Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/factory/declarations.py: 56%
287 statements
« prev ^ index » next coverage.py v6.4.4, created at 2023-07-17 14:22 -0600
1# Copyright: See the LICENSE file.
4import itertools
5import logging
6import typing as T
8from . import enums, errors, utils
# Logger shared by every declaration below; they all emit debug-level traces.
logger = logging.getLogger(name='factory.generate')
class BaseDeclaration(utils.OrderedBase):
    """A lazily-evaluated factory declaration.

    A declaration flags a factory attribute whose value is computed at build
    time, which lets it depend on other declarations of the same factory.
    """

    FACTORY_BUILDER_PHASE = enums.BuilderPhase.ATTRIBUTE_RESOLUTION

    #: When True, nested declarations found in the context are resolved by
    #: ``unroll_context`` before ``evaluate`` runs.
    #: Declarations that perform their own unrolling set this to False.
    UNROLL_CONTEXT_BEFORE_EVALUATION = True

    def __init__(self, **defaults):
        super().__init__()
        # Keyword arguments act as default context values for this declaration.
        self._defaults = defaults or {}

    def unroll_context(self, instance, step, context):
        """Merge ``self._defaults`` with ``context`` and resolve nested declarations.

        Entries of ``context`` take precedence over the defaults.
        """
        merged = dict(self._defaults)
        merged.update(context)

        needs_resolution = (
            self.UNROLL_CONTEXT_BEFORE_EVALUATION
            and any(enums.get_builder_phase(value) for value in merged.values())
        )
        if not needs_resolution:
            # Unrolling disabled, or every value is already flat: nothing to do.
            return merged

        # Imported locally; presumably to avoid an import cycle with factory.base.
        import factory.base
        return step.recurse(
            factory.base.DictFactory, merged, force_sequence=step.sequence,
        )

    def evaluate_pre(self, instance, step, overrides):
        """Unroll ``overrides`` into a full context, then evaluate with it."""
        return self.evaluate(
            instance, step, self.unroll_context(instance, step, overrides),
        )

    def evaluate(self, instance, step, extra):
        """Compute the value of this declaration; subclasses must override.

        Args:
            instance (builder.Resolver): The object holding currently computed
                attributes
            step: a factory.builder.BuildStep
            extra (dict): additional, call-time added kwargs
                for the step.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('This is an abstract method')
class OrderedDeclaration(BaseDeclaration):
    """Compatibility name for :class:`BaseDeclaration`; adds no behavior.

    Presumably kept so that code using the older name keeps working.
    """

    # FIXME(rbarrois)
class LazyFunction(BaseDeclaration):
    """Declaration whose value is the result of a zero-argument callable.

    Attributes:
        function (callable): called with no arguments at evaluation time;
            its return value becomes the attribute value.
    """

    def __init__(self, function):
        super().__init__()
        self.function = function

    def evaluate(self, instance, step, extra):
        fn = self.function
        logger.debug("LazyFunction: Evaluating %r on %r", fn, step)
        return fn()
class LazyAttribute(BaseDeclaration):
    """Declaration computed from the object under construction.

    Attributes:
        function (callable): called with the ``instance`` being built (the
            resolver of already-computed attributes); its return value becomes
            the attribute value.
    """

    def __init__(self, function):
        super().__init__()
        self.function = function

    def evaluate(self, instance, step, extra):
        fn = self.function
        logger.debug("LazyAttribute: Evaluating %r on %r", fn, instance)
        return fn(instance)
class _UNSPECIFIED:
    """Sentinel meaning "no default supplied", so that None stays a valid default."""
def deepgetattr(obj, name, default=_UNSPECIFIED):
    """Read a possibly dotted attribute path from ``obj``.

    Works like :func:`getattr`, but ``name`` may contain dots: each segment is
    looked up on the result of the previous one (``deepgetattr(o, 'a.b')`` is
    ``o.a.b``).

    Args:
        obj (object): the object to start the lookup from
        name (str): attribute path, segments separated by '.'
        default (object): value returned when a segment is missing

    Returns:
        the value at the end of the path, or ``default`` if a segment is missing.

    Raises:
        AttributeError: if a segment is missing and no ``default`` was given.
    """
    try:
        if '.' not in name:
            return getattr(obj, name)
        head, tail = name.split('.', 1)
        return deepgetattr(getattr(obj, head), tail, default)
    except AttributeError:
        if default is not _UNSPECIFIED:
            return default
        raise
class SelfAttribute(BaseDeclaration):
    """Declaration copying the value of another attribute.

    Leading dots in the attribute name choose where the lookup starts:
    none or one dot means the object being built, two dots mean the next
    entry of ``step.chain`` (its parent), and so on.

    Attributes:
        depth (int): number of leading dots stripped from the name.
        attribute_name (str): the (possibly dotted) attribute path to copy.
        default (object): value used when the attribute does not exist.
    """

    def __init__(self, attribute_name, default=_UNSPECIFIED):
        super().__init__()
        stripped_name = attribute_name.lstrip('.')
        self.depth = len(attribute_name) - len(stripped_name)
        self.attribute_name = stripped_name
        self.default = default

    def evaluate(self, instance, step, extra):
        # depth 0 or 1 anchors on the current object; deeper walks up the chain.
        target = step.chain[self.depth - 1] if self.depth > 1 else instance
        logger.debug("SelfAttribute: Picking attribute %r on %r", self.attribute_name, target)
        return deepgetattr(target, self.attribute_name, self.default)

    def __repr__(self):
        return '<%s(%r, default=%r)>' % (
            type(self).__name__, self.attribute_name, self.default,
        )
class Iterator(BaseDeclaration):
    """Declaration taking successive values from an iterable.

    Warning: the underlying iterator should never be exhausted.

    Attributes:
        iterator: the ``utils.ResetableIterator`` in use, or None until the
            first evaluation.
        getter (callable or None): optional transform applied to each value.
        iterator_builder (callable): builds ``iterator`` from the source
            iterable, wrapped in ``itertools.cycle`` when ``cycle`` is true.
    """

    def __init__(self, iterator, cycle=True, getter=None):
        super().__init__()
        self.getter = getter
        self.iterator = None

        def build():
            source = itertools.cycle(iterator) if cycle else iterator
            return utils.ResetableIterator(source)

        # Deferred: the source is only touched on the first evaluation.
        self.iterator_builder = build

    def evaluate(self, instance, step, extra):
        # Unroll as late as possible; this helps with lazy sources such as
        # ResetableIterator(MyModel.objects.all()).
        if self.iterator is None:
            self.iterator = self.iterator_builder()

        logger.debug("Iterator: Fetching next value from %r", self.iterator)
        value = next(iter(self.iterator))
        return value if self.getter is None else self.getter(value)

    def reset(self):
        """Reset the internal iterator, if it has been built yet."""
        if self.iterator is None:
            return
        self.iterator.reset()
class Sequence(BaseDeclaration):
    """Declaration for 'sequenced' fields, typically increasing unique values.

    Attributes:
        function (callable): called with the current sequence counter (as an
            int); its return value becomes the attribute value.
    """

    def __init__(self, function):
        super().__init__()
        self.function = function

    def evaluate(self, instance, step, extra):
        counter = step.sequence
        logger.debug("Sequence: Computing next value of %r for seq=%s", self.function, counter)
        return self.function(int(counter))
class LazyAttributeSequence(Sequence):
    """Composite of a LazyAttribute and a Sequence.

    Attributes:
        function (callable): called as ``function(instance, n)`` where
            ``instance`` is the object being built and ``n`` is the current
            sequence counter converted with ``int()``.
    """
    def evaluate(self, instance, step, extra):
        logger.debug(
            "LazyAttributeSequence: Computing next value of %r for seq=%s, obj=%r",
            self.function, step.sequence, instance)
        return self.function(instance, int(step.sequence))
class ContainerAttribute(BaseDeclaration):
    """Variant of LazyAttribute that also receives the containing objects.

    Attributes:
        function (callable): called as ``function(instance, containers)``,
            where ``containers`` is ``step.chain`` without its first entry
            (the object currently being built).
        strict (bool): when True, evaluating with no containers (i.e. outside
            a SubFactory) raises TypeError.
    """
    def __init__(self, function, strict=True):
        super().__init__()
        self.function = function
        self.strict = strict

    def evaluate(self, instance, step, extra):
        """Evaluate the current ContainerAttribute.

        Args:
            instance (builder.Resolver): the object holding currently computed
                attributes.
            step: a factory.builder.BuildStep; ``step.chain`` holds the
                current instance followed by its containers.
            extra (dict): call-time kwargs; unused here.

        Returns:
            the result of ``self.function(instance, containers)``.

        Raises:
            TypeError: in strict mode, when there are no containers.
        """
        # Strip the current instance from the chain
        containers = step.chain[1:]
        if self.strict and not containers:
            raise TypeError(
                "A ContainerAttribute in 'strict' mode can only be used "
                "within a SubFactory.")

        return self.function(instance, containers)
class ParameteredAttribute(BaseDeclaration):
    """Base class for attributes expecting parameters.

    Attributes:
        _defaults (dict): default values for the parameters (the keyword
            arguments given to the constructor). ``unroll_context`` merges
            call-time overrides over them before ``evaluate`` runs.
    """

    def evaluate(self, instance, step, extra):
        """Evaluate the current definition and fill its attributes.

        Uses attributes definition in the following order:
        - values defined when defining the ParameteredAttribute
        - additional values defined when instantiating the containing factory

        Args:
            instance (builder.Resolver): The object holding currently computed
                attributes
            step: a factory.builder.BuildStep
            extra (dict): additional, call-time added kwargs
                for the step.
        """
        return self.generate(step, extra)

    def generate(self, step, params):
        """Actually generate the related attribute; subclasses must override.

        Args:
            step: a factory.builder.BuildStep
            params (dict): parameters inherited from init and evaluation-time
                overrides.

        Returns:
            Computed value for the current declaration.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class _FactoryWrapper:
    """Handle a 'factory' arg.

    Such args can be either a Factory subclass, or a fully qualified import
    path for that subclass (e.g 'myapp.factories.MyFactory'). Paths are
    imported lazily, on the first call to :meth:`get`.

    Raises:
        ValueError: (from the constructor) if the argument is neither a class
            nor a dotted string.
    """
    def __init__(self, factory_or_path):
        self.factory = None
        self.module = self.name = ''
        if isinstance(factory_or_path, type):
            self.factory = factory_or_path
            return
        if not (isinstance(factory_or_path, str) and '.' in factory_or_path):
            # Wrap in a 1-tuple: a tuple argument would otherwise be spread
            # over the format placeholders and break the message.
            raise ValueError(
                "A factory= argument must receive either a class "
                "or the fully qualified path to a Factory subclass; got "
                "%r instead." % (factory_or_path,))
        self.module, self.name = factory_or_path.rsplit('.', 1)

    def get(self):
        """Return the wrapped factory, importing it on first use when needed."""
        if self.factory is None:
            self.factory = utils.import_object(
                self.module,
                self.name,
            )
        return self.factory

    def __repr__(self):
        if self.factory is None:
            return f'<_FactoryImport: {self.module}.{self.name}>'
        # self.factory is a class: show its own dotted path rather than
        # ``self.factory.__class__``, which would be its metaclass.
        return f'<_FactoryImport: {self.factory.__module__}.{self.factory.__name__}>'
