Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/pandas/core/groupby/numba_.py: 19%

52 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1"""Common utilities for Numba operations with groupby ops""" 

2from __future__ import annotations 

3 

4import functools 

5import inspect 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 Callable, 

10) 

11 

12import numpy as np 

13 

14from pandas._typing import Scalar 

15from pandas.compat._optional import import_optional_dependency 

16 

17from pandas.core.util.numba_ import ( 

18 NumbaUtilError, 

19 jit_user_function, 

20) 

21 

22 

def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable.
    NumbaUtilError
        If the leading parameters of ``func`` are not named ``values`` and
        ``index``, in that order.
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    if (
        len(udf_signature) < min_number_args
        or udf_signature[:min_number_args] != expected_args
    ):
        # Not every callable carries ``__name__`` (e.g. functools.partial
        # objects or instances defining ``__call__``); fall back to repr so
        # the caller still gets NumbaUtilError instead of AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )

60 

61 

@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-compiled groupby aggregation kernel around ``func``.

    The user's function is jitted first, then embedded in a jitted loop that
    walks every (group, column) pair and stores one scalar per pair.  The
    same ``nopython`` / ``nogil`` / ``parallel`` flags are handed to both the
    user's function and the surrounding loop.  Results are memoised on the
    full argument tuple via ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        function evaluated once per group and column; it will be JITed
    nopython : bool
        forwarded to numba.jit as ``nopython``
    nogil : bool
        forwarded to numba.jit as ``nogil``
    parallel : bool
        forwarded to numba.jit as ``parallel``

    Returns
    -------
    Numba function
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:

        # ``begin`` and ``end`` hold the slice bounds of each group, so they
        # must pair up one-to-one.
        n_groups = len(begin)
        assert n_groups == len(end)

        # One output row per group, one output column per input column.
        out = np.empty((n_groups, num_columns))
        for grp in numba.prange(n_groups):
            start = begin[grp]
            stop = end[grp]
            index_chunk = index[start:stop]
            for col in numba.prange(num_columns):
                values_chunk = values[start:stop, col]
                out[grp, col] = numba_func(values_chunk, index_chunk, *args)
        return out

    return group_agg

121 

122 

@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-compiled groupby transform kernel around ``func``.

    The user's function is jitted first, then embedded in a jitted loop that
    walks every (group, column) pair and writes the function's output back
    into the rows belonging to that group.  The same ``nopython`` / ``nogil``
    / ``parallel`` flags are handed to both the user's function and the
    surrounding loop.  Results are memoised on the full argument tuple via
    ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        function evaluated once per group and column; it will be JITed
    nopython : bool
        forwarded to numba.jit as ``nopython``
    nogil : bool
        forwarded to numba.jit as ``nogil``
    parallel : bool
        forwarded to numba.jit as ``parallel``

    Returns
    -------
    Numba function
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:

        # ``begin`` and ``end`` hold the slice bounds of each group, so they
        # must pair up one-to-one.
        n_groups = len(begin)
        assert n_groups == len(end)

        # Output has one row per input row; each group fills its own rows.
        out = np.empty((len(values), num_columns))
        for grp in numba.prange(n_groups):
            start = begin[grp]
            stop = end[grp]
            index_chunk = index[start:stop]
            for col in numba.prange(num_columns):
                values_chunk = values[start:stop, col]
                out[start:stop, col] = numba_func(values_chunk, index_chunk, *args)
        return out

    return group_transform