Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/django/db/models/expressions.py: 42%

883 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1import copy 

2import datetime 

3import functools 

4import inspect 

5from decimal import Decimal 

6from uuid import UUID 

7 

8from django.core.exceptions import EmptyResultSet, FieldError 

9from django.db import DatabaseError, NotSupportedError, connection 

10from django.db.models import fields 

11from django.db.models.constants import LOOKUP_SEP 

12from django.db.models.query_utils import Q 

13from django.utils.deconstruct import deconstructible 

14from django.utils.functional import cached_property 

15from django.utils.hashable import make_hashable 

16 

17 

class SQLiteNumericMixin:
    """
    Mixin for expressions that may need an explicit NUMERIC cast on SQLite.

    When the expression's output_field is a DecimalField, the compiled SQL is
    wrapped in ``CAST(... AS NUMERIC)`` so that it filters correctly.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile with as_sql() and cast the result to NUMERIC when the output
        field is a DecimalField. If the output field cannot be resolved
        (FieldError), the SQL is returned unchanged.
        """
        compiled_sql, compiled_params = self.as_sql(
            compiler, connection, **extra_context
        )
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            needs_cast = False
        if needs_cast:
            compiled_sql = "CAST(%s AS NUMERIC)" % compiled_sql
        return compiled_sql, compiled_params

32 

33 