class SubFactory(BaseDeclaration):
    """Base class for attributes based upon a sub-factory.

    Attributes:
        _defaults (dict): overrides to the defaults defined in the wrapped
            factory (the keyword arguments given to the constructor)
        factory_wrapper (_FactoryWrapper): resolves the wrapped factory,
            given either as a class or as a dotted import path
    """

    # Whether to align the attribute's sequence counter to the holding
    # factory's sequence counter
    FORCE_SEQUENCE = False
    # The merged context is handed to the sub-factory unresolved
    # (see BaseDeclaration.UNROLL_CONTEXT_BEFORE_EVALUATION).
    UNROLL_CONTEXT_BEFORE_EVALUATION = False

    def __init__(self, factory, **kwargs):
        super().__init__(**kwargs)
        self.factory_wrapper = _FactoryWrapper(factory)

    def get_factory(self):
        """Retrieve the wrapped factory.Factory subclass."""
        return self.factory_wrapper.get()

    def evaluate(self, instance, step, extra):
        """Build the wrapped factory, passing ``extra`` as its parameters.

        Args:
            instance (builder.Resolver): The object holding currently computed
                attributes (unused here)
            step: a factory.builder.BuildStep
            extra (dict): this declaration's defaults merged with call-time
                overrides, forwarded to the sub-factory.

        Returns:
            the result of ``step.recurse`` for the wrapped factory.
        """
        subfactory = self.get_factory()
        # NOTE(review): the "create=%r" placeholder is filled with ``step``,
        # not with a create/build flag.
        logger.debug(
            "SubFactory: Instantiating %s.%s(%s), create=%r",
            subfactory.__module__, subfactory.__name__,
            utils.log_pprint(kwargs=extra),
            step,
        )
        force_sequence = step.sequence if self.FORCE_SEQUENCE else None
        return step.recurse(subfactory, extra, force_sequence=force_sequence)
