Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/django/test/utils.py: 40%

516 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1import asyncio 

2import collections 

3import logging 

4import os 

5import re 

6import sys 

7import time 

8import warnings 

9from contextlib import contextmanager 

10from functools import wraps 

11from io import StringIO 

12from itertools import chain 

13from types import SimpleNamespace 

14from unittest import TestCase, skipIf, skipUnless 

15from xml.dom.minidom import Node, parseString 

16 

17from django.apps import apps 

18from django.apps.registry import Apps 

19from django.conf import UserSettingsHolder, settings 

20from django.core import mail 

21from django.core.exceptions import ImproperlyConfigured 

22from django.core.signals import request_started 

23from django.db import DEFAULT_DB_ALIAS, connections, reset_queries 

24from django.db.models.options import Options 

25from django.template import Template 

26from django.test.signals import setting_changed, template_rendered 

27from django.urls import get_script_prefix, set_script_prefix 

28from django.utils.deprecation import RemovedInDjango50Warning 

29from django.utils.translation import deactivate 

30 

31try: 

32 import jinja2 

33except ImportError: 

34 jinja2 = None 

35 

36 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = (
    "Approximate",
    "ContextList",
    "isolate_lru_cache",
    "get_runner",
    "CaptureQueriesContext",
    "ignore_warnings",
    "isolate_apps",
    "modify_settings",
    "override_settings",
    "override_system_checks",
    "tag",
    "requires_tz_support",
    "setup_databases",
    "setup_test_environment",
    "teardown_test_environment",
)

# True when the time module provides tzset(), i.e. the process time zone can
# be changed at runtime; used by requires_tz_support to skip tests otherwise.
TZ_SUPPORT = hasattr(time, "tzset")

56 

57 

class Approximate:
    """
    Wrap a number so that it compares equal to any value that is exactly
    equal to it, or that matches it once the absolute difference is rounded
    to ``places`` decimal places.
    """

    def __init__(self, val, places=7):
        self.val, self.places = val, places

    def __repr__(self):
        return repr(self.val)

    def __eq__(self, other):
        exact = self.val == other
        if exact:
            return exact
        difference = abs(self.val - other)
        return round(difference, self.places) == 0

68 

69 

class ContextList(list):
    """
    A list of context objects that also allows looking up a context variable
    directly by its (string) name; non-string keys index the list as usual.
    """

    def __getitem__(self, key):
        if not isinstance(key, str):
            return super().__getitem__(key)
        for context in self:
            if key in context:
                return context[key]
        raise KeyError(key)

    def get(self, key, default=None):
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        else:
            return True

    def keys(self):
        """
        Return the set of keys from every mapping in every subcontext.
        """
        return {
            name
            for context in self
            for mapping in context
            for name in mapping
        }

103 

104 

def instrumented_test_render(self, context):
    """
    Drop-in replacement for Template._render used during tests. It sends the
    template_rendered signal (so the test Client can see which template and
    context were rendered) and then renders the template's node list.
    """
    template_rendered.send(sender=self, template=self, context=context)
    nodes = self.nodelist
    return nodes.render(context)

112 

113 

114class _TestState: 

115 pass 

116 

117 

def setup_test_environment(debug=None):
    """
    Prepare global state for a test run:

    - add "testserver" (the test client's default host) to ALLOWED_HOSTS;
    - set DEBUG to ``debug`` (``None`` keeps the current settings.DEBUG);
    - switch EMAIL_BACKEND to the locmem backend and reset mail.outbox;
    - replace Template._render with instrumented_test_render;
    - deactivate the active translation.

    The replaced values are saved on _TestState for
    teardown_test_environment(). A second call without an intervening
    teardown raises RuntimeError, because it would overwrite them.
    """
    if hasattr(_TestState, "saved_data"):
        # Running setup twice would clobber the values saved the first time.
        raise RuntimeError(
            "setup_test_environment() was already called and can't be called "
            "again without first calling teardown_test_environment()."
        )

    debug = settings.DEBUG if debug is None else debug

    saved = SimpleNamespace()
    _TestState.saved_data = saved

    saved.allowed_hosts = settings.ALLOWED_HOSTS
    settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"]

    saved.debug = settings.DEBUG
    settings.DEBUG = debug

    saved.email_backend = settings.EMAIL_BACKEND
    settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend"

    saved.template_render = Template._render
    Template._render = instrumented_test_render

    mail.outbox = []

    deactivate()