class Combinable:
    """
    Mixin that lets expression-like objects be joined with a connector,
    e.g. ``F('foo') + F('bar')``.
    """

    # Arithmetic connectors.
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The '%' is doubled because the connector may end up in SQL strings
    # that also undergo %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors. They are produced by .bitand(), .bitor(), etc.,
    # because the '&' and '|' Python operators are reserved for boolean use.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Return a CombinedExpression joining self and other with connector.

        Operands without resolve_expression() are wrapped in Value(). When
        reversed is true, other becomes the left-hand side.
        """
        if not hasattr(other, "resolve_expression"):
            other = Value(other)
        lhs, rhs = (other, self) if reversed else (self, other)
        return CombinedExpression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    # Forward arithmetic: self <op> other.

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Reflected arithmetic: other <op> self.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # '&' and '|' combine two conditional operands into Q objects; any other
    # use is rejected in favor of the explicit bitwise methods below.

    def __and__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand() and .bitor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def __or__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand() and .bitor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    # Explicit bitwise operations.

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

149 

150 

class BaseExpression:
    """Base class for all query expressions."""

    empty_result_set_value = NotImplemented
    # Aggregate-specific state.
    is_summary = False
    _output_field_resolved_to_none = False
    # Whether the expression can appear in a WHERE clause.
    filterable = True
    # Whether the expression can be used as a source expression in Window.
    window_compatible = False

    def __init__(self, output_field=None):
        # Only shadow the lazily inferred output_field when one is supplied.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # Drop the cached convert_value. It may be a lambda (see
        # convert_value), which pickle cannot serialize; it is rebuilt lazily.
        state = dict(self.__dict__)
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters for values loaded from the database: this
        expression's convert_value (unless it is the no-op), followed by the
        output field's own converters.
        """
        if self.convert_value is self._convert_value_noop:
            converters = []
        else:
            converters = [self.convert_value]
        return converters + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        return []

    def set_source_expressions(self, exprs):
        # A bare expression has no children, so only an empty sequence fits.
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce each argument to an expression. Objects that have
        resolve_expression() pass through unchanged, strings become F()
        references, and anything else is wrapped in Value().
        """
        parsed = []
        for arg in expressions:
            if hasattr(arg, "resolve_expression"):
                parsed.append(arg)
            elif isinstance(arg, str):
                parsed.append(F(arg))
            else:
                parsed.append(Value(arg))
        return parsed

    def as_sql(self, compiler, connection):
        """
        Return a ``(sql, params)`` tuple to include in the current query.

        ``sql`` is a SQL fragment containing placeholders, and ``params``
        lists the values to substitute for them, in order.

        Arguments:
         * compiler: the query compiler generating the query. Its compile()
           method turns a child expression into a (sql, params) tuple.
         * connection: the database connection used for the current query.

        Backends can supply their own implementation through an
        ``as_{vendor}`` method, for example:

            def override_as_sql(self, compiler, connection):
                # custom logic
                return super().as_sql(compiler, connection)

            setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        for expr in self.get_source_expressions():
            if expr and expr.contains_aggregate:
                return True
        return False

    @cached_property
    def contains_over_clause(self):
        for expr in self.get_source_expressions():
            if expr and expr.contains_over_clause:
                return True
        return False

    @cached_property
    def contains_column_references(self):
        for expr in self.get_source_expressions():
            if expr and expr.contains_column_references:
                return True
        return False

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Return a copy of this expression that is ready to be added to query.

        Each truthy source expression is resolved; falsy ones become None.

        Arguments:
         * query: the backend query implementation
         * allow_joins: whether joins may be used in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: whether this is a terminal aggregate clause
         * for_save: whether the expression is about to be used in a save or
           update
        """
        clone = self.copy()
        clone.is_summary = summarize
        resolved = []
        for expr in clone.get_source_expressions():
            # for_save is not forwarded to child expressions.
            resolved.append(
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
            )
        clone.set_source_expressions(resolved)
        return clone

    @property
    def conditional(self):
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output field type of this expression."""
        resolved = self._resolve_output_field()
        if resolved is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return resolved

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field, or None if _resolve_output_field() could not
        infer a type. Any other FieldError propagates.
        """
        try:
            return self.output_field
        except FieldError:
            if self._output_field_resolved_to_none:
                return None
            raise

    def _resolve_output_field(self):
        """
        Infer the output type from the source fields.

        Source fields that are None are ignored. The first remaining field is
        returned, provided every other remaining field is of a compatible
        class. This is a convenience for the common case (compare `2 + 2`
        with `2 / 3`), and users should set output_field explicitly for more
        complex computations.

        Return None when there are no non-None sources. The output_field
        property then raises the error.
        """
        candidates = (
            field for field in self.get_source_fields() if field is not None
        )
        first = next(candidates, None)
        if first is None:
            return None
        for other in candidates:
            if not isinstance(first, other.__class__):
                raise FieldError(
                    "Expression contains mixed types: %s, %s. You must "
                    "set output_field."
                    % (
                        first.__class__.__name__,
                        other.__class__.__name__,
                    )
                )
        return first

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Return a converter for database values.

        Expressions provide their own converters because a user-specified
        output_field may differ from the type the database returns. Float,
        *IntegerField and Decimal outputs get a cast that preserves None.
        Every other output type gets the no-op converter.
        """
        internal_type = self.output_field.get_internal_type()
        if internal_type == "FloatField":
            cast = float
        elif internal_type.endswith("IntegerField"):
            cast = int
        elif internal_type == "DecimalField":
            cast = Decimal
        else:
            return self._convert_value_noop
        return lambda value, expression, connection: (
            None if value is None else cast(value)
        )

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        clone = self.copy()
        relabeled = []
        for expr in self.get_source_expressions():
            relabeled.append(
                None if expr is None else expr.relabeled_clone(change_map)
            )
        clone.set_source_expressions(relabeled)
        return clone

    def copy(self):
        return copy.copy(self)

    def get_group_by_cols(self, alias=None):
        # Without an aggregate, the whole expression is the grouping column.
        # Otherwise, collect the grouping columns of the children.
        if not self.contains_aggregate:
            return [self]
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def get_source_fields(self):
        """Return the underlying field types used by this aggregate."""
        return [
            expr._output_field_or_none for expr in self.get_source_expressions()
        ]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Yield this expression, then every subexpression, depth-first.
        Falsy sources are skipped.
        """
        yield self
        for expr in self.get_source_expressions():
            if not expr:
                continue
            if hasattr(expr, "flatten"):
                yield from expr.flatten()
            else:
                yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses, delegated to the output field when
        it defines select_format() (for example, EXISTS needs wrapping in
        CASE WHEN on Oracle).
        """
        if not hasattr(self.output_field, "select_format"):
            return sql, params
        return self.output_field.select_format(compiler, sql, params)

422 

423 

