Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/django/contrib/postgres/fields/array.py: 38%

203 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1import json 

2 

3from django.contrib.postgres import lookups 

4from django.contrib.postgres.forms import SimpleArrayField 

5from django.contrib.postgres.validators import ArrayMaxLengthValidator 

6from django.core import checks, exceptions 

7from django.db.models import Field, Func, IntegerField, Transform, Value 

8from django.db.models.fields.mixins import CheckFieldDefaultMixin 

9from django.db.models.lookups import Exact, In 

10from django.utils.translation import gettext_lazy as _ 

11 

12from ..utils import prefix_validation_error 

13from .utils import AttributeSetter 

14 

# Public API of this module: only the model field is exported; the lookups
# and transforms below are registered on it rather than imported directly.
__all__ = [
    "ArrayField",
]

16 

17 

class ArrayField(CheckFieldDefaultMixin, Field):
    """Model field holding a PostgreSQL array whose elements are ``base_field``.

    ``size`` optionally bounds the number of elements: when truthy it adds an
    ``ArrayMaxLengthValidator`` and is emitted in the column type as
    ``<base type>[size]``.
    """

    empty_strings_allowed = False
    default_error_messages = {
        "item_invalid": _("Item %(nth)s in the array did not validate:"),
        "nested_array_mismatch": _("Nested arrays must have the same length."),
    }
    _default_hint = ("list", "[]")

    def __init__(self, base_field, size=None, **kwargs):
        self.base_field = base_field
        self.size = size
        if self.size:
            self.default_validators = [
                *self.default_validators,
                ArrayMaxLengthValidator(self.size),
            ]
        # For performance, only add a from_db_value() method if the base field
        # implements it.
        if hasattr(self.base_field, "from_db_value"):
            self.from_db_value = self._from_db_value
        super().__init__(**kwargs)

    @property
    def model(self):
        try:
            return self.__dict__["model"]
        except KeyError:
            raise AttributeError(
                "'%s' object has no attribute 'model'" % self.__class__.__name__
            )

    @model.setter
    def model(self, model):
        # Keep the base field attached to the same model as the array field.
        self.__dict__["model"] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        # A list/tuple is a legitimate array *value*, not a group of choices.
        return isinstance(value, (list, tuple)) or super()._choices_is_value(value)

    def check(self, **kwargs):
        """Run the field's checks plus those of ``base_field``.

        Serious base-field messages (ERROR and above) are reported as
        ``postgres.E001``; warning-level messages are reported as the warning
        ``postgres.W004`` instead of being escalated to an error.
        """
        errors = super().check(**kwargs)
        if self.base_field.remote_field:
            errors.append(
                checks.Error(
                    "Base field for array cannot be a related field.",
                    obj=self,
                    id="postgres.E002",
                )
            )
        else:
            # Remove the field name checks as they are not needed here.
            base_checks = self.base_field.check()
            if base_checks:

                def _join(messages):
                    return "\n    ".join(
                        "%s (%s)" % (message.msg, message.id) for message in messages
                    )

                serious = [m for m in base_checks if m.is_serious()]
                warnings = [
                    m
                    for m in base_checks
                    if m.is_serious(checks.WARNING) and not m.is_serious()
                ]
                if serious:
                    errors.append(
                        checks.Error(
                            "Base field for array has errors:\n    %s"
                            % _join(serious),
                            obj=self,
                            id="postgres.E001",
                        )
                    )
                if warnings:
                    errors.append(
                        checks.Warning(
                            "Base field for array has warnings:\n    %s"
                            % _join(warnings),
                            obj=self,
                            id="postgres.W004",
                        )
                    )
        return errors

    def set_attributes_from_name(self, name):
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    @property
    def description(self):
        return "Array of %s" % self.base_field.description

    def db_type(self, connection):
        size = self.size or ""
        return "%s[%s]" % (self.base_field.db_type(connection), size)

    def cast_db_type(self, connection):
        size = self.size or ""
        return "%s[%s]" % (self.base_field.cast_db_type(connection), size)

    def get_placeholder(self, value, compiler, connection):
        # Cast the bound parameter to the full array column type.
        return "%s::{}".format(self.db_type(connection))

    def get_db_prep_value(self, value, connection, prepared=False):
        # Prepare each element with the base field; non-sequences pass through.
        if isinstance(value, (list, tuple)):
            return [
                self.base_field.get_db_prep_value(i, connection, prepared=False)
                for i in value
            ]
        return value

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        # Point migrations at the public import path.
        if path == "django.contrib.postgres.fields.array.ArrayField":
            path = "django.contrib.postgres.fields.ArrayField"
        kwargs.update(
            {
                "base_field": self.base_field.clone(),
                "size": self.size,
            }
        )
        return name, path, args, kwargs

    def to_python(self, value):
        if isinstance(value, str):
            # Assume we're deserializing
            vals = json.loads(value)
            value = [self.base_field.to_python(val) for val in vals]
        return value

    def _from_db_value(self, value, expression, connection):
        # Installed as from_db_value() only when the base field defines one.
        if value is None:
            return value
        return [
            self.base_field.from_db_value(item, expression, connection)
            for item in value
        ]

    def value_to_string(self, obj):
        """Serialize the array as JSON, delegating each element to base_field."""
        vals = self.value_from_object(obj)
        base_field = self.base_field
        values = []
        for val in vals:
            if val is None:
                values.append(None)
            else:
                # base_field.value_to_string() reads the value off an object,
                # so wrap the element in a stand-in exposing base_field.attname.
                holder = AttributeSetter(base_field.attname, val)
                values.append(base_field.value_to_string(holder))
        return json.dumps(values)

    def get_transform(self, name):
        """Resolve ``<n>`` to an index transform and ``<a>_<b>`` to a slice."""
        transform = super().get_transform(name)
        if transform:
            return transform
        if "_" not in name:
            try:
                index = int(name)
            except ValueError:
                pass
            else:
                index += 1  # postgres uses 1-indexing
                return IndexTransformFactory(index, self.base_field)
        try:
            start, end = name.split("_")
            start = int(start) + 1
            end = int(end)  # don't add one here because postgres slices are weird
        except ValueError:
            pass
        else:
            return SliceTransformFactory(start, end)

    def validate(self, value, model_instance):
        super().validate(value, model_instance)
        for index, part in enumerate(value):
            try:
                self.base_field.validate(part, model_instance)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages["item_invalid"],
                    code="item_invalid",
                    params={"nth": index + 1},
                )
        if isinstance(self.base_field, ArrayField):
            # Nested arrays must be rectangular.
            if len({len(i) for i in value}) > 1:
                raise exceptions.ValidationError(
                    self.error_messages["nested_array_mismatch"],
                    code="nested_array_mismatch",
                )

    def run_validators(self, value):
        super().run_validators(value)
        for index, part in enumerate(value):
            try:
                self.base_field.run_validators(part)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages["item_invalid"],
                    code="item_invalid",
                    params={"nth": index + 1},
                )

    def formfield(self, **kwargs):
        return super().formfield(
            **{
                "form_class": SimpleArrayField,
                "base_field": self.base_field.formfield(),
                "max_length": self.size,
                **kwargs,
            }
        )