152 

153 

def teardown_test_environment():
    """
    Undo setup_test_environment().

    Restore ALLOWED_HOSTS, DEBUG, EMAIL_BACKEND and Template._render from the
    values saved on _TestState, then drop the saved state and mail.outbox.
    """
    saved = _TestState.saved_data

    for setting_name, saved_attr in (
        ("ALLOWED_HOSTS", "allowed_hosts"),
        ("DEBUG", "debug"),
        ("EMAIL_BACKEND", "email_backend"),
    ):
        setattr(settings, setting_name, getattr(saved, saved_attr))
    Template._render = saved.template_render

    del _TestState.saved_data
    del mail.outbox

168 

169 

def setup_databases(
    verbosity,
    interactive,
    *,
    time_keeper=None,
    keepdb=False,
    debug_sql=False,
    parallel=0,
    aliases=None,
    serialized_aliases=None,
    **kwargs,
):
    """
    Create the test databases.

    For each group of aliases that share an underlying database, the first
    alias creates it (plus ``parallel`` numbered clones when ``parallel > 1``)
    and the remaining aliases are set up as test mirrors of it. Aliases
    configured as mirrors are then pointed at their targets. When
    ``debug_sql`` is true, every connection gets a forced debug cursor.

    Return a list of ``(connection, old_db_name, created)`` triples, where
    ``created`` is True only for the alias that created the database.
    """
    if time_keeper is None:
        time_keeper = NullTimeKeeper()

    test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases)

    old_names = []

    for db_name, db_aliases in test_databases.values():
        creator_alias = None
        for alias in db_aliases:
            connection = connections[alias]
            old_names.append((connection, db_name, creator_alias is None))

            if creator_alias is not None:
                # Every alias after the first mirrors the first one.
                connections[alias].creation.set_as_test_mirror(
                    connections[creator_alias].settings_dict
                )
                continue

            # The first alias of the group actually creates the database.
            creator_alias = alias
            with time_keeper.timed(" Creating '%s'" % alias):
                # RemovedInDjango50Warning: when the deprecation ends,
                # replace with:
                # serialize_alias = (
                #     serialized_aliases is None
                #     or alias in serialized_aliases
                # )
                try:
                    serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"]
                except KeyError:
                    serialize_alias = (
                        serialized_aliases is None or alias in serialized_aliases
                    )
                else:
                    warnings.warn(
                        "The SERIALIZE test database setting is "
                        "deprecated as it can be inferred from the "
                        "TestCase/TransactionTestCase.databases that "
                        "enable the serialized_rollback feature.",
                        category=RemovedInDjango50Warning,
                    )
                connection.creation.create_test_db(
                    verbosity=verbosity,
                    autoclobber=not interactive,
                    keepdb=keepdb,
                    serialize=serialize_alias,
                )
            if parallel > 1:
                for index in range(parallel):
                    with time_keeper.timed(" Cloning '%s'" % alias):
                        connection.creation.clone_test_db(
                            suffix=str(index + 1),
                            verbosity=verbosity,
                            keepdb=keepdb,
                        )

    # Point the aliases configured with TEST["MIRROR"] at their targets.
    for alias, mirror_alias in mirrored_aliases.items():
        connections[alias].creation.set_as_test_mirror(
            connections[mirror_alias].settings_dict
        )

    if debug_sql:
        for alias in connections:
            connections[alias].force_debug_cursor = True

    return old_names

251 

252 

def iter_test_cases(tests):
    """
    Yield each unittest.TestCase contained in ``tests``.

    ``tests`` may be a suite or any iterable. Items that are not TestCase
    instances are treated as nested suites and descended into. A plain string
    raises TypeError, because iterating it would otherwise recurse until
    RecursionError.
    """
    for item in tests:
        if isinstance(item, str):
            raise TypeError(
                f"Test {item!r} must be a test case or test suite not string "
                f"(was found in {tests!r})."
            )
        if not isinstance(item, TestCase):
            # Assume anything else is a (possibly nested) test suite.
            yield from iter_test_cases(item)
            continue
        yield item

