Coverage for adhoc-cicd-odoo-odoo / odoo / orm / registry.py: 75%
688 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 18:15 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 18:15 +0000
1# Part of Odoo. See LICENSE file for full copyright and licensing details.
3""" Models registries.
5"""
6from __future__ import annotations
8import functools
9import inspect
10import logging
11import os
12import threading
13import time
14import typing
15import warnings
16from collections import defaultdict, deque
17from collections.abc import Mapping
18from contextlib import closing, contextmanager, nullcontext, ExitStack
19from functools import partial
20from operator import attrgetter
22import psycopg2.sql
24from odoo import sql_db
25from odoo.tools import (
26 SQL,
27 OrderedSet,
28 config,
29 gc,
30 lazy_classproperty,
31 remove_accents,
32 sql,
33)
34from odoo.tools.func import locked, reset_cached_properties
35from odoo.tools.lru import LRU
36from odoo.tools.misc import Collector, format_frame
38from .utils import SUPERUSER_ID
39from . import model_classes
41if typing.TYPE_CHECKING:
42 from collections.abc import Callable, Collection, Iterable, Iterator
43 from odoo.fields import Field
44 from odoo.models import BaseModel
45 from odoo.sql_db import BaseCursor, Connection, Cursor
46 from odoo.modules import module_graph
_logger = logging.getLogger('odoo.registry')
_schema = logging.getLogger('odoo.schema')

# Maximum number of entries of each named cache container; Registry.init
# creates one LRU of that size per entry.
_REGISTRY_CACHES = dict([
    ('default', 8192),
    ('assets', 512),
    ('stable', 1024),
    ('templates', 1024),
    ('routing', 1024),                  # 2 entries per website
    ('routing.rewrites', 8192),         # url_rewrite entries
    ('templates.cached_values', 2048),  # arbitrary
    ('groups', 64),                     # see res.groups
])

# Cache invalidation dependencies: each cache key maps to the cache containers
# to clear when that key is invalidated, i.e.
# {'cache_key': ('cache_container_1', 'cache_container_3', ...)}
_CACHES_BY_KEY = dict([
    ('default', ('default', 'templates.cached_values')),
    ('assets', ('assets', 'templates.cached_values')),
    ('stable', ('stable', 'default', 'templates.cached_values')),
    ('templates', ('templates', 'templates.cached_values')),
    ('routing', ('routing', 'routing.rewrites', 'templates.cached_values')),
    # the processing of groups is saved in the view
    ('groups', ('groups', 'templates', 'templates.cached_values')),
])

# Delay in seconds (20 minutes); presumably how long before the read-only
# replica is retried after a failure — confirm against its users.
_REPLICA_RETRY_TIME = 20 * 60
def _unaccent(x: SQL | str | psycopg2.sql.Composable) -> SQL | str | psycopg2.sql.Composed:
    """ Wrap ``x`` in a call to the PostgreSQL ``unaccent`` function.

    The result has the same flavour as the input: an Odoo ``SQL`` object gives
    an ``SQL`` object, a psycopg2 composable gives a psycopg2 composed query,
    and anything else is formatted into a plain string.
    """
    if isinstance(x, SQL):
        wrapped = SQL("unaccent(%s)", x)
    elif isinstance(x, psycopg2.sql.Composable):
        wrapped = psycopg2.sql.SQL('unaccent({})').format(x)
    else:
        wrapped = f'unaccent({x})'
    return wrapped
