Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/django/contrib/postgres/fields/ranges.py: 63%

215 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1import datetime 

2import json 

3 

4from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range 

5 

6from django.contrib.postgres import forms, lookups 

7from django.db import models 

8from django.db.models.lookups import PostgresOperatorLookup 

9 

10from .utils import AttributeSetter 

11 

# Names exported by ``from ... import *``.
__all__ = [
    # Model fields.
    "RangeField",
    "IntegerRangeField",
    "BigIntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    # Expression and operator helpers.
    "RangeBoundary",
    "RangeOperators",
]

22 

23 

class RangeBoundary(models.Expression):
    """A class that represents range boundaries.

    Compiles to a quoted two-character SQL literal: ``[`` (inclusive) or
    ``(`` (exclusive) for the lower bound, followed by ``]`` (inclusive) or
    ``)`` (exclusive) for the upper bound. The defaults produce ``'[)'``.
    """

    def __init__(self, inclusive_lower=True, inclusive_upper=False):
        # Square brackets mark an inclusive bound, parentheses an exclusive one.
        if inclusive_lower:
            self.lower = "["
        else:
            self.lower = "("
        if inclusive_upper:
            self.upper = "]"
        else:
            self.upper = ")"

    def as_sql(self, compiler, connection):
        # The bounds string is emitted inline as a literal; no parameters.
        return f"'{self.lower}{self.upper}'", []

33 

34 

class RangeOperators:
    """PostgreSQL range operator symbols used by the range lookups.

    Reference:
    https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
    """

    # Equality.
    EQUAL = "="
    NOT_EQUAL = "<>"

    # Containment and overlap.
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    OVERLAPS = "&&"

    # Relative position of one range to another.
    FULLY_LT = "<<"
    FULLY_GT = ">>"
    NOT_LT = "&>"
    NOT_GT = "&<"

    # Adjacency.
    ADJACENT_TO = "-|-"

47 

48 

class RangeField(models.Field):
    """Base model field for PostgreSQL range columns.

    Concrete subclasses provide three class attributes:

    * ``base_field`` -- model field class for the range's bound values.
      ``__init__`` replaces it with an instance of that class.
    * ``range_type`` -- psycopg2 ``Range`` subclass used for Python values.
    * ``form_field`` -- default form field class passed to ``formfield()``.
    """

    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        # Give each range field its own bound-field instance, so the ``model``
        # setter below can keep that instance's model in step with ours.
        if hasattr(self, "base_field"):
            self.base_field = self.base_field()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        # ``model`` lives in the instance dict because the setter must also
        # forward the value to ``base_field``.
        try:
            return self.__dict__["model"]
        except KeyError:
            class_name = self.__class__.__name__
            raise AttributeError(
                "'%s' object has no attribute 'model'" % class_name
            )

    @model.setter
    def model(self, model):
        self.__dict__["model"] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        # A (lower, upper) pair is a single choice value, not a choice group.
        if isinstance(value, (list, tuple)):
            return True
        return super()._choices_is_value(value)

    def get_prep_value(self, value):
        """Turn a bare ``(lower, upper)`` sequence into ``range_type``.

        ``None``, ``Range`` instances and any other value are returned as-is.
        """
        if isinstance(value, (list, tuple)) and not isinstance(value, Range):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        """Coerce a sequence or a JSON string into ``range_type``.

        A string is assumed to be serialized data (presumably the JSON that
        ``value_to_string`` emits); its ``lower``/``upper`` entries are
        converted with ``base_field.to_python`` before building the range.
        Any other value is returned unchanged.
        """
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        if not isinstance(value, str):
            return value
        # Assume we're deserializing.
        payload = json.loads(value)
        for bound_name in ("lower", "upper"):
            if bound_name in payload:
                payload[bound_name] = self.base_field.to_python(payload[bound_name])
        return self.range_type(**payload)

    def set_attributes_from_name(self, name):
        # Propagate the attribute name to the bound field as well;
        # value_to_string relies on base_field.attname.
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """Serialize this field's value on ``obj`` to a JSON string.

        Returns ``None`` for a null value and ``{"empty": true}`` for an empty
        range. Otherwise returns ``bounds``, ``lower`` and ``upper``, with each
        non-null bound rendered by ``base_field.value_to_string``.
        """
        range_value = self.value_from_object(obj)
        if range_value is None:
            return None
        if range_value.isempty:
            return json.dumps({"empty": True})
        serialized = {"bounds": range_value._bounds}
        for bound_name in ("lower", "upper"):
            bound_value = getattr(range_value, bound_name)
            # Wrap the bare bound so base_field can read it via its attname.
            serialized[bound_name] = (
                None
                if bound_value is None
                else self.base_field.value_to_string(
                    AttributeSetter(self.base_field.attname, bound_value)
                )
            )
        return json.dumps(serialized)

    def formfield(self, **kwargs):
        # A caller-supplied form_class takes precedence over the default.
        return super().formfield(**{"form_class": self.form_field, **kwargs})