272 

273 

def dependency_ordered(test_databases, dependencies):
    """
    Reorder ``test_databases`` (an iterable of ``(signature, (db_name,
    aliases))`` pairs) so that each database comes after every alias it
    depends on according to ``dependencies`` (TEST[DEPENDENCIES]).

    Raise ImproperlyConfigured if a database depends on one of its own
    aliases, or if the dependencies can't be satisfied (a cycle).
    """
    # For each signature, the union of the dependencies of all its aliases.
    deps_by_signature = {}
    for signature, (_, aliases) in test_databases:
        combined = set()
        for alias in aliases:
            combined.update(dependencies.get(alias, []))
        if not combined.isdisjoint(aliases):
            raise ImproperlyConfigured(
                "Circular dependency: databases %r depend on each other, "
                "but are aliases." % aliases
            )
        deps_by_signature[signature] = combined

    ordered = []
    resolved = set()
    pending = test_databases
    while pending:
        progressed = False
        still_waiting = []
        # Entries resolved earlier in this pass can satisfy later ones.
        for signature, (db_name, aliases) in pending:
            entry = (signature, (db_name, aliases))
            if deps_by_signature[signature] <= resolved:
                resolved.update(aliases)
                ordered.append(entry)
                progressed = True
            else:
                still_waiting.append(entry)
        if not progressed:
            raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
        pending = still_waiting
    return ordered

314 

315 

def get_unique_databases_and_mirrors(aliases=None):
    """
    Work out which test databases actually need to be created.

    Entries in DATABASES that resolve to the same database (same
    test_db_signature()) are merged, and aliases configured as test mirrors
    are set aside. Only aliases in ``aliases`` are considered for creation;
    the default is all connections.

    Return ``(test_databases, mirrored_aliases)``:
    - test_databases: dict, in dependency order, mapping each signature to
      ``(NAME, [aliases...])``, all sharing one underlying database;
    - mirrored_aliases: dict mapping each mirror alias to its target alias.
    """
    if aliases is None:
        aliases = connections
    mirrored_aliases = {}
    grouped = {}
    dependencies = {}
    default_sig = connections[DEFAULT_DB_ALIAS].creation.test_db_signature()

    for alias in connections:
        connection = connections[alias]
        test_settings = connection.settings_dict["TEST"]

        mirror_target = test_settings["MIRROR"]
        if mirror_target:
            # Mirrors are not created; just remember what they mirror.
            mirrored_aliases[alias] = mirror_target
            continue
        if alias not in aliases:
            continue

        # Aliases with identical signatures only need one test database.
        _, members = grouped.setdefault(
            connection.creation.test_db_signature(),
            (connection.settings_dict["NAME"], []),
        )
        # The default database must come first because data migrations use
        # the default alias by default.
        if alias == DEFAULT_DB_ALIAS:
            members.insert(0, alias)
        else:
            members.append(alias)

        if "DEPENDENCIES" in test_settings:
            dependencies[alias] = test_settings["DEPENDENCIES"]
        elif (
            alias != DEFAULT_DB_ALIAS
            and connection.creation.test_db_signature() != default_sig
        ):
            dependencies[alias] = test_settings.get(
                "DEPENDENCIES", [DEFAULT_DB_ALIAS]
            )

    ordered = dependency_ordered(grouped.items(), dependencies)
    return dict(ordered), mirrored_aliases

370 

371 

def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
    """
    Destroy the test databases listed in ``old_config``, an iterable of
    ``(connection, old_name, destroy)`` triples such as the one
    setup_databases() returns.

    Entries whose ``destroy`` flag is false (mirrors) are skipped. When
    ``parallel > 1``, the clones suffixed "1".."parallel" are destroyed
    before the main database.
    """
    for connection, old_name, destroy in old_config:
        if not destroy:
            continue
        if parallel > 1:
            for suffix in range(1, parallel + 1):
                connection.creation.destroy_test_db(
                    suffix=str(suffix),
                    verbosity=verbosity,
                    keepdb=keepdb,
                )
        connection.creation.destroy_test_db(old_name, verbosity, keepdb)

384 

385 

