Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/pandas/core/groupby/numba_.py: 19%
52 statements
« prev ^ index » next coverage.py v6.4.4, created at 2023-07-17 14:22 -0600
« prev ^ index » next coverage.py v6.4.4, created at 2023-07-17 14:22 -0600
1"""Common utilities for Numba operations with groupby ops"""
2from __future__ import annotations
4import functools
5import inspect
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 Callable,
10)
12import numpy as np
14from pandas._typing import Scalar
15from pandas.compat._optional import import_optional_dependency
17from pandas.core.util.numba_ import (
18 NumbaUtilError,
19 jit_user_function,
20)
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : callable
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable.
    NumbaUtilError
        If the leading parameters of ``func`` are not ``values`` and
        ``index``, in that order.
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    if (
        len(udf_signature) < min_number_args
        or udf_signature[:min_number_args] != expected_args
    ):
        # Not every callable carries ``__name__`` (e.g. instances of classes
        # that define ``__call__``); fall back to the repr so the signature
        # problem is reported instead of an unrelated AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-jitted groupby aggregation kernel around ``func``.

    The user's function is jitted first, and then called from inside a
    jitted loop over groups and columns. The same ``nopython``, ``nogil``
    and ``parallel`` settings are applied to the user's function and to the
    surrounding loop. Results are memoized by ``functools.lru_cache``, keyed
    on the four arguments.

    Parameters
    ----------
    func : function
        function applied to each group; it will be JITed
    nopython : bool
        nopython flag forwarded to numba.jit
    nogil : bool
        nogil flag forwarded to numba.jit
    parallel : bool
        parallel flag forwarded to numba.jit

    Returns
    -------
    Numba function
    """
    jitted_udf = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}

    @numba.jit(**jit_options)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:

        assert len(begin) == len(end)
        num_groups = len(begin)

        # One output row per group, one output column per input column.
        result = np.empty((num_groups, num_columns))
        for group_no in numba.prange(num_groups):
            start = begin[group_no]
            stop = end[group_no]
            group_index = index[start:stop]
            for col in numba.prange(num_columns):
                result[group_no, col] = jitted_udf(
                    values[start:stop, col], group_index, *args
                )
        return result

    return group_agg
@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-jitted groupby transform kernel around ``func``.

    The user's function is jitted first, and then called from inside a
    jitted loop over groups and columns. The same ``nopython``, ``nogil``
    and ``parallel`` settings are applied to the user's function and to the
    surrounding loop. Results are memoized by ``functools.lru_cache``, keyed
    on the four arguments.

    Parameters
    ----------
    func : function
        function applied to each group; it will be JITed
    nopython : bool
        nopython flag forwarded to numba.jit
    nogil : bool
        nogil flag forwarded to numba.jit
    parallel : bool
        parallel flag forwarded to numba.jit

    Returns
    -------
    Numba function
    """
    jitted_udf = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}

    @numba.jit(**jit_options)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:

        assert len(begin) == len(end)
        num_groups = len(begin)

        # Output has one row per input row; each group's result is written
        # back into the row range that group was read from.
        result = np.empty((len(values), num_columns))
        for group_no in numba.prange(num_groups):
            start = begin[group_no]
            stop = end[group_no]
            group_index = index[start:stop]
            for col in numba.prange(num_columns):
                result[start:stop, col] = jitted_udf(
                    values[start:stop, col], group_index, *args
                )
        return result

    return group_transform