122 

123 

class IntegerRangeField(RangeField):
    """Range of ``IntegerField`` values stored in an ``int4range`` column."""

    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.IntegerField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "int4range"

131 

132 

class BigIntegerRangeField(RangeField):
    """Range of ``BigIntegerField`` values stored in an ``int8range`` column."""

    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.BigIntegerField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "int8range"

140 

141 

class DecimalRangeField(RangeField):
    """Range of ``DecimalField`` values stored in a ``numrange`` column."""

    form_field = forms.DecimalRangeField
    range_type = NumericRange
    base_field = models.DecimalField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "numrange"

149 

150 

class DateTimeRangeField(RangeField):
    """Range of ``DateTimeField`` values stored in a ``tstzrange`` column."""

    form_field = forms.DateTimeRangeField
    range_type = DateTimeTZRange
    base_field = models.DateTimeField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "tstzrange"

158 

159 

class DateRangeField(RangeField):
    """Range of ``DateField`` values stored in a ``daterange`` column."""

    form_field = forms.DateRangeField
    range_type = DateRange
    base_field = models.DateField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "daterange"

167 

168 

# Generic containment/overlap lookups shared by every range field, registered
# in this order: DataContains, ContainedBy, Overlap.
for _generic_lookup in (lookups.DataContains, lookups.ContainedBy, lookups.Overlap):
    RangeField.register_lookup(_generic_lookup)
del _generic_lookup

172 

173 

class DateTimeRangeContains(PostgresOperatorLookup):
    """``contains`` lookup for date and datetime range fields.

    A plain ``datetime.date``/``datetime.datetime`` right-hand side is wrapped
    in ``models.Value``. If the right-hand expression then has an output field
    that is not an instance of the left-hand range field's class, the SQL gets
    a ``::<type>`` cast to the database type of the range's ``base_field``.
    """

    lookup_name = "contains"
    postgres_operator = RangeOperators.CONTAINS

    def process_rhs(self, compiler, connection):
        # datetime.datetime subclasses datetime.date, so both get wrapped.
        if isinstance(self.rhs, datetime.date):
            self.rhs = models.Value(self.rhs).resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_postgresql(self, compiler, connection):
        sql, params = super().as_postgresql(compiler, connection)
        rhs_output = (
            self.rhs._output_field_or_none
            if isinstance(self.rhs, models.Expression)
            else None
        )
        # No cast when the rhs output field is unknown, or when the rhs
        # already has the lhs's range field type.
        if not rhs_output or isinstance(
            rhs_output, self.lhs.output_field.__class__
        ):
            return sql, params
        base_internal_type = self.lhs.output_field.base_field.get_internal_type()
        return f"{sql}::{connection.data_types.get(base_internal_type)}", params

206 

207 

# The date/datetime-aware ``contains`` lookup applies only to these two fields.
for _temporal_range_field in (DateRangeField, DateTimeRangeField):
    _temporal_range_field.register_lookup(DateTimeRangeContains)
del _temporal_range_field

210 

211 

class RangeContainedBy(PostgresOperatorLookup):
    """``contained_by`` lookup: is a scalar column value inside a range?

    The right-hand side is cast to the range type that ``type_mapping`` assigns
    to the left-hand column's cast type. ``FloatField`` and
    ``SmallIntegerField`` columns are themselves cast (to ``numeric`` and
    ``integer``) so they match the element type of the chosen range.
    """

    lookup_name = "contained_by"
    postgres_operator = RangeOperators.CONTAINED_BY
    # Column cast type (any "(...)" modifier removed) -> PostgreSQL range type.
    type_mapping = {
        "smallint": "int4range",
        "integer": "int4range",
        "bigint": "int8range",
        "double precision": "numrange",
        "numeric": "numrange",
        "date": "daterange",
        "timestamp with time zone": "tstzrange",
    }

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Ignore precision for DecimalFields, e.g. "numeric(10, 2)" -> "numeric".
        column_type = self.lhs.output_field.cast_db_type(connection).partition("(")[0]
        return f"{rhs}::{self.type_mapping[column_type]}", rhs_params

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        field = self.lhs.output_field
        if isinstance(field, models.FloatField):
            # "double precision" maps to numrange, whose elements are numeric.
            return "%s::numeric" % lhs, lhs_params
        if isinstance(field, models.SmallIntegerField):
            # "smallint" maps to int4range, whose elements are integer.
            return "%s::integer" % lhs, lhs_params
        return lhs, lhs_params

    def get_prep_lookup(self):
        # NOTE(review): the bare RangeField defines no ``range_type``, so a
        # list/tuple rhs would raise AttributeError here; presumably callers
        # pass Range instances. Confirm before relying on sequence input.
        return RangeField().get_prep_value(self.rhs)