def get_runner(settings, test_runner_class=None):
    """
    Import and return the test runner class.

    ``test_runner_class`` is a dotted path such as
    "django.test.runner.DiscoverRunner". It defaults to
    ``settings.TEST_RUNNER``. Everything before the last dot is imported as a
    module, and the last component is looked up on that module.
    """
    test_runner_class = test_runner_class or settings.TEST_RUNNER
    test_path = test_runner_class.split(".")
    class_name = test_path[-1]
    # Allow for relative paths
    if len(test_path) > 1:
        test_module_name = ".".join(test_path[:-1])
    else:
        # NOTE(review): with level=0, __import__(".") does not appear able to
        # resolve a relative import, so a bare class name is expected to fail
        # here. Kept unchanged for backward compatibility.
        test_module_name = "."
    # fromlist is documented as a list of names. A bare string would be
    # iterated character by character, probing spurious one-letter
    # submodules whenever the parent module is a package.
    test_module = __import__(test_module_name, {}, {}, [class_name])
    return getattr(test_module, class_name)

396 

397 

class TestContextDecorator:
    """
    Base class for temporary alterations during tests.

    Instances can be used as a context manager, as a decorator for test
    functions (plain or async), or as a decorator for unittest.TestCase
    subclasses. Subclasses implement enable() and disable().

    `attr_name`: when decorating a class, the test-instance attribute that
    receives enable()'s return value.

    `kwarg_name`: when decorating a function, the keyword argument that
    receives enable()'s return value.
    """

    def __init__(self, attr_name=None, kwarg_name=None):
        self.attr_name = attr_name
        self.kwarg_name = kwarg_name

    def enable(self):
        raise NotImplementedError

    def disable(self):
        raise NotImplementedError

    def __enter__(self):
        return self.enable()

    def __exit__(self, exc_type, exc_value, traceback):
        self.disable()

    def decorate_class(self, cls):
        if not issubclass(cls, TestCase):
            raise TypeError("Can only decorate subclasses of unittest.TestCase")
        original_setUp = cls.setUp

        def setUp(test_instance):
            value = self.enable()
            # Registered before the wrapped setUp runs, so disable() still
            # happens if that setUp raises.
            test_instance.addCleanup(self.disable)
            if self.attr_name:
                setattr(test_instance, self.attr_name, value)
            original_setUp(test_instance)

        cls.setUp = setUp
        return cls

    def decorate_callable(self, func):
        if not asyncio.iscoroutinefunction(func):

            @wraps(func)
            def inner(*args, **kwargs):
                with self as context:
                    if self.kwarg_name:
                        kwargs[self.kwarg_name] = context
                    return func(*args, **kwargs)

            return inner

        # Wrapping a coroutine function needs an async wrapper, so that the
        # `with` block spans the awaited call and not just coroutine creation.
        @wraps(func)
        async def async_inner(*args, **kwargs):
            with self as context:
                if self.kwarg_name:
                    kwargs[self.kwarg_name] = context
                return await func(*args, **kwargs)

        return async_inner

    def __call__(self, decorated):
        if isinstance(decorated, type):
            return self.decorate_class(decorated)
        if callable(decorated):
            return self.decorate_callable(decorated)
        raise TypeError("Cannot decorate object of type %s" % type(decorated))

470 

471 

