Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/django/contrib/postgres/fields/ranges.py: 63%
215 statements
« prev ^ index » next coverage.py v6.4.4, created at 2023-07-17 14:22 -0600
1import datetime
2import json
4from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
6from django.contrib.postgres import forms, lookups
7from django.db import models
8from django.db.models.lookups import PostgresOperatorLookup
10from .utils import AttributeSetter
# Names exported by ``from ... import *``: the range model fields first,
# then the helpers for range boundary expressions and range operators.
__all__ = [
    # Model fields.
    "RangeField",
    "IntegerRangeField",
    "BigIntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    # Expression / operator helpers.
    "RangeBoundary",
    "RangeOperators",
]
class RangeBoundary(models.Expression):
    """Expression that renders range bound characters as a quoted SQL literal.

    ``inclusive_lower`` picks ``[`` (truthy) or ``(`` (falsy) for the lower
    bound, and ``inclusive_upper`` picks ``]`` (truthy) or ``)`` (falsy) for
    the upper bound. With the defaults, the rendered SQL is ``'[)'``.
    """

    def __init__(self, inclusive_lower=True, inclusive_upper=False):
        if inclusive_lower:
            self.lower = "["
        else:
            self.lower = "("
        if inclusive_upper:
            self.upper = "]"
        else:
            self.upper = ")"

    def as_sql(self, compiler, connection):
        # The bounds are emitted inline as a literal, so there are no params.
        return f"'{self.lower}{self.upper}'", []
class RangeOperators:
    """PostgreSQL range operator symbols used by the range lookups.

    See https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
    """

    # Equality / inequality.
    EQUAL = "="
    NOT_EQUAL = "<>"
    # Containment.
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    # Overlap.
    OVERLAPS = "&&"
    # Relative position.
    FULLY_LT = "<<"
    FULLY_GT = ">>"
    NOT_LT = "&>"
    NOT_GT = "&<"
    # Adjacency.
    ADJACENT_TO = "-|-"
class RangeField(models.Field):
    """Base class for PostgreSQL range model fields.

    Concrete subclasses define:
    - ``base_field``: the model field class for the bound values; it is
      instantiated per field in ``__init__``.
    - ``range_type``: the psycopg2 ``Range`` subclass used for Python values.
    - ``form_field``: the default form class returned by ``formfield()``.
    """

    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        # Build this field's own base_field instance so that its model can be
        # kept in sync with ours (see the ``model`` setter).
        if hasattr(self, "base_field"):
            self.base_field = self.base_field()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        """The model this field is attached to.

        Raises AttributeError until one has been assigned.
        """
        try:
            return self.__dict__["model"]
        except KeyError:
            raise AttributeError(
                "'%s' object has no attribute 'model'" % self.__class__.__name__
            )

    @model.setter
    def model(self, model):
        # Keep the bound-value field pointing at the same model.
        self.__dict__["model"] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        # A (lower, upper) pair is a choice value, not a named group.
        if isinstance(value, (list, tuple)):
            return True
        return super()._choices_is_value(value)

    def get_prep_value(self, value):
        """Convert a ``(lower, upper)`` pair to ``range_type``.

        None, Range instances and any other value are returned unchanged.
        """
        if value is None or isinstance(value, Range):
            return value
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        """Coerce a JSON string or a ``(lower, upper)`` pair to ``range_type``."""
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        if not isinstance(value, str):
            return value
        # Strings are treated as deserialization input, e.g. the JSON emitted
        # by value_to_string(); each bound is converted via base_field.
        payload = json.loads(value)
        for bound_name in ("lower", "upper"):
            if bound_name in payload:
                payload[bound_name] = self.base_field.to_python(payload[bound_name])
        return self.range_type(**payload)

    def set_attributes_from_name(self, name):
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """Serialize the range stored on ``obj`` as a JSON string (or None)."""
        value = self.value_from_object(obj)
        if value is None:
            return None
        if value.isempty:
            return json.dumps({"empty": True})
        base_field = self.base_field
        serialized = {"bounds": value._bounds}
        for bound_name in ("lower", "upper"):
            bound = getattr(value, bound_name)
            if bound is None:
                serialized[bound_name] = None
                continue
            # base_field.value_to_string() reads the value from an object
            # attribute, so wrap the bare bound in one.
            holder = AttributeSetter(base_field.attname, bound)
            serialized[bound_name] = base_field.value_to_string(holder)
        return json.dumps(serialized)

    def formfield(self, **kwargs):
        # Default to the subclass's form_field unless the caller overrides it.
        kwargs.setdefault("form_class", self.form_field)
        return super().formfield(**kwargs)