@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        """
        A hashable tuple identifying this expression.

        The tuple is the class followed by (name, value) pairs for every
        constructor argument, with defaults applied. A Field argument that
        has both a name and a model becomes (model label, field name); any
        other Field becomes its type. All remaining values are passed
        through make_hashable().
        """
        signature = inspect.signature(self.__init__)
        args, kwargs = self._constructor_args
        bound_args = signature.bind_partial(*args, **kwargs)
        bound_args.apply_defaults()
        identity = [self.__class__]
        for arg_name, value in bound_args.arguments.items():
            if not isinstance(value, fields.Field):
                value = make_hashable(value)
            elif value.name and value.model:
                value = (value.model._meta.label, value.name)
            else:
                value = type(value)
            identity.append((arg_name, value))
        return tuple(identity)

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

454 

455 

# For each arithmetic connector, a separate list of
# (lhs field type, rhs field type, result field type) triples describing
# which field type results from combining two numeric field types.
_connector_combinators = {
    connector: [
        # Integer with Integer -> Integer.
        (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        # Integer mixed with Decimal (either side) -> Decimal.
        (fields.IntegerField, fields.DecimalField, fields.DecimalField),
        (fields.DecimalField, fields.IntegerField, fields.DecimalField),
        # Integer mixed with Float (either side) -> Float.
        (fields.IntegerField, fields.FloatField, fields.FloatField),
        (fields.FloatField, fields.IntegerField, fields.FloatField),
    ]
    for connector in (
        Combinable.ADD,
        Combinable.SUB,
        Combinable.MUL,
        Combinable.DIV,
    )
}

466 

467 

@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the result field type of the first _connector_combinators entry
    for connector whose lhs and rhs types are superclasses of lhs_type and
    rhs_type. Return None when nothing matches or the connector is unknown.
    """
    return next(
        (
            combined_type
            for lhs_base, rhs_base, combined_type in _connector_combinators.get(
                connector, ()
            )
            if issubclass(lhs_type, lhs_base) and issubclass(rhs_type, rhs_base)
        ),
        None,
    )

476 

477 

class CombinedExpression(SQLiteNumericMixin, Expression):
    """Two expressions joined by a connector (arithmetic or bitwise)."""

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """
        Use the generic inference first. If that raises FieldError, fall back
        to the connector/operand-type table, and re-raise if no entry
        matches.
        """
        try:
            return super()._resolve_output_field()
        except FieldError:
            lhs_field_type = type(self.lhs.output_field)
            rhs_field_type = type(self.rhs.output_field)
            combined_type = _resolve_combined_type(
                self.connector, lhs_field_type, rhs_field_type
            )
            if combined_type is not None:
                return combined_type()
            raise

    def as_sql(self, compiler, connection):
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        rhs_sql, rhs_params = compiler.compile(self.rhs)
        sql = connection.ops.combine_expression(self.connector, [lhs_sql, rhs_sql])
        # Parenthesize to preserve the order of precedence.
        return "(%s)" % sql, [*lhs_params, *rhs_params]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve both operands.

        Unless self is already a DurationExpression or TemporalSubtraction,
        re-dispatch to one of those when the operand types call for it:
        DurationExpression when exactly one side is a DurationField, or
        TemporalSubtraction when subtracting two values of the same
        date/time type. Otherwise return a copy holding the resolved
        operands.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, (DurationExpression, TemporalSubtraction)):

            def internal_type(expr):
                # None when the operand's output field can't be determined.
                try:
                    return expr.output_field.get_internal_type()
                except (AttributeError, FieldError):
                    return None

            lhs_type = internal_type(lhs)
            rhs_type = internal_type(rhs)
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                # The replacement node receives the unresolved operands and
                # resolves them itself.
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(query, allow_joins, reuse, summarize, for_save)
            if (
                self.connector == self.SUB
                and lhs_type in {"DateField", "DateTimeField", "TimeField"}
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query, allow_joins, reuse, summarize, for_save
                )
        clone = self.copy()
        clone.is_summary = summarize
        clone.lhs, clone.rhs = lhs, rhs
        return clone

570 

571 

class DurationExpression(CombinedExpression):
    """
    Combined expression for arithmetic that involves durations.

    When the backend lacks a native duration field, DurationField operands
    are formatted with format_for_duration_arithmetic(), and the operands
    are joined with combine_duration_expression().
    """

    def compile(self, side, compiler, connection):
        """
        Compile one operand. If its output field is a DurationField, the SQL
        is wrapped for duration arithmetic; if the output field can't be
        resolved (FieldError), the operand is compiled as-is.
        """
        try:
            output = side.output_field
        except FieldError:
            is_duration = False
        else:
            is_duration = output.get_internal_type() == "DurationField"
        if not is_duration:
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        lhs_sql, lhs_params = self.compile(self.lhs, compiler, connection)
        rhs_sql, rhs_params = self.compile(self.rhs, compiler, connection)
        sql = connection.ops.combine_duration_expression(
            self.connector, [lhs_sql, rhs_sql]
        )
        # Parenthesize to preserve the order of precedence.
        return "(%s)" % sql, [*lhs_params, *rhs_params]

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile normally, but reject multiplication and division on SQLite
        unless both operands are decimal, duration, float or integer fields.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params

621 

622 

class TemporalSubtraction(CombinedExpression):
    """Subtraction of two date/time values, producing a DurationField."""

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        # The backend receives the lhs internal type plus both
        # (sql, params) pairs.
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )

636 

637 

@deconstructible
class F(Combinable):
    """A by-name reference to a field, resolved later against a query."""

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Delegate to the query's name resolution; for_save is not used.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to an instance of exactly the same class with the same
        # name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

668 

669 

class ResolvedOuterRef(F):
    """
    A reference to a column of an outer query.

    The outer reference has already been resolved because the inner query
    is used as a subquery. Compiling it directly raises ValueError.
    """

    contains_aggregate = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        return self

    def get_group_by_cols(self, alias=None):
        # An outer reference contributes no GROUP BY columns.
        return []

699 

700 

class OuterRef(F):
    """A field reference meant to be resolved against an outer query."""

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        name = self.name
        if not isinstance(name, self.__class__):
            return ResolvedOuterRef(name)
        # A nested OuterRef(OuterRef(...)) is unwrapped one level.
        return name

    def relabeled_clone(self, relabels):
        return self

711 

712 

class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                "'%s' takes exactly %s %s (%s given)"
                % (self.__class__.__name__, self.arity, noun, len(expressions))
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        extra = {**self.extra, **self._get_repr_options()}
        if not extra:
            return f"{self.__class__.__name__}({args})"
        extra_repr = ", ".join(
            "%s=%s" % (key, val) for key, val in sorted(extra.items())
        )
        return f"{self.__class__.__name__}({args}, {extra_repr})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        # copy() gives the clone its own list, so it is updated in place.
        args = clone.source_expressions
        for index, arg in enumerate(args):
            args[index] = arg.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return clone

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """
        Render the function call.

        Each argument is compiled. An argument that raises EmptyResultSet is
        replaced by Value(arg.empty_result_set_value), or the exception is
        re-raised when that value is NotImplemented. The function name,
        template and arg_joiner come from, in order of priority: this
        method's parameters, the instance's extra options and extra_context,
        and finally the class attributes.
        """
        connection.ops.check_expression_support(self)
        sql_parts, params = [], []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined = arg_joiner.join(sql_parts)
        data["expressions"] = joined
        data["field"] = joined
        return template % data, params

    def copy(self):
        # Give the clone its own argument list and extra dict so that
        # mutating them doesn't affect the original.
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone

809 

810 

class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Default for_save so that unresolved instances can still be compiled,
    # pending a decision in #25425.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. It is added to the
           SQL parameter list and properly quoted.
         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        val = self.value
        output_field = self._output_field_or_none
        if output_field is not None:
            prepare = (
                output_field.get_db_prep_save
                if self.for_save
                else output_field.get_db_prep_value
            )
            val = prepare(val, connection=connection)
            if hasattr(output_field, "get_placeholder"):
                return output_field.get_placeholder(val, compiler, connection), [val]
        if val is None:
            # cx_Oracle does not always convert None to the appropriate NULL
            # type (like in case expressions using numbers), so a literal SQL
            # NULL is used instead.
            return "NULL", []
        return "%s", [val]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self, alias=None):
        return []

    def _resolve_output_field(self):
        # Checked in order: bool before int (bool subclasses int), and
        # datetime before date (datetime subclasses date).
        for value_type, field_class in (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(self.value, value_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        return self.value

888 

889 

class RawSQL(Expression):
    """
    A raw SQL fragment and its parameters, compiled as a parenthesized
    sub-expression. The output field defaults to a generic fields.Field().
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        return "(%s)" % self.sql, self.params

    def get_group_by_cols(self, alias=None):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve, on the query, any parent-model field whose column name
        appears (case-insensitively) in the raw SQL. The return value of
        query.resolve_ref() is discarded, so only its side effects on the
        query matter. At most one field per parent is resolved.

        When query is None (the default) or has no model, there are no
        parent fields to resolve, so that step is skipped instead of raising
        AttributeError.
        """
        model = getattr(query, "model", None)
        if model is not None:
            for parent in model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in self.sql.lower():
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

919 

920 

class Star(Expression):
    """The SQL ``*`` wildcard, compiled with no parameters."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        sql, params = "*", []
        return sql, params

927 

928 

class Col(Expression):
    """
    A reference to target's column, qualified by a table alias when alias
    is truthy. The output field defaults to target.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(self, compiler, connection):
        identifiers = [self.target.column]
        if self.alias:
            identifiers.insert(0, self.alias)
        quoted = (compiler.quote_name_unless_alias(name) for name in identifiers)
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self, alias=None):
        return [self]

    def get_db_converters(self, connection):
        # Output-field converters run first. When the output field differs
        # from the target, the target's converters are appended.
        output_converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return output_converters
        return output_converters + self.target.get_db_converters(connection)

967 

968 

class Ref(Expression):
    """
    Points at a column alias already defined on the query; e.g.
    ``Ref('sum_cost')`` for ``qs.annotate(sum_cost=Sum('cost'))``.

    ``refs`` is the alias name that gets rendered. ``source`` is the
    expression behind the alias.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return "{}({}, {})".format(type(self).__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # ``source`` was resolved when the alias was created; this is only a
        # reference to its name, so there is nothing left to resolve.
        return self

    def relabeled_clone(self, relabels):
        # Returned unchanged regardless of ``relabels``.
        return self

    def as_sql(self, compiler, connection):
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self, alias=None):
        return [self]

1003 

1004 

class ExpressionList(Func):
    """
    A group of one or more expressions rendered as a single argument, e.g.
    the item list of an ordering clause. Raises ValueError when constructed
    with no expressions.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if not expressions:
            message = "%s requires at least one expression." % self.__class__.__name__
            raise ValueError(message)
        super().__init__(*expressions, **extra)

    def __str__(self):
        rendered = map(str, self.source_expressions)
        return self.arg_joiner.join(rendered)

    def as_sqlite(self, compiler, connection, **extra_context):
        # No numeric cast is needed for a list of expressions; use the generic
        # SQL instead of any inherited SQLite-specific rendering.
        return self.as_sql(compiler, connection, **extra_context)

1027 

1028 

class ExpressionWrapper(Expression):
    """
    Wraps another expression to attach extra context, such as an explicit
    ``output_field``. Compiles to exactly the wrapped expression's SQL.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self, alias=None):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Delegate to a copy of the inner expression that carries this
        # wrapper's output field, without modifying the original.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols(alias=alias)

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)