class override_settings(TestContextDecorator):
    """
    Act as either a decorator or a context manager. If it's a decorator, take a
    function and return a wrapped function. If it's a contextmanager, use it
    with the ``with`` statement. In either event, entering/exiting are called
    before and after, respectively, the function/block is executed.

    The keyword arguments map setting names to the values they take while the
    override is active. Decorating a class does not apply them here:
    decorate_class() only records them on the class via save_options().
    """

    # Exception raised by a setting_changed receiver during enable(). It is
    # stored so disable() can re-raise it after restoring the old settings.
    enable_exception = None

    def __init__(self, **kwargs):
        self.options = kwargs
        super().__init__()

    def enable(self):
        """Install the overridden settings and send setting_changed for each."""
        # Keep this code at the beginning to leave the settings unchanged
        # in case it raises an exception because INSTALLED_APPS is invalid.
        if "INSTALLED_APPS" in self.options:
            try:
                apps.set_installed_apps(self.options["INSTALLED_APPS"])
            except Exception:
                apps.unset_installed_apps()
                raise
        # Wrap the current settings object in a UserSettingsHolder and set
        # each overridden value on the holder.
        override = UserSettingsHolder(settings._wrapped)
        for key, new_value in self.options.items():
            setattr(override, key, new_value)
        # Remember the replaced settings object so disable() can swap it back.
        self.wrapped = settings._wrapped
        settings._wrapped = override
        for key, new_value in self.options.items():
            try:
                setting_changed.send(
                    sender=settings._wrapped.__class__,
                    setting=key,
                    value=new_value,
                    enter=True,
                )
            except Exception as exc:
                # A receiver failed: roll back right away. disable() restores
                # the previous settings and then re-raises ``exc``, which also
                # ends this loop.
                self.enable_exception = exc
                self.disable()

    def disable(self):
        """
        Restore the settings replaced by enable() and send setting_changed
        (with enter=False) for every overridden name.

        If enable() stored an exception, re-raise it. Otherwise, raise the
        first exception a receiver returned, if any.
        """
        if "INSTALLED_APPS" in self.options:
            apps.unset_installed_apps()
        settings._wrapped = self.wrapped
        del self.wrapped
        responses = []
        for key in self.options:
            # Report the value now in effect (None if the restored settings
            # don't define it).
            new_value = getattr(settings, key, None)
            # send_robust() lets every key be processed: receiver exceptions
            # come back inside the (receiver, response) pairs and are raised
            # below, after the loop.
            responses_for_setting = setting_changed.send_robust(
                sender=settings._wrapped.__class__,
                setting=key,
                value=new_value,
                enter=False,
            )
            responses.extend(responses_for_setting)
        if self.enable_exception is not None:
            exc = self.enable_exception
            self.enable_exception = None
            raise exc
        for _, response in responses:
            if isinstance(response, Exception):
                raise response

    def save_options(self, test_func):
        # Record the options on ``test_func`` (a test class when called from
        # decorate_class()) by merging them into ``_overridden_settings``.
        if test_func._overridden_settings is None:
            test_func._overridden_settings = self.options
        else:
            # Duplicate dict to prevent subclasses from altering their parent.
            test_func._overridden_settings = {
                **test_func._overridden_settings,
                **self.options,
            }

    def decorate_class(self, cls):
        # Local import, presumably to avoid an import cycle with django.test.
        from django.test import SimpleTestCase

        if not issubclass(cls, SimpleTestCase):
            raise ValueError(
                "Only subclasses of Django SimpleTestCase can be decorated "
                "with override_settings"
            )
        self.save_options(cls)
        return cls

555 

556 

class modify_settings(override_settings):
    """
    Like override_settings, but instead of replacing a list setting outright,
    append, prepend, or remove items from it, e.g.
    ``modify_settings(MIDDLEWARE={"append": "x", "remove": ["y"]})``.
    """

    def __init__(self, *args, **kwargs):
        if not args:
            self.operations = list(kwargs.items())
        else:
            # Hack used when instantiating from SimpleTestCase.setUpClass.
            assert not kwargs
            self.operations = args[0]
        # Deliberately skip override_settings.__init__ (which would store
        # kwargs as self.options); the options are computed in enable().
        super(override_settings, self).__init__()

    def save_options(self, test_func):
        if test_func._modified_settings is None:
            test_func._modified_settings = self.operations
            return
        # Build a new list so that a subclass never alters its parent's.
        test_func._modified_settings = (
            list(test_func._modified_settings) + self.operations
        )

    def enable(self):
        self.options = {}
        for name, operations in self.operations:
            # When called from SimpleTestCase.setUpClass the same setting may
            # be listed several times; later operations build on earlier ones.
            if name in self.options:
                value = self.options[name]
            else:
                value = list(getattr(settings, name, []))
            for action, items in operations.items():
                # items may be a single value or an iterable.
                if isinstance(items, str):
                    items = [items]
                if action == "append":
                    value = value + [item for item in items if item not in value]
                elif action == "prepend":
                    value = [item for item in items if item not in value] + value
                elif action == "remove":
                    value = [item for item in value if item not in items]
                else:
                    raise ValueError("Unsupported action: %s" % action)
            self.options[name] = value
        super().enable()

604 super().enable() 

605 

606 