class Dict(SubFactory):
    """Build a dict whose values are regular factory declarations."""

    FORCE_SEQUENCE = True

    def __init__(self, params, dict_factory='factory.DictFactory'):
        # The entries become the keyword defaults of the wrapped dict factory.
        fields = dict(params)
        super().__init__(dict_factory, **fields)
class List(SubFactory):
    """Build a list whose items are regular factory declarations."""

    FORCE_SEQUENCE = True

    def __init__(self, params, list_factory='factory.ListFactory'):
        # Key each declaration by its stringified position: '0', '1', ...
        indexed = {}
        for position, declaration in enumerate(params):
            indexed[str(position)] = declaration
        super().__init__(list_factory, **indexed)
417# Parameters
418# ==========
class Skip:
    """Marker meaning "no value"; instances are always falsy."""

    def __bool__(self):
        return False


# Shared marker instance; the default for both branches of Maybe.
SKIP = Skip()
class Maybe(BaseDeclaration):
    """Pick one of two declarations depending on a decider.

    Args:
        decider: a declaration, or a plain attribute name, which is wrapped in
            ``SelfAttribute(decider, default=None)``.
        yes_declaration: value or declaration used when the decider is truthy.
        no_declaration: value or declaration used when the decider is falsy.

    Raises:
        TypeError: if the two branches belong to different builder phases.
    """

    def __init__(self, decider, yes_declaration=SKIP, no_declaration=SKIP):
        super().__init__()

        if enums.get_builder_phase(decider) is None:
            # No builder phase => flat value, read as an attribute name.
            decider = SelfAttribute(decider, default=None)

        self.decider = decider
        self.yes = yes_declaration
        self.no = no_declaration

        phases = {
            'yes_declaration': enums.get_builder_phase(yes_declaration),
            'no_declaration': enums.get_builder_phase(no_declaration),
        }
        used_phases = set(phases.values()) - {None}

        if len(used_phases) > 1:
            raise TypeError(f"Inconsistent phases for {self!r}: {phases!r}")

        # Adopt the branches' phase; flat-only branches default to attribute resolution.
        self.FACTORY_BUILDER_PHASE = (
            used_phases.pop() if used_phases else enums.BuilderPhase.ATTRIBUTE_RESOLUTION
        )

    def evaluate_post(self, instance, step, overrides):
        """Handle post-generation declarations"""
        decider_phase = enums.get_builder_phase(self.decider)
        if decider_phase == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
            # Work on the *builder stub*, not on the actual instance:
            # this gives access to all Params-level definitions.
            choice = self.decider.evaluate_pre(
                instance=step.stub, step=step, overrides=overrides)
        else:
            assert decider_phase == enums.BuilderPhase.POST_INSTANTIATION
            choice = self.decider.evaluate_post(
                instance=instance, step=step, overrides={})

        target = self.yes if choice else self.no
        if enums.get_builder_phase(target) != enums.BuilderPhase.POST_INSTANTIATION:
            # Flat value (can't be ATTRIBUTE_RESOLUTION, checked in __init__)
            return target
        return target.evaluate_post(instance=instance, step=step, overrides=overrides)

    def evaluate_pre(self, instance, step, overrides):
        """Handle attribute-resolution declarations"""
        choice = self.decider.evaluate(instance=instance, step=step, extra={})
        target = self.yes if choice else self.no

        if not isinstance(target, BaseDeclaration):
            # Flat value (can't be POST_INSTANTIATION, checked in __init__)
            return target
        return target.evaluate_pre(instance=instance, step=step, overrides=overrides)

    def __repr__(self):
        return f'Maybe({self.decider!r}, yes={self.yes!r}, no={self.no!r})'