104 if isinstance(value, (list, tuple)): 

105 return [ 

106 self.base_field.get_db_prep_value(i, connection, prepared=False) 

107 for i in value 

108 ] 

109 return value 

110 

111 def deconstruct(self): 

112 name, path, args, kwargs = super().deconstruct() 

113 if path == "django.contrib.postgres.fields.array.ArrayField": 

114 path = "django.contrib.postgres.fields.ArrayField" 

115 kwargs.update( 

116 { 

117 "base_field": self.base_field.clone(), 

118 "size": self.size, 

119 } 

120 ) 

121 return name, path, args, kwargs 

122 

123 def to_python(self, value): 

124 if isinstance(value, str): 

125 # Assume we're deserializing 

126 vals = json.loads(value) 

127 value = [self.base_field.to_python(val) for val in vals] 

128 return value 

129 

130 def _from_db_value(self, value, expression, connection): 

131 if value is None: 

132 return value 

133 return [ 

134 self.base_field.from_db_value(item, expression, connection) 

135 for item in value 

136 ] 

137 

138 def value_to_string(self, obj): 

139 values = [] 

140 vals = self.value_from_object(obj) 

141 base_field = self.base_field 

142 

143 for val in vals: 

144 if val is None: 

145 values.append(None) 

146 else: 

147 obj = AttributeSetter(base_field.attname, val) 

148 values.append(base_field.value_to_string(obj)) 

149 return json.dumps(values) 

150 

151 def get_transform(self, name): 

152 transform = super().get_transform(name) 

153 if transform: 

154 return transform 

155 if "_" not in name: 

156 try: 

157 index = int(name) 

158 except ValueError: 

159 pass 

160 else: 

161 index += 1 # postgres uses 1-indexing 

162 return IndexTransformFactory(index, self.base_field) 

163 try: 

164 start, end = name.split("_") 

165 start = int(start) + 1 

166 end = int(end) # don't add one here because postgres slices are weird 

167 except ValueError: 

168 pass 

169 else: 

170 return SliceTransformFactory(start, end) 

171 

172 def validate(self, value, model_instance): 

173 super().validate(value, model_instance) 

174 for index, part in enumerate(value): 

175 try: 

176 self.base_field.validate(part, model_instance) 

177 except exceptions.ValidationError as error: 

178 raise prefix_validation_error( 

179 error, 

180 prefix=self.error_messages["item_invalid"], 

181 code="item_invalid", 

182 params={"nth": index + 1}, 

183 ) 