class override_system_checks(TestContextDecorator):
    """
    Temporarily replace the registered system checks with ``new_checks``.
    The deployment checks are also replaced when ``deployment_checks`` is
    given. This is handy when overriding `INSTALLED_APPS`: excluding `auth`,
    for instance, also means excluding its checks.
    """

    def __init__(self, new_checks, deployment_checks=None):
        from django.core.checks.registry import registry

        self.registry = registry
        self.new_checks = new_checks
        self.deployment_checks = deployment_checks
        super().__init__()

    def enable(self):
        registry = self.registry
        self.old_checks = registry.registered_checks
        registry.registered_checks = set()
        for check in self.new_checks:
            registry.register(check, *getattr(check, "tags", ()))

        self.old_deployment_checks = registry.deployment_checks
        if self.deployment_checks is None:
            return
        registry.deployment_checks = set()
        for check in self.deployment_checks:
            registry.register(check, *getattr(check, "tags", ()), deploy=True)

    def disable(self):
        self.registry.registered_checks = self.old_checks
        self.registry.deployment_checks = self.old_deployment_checks

636 

637 

def compare_xml(want, got):
    """
    Return whether the XML strings ``want`` and ``got`` are equivalent.

    Plain string comparison is too strict (attribute order, for instance,
    shouldn't matter), so both strings are parsed and compared element by
    element. Tag names, attributes, direct text and child elements must
    match. In direct text, runs of two or more spaces/tabs/newlines are
    collapsed to a single space. Comment nodes, processing instructions, the
    document type node, and leading/trailing whitespace are ignored.

    Literal "\\n" sequences are first turned into newlines. If ``want``
    doesn't start with an XML declaration, both strings are wrapped in a
    <root> element, so fragments such as "<foo/><bar/>" can be compared.

    Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
    """
    whitespace_run = re.compile(r"[ \t\n][ \t\n]+")
    skipped_node_types = (
        Node.COMMENT_NODE,
        Node.DOCUMENT_TYPE_NODE,
        Node.PROCESSING_INSTRUCTION_NODE,
    )

    def direct_text(element):
        text = "".join(
            node.data
            for node in element.childNodes
            if node.nodeType == Node.TEXT_NODE
        )
        return whitespace_run.sub(" ", text)

    def child_elements(element):
        return [
            node for node in element.childNodes if node.nodeType == Node.ELEMENT_NODE
        ]

    def same_element(a, b):
        if a.tagName != b.tagName or direct_text(a) != direct_text(b):
            return False
        if dict(a.attributes.items()) != dict(b.attributes.items()):
            return False
        a_children, b_children = child_elements(a), child_elements(b)
        return len(a_children) == len(b_children) and all(
            same_element(x, y) for x, y in zip(a_children, b_children)
        )

    def root_node(document):
        return next(
            (
                node
                for node in document.childNodes
                if node.nodeType not in skipped_node_types
            ),
            None,
        )

    want = want.strip().replace("\\n", "\n")
    got = got.strip().replace("\\n", "\n")

    if not want.startswith("<?xml"):
        want = "<root>%s</root>" % want
        got = "<root>%s</root>" % got

    return same_element(root_node(parseString(want)), root_node(parseString(got)))

705 

706 

class CaptureQueriesContext:
    """
    Context manager that records the queries ``connection`` executes inside
    the ``with`` block. The object behaves like a read-only list of those
    queries (iteration, indexing, len()); ``captured_queries`` returns the
    slice itself.
    """

    def __init__(self, connection):
        self.connection = connection

    def __iter__(self):
        return iter(self.captured_queries)

    def __getitem__(self, index):
        return self.captured_queries[index]

    def __len__(self):
        return len(self.captured_queries)

    @property
    def captured_queries(self):
        start, end = self.initial_queries, self.final_queries
        return self.connection.queries[start:end]

    def __enter__(self):
        conn = self.connection
        self.force_debug_cursor = conn.force_debug_cursor
        conn.force_debug_cursor = True
        # Connect first so that any queries run while connecting are not
        # counted.
        conn.ensure_connection()
        self.initial_queries = len(conn.queries_log)
        self.final_queries = None
        # Detach reset_queries from request_started while capturing
        # (presumably so a request inside the block can't clear the log).
        request_started.disconnect(reset_queries)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.force_debug_cursor = self.force_debug_cursor
        request_started.connect(reset_queries)
        if exc_type is None:
            self.final_queries = len(self.connection.queries_log)