class Parameter(utils.OrderedBase):
    """A complex parameter, to be used in a Factory.Params section.

    Subclasses implement :meth:`as_declarations`, which performs the actual
    declaration override, and may override :meth:`get_revdeps` to report the
    other parameters they alter.
    """

    def as_declarations(self, field_name, declarations):
        """Compute the overrides for this parameter.

        Args:
            - field_name (str): the field this parameter is installed at
            - declarations (dict): the global factory declarations

        Returns:
            dict: the declarations to override

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_revdeps(self, parameters):
        """Return the other parameters modified by this one (none by default)."""
        return []
class SimpleParameter(Parameter):
    """Parameter holding a plain value, installed as-is under its field name."""

    def __init__(self, value):
        super().__init__()
        self.value = value

    def as_declarations(self, field_name, declarations):
        return {field_name: self.value}

    @classmethod
    def wrap(cls, value):
        """Return ``value`` as a Parameter, wrapping non-Parameters in ``cls``."""
        if isinstance(value, Parameter):
            # touch_creation_counter comes from utils.OrderedBase (presumably
            # refreshing the parameter's declaration order).
            value.touch_creation_counter()
            return value
        return cls(value)
class Trait(Parameter):
    """The simplest complex parameter: a boolean flag enabling a set of overrides.

    Attributes:
        overrides (dict): field name -> declaration applied when the flag is set.
    """

    def __init__(self, **overrides):
        super().__init__()
        self.overrides = overrides

    def as_declarations(self, field_name, declarations):
        result = {}
        for target_field, enabled_value in self.overrides.items():
            # One extra leading dot per enums.SPLITTER in target_field, so the
            # flag is read from the object declaring this trait.
            prefix = '.' * target_field.count(enums.SPLITTER)
            flag_path = '%s.%s' % (prefix, field_name)
            result[target_field] = Maybe(
                decider=SelfAttribute(flag_path, default=False),
                yes_declaration=enabled_value,
                no_declaration=declarations.get(target_field, SKIP),
            )
        return result

    def get_revdeps(self, parameters):
        """This might alter fields it's injecting."""
        return [name for name in parameters if name in self.overrides]

    def __repr__(self):
        rendered = ', '.join(
            '%s=%r' % (key, value) for key, value in self.overrides.items()
        )
        return '%s(%s)' % (self.__class__.__name__, rendered)