184 if isinstance(self.base_field, ArrayField): 

185 if len({len(i) for i in value}) > 1: 

186 raise exceptions.ValidationError( 

187 self.error_messages["nested_array_mismatch"], 

188 code="nested_array_mismatch", 

189 ) 

190 

191 def run_validators(self, value): 

192 super().run_validators(value) 

193 for index, part in enumerate(value): 

194 try: 

195 self.base_field.run_validators(part) 

196 except exceptions.ValidationError as error: 

197 raise prefix_validation_error( 

198 error, 

199 prefix=self.error_messages["item_invalid"], 

200 code="item_invalid", 

201 params={"nth": index + 1}, 

202 ) 

203 

204 def formfield(self, **kwargs): 

205 return super().formfield( 

206 **{ 

207 "form_class": SimpleArrayField, 

208 "base_field": self.base_field.formfield(), 

209 "max_length": self.size, 

210 **kwargs, 

211 } 

212 ) 

213 

214 

215class ArrayRHSMixin: 

216 def __init__(self, lhs, rhs): 

217 if isinstance(rhs, (tuple, list)): 

218 expressions = [] 

219 for value in rhs: 

220 if not hasattr(value, "resolve_expression"): 

221 field = lhs.output_field 

222 value = Value(field.base_field.get_prep_value(value)) 

223 expressions.append(value) 

224 rhs = Func( 

225 *expressions, 

226 function="ARRAY", 

227 template="%(function)s[%(expressions)s]", 

228 ) 

229 super().__init__(lhs, rhs) 

230 

231 def process_rhs(self, compiler, connection): 

232 rhs, rhs_params = super().process_rhs(compiler, connection) 

233 cast_type = self.lhs.output_field.cast_db_type(connection) 

234 return "%s::%s" % (rhs, cast_type), rhs_params 

235 

236 

237@ArrayField.register_lookup 

238class ArrayContains(ArrayRHSMixin, lookups.DataContains): 

239 pass 

240 

241 

242@ArrayField.register_lookup 

243class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy): 

244 pass 

245 

246 

247@ArrayField.register_lookup 

248class ArrayExact(ArrayRHSMixin, Exact): 

249 pass 

250 

251 

252@ArrayField.register_lookup 

253class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): 

254 pass 

255 

256 

257@ArrayField.register_lookup 

258class ArrayLenTransform(Transform): 

259 lookup_name = "len" 

260 output_field = IntegerField() 

261 

262 def as_sql(self, compiler, connection): 

263 lhs, params = compiler.compile(self.lhs) 

264 # Distinguish NULL and empty arrays 

265 return ( 

266 "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " 

267 "coalesce(array_length(%(lhs)s, 1), 0) END" 

268 ) % {"lhs": lhs}, params 

269 

270 

271@ArrayField.register_lookup 

272class ArrayInLookup(In): 

273 def get_prep_lookup(self): 

274 values = super().get_prep_lookup() 

275 if hasattr(values, "resolve_expression"): 

276 return values 

277 # In.process_rhs() expects values to be hashable, so convert lists 

278 # to tuples. 

279 prepared_values = [] 

280 for value in values: 

281 if hasattr(value, "resolve_expression"): 

282 prepared_values.append(value) 

283 else: 

284 prepared_values.append(tuple(value)) 

285 return prepared_values 

286 

287 

288class IndexTransform(Transform): 

289 def __init__(self, index, base_field, *args, **kwargs): 

290 super().__init__(*args, **kwargs) 

291 self.index = index 

292 self.base_field = base_field 

293 

294 def as_sql(self, compiler, connection): 

295 lhs, params = compiler.compile(self.lhs) 

296 return "%s[%%s]" % lhs, params + [self.index] 

297 

298 @property 

299 def output_field(self): 

300 return self.base_field 

301 

302 

303class IndexTransformFactory: 

304 def __init__(self, index, base_field): 

305 self.index = index 

306 self.base_field = base_field 

307 

308 def __call__(self, *args, **kwargs): 

309 return IndexTransform(self.index, self.base_field, *args, **kwargs) 

310 

311 

312class SliceTransform(Transform): 

313 def __init__(self, start, end, *args, **kwargs): 

314 super().__init__(*args, **kwargs) 

315 self.start = start 

316 self.end = end 

317 

318 def as_sql(self, compiler, connection): 

319 lhs, params = compiler.compile(self.lhs) 

320 return "%s[%%s:%%s]" % lhs, params + [self.start, self.end] 

321 

322 

323class SliceTransformFactory: 

324 def __init__(self, start, end): 

325 self.start = start 

326 self.end = end 

327 

328 def __call__(self, *args, **kwargs): 

329 return SliceTransform(self.start, self.end, *args, **kwargs)