1059 

1060 

class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition is a Q object, a boolean (``conditional``) expression, or
    keyword lookups. Lookups are folded into a Q, combined with ``condition``
    when one is given.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups:
            if condition is None:
                condition, lookups = Q(**lookups), None
            elif getattr(condition, "conditional", False):
                condition, lookups = Q(condition, **lookups), None
        is_valid = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not lookups
        )
        if not is_valid:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return "WHEN %r THEN %r" % (self.condition, self.result)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result's field matters for output-field resolution.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = self.copy()
        resolved.is_summary = summarize
        if hasattr(resolved.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False; only the
            # result honours the caller's for_save flag.
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        template_params = extra_context
        template_params["condition"] = condition_sql
        template_params["result"] = result_sql
        sql_params = [*condition_params, *result_params]
        return (template or self.template) % template_params, sql_params

    def get_group_by_cols(self, alias=None):
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

1132 

1133 

class Case(Expression):
    """
    A searched SQL CASE expression built from When() branches and a default:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    With no branches, or when every branch compiles to an empty result set,
    only the default is emitted.
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        joined_cases = ", ".join(map(str, self.cases))
        return "CASE %s, ELSE %r" % (joined_cases, self.default)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in resolved.cases
        ]
        resolved.default = resolved.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def copy(self):
        clone = super().copy()
        # Give the clone its own list so that edits to it don't affect self.
        clone.cases = list(clone.cases)
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        compiled_whens = []
        params = []
        for when in self.cases:
            try:
                when_sql, when_params = compiler.compile(when)
            except EmptyResultSet:
                # A branch that can never match is simply dropped.
                continue
            compiled_whens.append(when_sql)
            params += when_params
        default_sql, default_params = compiler.compile(self.default)
        if not compiled_whens:
            # Every branch was dropped; the CASE collapses to its default.
            return default_sql, default_params
        template_params = {**self.extra, **extra_context}
        joiner = case_joiner or self.case_joiner
        template_params["cases"] = joiner.join(compiled_whens)
        template_params["default"] = default_sql
        params += default_params
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self, alias=None):
        if self.cases:
            return super().get_group_by_cols(alias)
        return self.default.get_group_by_cols(alias)

1225 

1226 

class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query, which are resolved when the subquery is applied to that query.
    Accepts either a QuerySet (its ``.query`` is used) or a query object.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Allow the usage of both QuerySet and sql.Query objects.
        self.query = getattr(queryset, "query", queryset)
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        # The inner query is cloned so the copy can be changed independently.
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, query=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = {**self.extra, **extra_context}
        inner_query = query or self.query
        subquery_sql, sql_params = inner_query.as_sql(compiler, connection)
        # Drop the first and last characters of the inner SQL (presumably its
        # own wrapping parentheses); the template supplies the outer wrapper.
        template_params["subquery"] = subquery_sql[1:-1]
        template = template or template_params.get("template", self.template)
        return template % template_params, sql_params

    def get_group_by_cols(self, alias=None):
        if alias:
            return [Ref(alias, self)]
        external_cols = self.get_external_cols()
        multivalued = any(col.possibly_multivalued for col in external_cols)
        return [self] if multivalued else external_cols

1282 

1283 

class Exists(Subquery):
    """An ``EXISTS(...)`` subquery; ``~`` toggles a leading ``NOT``."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()

    def __init__(self, queryset, negated=False, **kwargs):
        self.negated = negated
        super().__init__(queryset, **kwargs)

    def __invert__(self):
        inverted = self.copy()
        inverted.negated = not self.negated
        return inverted

    def as_sql(self, compiler, connection, template=None, **extra_context):
        exists_query = self.query.exists(using=connection.alias)
        sql, params = super().as_sql(
            compiler,
            connection,
            template=template,
            query=exists_query,
            **extra_context,
        )
        return ("NOT {}".format(sql) if self.negated else sql), params

    def select_format(self, compiler, sql, params):
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return "CASE WHEN {} THEN 1 ELSE 0 END".format(sql), params

1317 

1318 

class OrderBy(Expression):
    """
    An ordering term: ``<expression> ASC|DESC`` with optional NULLS FIRST or
    NULLS LAST placement. The two null options are mutually exclusive.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self, expression, descending=False, nulls_first=False, nulls_last=False
    ):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = "%s NULLS LAST" % template
            elif self.nulls_first:
                template = "%s NULLS FIRST" % template
        # Without a native NULLS modifier, emulate it by sorting on an
        # IS [NOT] NULL test first, unless the backend's default null
        # placement already gives the requested order.
        elif self.nulls_last and not (
            self.descending and features.order_by_nulls_first
        ):
            template = "%%(expression)s IS NULL, %s" % template
        elif self.nulls_first and not (
            not self.descending and features.order_by_nulls_first
        ):
            template = "%%(expression)s IS NOT NULL, %s" % template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        direction = "DESC" if self.descending else "ASC"
        placeholders = {
            "expression": expression_sql,
            "ordering": direction,
            **extra_context,
        }
        # The expression may appear several times in the template; repeat
        # its params once per occurrence.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def as_oracle(self, compiler, connection):
        # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped
        # in a CASE WHEN.
        target = self
        if connection.ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            target = self.copy()
            target.expression = Case(
                When(self.expression, then=True),
                default=False,
            )
        return target.as_sql(compiler, connection)

    def get_group_by_cols(self, alias=None):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip the direction (and any explicit null placement) in place."""
        self.descending = not self.descending
        if self.nulls_first or self.nulls_last:
            self.nulls_first, self.nulls_last = (
                not self.nulls_first,
                not self.nulls_last,
            )
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True

1404 

1405 

class Window(SQLiteNumericMixin, Expression):
    """
    ``<expression> OVER (<window>)``, where the window is built from optional
    ``partition_by``, ``order_by`` and ``frame`` parts.

    ``expression`` must be flagged ``window_compatible``. ``partition_by`` may
    be a single value or a list/tuple and is wrapped in an ExpressionList.
    ``order_by`` may be an expression or a list/tuple of expressions.
    ``frame`` is compiled as-is when truthy.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    filterable = False

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses."
                % expression.__class__.__name__
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = ExpressionList(*self.order_by)
            elif not isinstance(self.order_by, BaseExpression):
                raise ValueError(
                    "order_by must be either an Expression or a sequence of "
                    "expressions."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        # Without an explicit output_field, use the wrapped expression's.
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        """
        Return ``(sql, params)`` for the windowed expression.

        Raises NotSupportedError when the backend lacks OVER clause support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], []

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            # Append whole SQL fragments; extend() on a string would add it
            # one character at a time.
            window_sql.append(sql_expr)
            window_params.extend(sql_params)

        if self.order_by is not None:
            window_sql.append(" ORDER BY ")
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(" " + frame_sql)
            window_params.extend(frame_params)

        # Build a fresh list instead of calling params.extend(). Compiled
        # params need not be a list (RawSQL.as_sql, for example, returns its
        # params object unchanged, which may be a tuple). Extending in place
        # could also mutate a list that the compiled expression still owns.
        params = [*params, *window_params]
        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": "".join(window_sql).strip()},
            params,
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            clone = self.copy()
            source_expressions = clone.get_source_expressions()
            # The clone appears to share its inner expression with self (the
            # base copy() seems shallow; compare Case.copy and Subquery.copy,
            # which re-copy their mutable attributes). Copy the inner
            # expression before overriding its output field so the override
            # does not leak into self.source_expression.
            source_expressions[0] = source_expressions[0].copy()
            source_expressions[0].output_field = fields.FloatField()
            clone.set_source_expressions(source_expressions)
            return super(Window, clone).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            "ORDER BY " + str(self.order_by) if self.order_by else "",
            str(self.frame or ""),
        )

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_group_by_cols(self, alias=None):
        # Window expressions never contribute to GROUP BY.
        return []