570# Post-generation
571# ===============
class PostGenerationContext(T.NamedTuple):
    """Values handed to a post-generation declaration's ``call`` hook."""

    #: Whether a value was passed for the declaration itself.
    value_provided: bool
    #: That value.
    value: T.Any
    #: The remaining context entries.
    extra: T.Dict[str, T.Any]
class PostGenerationDeclaration(BaseDeclaration):
    """Declarations to be called once the model object has been generated."""

    FACTORY_BUILDER_PHASE = enums.BuilderPhase.POST_INSTANTIATION

    def evaluate_post(self, instance, step, overrides):
        """Unroll ``overrides`` and run :meth:`call` with the result.

        The context entry keyed ``''`` is the value passed for the declaration
        itself. Every other entry is forwarded as ``extra``.
        """
        context = self.unroll_context(instance, step, overrides)
        postgen_context = PostGenerationContext(
            value_provided='' in context,
            value=context.get(''),
            extra={k: v for k, v in context.items() if k != ''},
        )
        return self.call(instance, step, postgen_context)

    def call(self, instance, step, context):  # pragma: no cover
        """Call this hook; subclasses must override.

        Args:
            instance (object): the newly generated object
            step: a factory.builder.BuildStep for the current build
            context (PostGenerationContext): values extracted from the
                containing factory's declaration

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class PostGeneration(PostGenerationDeclaration):
    """Calls a given function once the object has been generated.

    The function is called as ``function(instance, create, value, **extra)``,
    where ``create`` is True when the builder strategy is CREATE_STRATEGY.
    """

    def __init__(self, function):
        super().__init__()
        self.function = function

    def call(self, instance, step, context):
        fn = self.function
        logger.debug(
            "PostGeneration: Calling %s.%s(%s)",
            fn.__module__,
            fn.__name__,
            utils.log_pprint((instance, step), context._asdict()),
        )
        is_create = step.builder.strategy == enums.CREATE_STRATEGY
        return fn(instance, is_create, context.value, **context.extra)
class RelatedFactory(PostGenerationDeclaration):
    """Calls a factory once the object has been generated.

    Attributes:
        factory_wrapper (_FactoryWrapper): resolves the factory to call
        defaults (dict): extra declarations for calling the related factory
        name (str): when non-empty, the keyword under which the generated
            object is passed to the related factory
    """

    UNROLL_CONTEXT_BEFORE_EVALUATION = False

    def __init__(self, factory, factory_related_name='', **defaults):
        super().__init__()

        self.name = factory_related_name
        self.defaults = defaults
        self.factory_wrapper = _FactoryWrapper(factory)

    def get_factory(self):
        """Retrieve the wrapped factory.Factory subclass."""
        return self.factory_wrapper.get()

    def call(self, instance, step, context):
        factory = self.get_factory()

        if context.value_provided:
            # An explicit value replaces generation entirely.
            logger.debug(
                "RelatedFactory: Using provided %r instead of generating %s.%s.",
                context.value,
                factory.__module__, factory.__name__,
            )
            return context.value

        # Call-time extras win over the declared defaults.
        kwargs = {**self.defaults, **context.extra}
        if self.name:
            kwargs[self.name] = instance

        logger.debug(
            "RelatedFactory: Generating %s.%s(%s)",
            factory.__module__,
            factory.__name__,
            utils.log_pprint((step,), kwargs),
        )
        return step.recurse(factory, kwargs)
class RelatedFactoryList(RelatedFactory):
    """Calls a factory 'size' times once the object has been generated.

    Attributes:
        factory (Factory): the factory to call "size-times"
        defaults (dict): extra declarations for calling the related factory
        factory_related_name (str): the name to use to refer to the generated
            object when calling the related factory
        size (int|callable): the number of calls, either an int or a
            zero-argument callable returning one; the result is a list of
            that many objects.
    """

    def __init__(self, factory, factory_related_name='', size=2, **defaults):
        self.size = size
        super().__init__(factory, factory_related_name, **defaults)

    def call(self, instance, step, context):
        count = self.size if isinstance(self.size, int) else self.size()
        # Bind the parent method here: zero-argument super() does not work
        # inside the comprehension's own scope.
        parent_call = super().call
        return [parent_call(instance, step, context) for _ in range(count)]
class NotProvided:
    """Sentinel: no positional argument was given to PostGenerationMethodCall."""
class PostGenerationMethodCall(PostGenerationDeclaration):
    """Calls a method of the generated object.

    Attributes:
        method_name (str): the method to call
        method_arg (object): the single positional argument to pass, or the
            ``NotProvided`` sentinel when none was given. A value passed for
            the declaration at call time replaces it.
        method_kwargs (dict): keyword arguments to pass to the method; call-time
            extras are merged over them

    Raises:
        errors.InvalidDeclarationError: (from the constructor) when more than
            one positional argument is given.

    Example:
        class UserFactory(factory.Factory):
            ...
            password = factory.PostGenerationMethodCall('set_pass', password='')
    """
    def __init__(self, method_name, *args, **kwargs):
        super().__init__()
        if len(args) > 1:
            raise errors.InvalidDeclarationError(
                "A PostGenerationMethodCall can only handle 1 positional argument; "
                "please provide other parameters through keyword arguments."
            )
        self.method_name = method_name
        self.method_arg = args[0] if args else NotProvided
        self.method_kwargs = kwargs

    def call(self, instance, step, context):
        # Precedence for the positional argument: call-time value, then the
        # declared argument, then none at all.
        if context.value_provided:
            args = (context.value,)
        elif self.method_arg is NotProvided:
            args = ()
        else:
            args = (self.method_arg,)

        kwargs = dict(self.method_kwargs)
        kwargs.update(context.extra)
        method = getattr(instance, self.method_name)
        logger.debug(
            "PostGenerationMethodCall: Calling %r.%s(%s)",
            instance,
            self.method_name,
            utils.log_pprint(args, kwargs),
        )
        return method(*args, **kwargs)