745 

746 

class ignore_warnings(TestContextDecorator):
    """
    Ignore warnings while active.

    The keyword arguments go to warnings.filterwarnings() when they include
    ``message`` or ``module``, and to warnings.simplefilter() otherwise.
    """

    def __init__(self, **kwargs):
        self.ignore_kwargs = kwargs
        needs_full_filter = "message" in kwargs or "module" in kwargs
        self.filter_func = (
            warnings.filterwarnings if needs_full_filter else warnings.simplefilter
        )
        super().__init__()

    def enable(self):
        # catch_warnings() saves the warning filters on entry and restores
        # them on exit, so the "ignore" filter added below is temporary.
        self.catch_warnings = warnings.catch_warnings()
        self.catch_warnings.__enter__()
        self.filter_func("ignore", **self.ignore_kwargs)

    def disable(self):
        self.catch_warnings.__exit__(*sys.exc_info())

763 

764 

# Without time.tzset (e.g. on Windows), the time zone the process runs in
# can't be changed. Tests must therefore be skipped there if they don't pin
# a specific time zone (via timezone.override or equivalent), or if they
# interpret naive datetimes in the default time zone.
requires_tz_support = skipIf(
    not TZ_SUPPORT,
    "This test relies on the ability to run a program in an arbitrary "
    "time zone, but your operating system isn't able to do that.",
)

775 

776 

@contextmanager
def extend_sys_path(*paths):
    """
    Append ``paths`` to sys.path for the duration of the block.

    On exit, sys.path is set back to a copy taken on entry.
    """
    saved_path = list(sys.path)
    sys.path.extend(paths)
    try:
        yield
    finally:
        sys.path = saved_path

786 

787 

@contextmanager
def isolate_lru_cache(lru_cache_object):
    """
    Run the block with an empty cache: clear ``lru_cache_object`` when
    entering and again when leaving (even if the block raises).
    """
    clear = lru_cache_object.cache_clear
    clear()
    try:
        yield
    finally:
        clear()

796 

797 

@contextmanager
def captured_output(stream_name):
    """
    Temporarily replace ``sys.<stream_name>`` with a fresh StringIO and
    yield it. The original stream is restored on exit. This backs
    captured_stdout(), captured_stderr() and captured_stdin().

    Note: This function and the ``captured_std*`` helpers are copied from
    CPython's ``test.support`` module.
    """
    original_stream = getattr(sys, stream_name)
    replacement = StringIO()
    setattr(sys, stream_name, replacement)
    try:
        yield replacement
    finally:
        setattr(sys, stream_name, original_stream)

811 

812 

def captured_stdout():
    """Capture the output of sys.stdout:

       with captured_stdout() as stdout:
           print("hello")
       self.assertEqual(stdout.getvalue(), "hello\\n")
    """
    return captured_output(stream_name="stdout")

821 

822 

def captured_stderr():
    """Capture the output of sys.stderr:

       with captured_stderr() as stderr:
           print("hello", file=sys.stderr)
       self.assertEqual(stderr.getvalue(), "hello\\n")
    """
    return captured_output(stream_name="stderr")

831 

832 

def captured_stdin():
    """Capture the input to sys.stdin:

       with captured_stdin() as stdin:
           stdin.write("hello\\n")
           stdin.seek(0)
           # call test code that consumes from sys.stdin
           captured = input()
       self.assertEqual(captured, "hello")
    """
    return captured_output(stream_name="stdin")

844 

845 

@contextmanager
def freeze_time(t):
    """
    Make time.time() return ``t`` inside the block by patching the attribute
    on the time module. Code that did ``from time import time`` keeps the
    real function.

    This isn't meant as a public API; it just trims some repetition in
    Django's own test suite.
    """
    real_time = time.time

    def frozen():
        return t

    time.time = frozen
    try:
        yield
    finally:
        time.time = real_time

861 

862 