242 

243 

# Scalar field types that gain the ``contained_by`` lookup, in the order
# they are registered.
for _scalar_field in (
    models.DateField,
    models.DateTimeField,
    models.IntegerField,
    models.FloatField,
    models.DecimalField,
):
    _scalar_field.register_lookup(RangeContainedBy)
del _scalar_field

249 

250 

@RangeField.register_lookup
class FullyLessThan(PostgresOperatorLookup):
    """``fully_lt`` lookup, rendered with the PostgreSQL ``<<`` operator."""

    postgres_operator = RangeOperators.FULLY_LT
    lookup_name = "fully_lt"

255 

256 

@RangeField.register_lookup
class FullGreaterThan(PostgresOperatorLookup):
    """``fully_gt`` lookup, rendered with the PostgreSQL ``>>`` operator."""

    postgres_operator = RangeOperators.FULLY_GT
    lookup_name = "fully_gt"

261 

262 

@RangeField.register_lookup
class NotLessThan(PostgresOperatorLookup):
    """``not_lt`` lookup, rendered with the PostgreSQL ``&>`` operator."""

    postgres_operator = RangeOperators.NOT_LT
    lookup_name = "not_lt"

267 

268 

@RangeField.register_lookup
class NotGreaterThan(PostgresOperatorLookup):
    """``not_gt`` lookup, rendered with the PostgreSQL ``&<`` operator."""

    postgres_operator = RangeOperators.NOT_GT
    lookup_name = "not_gt"

273 

274 

@RangeField.register_lookup
class AdjacentToLookup(PostgresOperatorLookup):
    """``adjacent_to`` lookup, rendered with the PostgreSQL ``-|-`` operator."""

    postgres_operator = RangeOperators.ADJACENT_TO
    lookup_name = "adjacent_to"

279 

280 

@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform: applies SQL ``lower()`` to the range."""

    function = "lower"
    lookup_name = "startswith"

    @property
    def output_field(self):
        # The result is typed like the range's bound (element) field.
        range_field = self.lhs.output_field
        return range_field.base_field

289 

290 

@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform: applies SQL ``upper()`` to the range."""

    function = "upper"
    lookup_name = "endswith"

    @property
    def output_field(self):
        # The result is typed like the range's bound (element) field.
        range_field = self.lhs.output_field
        return range_field.base_field

299 

300 

@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform: SQL ``isempty()``, yielding a boolean."""

    output_field = models.BooleanField()
    function = "isempty"
    lookup_name = "isempty"

306 

307 

@RangeField.register_lookup
class LowerInclusive(models.Transform):
    """``lower_inc`` transform: SQL ``LOWER_INC()``, yielding a boolean."""

    output_field = models.BooleanField()
    function = "LOWER_INC"
    lookup_name = "lower_inc"

313 

314 

@RangeField.register_lookup
class LowerInfinite(models.Transform):
    """``lower_inf`` transform: SQL ``LOWER_INF()``, yielding a boolean."""

    output_field = models.BooleanField()
    function = "LOWER_INF"
    lookup_name = "lower_inf"

320 

321 

@RangeField.register_lookup
class UpperInclusive(models.Transform):
    """``upper_inc`` transform: SQL ``UPPER_INC()``, yielding a boolean."""

    output_field = models.BooleanField()
    function = "UPPER_INC"
    lookup_name = "upper_inc"

327 

328 

@RangeField.register_lookup
class UpperInfinite(models.Transform):
    """``upper_inf`` transform: SQL ``UPPER_INF()``, yielding a boolean."""

    output_field = models.BooleanField()
    function = "UPPER_INF"
    lookup_name = "upper_inf"