class IntegerRangeField(RangeField):
    """Range of ``IntegerField`` values, stored in an ``int4range`` column."""

    range_type = NumericRange
    base_field = models.IntegerField
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "int4range"
class BigIntegerRangeField(RangeField):
    """Range of ``BigIntegerField`` values, stored in an ``int8range`` column."""

    range_type = NumericRange
    base_field = models.BigIntegerField
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "int8range"
class DecimalRangeField(RangeField):
    """Range of ``DecimalField`` values, stored in a ``numrange`` column."""

    range_type = NumericRange
    base_field = models.DecimalField
    form_field = forms.DecimalRangeField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "numrange"
class DateTimeRangeField(RangeField):
    """Range of ``DateTimeField`` values, stored in a ``tstzrange`` column."""

    range_type = DateTimeTZRange
    base_field = models.DateTimeField
    form_field = forms.DateTimeRangeField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "tstzrange"
class DateRangeField(RangeField):
    """Range of ``DateField`` values, stored in a ``daterange`` column."""

    range_type = DateRange
    base_field = models.DateField
    form_field = forms.DateRangeField

    def db_type(self, connection):
        # The column type does not depend on the connection.
        return "daterange"
# Generic containment/overlap lookups shared by every RangeField subclass,
# registered in the same order as before.
for _generic_lookup in (lookups.DataContains, lookups.ContainedBy, lookups.Overlap):
    RangeField.register_lookup(_generic_lookup)
del _generic_lookup
class DateTimeRangeContains(PostgresOperatorLookup):
    """``contains`` lookup for DateRangeField / DateTimeRangeField.

    A bare ``datetime.date`` (or ``datetime.datetime``) rhs is wrapped in
    ``Value`` and resolved. When the rhs then has an output field that is not
    the lhs range field's class, a ``::<base type>`` cast is appended to the
    generated SQL.
    """

    lookup_name = "contains"
    postgres_operator = RangeOperators.CONTAINS

    def process_rhs(self, compiler, connection):
        # Turn a plain date/datetime rhs into a resolved expression so that
        # as_postgresql() can inspect its output field.
        if isinstance(self.rhs, datetime.date):
            self.rhs = models.Value(self.rhs).resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_postgresql(self, compiler, connection):
        sql, params = super().as_postgresql(compiler, connection)
        rhs_field = None
        if isinstance(self.rhs, models.Expression):
            rhs_field = self.rhs._output_field_or_none
        # No cast is needed when the rhs is already the same range type.
        if rhs_field and not isinstance(rhs_field, self.lhs.output_field.__class__):
            internal_type = self.lhs.output_field.base_field.get_internal_type()
            sql = "%s::%s" % (sql, connection.data_types.get(internal_type))
        return sql, params


# Only the date-based range fields need the rhs cast.
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
class RangeContainedBy(PostgresOperatorLookup):
    """``contained_by`` (``<@``) lookup for scalar date/datetime/number fields.

    The rhs is cast to the range type that matches the lhs column's database
    type (see ``type_mapping``). A ``(lower, upper)`` list/tuple rhs is
    converted into the matching psycopg2 range type.
    """

    lookup_name = "contained_by"
    # Maps the lhs cast db type (with any precision suffix stripped) to the
    # PostgreSQL range type used to cast the rhs.
    type_mapping = {
        "smallint": "int4range",
        "integer": "int4range",
        "bigint": "int8range",
        "double precision": "numrange",
        "numeric": "numrange",
        "date": "daterange",
        "timestamp with time zone": "tstzrange",
    }
    postgres_operator = RangeOperators.CONTAINED_BY

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Ignore precision for DecimalFields, e.g. "numeric(5, 2)" -> "numeric".
        db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0]
        cast_type = self.type_mapping[db_type]
        return "%s::%s" % (rhs, cast_type), rhs_params

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        # Cast the lhs to the element type of the range chosen in
        # type_mapping: double precision is compared against numrange, and
        # smallint against int4range.
        if isinstance(self.lhs.output_field, models.FloatField):
            lhs = "%s::numeric" % lhs
        elif isinstance(self.lhs.output_field, models.SmallIntegerField):
            lhs = "%s::integer" % lhs
        return lhs, lhs_params

    def _range_type_for_lhs(self):
        """Return the psycopg2 Range class matching the lhs field type."""
        output_field = self.lhs.output_field
        # DateTimeField subclasses DateField, so it must be checked first.
        if isinstance(output_field, models.DateTimeField):
            return DateTimeTZRange
        if isinstance(output_field, models.DateField):
            return DateRange
        return NumericRange

    def get_prep_lookup(self):
        rhs = self.rhs
        if isinstance(rhs, (list, tuple)):
            # A bare RangeField() has no range_type, so its get_prep_value()
            # would raise AttributeError for a (lower, upper) pair. Build the
            # range from the lhs field type here instead.
            rhs = self._range_type_for_lhs()(rhs[0], rhs[1])
        return RangeField().get_prep_value(rhs)