1516 

1517 

class WindowFrame(Expression):
    """
    Base class for the frame clause of a window expression
    (``<frame_type> BETWEEN <start> AND <end>``).

    Subclasses supply ``frame_type`` and ``window_frame_start_end()``; all
    other processing (deliberately minimal validation) lives here.

    ``start`` and ``end`` are optional offsets. In ``__str__``:
    - a missing start means UNBOUNDED PRECEDING and a missing end means
      UNBOUNDED FOLLOWING;
    - 0 means CURRENT ROW;
    - a negative start and a positive end give N PRECEDING / N FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {
            "frame_type": self.frame_type,
            "start": start_sql,
            "end": end_sql,
        }
        return self.template % context, []

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_group_by_cols(self, alias=None):
        return []

    def __str__(self):
        # Uses the module-level default ``connection`` for the keyword text.
        low = self.start.value
        if low is not None and low < 0:
            start = "%d %s" % (abs(low), connection.ops.PRECEDING)
        elif low is not None and low == 0:
            start = connection.ops.CURRENT_ROW
        else:
            start = connection.ops.UNBOUNDED_PRECEDING

        high = self.end.value
        if high is not None and high > 0:
            end = "%d %s" % (high, connection.ops.FOLLOWING)
        elif high is not None and high == 0:
            end = connection.ops.CURRENT_ROW
        else:
            end = connection.ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")

1582 

1583 

class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        # The backend's operations class renders the ROWS bounds.
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)

1589 

1590 

class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        # The backend's operations class renders the RANGE bounds.
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)