Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/pandas/core/window/numba_.py: 8%
146 statements
« prev ^ index » next coverage.py v6.4.4, created at 2023-07-17 14:22 -0600
1from __future__ import annotations
3import functools
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Callable,
8)
10import numpy as np
12from pandas._typing import Scalar
13from pandas.compat._optional import import_optional_dependency
15from pandas.core.util.numba_ import jit_user_function
@functools.lru_cache(maxsize=None)
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled rolling ``apply`` kernel around a user function.

    The user's ``func`` is compiled first. The returned kernel then calls it on
    every window. The same ``nopython``/``nogil``/``parallel`` options are used
    for both the user's function and the kernel itself. Results are memoized
    per argument tuple via ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        Function evaluated on each window; it is JIT-compiled here.
    nopython : bool
        Passed through to ``numba.jit`` as ``nopython``.
    nogil : bool
        Passed through to ``numba.jit`` as ``nogil``.
    parallel : bool
        Passed through to ``numba.jit`` as ``parallel``.

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods, *args)``.

        - It returns one float per window, where window ``k`` is
          ``values[begin[k]:end[k]]``.
        - A window with fewer than ``minimum_periods`` non-NaN entries
          yields NaN instead of calling ``func``.
        - Extra ``*args`` are forwarded to ``func``.
    """
    jitted_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for k in numba.prange(n_windows):
            segment = values[begin[k] : end[k]]
            # Only non-NaN entries count towards min_periods.
            n_valid = len(segment) - np.sum(np.isnan(segment))
            if n_valid < minimum_periods:
                out[k] = np.nan
            else:
                out[k] = jitted_func(segment, *args)
        return out

    return roll_apply
@functools.lru_cache(maxsize=None)
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        Only used when ``normalize`` is True.
        If True, each new observation gets weight 1 and the accumulated old
        weight keeps growing. If False, new observations get weight ``alpha``
        and the old weight is reset to 1 after every observation.
    ignore_na : bool
        If False, a missing value still decays the running state.
        If True, missing values leave the state untouched.
    deltas : tuple
        Decay exponents, only used when ``normalize`` is True.
        ``deltas[k]`` is applied when moving from ``values[k]`` to
        ``values[k + 1]``.
    normalize : bool
        True computes an exponentially weighted mean.
        False computes an exponentially decayed running sum.

    Returns
    -------
    Numba function
        ``ewm(values, begin, end, minimum_periods)``.

        - It returns an array of length ``len(values)``.
        - Each window ``values[begin[i]:end[i]]`` is processed independently.
        - Output positions where fewer than ``minimum_periods`` non-NaN
          observations have been seen (within the window) are NaN.
        - Positions not covered by any non-empty window are left
          uninitialized.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            # numba does not bounds-check indexing by default, so reading
            # window[0] of an empty window would be undefined behavior.
            if len(window) > 0:
                sub_result = np.empty(len(window))

                weighted = window[0]
                nobs = int(not np.isnan(weighted))
                sub_result[0] = weighted if nobs >= minimum_periods else np.nan
                old_wt = 1.0

                for j in range(1, len(window)):
                    cur = window[j]
                    is_observation = not np.isnan(cur)
                    nobs += is_observation
                    if not np.isnan(weighted):
                        if is_observation or not ignore_na:
                            if normalize:
                                # note that len(deltas) = len(vals) - 1 and
                                # deltas[k] is used in conjunction with vals[k+1]
                                old_wt *= old_wt_factor ** deltas[start + j - 1]
                            else:
                                weighted = old_wt_factor * weighted
                            if is_observation:
                                if normalize:
                                    # avoid numerical errors on constant series:
                                    # the weighted mean of equal values is itself
                                    if weighted != cur:
                                        weighted = old_wt * weighted + new_wt * cur
                                        weighted = weighted / (old_wt + new_wt)
                                    if adjust:
                                        old_wt += new_wt
                                    else:
                                        old_wt = 1.0
                                else:
                                    weighted += cur
                    elif is_observation:
                        # First observation after leading NaNs seeds the state.
                        weighted = cur

                    sub_result[j] = weighted if nobs >= minimum_periods else np.nan

                result[start:stop] = sub_result

        return result

    return ewm
@functools.lru_cache(maxsize=None)
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled kernel that applies ``func`` table-wise per window.

    For every window, ``func`` receives the 2D slice ``values[begin[k]:end[k]]``
    (rows of the window by all columns). It must produce one value per column.
    ``func`` is intended to operate row-wise; the result is transposed for
    ``axis=1`` by the caller.

    The user's ``func`` is JIT-compiled first. The kernel then calls it inline.
    The same numba options apply to both.

    Parameters
    ----------
    func : function
        Function evaluated on each window; it is JIT-compiled here.
    nopython : bool
        Passed through to ``numba.jit`` as ``nopython``.
    nogil : bool
        Passed through to ``numba.jit`` as ``nogil``.
    parallel : bool
        Passed through to ``numba.jit`` as ``parallel``.

    Returns
    -------
    Numba function
        ``roll_table(values, begin, end, minimum_periods, *args)``.

        - It returns a ``(len(begin), values.shape[1])`` float array.
        - A cell becomes NaN when its column has fewer than
          ``minimum_periods`` non-NaN entries in that window.
        - Extra ``*args`` are forwarded to ``func``.
    """
    jitted_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        out = np.empty((n_windows, values.shape[1]))
        # Per-cell flag (stored as float) telling whether min_periods was met.
        enough_obs = np.empty(out.shape)
        for k in numba.prange(n_windows):
            block = values[begin[k] : end[k]]
            n_valid = len(block) - np.sum(np.isnan(block), axis=0)
            enough_obs[k, :] = n_valid >= minimum_periods
            out[k, :] = jitted_func(block, *args)
        return np.where(enough_obs, out, np.nan)

    return roll_table