models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
models.DecimalField.register_lookup(RangeContainedBy)
class FullyLessThan(PostgresOperatorLookup):
    """``fully_lt``: the range lies entirely before the rhs range (``<<``)."""

    postgres_operator = RangeOperators.FULLY_LT
    lookup_name = "fully_lt"


RangeField.register_lookup(FullyLessThan)
class FullGreaterThan(PostgresOperatorLookup):
    """``fully_gt``: the range lies entirely after the rhs range (``>>``)."""

    postgres_operator = RangeOperators.FULLY_GT
    lookup_name = "fully_gt"


RangeField.register_lookup(FullGreaterThan)
class NotLessThan(PostgresOperatorLookup):
    """``not_lt``: the range does not extend below the rhs range (``&>``)."""

    postgres_operator = RangeOperators.NOT_LT
    lookup_name = "not_lt"


RangeField.register_lookup(NotLessThan)
class NotGreaterThan(PostgresOperatorLookup):
    """``not_gt``: the range does not extend above the rhs range (``&<``)."""

    postgres_operator = RangeOperators.NOT_GT
    lookup_name = "not_gt"


RangeField.register_lookup(NotGreaterThan)
class AdjacentToLookup(PostgresOperatorLookup):
    """``adjacent_to``: the range is adjacent to the rhs range (``-|-``)."""

    postgres_operator = RangeOperators.ADJACENT_TO
    lookup_name = "adjacent_to"


RangeField.register_lookup(AdjacentToLookup)
class RangeStartsWith(models.Transform):
    """``startswith`` transform: the range's lower bound via SQL ``lower()``."""

    function = "lower"
    lookup_name = "startswith"

    @property
    def output_field(self):
        # A bound has the type of the range field's base_field.
        return self.lhs.output_field.base_field


RangeField.register_lookup(RangeStartsWith)
class RangeEndsWith(models.Transform):
    """``endswith`` transform: the range's upper bound via SQL ``upper()``."""

    function = "upper"
    lookup_name = "endswith"

    @property
    def output_field(self):
        # A bound has the type of the range field's base_field.
        return self.lhs.output_field.base_field


RangeField.register_lookup(RangeEndsWith)
class IsEmpty(models.Transform):
    """``isempty`` transform: boolean result of SQL ``isempty(range)``."""

    output_field = models.BooleanField()
    function = "isempty"
    lookup_name = "isempty"


RangeField.register_lookup(IsEmpty)
class LowerInclusive(models.Transform):
    """``lower_inc`` transform: boolean result of SQL ``LOWER_INC(range)``."""

    output_field = models.BooleanField()
    function = "LOWER_INC"
    lookup_name = "lower_inc"


RangeField.register_lookup(LowerInclusive)
class LowerInfinite(models.Transform):
    """``lower_inf`` transform: boolean result of SQL ``LOWER_INF(range)``."""

    output_field = models.BooleanField()
    function = "LOWER_INF"
    lookup_name = "lower_inf"


RangeField.register_lookup(LowerInfinite)
class UpperInclusive(models.Transform):
    """``upper_inc`` transform: boolean result of SQL ``UPPER_INC(range)``."""

    output_field = models.BooleanField()
    function = "UPPER_INC"
    lookup_name = "upper_inc"


RangeField.register_lookup(UpperInclusive)
class UpperInfinite(models.Transform):
    """``upper_inf`` transform: boolean result of SQL ``UPPER_INF(range)``."""

    output_field = models.BooleanField()
    function = "UPPER_INF"
    lookup_name = "upper_inf"


RangeField.register_lookup(UpperInfinite)