def require_jinja2(test_func):
    """
    Decorate a test so that it is skipped when Jinja2 isn't installed, and
    otherwise runs with both the Django and the Jinja2 template engines
    configured (APP_DIRS enabled for both).
    """
    templates = [
        {
            "BACKEND": "django.template.backends.django.DjangoTemplates",
            "APP_DIRS": True,
        },
        {
            "BACKEND": "django.template.backends.jinja2.Jinja2",
            "APP_DIRS": True,
            "OPTIONS": {"keep_trailing_newline": True},
        },
    ]
    skip_without_jinja2 = skipIf(jinja2 is None, "this test requires jinja2")
    return override_settings(TEMPLATES=templates)(skip_without_jinja2(test_func))

882 

883 

class override_script_prefix(TestContextDecorator):
    """
    Temporarily set the script prefix; usable as a decorator or as a
    context manager.
    """

    def __init__(self, prefix):
        self.prefix = prefix
        super().__init__()

    def enable(self):
        previous = get_script_prefix()
        self.old_prefix = previous
        set_script_prefix(self.prefix)

    def disable(self):
        set_script_prefix(self.old_prefix)

897 

898 

class LoggingCaptureMixin:
    """
    Test mixin that captures the output of the 'django' logger.

    setUp() redirects the stream of the logger's first handler to a StringIO,
    exposed as the instance attribute ``logger_output``. tearDown() puts the
    original stream back.
    """

    def setUp(self):
        self.logger = logging.getLogger("django")
        handler = self.logger.handlers[0]
        self.old_stream = handler.stream
        self.logger_output = StringIO()
        handler.stream = self.logger_output

    def tearDown(self):
        handler = self.logger.handlers[0]
        handler.stream = self.old_stream

913 

914 

class isolate_apps(TestContextDecorator):
    """
    Register the models defined in the wrapped context (decorated class,
    decorated function or ``with`` block) to an isolated app registry, by
    temporarily replacing Options.default_apps.

    The positional arguments are the installed apps the isolated registry
    should contain.

    Two optional keyword arguments can be specified:

    `attr_name`: attribute that receives the isolated registry when used as
    a class decorator.

    `kwarg_name`: keyword argument that receives the isolated registry when
    used as a function decorator.
    """

    def __init__(self, *installed_apps, **kwargs):
        self.installed_apps = installed_apps
        super().__init__(**kwargs)

    def enable(self):
        self.old_apps = Options.default_apps
        isolated = Apps(self.installed_apps)
        Options.default_apps = isolated
        return isolated

    def disable(self):
        Options.default_apps = self.old_apps

944 

945 

class TimeKeeper:
    """
    Record how long named steps take (measured with time.perf_counter) and
    print them to stderr on request.
    """

    def __init__(self):
        self.records = collections.defaultdict(list)

    @contextmanager
    def timed(self, name):
        # Create the entry up front so it takes its insertion-order position
        # when the timer starts, not when it stops.
        self.records.setdefault(name, [])
        started = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - started
            self.records[name].append(elapsed)

    def print_results(self):
        for name, durations in self.records.items():
            for duration in durations:
                sys.stderr.write("%s took %.3fs" % (name, duration) + os.linesep)

965 

966 

class NullTimeKeeper:
    """No-op stand-in for TimeKeeper: times nothing and prints nothing."""

    @contextmanager
    def timed(self, name):
        yield

    def print_results(self):
        return None

974 

975 

def tag(*tags):
    """
    Return a decorator that adds ``tags`` to a test class or method.

    If the object already has a ``tags`` attribute (possibly inherited), a new
    set holding the union is assigned, so the existing set is not mutated.
    """

    def decorator(obj):
        obj.tags = obj.tags.union(tags) if hasattr(obj, "tags") else set(tags)
        return obj

    return decorator

987 

988 

@contextmanager
def register_lookup(field, *lookups, lookup_name=None):
    """
    Context manager to temporarily register lookups on a model field using
    lookup_name (or the lookup's lookup_name if not provided).

    Only the lookups that were actually registered are unregistered on exit.
    If registration fails part-way, the original error therefore propagates
    unchanged, instead of being masked by an error from unregistering a
    lookup that was never added.
    """
    registered = []
    try:
        for lookup in lookups:
            field.register_lookup(lookup, lookup_name)
            # Track successes so the cleanup below undoes exactly these.
            registered.append(lookup)
        yield
    finally:
        for lookup in registered:
            field._unregister_lookup(lookup, lookup_name)