240# This function will no longer be needed once numba supports
241# axis for all np.nan* agg functions
242# https://github.com/numba/numba/issues/1269
243@functools.lru_cache(maxsize=None)
244def generate_manual_numpy_nan_agg_with_axis(nan_func):
245 if TYPE_CHECKING:
246 import numba
247 else:
248 numba = import_optional_dependency("numba")
250 @numba.jit(nopython=True, nogil=True, parallel=True)
251 def nan_agg_with_axis(table):
252 result = np.empty(table.shape[1])
253 for i in numba.prange(table.shape[1]):
254 partition = table[:, i]
255 result[i] = nan_func(partition)
256 return result
258 return nan_agg_with_axis
@functools.lru_cache(maxsize=None)
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        Only used when ``normalize`` is True.
        If True, each new observation gets weight 1 and the accumulated old
        weight keeps growing. If False, new observations get weight ``alpha``
        and the old weight is reset to 1 after every observation.
    ignore_na : bool
        If False, a missing value still decays the running state.
        If True, missing values leave the state untouched.
    deltas : tuple
        Decay exponents, only used when ``normalize`` is True.
        ``deltas[i - 1]`` is applied when moving from row ``i - 1`` to row
        ``i`` of ``values``.
    normalize : bool
        True computes an exponentially weighted mean.
        False computes an exponentially decayed running sum.

    Returns
    -------
    Numba function
        ``ewm_table(values, begin, end, minimum_periods)``.

        - ``values`` is treated as a 2D table processed row by row, with
          independent running state per column.
        - ``begin`` and ``end`` are not used by this implementation; the whole
          table is one window.
        - It returns an array shaped like ``values``.
        - A cell is NaN while its column has fewer than ``minimum_periods``
          non-NaN observations so far.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # numba does not bounds-check indexing by default, so values[0] on an
        # empty table would be undefined behavior; nothing to compute anyway.
        if len(values) == 0:
            return result

        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            # Columns are independent, so they can be updated in parallel.
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and
                            # deltas[k] is used in conjunction with vals[k+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series:
                                # the weighted mean of equal values is itself
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    # First observation after leading NaNs seeds the column.
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table