86class Registry(Mapping[str, type["BaseModel"]]):
87 """ Model registry for a particular database.
89 The registry is essentially a mapping between model names and model classes.
90 There is one registry instance per database.
92 """
# Class-wide re-entrant lock; ``__new__`` holds it while it looks up or
# creates the registry of a database.
_lock: threading.RLock | DummyRLock = threading.RLock()
# Spare slot for a lock object, empty by default (its users are not shown here).
_saved_lock: threading.RLock | DummyRLock | None = None
@lazy_classproperty
def registries(cls) -> LRU[str, Registry]:
    """ A mapping from database names to registries.

    Its capacity is the ``registry_lru_size`` option when that option is set
    to a non-zero value; otherwise it is derived from the memory limits.
    """
    configured_size = config.get('registry_lru_size', None)
    if configured_size:
        return LRU(configured_size)
    if os.name != 'posix':
        # the soft memory limit cannot be specified on windows
        return LRU(42)
    # A registry takes 10MB of memory on average: reserve 10MB (registry)
    # plus 5MB (working memory) for each registry.
    bytes_per_registry = 15 * 1024 * 1024
    soft_limit = config['limit_memory_soft']
    if soft_limit <= 0:
        soft_limit = 2048 * 1024 * 1024
    return LRU(soft_limit // bytes_per_registry or 1)
def __new__(cls, db_name: str):
    """ Return the registry of database ``db_name``.

    An existing registry is reused; otherwise one is built by :meth:`new`.
    Both the lookup and the creation happen while holding ``cls._lock``.
    """
    assert db_name, "Missing database name"
    with cls._lock:
        try:
            registry = cls.registries[db_name]
        except KeyError:
            registry = cls.new(db_name)
        return registry
# Instance attributes, set to their initial values by ``init``:
_init: bool  # whether init needs to be done (reset to False once ``new`` has loaded the registry)
ready: bool  # whether everything is set up
loaded: bool  # whether all modules are loaded
models: dict[str, type[BaseModel]]  # model name -> model class
@classmethod
@locked
def new(
    cls,
    db_name: str,
    *,
    update_module: bool = False,
    install_modules: Collection[str] = (),
    upgrade_modules: Collection[str] = (),
    reinit_modules: Collection[str] = (),
    new_db_demo: bool | None = None,
    models_to_check: set[str] | None = None,
) -> Registry:
    """Create, load and return a new registry for the given database name.

    The registry under construction is stored in :attr:`registries` before
    its modules are loaded. If loading raises, the entry is removed again and
    the exception propagates.

    :param db_name: The name of the database to associate with the Registry instance.
    :param update_module: If ``True``, update modules while loading the registry.
        Defaults to ``False``. It is forced to ``True`` when any of
        ``install_modules``, ``upgrade_modules`` or ``reinit_modules`` is
        non-empty, or when the ``base.partially_updated_database`` marker is
        found (and removed) in an initialized database.
    :param install_modules: Names of modules to install. If a specified module
        is **not installed**, it and all of its direct and indirect
        dependencies will be installed. Defaults to an empty tuple.
    :param upgrade_modules: Names of modules to upgrade. Their direct or
        indirect dependent modules will also be upgraded. Defaults to an
        empty tuple.
    :param reinit_modules: Names of modules to reinitialize. If a specified
        module is **already installed**, it and all of its installed direct
        and indirect dependents will be re-initialized. Re-initialization
        means the module will be upgraded without running upgrade scripts,
        but with data loaded in ``'init'`` mode.
    :param new_db_demo: Whether to install demo data for the new database.
        If set to ``None``, the value will be determined by
        ``config['with_demo']``. Defaults to ``None``.
    :param models_to_check: Passed through unchanged to ``load_modules``.
    :return: the registry found in :attr:`registries` for ``db_name`` after
        loading. ``load_modules`` may have replaced the one built here.
    """
    started_at = time.time()
    registry: Registry = object.__new__(cls)
    registry.init(db_name)
    registry.new = registry.init = registry.registries = None  # type: ignore
    is_first_registry = not cls.registries

    # Initializing a registry calls general code that in turn calls
    # Registry() to obtain the registry being initialized. Publish it in the
    # registries mapping now, and take it out again if anything below raises.
    cls.delete(db_name)
    cls.registries[db_name] = registry  # pylint: disable=unsupported-assignment-operation
    try:
        registry.setup_signaling()
        with registry.cursor() as cr:
            # This transaction defines a critical section for multi-worker concurrency control.
            # When the transaction commits, the first worker proceeds to upgrade modules. Other workers
            # encounter a serialization error and retry, finding no upgrade marker in the database.
            # This significantly reduces the likelihood of concurrent module upgrades across workers.
            # NOTE: This block is intentionally outside the try-except below to prevent workers that fail
            # due to serialization errors from calling `reset_modules_state` while the first worker is
            # actively upgrading modules.
            from odoo.modules import db  # noqa: PLC0415
            if db.is_initialized(cr):
                cr.execute("DELETE FROM ir_config_parameter WHERE key='base.partially_updated_database'")
                if cr.rowcount:
                    # the marker was present: force a module update
                    update_module = True

        # This should be a method on Registry
        from odoo.modules.loading import load_modules, reset_modules_state  # noqa: PLC0415
        # Closed explicitly in ``finally`` (not used as a ``with``), so the
        # entered contexts are always exited without exception info.
        gc_guard = ExitStack()
        try:
            if upgrade_modules or install_modules or reinit_modules:
                update_module = True
            if new_db_demo is None:
                new_db_demo = config['with_demo']
            if is_first_registry and not update_module:
                # the first registry loaded without updates runs with the
                # garbage collector disabled until the stack is closed
                gc_guard.enter_context(gc.disabling_gc())
            load_modules(
                registry,
                update_module=update_module,
                upgrade_modules=upgrade_modules,
                install_modules=install_modules,
                reinit_modules=reinit_modules,
                new_db_demo=new_db_demo,
                models_to_check=models_to_check,
            )
        except Exception:
            reset_modules_state(db_name)
            raise
        finally:
            gc_guard.close()
    except Exception:
        _logger.error('Failed to load registry')
        del cls.registries[db_name]  # pylint: disable=unsupported-delete-operation
        raise

    del registry._reinit_modules

    # load_modules() above can replace the registry by indirectly calling
    # new() again (when modules have to be uninstalled): fetch the current one.
    registry = cls.registries[db_name]  # pylint: disable=unsubscriptable-object

    registry._init = False
    registry.ready = True
    registry.registry_invalidated = bool(update_module)
    registry.signal_changes()

    _logger.info("Registry loaded in %.3fs", time.time() - started_at)
    return registry
def init(self, db_name: str) -> None:
    """ Initialize the attributes of a freshly allocated registry.

    ``new`` calls this on an instance obtained with ``object.__new__``. It
    resets the loading flags, creates the model mapping and the named caches,
    and opens the read/write connection pool. A read-only pool is also opened
    when ``db_replica_host`` or ``test_enable`` is set, or ``dev_mode``
    contains ``replica``. It then creates the field dependency structures and
    the signaling state, and probes the database for ``unaccent`` and trigram
    support.
    """
    # loading state
    self._init = True
    self.loaded = False
    self.ready = False

    self.models: dict[str, type[BaseModel]] = {}  # model name -> model class
    self._sql_constraints = set()  # type: ignore
    # "{model}.{field_name}" -> translate function name, for translated fields in database
    self._database_translated_fields: dict[str, str] = {}
    # names of company dependent fields in database
    self._database_company_dependent_fields: set[str] = set()
    # test result collector, only when tests are enabled
    self._assertion_report: OdooTestResult | None = None
    if config['test_enable']:
        from odoo.tests.result import OdooTestResult  # noqa: PLC0415
        self._assertion_report = OdooTestResult()
    self._ordinary_tables: set[str] | None = None  # cached names of regular tables
    # functions to call on finalization of constraints, by key
    self._constraint_queue: dict[typing.Any, Callable[[BaseCursor], None]] = {}
    self.__caches: dict[str, LRU] = {name: LRU(size) for name, size in _REGISTRY_CACHES.items()}

    # context updated while loading modules
    self._force_upgrade_scripts: set[str] = set()  # modules whose upgrade scripts must run
    self._reinit_modules: set[str] = set()  # modules to reinitialize

    # modules fully loaded (maintained during init phase by `loading` module)
    self._init_modules: set[str] = set()  # modules that have been initialized
    self.updated_modules: list[str] = []  # installed/updated modules
    self.loaded_xmlids: set[str] = set()

    self.db_name = db_name
    self._db: Connection = sql_db.db_connect(db_name, readonly=False)
    self._db_readonly: Connection | None = None
    self._db_readonly_failed_time: float | None = None
    # by default, only use the readonly pool if we have a db_replica_host defined
    wants_readonly_pool = config['db_replica_host'] or config['test_enable'] or 'replica' in config['dev_mode']
    if wants_readonly_pool:
        self._db_readonly = sql_db.db_connect(db_name, readonly=True)

    # field dependencies
    self.field_depends: Collector[Field, Field] = Collector()
    self.field_depends_context: Collector[Field, str] = Collector()

    # field inverses
    self.many2many_relations: defaultdict[tuple[str, str, str], OrderedSet[tuple[str, str]]] = defaultdict(OrderedSet)

    # field setup dependents: invalidating the setup of a field also
    # invalidates the setup of the related fields depending on it
    # (for incremental model setup)
    self.field_setup_dependents: Collector[Field, Field] = Collector()

    # company dependent fields: {model_name: (field1, field2, ...)}
    self.many2one_company_dependents: Collector[str, Field] = Collector()

    # constraint checks
    self.not_null_fields: set[Field] = set()

    # caches of get_field_trigger_tree() and is_modifying_relations()
    self._field_trigger_trees: dict[Field, TriggerTree] = {}
    self._is_modifying_relations: dict[Field, bool] = {}

    # Inter-process signaling: the `orm_signaling_registry` sequence tells
    # that the whole registry must be reloaded; each `orm_signaling_...`
    # sequence tells that the corresponding cache must be cleared.
    self.registry_sequence: int = -1
    self.cache_sequences: dict[str, int] = {}

    # flags indicating invalidation of the registry or the cache (per thread)
    self._invalidation_flags = threading.local()

    from odoo.modules import db  # noqa: PLC0415
    with closing(self.cursor()) as cr:
        self.has_unaccent = db.has_unaccent(cr)
        self.has_trigram = db.has_trigram(cr)

    if self.has_unaccent:
        self.unaccent = _unaccent
        self.unaccent_python = remove_accents
    else:
        self.unaccent = lambda x: x  # type: ignore
        self.unaccent_python = lambda x: x
@classmethod
@locked
def delete(cls, db_name: str) -> None:
    """ Remove the registry of database ``db_name`` from :attr:`registries`, if present. """
    registries = cls.registries
    if db_name not in registries:  # pylint: disable=unsupported-membership-test
        return
    del registries[db_name]  # pylint: disable=unsupported-delete-operation
@classmethod
@locked
def delete_all(cls):
    """ Remove every registry from :attr:`registries`. """
    registries = cls.registries
    registries.clear()
321 #
322 # Mapping abstract methods implementation
323 # => mixin provides methods keys, items, values, get, __eq__, and __ne__
324 #
def __len__(self):
    """ Number of models in the registry. """
    models = self.models
    return len(models)
def __iter__(self):
    """ Iterate over the names of the registry's models. """
    return iter(self.models.keys())
def __getitem__(self, model_name: str) -> type[BaseModel]:
    """ Look up the model registered as ``model_name``.

    :raise KeyError: when no such model exists
    """
    models = self.models
    return models[model_name]
def __setitem__(self, model_name: str, model: type[BaseModel]):
    """ Register ``model`` under ``model_name``, replacing any previous model. """
    models = self.models
    models[model_name] = model
def __delitem__(self, model_name: str):
    """ Remove the (custom) model ``model_name`` from the registry.

    The name is also dropped from the ``_inherit_children`` of the remaining
    models, because a custom model can inherit from mixins ('mail.thread', ...).
    """
    models = self.models
    del models[model_name]
    for remaining_model in models.values():
        remaining_model._inherit_children.discard(model_name)
def descendants(self, model_names: Iterable[str], *kinds: typing.Literal['_inherit', '_inherits']) -> OrderedSet[str]:
    """ Return the models named in ``model_names`` plus every model that
    inherits from them, transitively, through the given ``kinds`` of
    inheritance (``'_inherit'`` and/or ``'_inherits'``).

    Models are visited breadth-first. Names not present in the registry are
    skipped, and the result keeps the order in which models were first reached.
    """
    assert all(kind in ('_inherit', '_inherits') for kind in kinds)
    children_attrs = [f'{kind}_children' for kind in kinds]

    reached: OrderedSet[str] = OrderedSet()
    pending = deque(model_names)
    while pending:
        model = self.get(pending.popleft())
        if model is None or model._name in reached:
            continue
        reached.add(model._name)
        for attr in children_attrs:
            pending.extend(getattr(model, attr))
    return reached
def load(self, module: module_graph.ModuleNode) -> list[str]:
    """ Add the model classes of ``module`` to this registry and return the
    names of the directly modified models.

    The Python classes are already imported; this instantiates, for this
    registry, every model class that ``MetaModel`` recorded for the module.
    All registry caches are cleared first, without signaling it. To get all
    impacted models, call :meth:`descendants` with ``'_inherit'`` and
    ``'_inherits'`` on the returned names.
    """
    from . import models  # noqa: PLC0415

    # clear cache to ensure consistency, but do not signal it
    for cache in self.__caches.values():
        cache.clear()

    reset_cached_properties(self)
    self._field_trigger_trees.clear()
    self._is_modifying_relations.clear()

    # models register themselves in self.models
    model_defs = models.MetaModel._module_to_models__.get(module.name, [])
    return [model_classes.add_to_registry(self, model_def)._name for model_def in model_defs]
398 @locked
399 def _setup_models__(self, cr: BaseCursor, model_names: Iterable[str] | None = None) -> None: # noqa: PLW3201
400 """ Perform the setup of models.
401 This must be called after loading modules and before using the ORM.
403 When given ``model_names``, it performs an incremental setup: only the
404 models impacted by the given ``model_names`` and all the already-marked
405 models will be set up. Otherwise, all models are set up.
    It also clears the registry caches (without signaling) and sets
    ``registry_invalidated`` to True.
406 """
407 from .environments import Environment # noqa: PLC0415
408 env = Environment(cr, SUPERUSER_ID, {})
409 env.invalidate_all()
411 # Uninstall registry hooks. Because of the condition, this only happens
412 # on a fully loaded registry, and not on a registry being loaded.
# (`ready` is False after init() and only set to True at the end of new())
413 if self.ready: 413 ↛ 414line 413 didn't jump to line 414 because the condition on line 413 was never true
414 for model in env.values():
415 model._unregister_hook()
417 # clear cache to ensure consistency, but do not signal it
418 for cache in self.__caches.values():
419 cache.clear()
421 reset_cached_properties(self)
422 self._field_trigger_trees.clear()
423 self._is_modifying_relations.clear()
424 self.registry_invalidated = True
426 # model classes on which to *not* recompute field_depends[_context]
427 models_field_depends_done = set()
429 if model_names is None:
# full setup: forget all relation/setup bookkeeping and all field dependencies
430 self.many2many_relations.clear()
431 self.field_setup_dependents.clear()
433 # mark all models for setup
434 for model_cls in self.models.values():
435 model_cls._setup_done__ = False
437 self.field_depends.clear()
438 self.field_depends_context.clear()
440 else:
441 # only mark impacted models for setup and invalidate related fields
442 model_names_to_setup = self.descendants(model_names, '_inherit', '_inherits')
# drop the many2many relation pairs whose first element is an impacted model
443 for fields in self.many2many_relations.values():
444 for pair in list(fields):
445 if pair[0] in model_names_to_setup:
446 fields.discard(pair)
448 for model_name in model_names_to_setup:
449 self[model_name]._setup_done__ = False
451 # recursively mark fields to re-setup
# models whose setup is still done keep their field_depends[_context];
# the fields of the other models seed the worklist `todo`
452 todo = []
453 for model_cls in self.models.values():
454 if model_cls._setup_done__:
455 models_field_depends_done.add(model_cls)
456 else:
457 todo.extend(model_cls._fields.values())
# worklist: `todo` grows while being iterated (see todo.extend below),
# and `done` prevents processing a field twice
459 done = set()
460 for field in todo:
461 if field in done:
462 continue
464 model_cls = self[field.model_name]
465 if model_cls._setup_done__ and field._base_fields__:
466 # the field has been created by model_classes._setup() as
467 # Field(_base_fields__=...); restore it to force its setup
468 name = field.name
469 base_fields = field._base_fields__
471 field.__dict__.clear()
472 field.__init__(_base_fields__=base_fields)
473 field._toplevel = True
474 field.__set_name__(model_cls, name)
475 field._setup_done = False
# NOTE(review): indentation is lost in this listing; whether the next line
# is inside the `if` above or runs for every field cannot be recovered here
# — confirm against the actual source before changing it.
477 models_field_depends_done.discard(model_cls)
479 # partial invalidation of field_depends[_context]
480 self.field_depends.pop(field, None)
481 self.field_depends_context.pop(field, None)
483 done.add(field)
# fields whose setup depends on this one must be re-setup as well
484 todo.extend(self.field_setup_dependents.pop(field, ()))
486 self.many2one_company_dependents.clear()
488 model_classes.setup_model_classes(env)
490 # determine field_depends and field_depends_context
# (skipped for the model classes whose dependencies were kept above)
491 for model_cls in self.models.values():
492 if model_cls in models_field_depends_done:
493 continue
494 model = model_cls(env, (), ())
495 for field in model._fields.values():
496 depends, depends_context = field.get_depends(model)
497 self.field_depends[field] = tuple(depends)
498 self.field_depends_context[field] = tuple(depends_context)
500 # clean the lazy_property again in case they are cached by another ongoing registry readonly request
501 reset_cached_properties(self)
503 # Reinstall registry hooks. Because of the condition, this only happens
504 # on a fully loaded registry, and not on a registry being loaded.
505 if self.ready: 505 ↛ 506line 505 didn't jump to line 506 because the condition on line 505 was never true
506 for model in env.values():
507 model._register_hook()
508 env.flush_all()
@functools.cached_property
def field_inverses(self) -> Collector[Field, Field]:
    """ Inverse fields of the relational fields of all models.

    Each relational field contributes through its ``setup_inverses``. The
    result is computed once and cached on the instance (``cached_property``).
    """
    inverses = Collector()
    relational_fields = (
        field
        for model_cls in self.models.values()
        for field in model_cls._fields.values()
        if field.relational
    )
    for field in relational_fields:
        field.setup_inverses(self, inverses)
    return inverses
@functools.cached_property
def field_computed(self) -> dict[Field, list[Field]]:
    """ Map every computed field to the fields of its model that share its
    compute method (the field itself included, in definition order).

    A warning is emitted when fields sharing a compute method disagree on
    ``compute_sudo``, ``precompute`` or ``store``.
    """
    computed: dict[Field, list[Field]] = {}
    for model_name, Model in self.models.items():
        by_method: defaultdict[Field, list[Field]] = defaultdict(list)
        for field in Model._fields.values():
            if not field.compute:
                continue
            shared = by_method[field.compute]
            shared.append(field)
            computed[field] = shared

        for group in by_method.values():
            if len(group) < 2:
                continue
            if len({fld.compute_sudo for fld in group}) > 1:
                names = ", ".join(fld.name for fld in group)
                warnings.warn(
                    f"{model_name}: inconsistent 'compute_sudo' for computed fields {names}. "
                    f"Either set 'compute_sudo' to the same value on all those fields, or "
                    f"use distinct compute methods for sudoed and non-sudoed fields.",
                    stacklevel=1,
                )
            if len({fld.precompute for fld in group}) > 1:
                names = ", ".join(fld.name for fld in group)
                warnings.warn(
                    f"{model_name}: inconsistent 'precompute' for computed fields {names}. "
                    f"Either set all fields as precompute=True (if possible), or "
                    f"use distinct compute methods for precomputed and non-precomputed fields.",
                    stacklevel=1,
                )
            if len({fld.store for fld in group}) > 1:
                unstored = ", ".join(fld.name for fld in group if not fld.store)
                stored = ", ".join(fld.name for fld in group if fld.store)
                warnings.warn(
                    f"{model_name}: inconsistent 'store' for computed fields, "
                    f"accessing {unstored} may recompute and update {stored}. "
                    f"Use distinct compute methods for stored and non-stored fields.",
                    stacklevel=1,
                )
    return computed
def get_trigger_tree(self, fields: list[Field], select: Callable[[Field], bool] = bool) -> TriggerTree:
    """ Return the trigger tree to traverse when ``fields`` have been modified.

    The trees of the given fields that have triggers are merged. ``select``
    is called on every field to decide which fields are kept in the tree
    nodes, which allows unnecessary fields to be discarded.
    """
    trees = []
    for field in fields:
        if field not in self._field_triggers:
            continue
        trees.append(self.get_field_trigger_tree(field))
    return TriggerTree.merge(trees, select)
def get_dependent_fields(self, field: Field) -> Iterator[Field]:
    """ Iterate over the fields that depend on ``field``.

    The fields come from the nodes of the field's trigger tree, visited
    depth-first. Nothing is yielded for a field without triggers.
    """
    if field in self._field_triggers:
        for node in self.get_field_trigger_tree(field).depth_first():
            yield from node.root
def _discard_fields(self, fields: list[Field]) -> None:
    """ Remove ``fields`` from the registry's dependency, trigger and
    inverse data structures. """
    # Tests usually don't reload the registry, so custom fields they create
    # may lack the entire dependency setup and be absent from field_depends.
    for field in fields:
        self.field_depends.pop(field, None)

    # drop the cached field triggers and the caches derived from them
    self.__dict__.pop('_field_triggers', None)
    self._field_trigger_trees.clear()
    self._is_modifying_relations.clear()

    # discard fields from field inverses
    self.field_inverses.discard_keys_and_values(fields)
def get_field_trigger_tree(self, field: Field) -> TriggerTree:
    """ Return the trigger tree of ``field``, built from the transitive
    closure of the field triggers.

    Trees are memoized in ``_field_trigger_trees``. A field without triggers
    gets a new, empty tree, which is not memoized.
    """
    memo = self._field_trigger_trees
    if field in memo:
        return memo[field]

    triggers = self._field_triggers
    if field not in triggers:
        return TriggerTree()

    def concat(seq1, seq2):
        """ Concatenate two field paths, dropping each many2one at the end of
        ``seq1`` that is directly followed by its inverse one2many at the
        start of ``seq2``. """
        while seq1 and seq2:
            f1, f2 = seq1[-1], seq2[0]
            if not (
                f1.type == 'many2one' and f2.type == 'one2many'
                and f1.name == f2.inverse_name
                and f1.model_name == f2.comodel_name
                and f1.comodel_name == f2.model_name
            ):
                break
            seq1, seq2 = seq1[:-1], seq2[1:]
        return seq1 + seq2

    def transitive_triggers(dependency, prefix=(), seen=()):
        """ Yield ``(path, targets)`` for ``dependency`` and, recursively, for
        its targets; fields already on the current chain are not revisited. """
        if dependency in seen or dependency not in triggers:
            return
        chain = seen + (dependency,)
        for path, targets in triggers[dependency].items():
            full_path = concat(prefix, path)
            yield full_path, targets
            for target in targets:
                yield from transitive_triggers(target, full_path, chain)

    tree = TriggerTree()
    for path, targets in transitive_triggers(field):
        node = tree
        for label in path:
            node = node.increase(label)
        if node.root:
            assert isinstance(node.root, OrderedSet)
            node.root.update(targets)
        else:
            node.root = OrderedSet(targets)

    memo[field] = tree
    return tree
@functools.cached_property
def _field_triggers(self) -> defaultdict[Field, defaultdict[tuple[str, ...], OrderedSet[Field]]]:
    """ The field triggers, i.e. the inverse of the field dependencies.

    The result looks like ``{field: {path: fields}}``: ``field`` is a
    dependency, ``path`` is the sequence of fields to inverse, and ``fields``
    are the fields depending on ``field``. Abstract models are ignored.
    Unresolvable dependencies are tolerated only for manual (custom) fields.
    """
    triggers: defaultdict[Field, defaultdict[tuple[str, ...], OrderedSet[Field]]] = defaultdict(lambda: defaultdict(OrderedSet))

    for Model in self.models.values():
        if Model._abstract:
            continue
        for field in Model._fields.values():
            try:
                dependencies = list(field.resolve_depends(self))
            except Exception:
                # dependencies of custom fields may not exist; ignore that case
                if not field.base_field.manual:
                    raise
                continue
            for *path, dep_field in dependencies:
                triggers[dep_field][tuple(reversed(path))].add(field)

    return triggers
def is_modifying_relations(self, field: Field) -> bool:
    """ Return whether ``field`` has dependent fields on some records, and
    whether modifying ``field`` might change the dependent records.

    That is the case for a field with triggers that is relational, has
    inverses, or has a dependent field that is relational or has inverses.
    The answer is memoized in ``_is_modifying_relations``.
    """
    memo = self._is_modifying_relations
    if field in memo:
        return memo[field]

    answer = field in self._field_triggers and bool(
        field.relational
        or self.field_inverses[field]
        or any(
            dependent.relational or self.field_inverses[dependent]
            for dependent in self.get_dependent_fields(field)
        )
    )
    memo[field] = answer
    return answer
def post_init(self, func: Callable, *args, **kwargs) -> None:
    """ Queue ``func(*args, **kwargs)`` to be called at the end of :meth:`~.init_models`. """
    deferred_call = partial(func, *args, **kwargs)
    self._post_init_queue.append(deferred_call)
def post_constraint(self, cr: BaseCursor, func: Callable[[BaseCursor], None], key) -> None:
    """ Call the given function, and delay it if it fails during an upgrade.

    If nothing is queued under ``key`` yet, ``func(cr)`` is called inside a
    savepoint. If ``key`` is already queued, ``func`` replaces the queued
    function without being called. When the call fails, the exception is
    logged (as an error when ``self._is_install`` is set, as info otherwise)
    and ``func`` is queued under ``key``, to be retried by
    :meth:`finalize_constraints`.

    :param cr: cursor on which the constraint is applied
    :param func: function applying the constraint on a cursor
    :param key: identifies the constraint in the queue
    """
    try:
        if key not in self._constraint_queue:
            # Module A may try to apply a constraint and fail but another module B inheriting
            # from Module A may try to reapply the same constraint and succeed, however the
            # constraint would already be in the _constraint_queue and would be executed again
            # at the end of the registry cycle, this would fail (already-existing constraint)
            # and generate an error, therefore a constraint should only be applied if it's
            # not already marked as "to be applied".
            with cr.savepoint(flush=False):
                func(cr)
        else:
            self._constraint_queue[key] = func
    except Exception as e:
        # Log the exception itself instead of unpacking ``e.args`` into the
        # logging call: an exception without args made that call raise
        # TypeError (escaping before ``func`` was queued), and extra args were
        # misread as %-format arguments.
        if self._is_install:
            _schema.error("%s", e)
        else:
            _schema.info("%s", e)
        self._constraint_queue[key] = func
def finalize_constraints(self, cr: Cursor) -> None:
    """ Call the constraint functions delayed by :meth:`post_constraint`,
    then empty the queue.

    Each function runs in its own savepoint. A failure is only logged as a
    warning, so the remaining functions still run and the queue is still
    cleared.

    :param cr: cursor on which the constraints are applied
    """
    for func in self._constraint_queue.values():
        try:
            with cr.savepoint(flush=False):
                func(cr)
        except Exception as e:
            # warn only, this is not a deployment showstopper, and
            # can sometimes be a transient error; log the exception itself
            # since unpacking ``e.args`` raised TypeError for an exception
            # without args, aborting the loop and leaving the queue uncleared
            _schema.warning("%s", e)
    self._constraint_queue.clear()
def init_models(self, cr: Cursor, model_names: Iterable[str], context: dict[str, typing.Any], install: bool = True):
    """ Initialize a list of models (given by their name).

    For each model, ``_auto_init`` and ``init`` are called to create or
    update its database table. The models, fields, selections, constraints
    and inherits are then reflected in the ``ir.model*`` tables. Next, the
    functions queued with :meth:`post_init` are run, and indexes, foreign
    keys and table existence are checked. Nothing is done when
    ``model_names`` is empty.

    The ``context`` may contain the following items:
    - ``module``: the name of the module being installed/updated, if any;
    - ``update_custom_fields``: whether custom fields should be updated;
    - ``models_to_check``: used here only to choose the log message.

    ``install`` is exposed as ``self._is_install`` during the call.
    """
    if not model_names:
        return

    if 'module' in context:
        _logger.info('module %s: creating or updating database tables', context['module'])
    elif context.get('models_to_check', False):
        _logger.info("verifying fields for every extended model")

    from .environments import Environment  # noqa: PLC0415
    env = Environment(cr, SUPERUSER_ID, context)
    models = [env[name] for name in model_names]

    try:
        # transient state, removed in the ``finally`` clause below
        self._post_init_queue: deque[Callable] = deque()
        # (table1, column1) -> (table2, column2, ondelete, model, module)
        self._foreign_keys: dict[tuple[str, str], tuple[str, str, str, BaseModel, str]] = {}
        self._is_install: bool = install

        for model in models:
            model._auto_init()
            model.init()

        env['ir.model']._reflect_models(model_names)
        env['ir.model.fields']._reflect_fields(model_names)
        env['ir.model.fields.selection']._reflect_selections(model_names)
        env['ir.model.constraint']._reflect_constraints(model_names)
        env['ir.model.inherit']._reflect_inherits(model_names)

        self._ordinary_tables = None

        # queued functions may queue further functions
        while self._post_init_queue:
            self._post_init_queue.popleft()()

        self.check_indexes(cr, model_names)
        self.check_foreign_keys(cr)

        env.flush_all()

        # make sure all tables are present
        self.check_tables_exist(cr)

    finally:
        del self._post_init_queue
        del self._foreign_keys
        del self._is_install
def check_null_constraints(self, cr: Cursor) -> None:
    """ Check that all not-null constraints are set, and rebuild
    :attr:`not_null_fields`.

    For every non-abstract ``_auto`` model, the ``id`` field is recorded. So
    is every stored, required column field whose column is NOT NULL in the
    current schema. A required field whose column lacks the constraint is
    reported with a warning instead.
    """
    cr.execute('''
        SELECT c.relname, a.attname
          FROM pg_attribute a
          JOIN pg_class c ON a.attrelid = c.oid
         WHERE c.relnamespace = current_schema::regnamespace
           AND a.attnotnull = true
           AND a.attnum > 0
           AND a.attname != 'id';
    ''')
    not_null_columns = set(cr.fetchall())

    self.not_null_fields.clear()
    for Model in self.models.values():
        if not Model._auto or Model._abstract:
            continue
        for field_name, field in Model._fields.items():
            if field_name == 'id':
                self.not_null_fields.add(field)
            elif field.column_type and field.store and field.required:
                if (Model._table, field_name) in not_null_columns:
                    self.not_null_fields.add(field)
                else:
                    _schema.warning("Missing not-null constraint on %s", field)
def check_indexes(self, cr: Cursor, model_names: Iterable[str]) -> None:
    """ Create the missing column indexes of the given models.

    Every stored column of a non-abstract ``_auto`` model is matched against
    the index named by ``sql.make_index_name``. When ``field.index`` requests
    an index that does not exist yet, it is created (plain btree, partial
    btree, or gin trigram index depending on ``field.index``). Existing
    indexes are never dropped: one found on the expected table for a field
    that no longer requests it is only reported at INFO level.

    :param cr: cursor used to inspect ``pg_indexes`` and create indexes
    :param model_names: names of the models to check (keys of ``self.models``)
    """
    expected = []
    for model_name in model_names:
        Model = self.models[model_name]
        if not Model._auto or Model._abstract:
            continue
        for field in Model._fields.values():
            if field.column_type and field.store:
                expected.append((sql.make_index_name(Model._table, field.name), Model._table, field))
    if not expected:
        return

    # existing indexes among the expected names, mapped to their table
    cr.execute("SELECT indexname, tablename FROM pg_indexes WHERE indexname IN %s"
               " AND schemaname = current_schema",
               [tuple(name for name, _table, _field in expected)])
    table_of_index = dict(cr.fetchall())

    for indexname, tablename, field in expected:
        index = field.index
        assert index in ('btree', 'btree_not_null', 'trigram', True, False, None)

        if not index:
            # no index requested: an existing one is left in place
            if tablename == table_of_index.get(indexname):
                _schema.info("Keep unexpected index %s on table %s", indexname, tablename)
            continue
        if indexname in table_of_index:
            continue  # the requested index is already there
        if index == 'trigram' and not self.has_trigram:
            continue  # trigram indexes are not supported by the database
        if field.translate and index != 'trigram':
            _schema.warning(f"Index attribute on {field!r} ignored, only trigram index is supported for translated fields")
            continue

        column_expression = f'"{field.name}"'
        if index == 'trigram':
            if field.translate:
                column_expression = f'''(jsonb_path_query_array({column_expression}, '$.*')::text)'''
            # `unaccent` goes into trigram indexes only: they mostly serve
            # (=)ilike searches, the only searches that apply unaccent
            from odoo.modules.db import FunctionStatus  # noqa: PLC0415
            if self.has_unaccent == FunctionStatus.INDEXABLE:
                column_expression = self.unaccent(column_expression)
            elif self.has_unaccent:
                warnings.warn(
                    "PostgreSQL function 'unaccent' is present but not immutable, "
                    "therefore trigram indexes may not be effective.",
                    stacklevel=1,
                )
            method, where = 'gin', ''
            expression = f'{column_expression} gin_trgm_ops'
        elif index == 'btree_not_null' and field.company_dependent:
            # searches on company-dependent fields add `AND col IS NOT NULL`,
            # so the index is built on that very condition
            method = 'btree'
            where = f'{column_expression} IS NOT NULL'
            expression = f'({column_expression} IS NOT NULL)'
        else:  # 'btree', 'btree_not_null' or True
            method = 'btree'
            where = f'{column_expression} IS NOT NULL' if index == 'btree_not_null' else ''
            expression = column_expression

        try:
            with cr.savepoint(flush=False):
                sql.create_index(cr, indexname, tablename, [expression], method, where)
        except psycopg2.OperationalError:
            _schema.error("Unable to add index %r for %s", indexname, self)
def add_foreign_key(
    self, table1: str, column1: str, table2: str, column2: str,
    ondelete: str, model: BaseModel, module: str,
    force: bool = True,
) -> None:
    """ Declare the foreign key expected on ``table1.column1``.

    The expectation is stored in ``self._foreign_keys`` under the key
    ``(table1, column1)``, for ``check_foreign_keys`` to enforce later.

    :param table2: referenced table
    :param column2: referenced column
    :param ondelete: ``ON DELETE`` rule of the constraint
    :param model: model on whose behalf the constraint is declared
    :param module: module the constraint belongs to
    :param force: when false, an expectation already declared for the same
        column is kept as is
    """
    foreign_keys = self._foreign_keys
    column_key = (table1, column1)
    if force or column_key not in foreign_keys:
        foreign_keys[column_key] = (table2, column2, ondelete, model, module)
def check_foreign_keys(self, cr: Cursor) -> None:
    """ Enforce the foreign keys declared in ``self._foreign_keys``.

    A missing constraint is created. An existing one whose target table,
    target column or ``ON DELETE`` rule differs is dropped and recreated.
    Every created constraint is reflected in ``ir.model.constraint``.

    :param cr: cursor used to inspect ``pg_constraint`` and alter tables
    """
    expected = self._foreign_keys
    if not expected:
        return

    # foreign keys currently defined on the involved tables
    query = """
        SELECT fk.conname, c1.relname, a1.attname, c2.relname, a2.attname, fk.confdeltype
          FROM pg_constraint AS fk
          JOIN pg_class AS c1 ON fk.conrelid = c1.oid
          JOIN pg_class AS c2 ON fk.confrelid = c2.oid
          JOIN pg_attribute AS a1 ON a1.attrelid = c1.oid AND fk.conkey[1] = a1.attnum
          JOIN pg_attribute AS a2 ON a2.attrelid = c2.oid AND fk.confkey[1] = a2.attnum
         WHERE fk.contype = 'f' AND c1.relname IN %s
           AND c1.relnamespace = current_schema::regnamespace
    """
    cr.execute(query, [tuple({table for table, _column in expected})])
    existing = {}
    for conname, table1, column1, table2, column2, deltype in cr.fetchall():
        existing[table1, column1] = (conname, table2, column2, deltype)

    for (table1, column1), (table2, column2, ondelete, model, module) in expected.items():
        deltype = sql._CONFDELTYPES[ondelete.upper()]
        current = existing.get((table1, column1))
        if current is not None:
            current_name, current_table2, current_column2, current_deltype = current
            if (current_table2, current_column2, current_deltype) == (table2, column2, deltype):
                continue  # already matches the expectation
            sql.drop_constraint(cr, table1, current_name)
        sql.add_foreign_key(cr, table1, column1, table2, column2, ondelete)
        conname = sql.get_foreign_keys(cr, table1, column1, table2, column2, ondelete)[0]
        model.env['ir.model.constraint']._reflect_constraint(model, conname, 'f', None, module)
def check_tables_exist(self, cr: Cursor) -> None:
    """
    Make sure every concrete model has its table, and try to create missing ones.

    Abstract models and models based on ``_table_query`` are ignored. Missing
    tables are recreated through the model's ``init()``; tables still missing
    afterwards are logged as errors.
    """
    from .environments import Environment  # noqa: PLC0415
    env = Environment(cr, SUPERUSER_ID, {})

    table2model = {}
    for name, model in env.registry.items():
        if model._abstract or model._table_query:
            continue
        table2model[model._table] = name

    absent_tables = set(table2model).difference(sql.existing_tables(cr, table2model))
    if not absent_tables:
        return

    missing = {table2model[table] for table in absent_tables}
    _logger.info("Models have no table: %s.", ", ".join(missing))
    for name in missing:
        _logger.info("Recreate table of model %s.", name)
        env[name].init()
    env.flush_all()

    # second pass: report the tables that could not be recreated
    absent_tables = set(table2model).difference(sql.existing_tables(cr, table2model))
    for table in absent_tables:
        _logger.error("Model %s has no table.", table2model[table])
def clear_cache(self, *cache_names: str) -> None:
    """ Clear the caches of methods decorated with ``tools.ormcache`` that
    depend on the given cache keys (``'default'`` when none is given).

    Each key is also recorded in ``self.cache_invalidated``.
    """
    names = cache_names or ('default',)
    assert not any('.' in name for name in names)
    for name in names:
        for container in _CACHES_BY_KEY[name]:
            self.__caches[container].clear()
        self.cache_invalidated.add(name)

    # report where the invalidation comes from, in debug mode only
    if _logger.isEnabledFor(logging.DEBUG):
        # logging this at INFO level would first require reducing the number
        # of invalidations, mainly in some setupclass and crons
        caller = format_frame(inspect.currentframe().f_back)  # type: ignore
        _logger.debug('Invalidating %s model caches from %s', ','.join(names), caller)
def clear_all_caches(self) -> None:
    """ Clear every cache used by methods decorated with ``tools.ormcache``,
    and record every cache key in ``self.cache_invalidated``.
    """
    for key, containers in _CACHES_BY_KEY.items():
        for container in containers:
            self.__caches[container].clear()
        self.cache_invalidated.add(key)

    caller = format_frame(inspect.currentframe().f_back)  # type: ignore
    # INFO once the registry is loaded, DEBUG while it is being loaded
    (_logger.info if self.loaded else _logger.debug)('Invalidating all model caches from %s', caller)
def is_an_ordinary_table(self, model: BaseModel) -> bool:
    """ Tell whether ``model._table`` is an ordinary table (``relkind = 'r'``)
    of the current schema.

    The set of ordinary tables among all the registry's models is queried on
    first use and kept in ``self._ordinary_tables``.
    """
    ordinary = self._ordinary_tables
    if ordinary is None:
        cr = model.env.cr
        cr.execute("""
            SELECT c.relname
              FROM pg_class c
             WHERE c.relname IN %s
               AND c.relkind = 'r'
               AND c.relnamespace = current_schema::regnamespace
        """, [tuple(Model._table for Model in self.models.values())])
        ordinary = self._ordinary_tables = {row[0] for row in cr.fetchall()}
    return model._table in ordinary
@property
def registry_invalidated(self) -> bool:
    """ Whether the current thread has modified the registry; ``False``
    as long as the flag has never been set. """
    flags = self._invalidation_flags
    try:
        return flags.registry
    except AttributeError:
        return False

@registry_invalidated.setter
def registry_invalidated(self, value: bool):
    # stored on the same flags object the getter reads from
    flags = self._invalidation_flags
    flags.registry = value
@property
def cache_invalidated(self) -> set[str]:
    """ The set of cache keys invalidated by the current thread; an empty
    set is created and stored on first access. """
    flags = self._invalidation_flags
    if not hasattr(flags, 'cache'):
        flags.cache = set()
    return flags.cache
def setup_signaling(self) -> None:
    """ Set up the inter-process signaling of this registry.

    One insert-only table exists per signal: ``orm_signaling_registry`` tells
    when the registry must be reloaded, and ``orm_signaling_<key>`` tells
    when the caches of a cache key must be cleared. Missing tables are
    created with one initial row, then the current sequences are loaded into
    ``self.registry_sequence`` and ``self.cache_sequences``.
    """
    with self.cursor() as cr:
        signaling_tables = tuple(f'orm_signaling_{cache_name}' for cache_name in ['registry', *_CACHES_BY_KEY])
        cr.execute("SELECT table_name FROM information_schema.tables"
                   " WHERE table_name IN %s AND table_schema = current_schema", [signaling_tables])
        present = {row[0] for row in cr.fetchall()}

        # Signaling used to rely on sequences, which do not work with logical
        # replication (see
        # https://www.postgresql.org/docs/current/logical-replication-restrictions.html),
        # hence the insert-only tables.
        for table_name in signaling_tables:
            if table_name in present:
                continue
            cr.execute(SQL(
                "CREATE TABLE %s (id SERIAL PRIMARY KEY, date TIMESTAMP DEFAULT now())",
                SQL.identifier(table_name),
            ))
            cr.execute(SQL("INSERT INTO %s DEFAULT VALUES", SQL.identifier(table_name)))

        registry_sequence, cache_sequences = self.get_sequences(cr)
        self.registry_sequence = registry_sequence
        self.cache_sequences.update(cache_sequences)

        cache_summary = ' '.join('[Cache %s: %s]' % (name, sequence) for name, sequence in self.cache_sequences.items())
        _logger.debug("Multiprocess load registry signaling: [Registry: %s] %s",
                      self.registry_sequence, cache_summary)
def get_sequences(self, cr: BaseCursor) -> tuple[int, dict[str, int]]:
    """ Read all signaling sequences with a single query.

    Each sequence is the highest ``id`` of its ``orm_signaling_*`` table.

    :return: the registry sequence, and a dict mapping every cache key of
        ``_CACHES_BY_KEY`` to its sequence
    """
    subqueries = [
        SQL('( SELECT max(id) FROM %s)', SQL.identifier(f'orm_signaling_{name}'))
        for name in ['registry', *_CACHES_BY_KEY]
    ]
    cr.execute(SQL("SELECT %s", SQL(', ').join(subqueries)))
    row = cr.fetchone()
    assert row is not None, "No result when reading signaling sequences"
    registry_sequence, *per_cache = row
    return registry_sequence, dict(zip(_CACHES_BY_KEY, per_cache))
def check_signaling(self, cr: BaseCursor | None = None) -> Registry:
    """ Check whether the registry has changed, and perform all necessary
    operations to update the registry. Return an up-to-date registry.

    When the registry sequence stored in the database differs from
    ``self.registry_sequence``, a new registry is built with
    ``Registry.new`` and returned. Otherwise, the caches whose sequence
    changed are cleared (without signaling them again) and ``self`` is
    returned.

    :param cr: cursor used to read the sequences; when ``None``, a readonly
        cursor is opened and closed here
    :return: the registry to use from now on (``self`` or a new registry)
    """
    with nullcontext(cr) if cr is not None else closing(self.cursor(readonly=True)) as cr:
        assert cr is not None
        db_registry_sequence, db_cache_sequences = self.get_sequences(cr)
        changes = ''
        # Check if the model registry must be reloaded
        if self.registry_sequence != db_registry_sequence:
            _logger.info("Reloading the model registry after database signaling.")
            # keep the outdated sequence: `self` is rebound to the new registry
            # below, and the debug message must show the actual transition
            old_registry_sequence = self.registry_sequence
            self = Registry.new(self.db_name)
            self.registry_sequence = db_registry_sequence
            if _logger.isEnabledFor(logging.DEBUG):
                changes += "[Registry - %s -> %s]" % (old_registry_sequence, db_registry_sequence)
        # Check if the model caches must be invalidated.
        else:
            invalidated = []
            for cache_name, cache_sequence in self.cache_sequences.items():
                expected_sequence = db_cache_sequences[cache_name]
                if cache_sequence != expected_sequence:
                    for cache in _CACHES_BY_KEY[cache_name]:  # don't call clear_cache to avoid signal loop
                        if cache not in invalidated:
                            invalidated.append(cache)
                            self.__caches[cache].clear()
                    self.cache_sequences[cache_name] = expected_sequence
                    if _logger.isEnabledFor(logging.DEBUG):
                        changes += "[Cache %s - %s -> %s]" % (cache_name, cache_sequence, expected_sequence)
            if invalidated:
                _logger.info("Invalidating caches after database signaling: %s", sorted(invalidated))
        if changes:
            _logger.debug("Multiprocess signaling check: %s", changes)
    return self
def signal_changes(self) -> None:
    """ Notify other processes if the registry or caches have been invalidated.

    A registry invalidation inserts a row into ``orm_signaling_registry``.
    Otherwise, each invalidated cache key gets a row in its
    ``orm_signaling_<key>`` table. The local sequences are advanced
    accordingly, then the invalidation flags are reset. When the registry
    is not ready, only a warning is logged.
    """
    if not self.ready:
        _logger.warning('Calling signal_changes when registry is not ready is not supported')
        return

    if self.registry_invalidated:
        _logger.info("Registry changed, signaling through the database")
        with self.cursor() as cr:
            cr.execute("INSERT INTO orm_signaling_registry DEFAULT VALUES")
            # If another process concurrently updates the registry,
            # self.registry_sequence will actually be out-of-date,
            # and the next call to check_signaling() will detect that and trigger a registry reload.
            # otherwise, self.registry_sequence should be equal to cr.fetchone()[0]
            self.registry_sequence += 1

    # no need to notify cache invalidation in case of registry invalidation,
    # because reloading the registry implies starting with an empty cache
    elif self.cache_invalidated:
        _logger.info("Caches invalidated, signaling through the database: %s", sorted(self.cache_invalidated))
        with self.cursor() as cr:
            for cache_name in self.cache_invalidated:
                cr.execute(SQL("INSERT INTO %s DEFAULT VALUES", SQL.identifier(f'orm_signaling_{cache_name}')))
                # If another process concurrently updates the cache,
                # self.cache_sequences[cache_name] will actually be out-of-date,
                # and the next call to check_signaling() will detect that and trigger cache invalidation.
                # otherwise, self.cache_sequences[cache_name] should be equal to cr.fetchone()[0]
                self.cache_sequences[cache_name] += 1

    self.registry_invalidated = False
    self.cache_invalidated.clear()
def reset_changes(self) -> None:
    """ Cancel the pending invalidations.

    If the registry was invalidated, its models are set up again from the
    database. The caches of every invalidated cache key are cleared, and the
    set of invalidated keys is emptied.
    """
    if self.registry_invalidated:
        with closing(self.cursor()) as cr:
            self._setup_models__(cr)
            self.registry_invalidated = False

    invalidated = self.cache_invalidated
    if not invalidated:
        return
    for name in invalidated:
        for container in _CACHES_BY_KEY[name]:
            self.__caches[container].clear()
    invalidated.clear()
@contextmanager
def manage_changes(self):
    """ Context manager to signal/discard registry and cache invalidations.

    Deprecated since 19.0: call ``signal_changes()`` and ``reset_changes()``
    directly. On normal exit the pending invalidations are signaled. If the
    body (or the signaling) raises, they are discarded and the exception is
    re-raised.

    :return: this registry, as the value bound by ``with ... as``
    """
    # stacklevel=3 skips this generator frame and contextlib's __enter__,
    # so the warning is attributed to the caller's `with` statement
    warnings.warn("Since 19.0, use signal_changes() and reset_changes() directly", DeprecationWarning, stacklevel=3)
    try:
        yield self
        self.signal_changes()
    except Exception:
        self.reset_changes()
        raise
def cursor(self, /, readonly: bool = False) -> BaseCursor:
    """ Return a new cursor for the database. The cursor itself may be used
    as a context manager to commit/rollback and close automatically.

    :param readonly: try to get a cursor on the replica database. A
        read/write cursor on the primary database is returned instead when
        there is no replica, or when opening the replica cursor fails; after
        such a failure, the replica is not tried again before
        ``_REPLICA_RETRY_TIME`` seconds.
    """
    if not readonly or self._db_readonly is None:
        return self._db.cursor()

    may_retry = (
        self._db_readonly_failed_time is None
        or time.monotonic() > self._db_readonly_failed_time + _REPLICA_RETRY_TIME
    )
    if may_retry:
        try:
            cr = self._db_readonly.cursor()
        except psycopg2.OperationalError:
            self._db_readonly_failed_time = time.monotonic()
            _logger.warning("Failed to open a readonly cursor, falling back to read-write cursor for %dmin %dsec", *divmod(_REPLICA_RETRY_TIME, 60))
        else:
            self._db_readonly_failed_time = None
            return cr
    # a readonly cursor was requested but a read/write one is handed out
    threading.current_thread().cursor_mode = 'ro->rw'
    return self._db.cursor()
class DummyRLock:
    """ No-op stand-in for a reentrant lock, used while running rpc and js tests. """

    def acquire(self):
        """ Do nothing. """

    def release(self):
        """ Do nothing. """

    def __enter__(self):
        self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()
class TriggerTree(dict['Field', 'TriggerTree']):
    """ The triggers of a field F, as a tree. The root node holds the fields
    that depend on F; each edge is labelled with a field to inverse in order
    to find out which records to recompute at the node below it.

    For instance, assume that G depends on F, H depends on X.F, I depends on
    W.X.F, and J depends on Y.F. The triggers of F will be the tree:

                                 [G]
                               X/   \\Y
                             [H]     [J]
                           W/
                         [I]

    This tree provides perfect support for the trigger mechanism:
    when F is modified on records,
     - mark G to recompute on records,
     - mark H to recompute on inverse(X, records),
     - mark I to recompute on inverse(W, inverse(X, records)),
     - mark J to recompute on inverse(Y, records).
    """
    __slots__ = ['root']
    root: Collection[Field]

    # pylint: disable=keyword-arg-before-vararg
    def __init__(self, root: Collection[Field] = (), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.root = root

    def __bool__(self) -> bool:
        # a node matters as soon as it holds fields or has children
        return bool(self.root) or len(self) > 0

    def __repr__(self) -> str:
        children = super().__repr__()
        return f"TriggerTree(root={self.root!r}, {children})"

    def increase(self, key: Field) -> TriggerTree:
        """ Return the subtree labelled ``key``, adding an empty one if missing. """
        if key in self:
            return self[key]
        subtree = self[key] = TriggerTree()
        return subtree

    def depth_first(self) -> Iterator[TriggerTree]:
        """ Iterate over this node and all its descendants, in pre-order. """
        yield self
        for child in self.values():
            yield from child.depth_first()

    @classmethod
    def merge(cls, trees: list[TriggerTree], select: Callable[[Field], bool] = bool) -> TriggerTree:
        """ Merge trigger trees into a single tree. The function ``select`` is
        called on every field to determine which fields should be kept in the
        tree nodes; this allows discarding some fields from the nodes. Merged
        subtrees that end up empty are dropped.
        """
        root_fields = OrderedSet()  # union of the roots, in order of appearance
        grouped = defaultdict(list)  # subtrees to merge, grouped by edge label
        for tree in trees:
            root_fields.update(tree.root)
            for label, subtree in tree.items():
                grouped[label].append(subtree)

        result = cls(list(filter(select, root_fields)))
        for label, subtrees in grouped.items():
            if merged := cls.merge(subtrees, select):
                result[label] = merged
        return result