Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/pandas/core/window/numba_.py: 8%

146 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1from __future__ import annotations 

2 

3import functools 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Callable, 

8) 

9 

10import numpy as np 

11 

12from pandas._typing import Scalar 

13from pandas.compat._optional import import_optional_dependency 

14 

15from pandas.core.util.numba_ import jit_user_function 

16 

17 

@functools.lru_cache(maxsize=None)
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba jitted rolling apply kernel around a user function.

    The user's function is jitted first, then called from inside a jitted
    loop over the windows. The ``nopython``, ``nogil`` and ``parallel``
    settings are used for the user's function _AND_ for the window loop.

    Parameters
    ----------
    func : function
        function evaluated on every window; it will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as ``roll_apply(values, begin, end, minimum_periods, *args)``;
        returns one float per ``begin``/``end`` pair.
    """
    jitted_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for idx in numba.prange(n_windows):
            chunk = values[begin[idx] : end[idx]]
            n_valid = len(chunk) - np.sum(np.isnan(chunk))
            if n_valid < minimum_periods:
                # too few non-NaN observations in this window
                out[idx] = np.nan
            else:
                out[idx] = jitted_func(chunk, *args)
        return out

    return roll_apply

76 

77 

@functools.lru_cache(maxsize=None)
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        smoothing is derived from it as ``alpha = 1 / (1 + com)``
    adjust : bool
        if True a new observation gets weight 1.0 and the old weight keeps
        accumulating; if False it gets weight ``alpha`` and the old weight
        is reset to 1.0 after every observation
    ignore_na : bool
        if True, a NaN entry does not decay the running value/weight
    deltas : tuple
        decay exponents; ``deltas[k]`` is used together with ``values[k + 1]``
        (only read when ``normalize`` is True)
    normalize : bool
        if True the running value is divided by the total weight (mean);
        if False the decayed values are accumulated (sum)

    Returns
    -------
    Numba function
        Called as ``ewm(values, begin, end, minimum_periods)``; returns an
        array with ``len(values)`` entries.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]

            # An empty window has no first element to seed the recursion
            # with. Indexing it would be an unchecked out-of-bounds read in
            # nopython mode, and there is nothing to write back anyway.
            if len(window) > 0:
                sub_result = np.empty(len(window))

                weighted = window[0]
                nobs = int(not np.isnan(weighted))
                sub_result[0] = weighted if nobs >= minimum_periods else np.nan
                old_wt = 1.0

                for j in range(1, len(window)):
                    cur = window[j]
                    is_observation = not np.isnan(cur)
                    nobs += is_observation
                    if not np.isnan(weighted):
                        if is_observation or not ignore_na:
                            if normalize:
                                # note that len(deltas) = len(vals) - 1 and
                                # deltas[i] is to be used in conjunction
                                # with vals[i+1]
                                old_wt *= old_wt_factor ** deltas[start + j - 1]
                            else:
                                weighted = old_wt_factor * weighted
                            if is_observation:
                                if normalize:
                                    # avoid numerical errors on constant series
                                    if weighted != cur:
                                        weighted = old_wt * weighted + new_wt * cur
                                        weighted = weighted / (old_wt + new_wt)
                                    if adjust:
                                        old_wt += new_wt
                                    else:
                                        old_wt = 1.0
                                else:
                                    weighted += cur
                    elif is_observation:
                        # first valid observation seen so far in this window
                        weighted = cur

                    sub_result[j] = weighted if nobs >= minimum_periods else np.nan

                result[start:stop] = sub_result

        return result

    return ewm

175 

176 

@functools.lru_cache(maxsize=None)
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba jitted kernel that applies a window function table-wise.

    ``func`` receives each window as an (M window size x N columns) array
    and must hand back a 1 x N array, i.e. one value per column. Func is
    intended to operate row-wise, but the result will be transposed for
    axis=1.

    The user's function is jitted first, then called from inside a jitted
    loop over the windows.

    Parameters
    ----------
    func : function
        function evaluated on every window; it will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as ``roll_table(values, begin, end, minimum_periods, *args)``;
        returns a ``(len(begin), values.shape[1])`` array.
    """
    jitted_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        out = np.empty((n_windows, values.shape[1]))
        # per window and column: does it hold enough non-NaN observations?
        enough_obs = np.empty(out.shape)
        for row in numba.prange(n_windows):
            chunk = values[begin[row] : end[row]]
            # count NaNs before handing the window to the user's function
            nan_per_col = np.sum(np.isnan(chunk), axis=0)
            col_values = jitted_func(chunk, *args)
            enough_obs[row, :] = len(chunk) - nan_per_col >= minimum_periods
            out[row, :] = col_values
        # blank out the entries that did not reach minimum_periods
        return np.where(enough_obs, out, np.nan)

    return roll_table

238 

239 

240# This function will no longer be needed once numba supports 

241# axis for all np.nan* agg functions 

242# https://github.com/numba/numba/issues/1269 

@functools.lru_cache(maxsize=None)
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Build a jitted reducer that applies ``nan_func`` to every column of a table.

    Parameters
    ----------
    nan_func : function
        called with one column (``table[:, i]``) at a time; its return value
        is stored in a float array
        (presumably one of numpy's ``np.nan*`` aggregations - confirm with
        the callers)

    Returns
    -------
    Numba function
        Called as ``nan_agg_with_axis(table)``; returns an array with
        ``table.shape[1]`` entries, one per column.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        n_cols = table.shape[1]
        out = np.empty(n_cols)
        for col in numba.prange(n_cols):
            out[col] = nan_func(table[:, col])
        return out

    return nan_agg_with_axis

259 

260 

@functools.lru_cache(maxsize=None)
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        smoothing is derived from it as ``alpha = 1 / (1 + com)``
    adjust : bool
        if True a new observation gets weight 1.0 and the old weight keeps
        accumulating; if False it gets weight ``alpha`` and the old weight
        is reset to 1.0 after every observation
    ignore_na : bool
        if True, a NaN entry does not decay the running value/weight
    deltas : tuple
        decay exponents; ``deltas[k]`` is used together with row ``k + 1``
        (only read when ``normalize`` is True)
    normalize: bool
        if True the running value is divided by the total weight (mean);
        if False the decayed values are accumulated (sum)

    Returns
    -------
    Numba function
        Called as ``ewm_table(values, begin, end, minimum_periods)``; returns
        an array with the same shape as ``values``. ``begin`` and ``end`` are
        accepted but not read by the kernel.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # A table without rows has no first row to seed the recursion with.
        # Indexing it would be an unchecked out-of-bounds read in nopython
        # mode, so hand back the (equally empty) result right away.
        if len(values) == 0:
            return result

        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    # first valid observation seen so far in this column
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table