New disambiguator tweaks #435

Merged
merged 4 commits into from
Oct 22, 2023
10 changes: 5 additions & 5 deletions HISTORY.md
@@ -8,19 +8,19 @@
- **Potentially breaking**: {py:func}`cattrs.gen.make_dict_structure_fn` and {py:func}`cattrs.gen.typeddicts.make_dict_structure_fn` will use the values for the `detailed_validation` and `forbid_extra_keys` parameters from the given converter by default now.
If you're using these functions directly, the old behavior can be restored by passing in the desired values directly.
([#410](https://github.com/python-attrs/cattrs/issues/410) [#411](https://github.com/python-attrs/cattrs/pull/411))
- **Potentially breaking**: The default union structuring strategy will also use fields annotated as `typing.Literal` to help guide structuring.
([#391](https://github.com/python-attrs/cattrs/pull/391))
- Python 3.12 is now supported. Python 3.7 is no longer supported; use older releases there.
([#424](https://github.com/python-attrs/cattrs/pull/424))
- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough).
- Introduce the `use_class_methods` strategy. Learn more [here](https://catt.rs/en/latest/strategies.html#using-class-specific-structure-and-unstructure-methods).
([#405](https://github.com/python-attrs/cattrs/pull/405))
- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough).
- The `omit` parameter of {py:func}`cattrs.override` is now of type `bool | None` (from `bool`).
`None` is the new default and means to apply default _cattrs_ handling to the attribute, which is to omit the attribute if it's marked as `init=False`, and keep it otherwise.
- Fix {py:func}`format_exception() <cattrs.v.format_exception>` parameter working for recursive calls to {py:func}`transform_error <cattrs.transform_error>`.
([#389](https://github.com/python-attrs/cattrs/issues/389))
- [_attrs_ aliases](https://www.attrs.org/en/stable/init.html#private-attributes-and-aliases) are now supported, although aliased fields still map to their attribute name instead of their alias by default when un/structuring.
([#322](https://github.com/python-attrs/cattrs/issues/322) [#391](https://github.com/python-attrs/cattrs/pull/391))
- Use [PDM](https://pdm.fming.dev/latest/) instead of Poetry.
- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/).
- Fix TypedDicts with periods in their field names.
([#376](https://github.com/python-attrs/cattrs/issues/376) [#377](https://github.com/python-attrs/cattrs/pull/377))
- Optimize and improve unstructuring of `Optional` (unions of one type and `None`).
@@ -45,10 +45,10 @@
([#420](https://github.com/python-attrs/cattrs/pull/420))
- Add support for `datetime.date`s to the PyYAML preconfigured converter.
([#393](https://github.com/python-attrs/cattrs/issues/393))
- Use [PDM](https://pdm.fming.dev/latest/) instead of Poetry.
- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/).
- Remove some unused lines in the unstructuring code.
([#416](https://github.com/python-attrs/cattrs/pull/416))
- Disambiguate a union of attrs classes where there's a `typing.Literal` tag of some sort.
([#391](https://github.com/python-attrs/cattrs/pull/391))

## 23.1.2 (2023-06-02)

64 changes: 62 additions & 2 deletions docs/unions.md
@@ -2,15 +2,75 @@

This section contains information on advanced union handling.

As mentioned in the structuring section, _cattrs_ is able to handle simple unions of _attrs_ classes automatically.
As mentioned in the structuring section, _cattrs_ is able to handle simple unions of _attrs_ classes [automatically](#default-union-strategy).
More complex cases require converter customization (since there are many ways of handling unions).

_cattrs_ also comes with a number of strategies to help handle unions:

- [tagged unions strategy](strategies.md#tagged-unions-strategy) mentioned below
- [union passthrough strategy](strategies.md#union-passthrough), which is preapplied to all the [preconfigured](preconf.md) converters

## Unstructuring unions with extra metadata
## Default Union Strategy

For convenience, _cattrs_ includes a default union structuring strategy which is a little more opinionated.

Given a union of several _attrs_ classes, the default union strategy will attempt to handle it in several ways.

First, it will look for `Literal` fields.
If all members of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field.

```python
from typing import Literal

from attrs import define

@define
class ClassA:
    field_one: Literal["one"]

@define
class ClassB:
    field_one: Literal["two"]
```

In this case, a payload containing `{"field_one": "one"}` will produce an instance of `ClassA`.

````{note}
The following snippet can be used to disable the use of literal fields, restoring the previous behavior.

```python
from functools import partial
from cattrs.disambiguators import is_supported_union

converter.register_structure_hook_factory(
    is_supported_union,
    partial(converter._gen_attrs_union_structure, use_literals=False),
)
```

````

If there are no appropriate fields, the strategy will examine the classes for **unique required fields**.

So, given a union of `ClassA` and `ClassB`:

```python
@define
class ClassA:
    field_one: str
    field_with_default: str = "a default"

@define
class ClassB:
    field_two: str
```

the strategy will determine that if a payload contains the key `field_one` it should be handled as `ClassA`, and if it contains the key `field_two` it should be handled as `ClassB`.
The field `field_with_default` will not be considered since it has a default value, so it gets treated as optional.
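
The unique-required-field rule can be illustrated with a simplified, standard-library-only sketch. This is not the _cattrs_ implementation: it uses `dataclasses` instead of _attrs_, and the helper names `required_field_names` and `make_unique_field_dis_func` are hypothetical.

```python
from dataclasses import MISSING, dataclass, fields
from typing import Any, Callable, Dict, Mapping, Set


@dataclass
class ClassA:
    field_one: str
    field_with_default: str = "a default"


@dataclass
class ClassB:
    field_two: str


def required_field_names(cls: type) -> Set[str]:
    """Names of fields without a default, i.e. keys a payload must contain."""
    return {
        f.name
        for f in fields(cls)
        if f.default is MISSING and f.default_factory is MISSING
    }


def make_unique_field_dis_func(*classes: type) -> Callable[[Mapping[str, Any]], type]:
    """Hypothetical helper: pick a class by a required field no other class has."""
    unique: Dict[type, Set[str]] = {}
    for cls in classes:
        other_names: Set[str] = set()
        for other in classes:
            if other is not cls:
                other_names |= {f.name for f in fields(other)}
        # Assumption of this sketch: a field only counts as unique if no
        # other class declares it at all, with or without a default.
        unique[cls] = required_field_names(cls) - other_names

    def dis_func(data: Mapping[str, Any]) -> type:
        for cls, names in unique.items():
            if any(name in data for name in names):
                return cls
        raise ValueError("No unique required field found in payload.")

    return dis_func


dis = make_unique_field_dis_func(ClassA, ClassB)
```

With these classes, a payload containing `field_one` resolves to `ClassA` and one containing `field_two` resolves to `ClassB`. A payload with only `field_with_default` matches neither, because that field has a default and is ignored.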

```{versionchanged} 23.2.0
Literals can now potentially be used to disambiguate.
```

## Unstructuring Unions with Extra Metadata

```{note}
_cattrs_ comes with the [tagged unions strategy](strategies.md#tagged-unions-strategy) for handling this exact use-case since version 23.1.
32 changes: 12 additions & 20 deletions src/cattrs/converters.py
@@ -19,9 +19,9 @@
    Union,
)

from attr import Attribute
from attr import has as attrs_has
from attr import resolve_types
from attrs import Attribute
from attrs import has as attrs_has
from attrs import resolve_types

from ._compat import (
    FrozenSetSubscriptable,
@@ -55,7 +55,7 @@
    is_typeddict,
    is_union_type,
)
from .disambiguators import create_default_dis_func
from .disambiguators import create_default_dis_func, is_supported_union
from .dispatch import MultiStrategyDispatch
from .errors import (
    IterableValidationError,
@@ -96,16 +96,6 @@ def _subclass(typ: Type) -> Callable[[Type], bool]:
    return lambda cls: issubclass(cls, typ)


def is_attrs_union(typ: Type) -> bool:
    return is_union_type(typ) and all(has(get_origin(e) or e) for e in typ.__args__)


def is_attrs_union_or_none(typ: Type) -> bool:
    return is_union_type(typ) and all(
        e is NoneType or has(get_origin(e) or e) for e in typ.__args__
    )


def is_optional(typ: Type) -> bool:
    return is_union_type(typ) and NoneType in typ.__args__ and len(typ.__args__) == 2

@@ -204,7 +194,7 @@ def __init__(
                (is_frozenset, self._structure_frozenset),
                (is_tuple, self._structure_tuple),
                (is_mapping, self._structure_dict),
                (is_attrs_union_or_none, self._gen_attrs_union_structure, True),
                (is_supported_union, self._gen_attrs_union_structure, True),
                (
                    lambda t: is_union_type(t) and t in self._union_struct_registry,
                    self._structure_union,
@@ -411,17 +401,19 @@ def _gen_structure_generic(self, cl: Type[T]) -> DictStructureFn[T]:
        )

    def _gen_attrs_union_structure(
        self, cl: Any
        self, cl: Any, use_literals: bool = True
    ) -> Callable[[Any, Type[T]], Optional[Type[T]]]:
        """
        Generate a structuring function for a union of attrs classes (and maybe None).

        :param use_literals: Whether to consider literal fields.
        """
        dis_fn = self._get_dis_func(cl)
        dis_fn = self._get_dis_func(cl, use_literals=use_literals)
        has_none = NoneType in cl.__args__

        if has_none:

            def structure_attrs_union(obj, _):
            def structure_attrs_union(obj, _) -> cl:
                if obj is None:
                    return None
                return self.structure(obj, dis_fn(obj))
@@ -719,7 +711,7 @@ def _structure_tuple(self, obj: Any, tup: Type[T]) -> T:
        return res

    @staticmethod
    def _get_dis_func(union: Any) -> Callable[[Any], Type]:
    def _get_dis_func(union: Any, use_literals: bool = True) -> Callable[[Any], Type]:
        """Fetch or try creating a disambiguation function for a union."""
        union_types = union.__args__
        if NoneType in union_types:  # type: ignore
@@ -738,7 +730,7 @@ def _get_dis_func(union: Any) -> Callable[[Any], Type]:
                type_=union,
            )

        return create_default_dis_func(*union_types)
        return create_default_dis_func(*union_types, use_literals=use_literals)

    def __deepcopy__(self, _) -> "BaseConverter":
        return self.copy()
119 changes: 67 additions & 52 deletions src/cattrs/disambiguators.py
@@ -2,71 +2,86 @@
from collections import OrderedDict, defaultdict
from functools import reduce
from operator import or_
from typing import Any, Callable, Dict, Mapping, Optional, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Set, Type, Union

from attr import NOTHING, fields, fields_dict
from attrs import NOTHING, fields, fields_dict

from cattrs._compat import get_args, get_origin, is_literal
from ._compat import get_args, get_origin, has, is_literal, is_union_type

__all__ = ("is_supported_union", "create_default_dis_func")

NoneType = type(None)


def is_supported_union(typ: Type) -> bool:
    """Whether the type is a union of attrs classes."""
    return is_union_type(typ) and all(
        e is NoneType or has(get_origin(e) or e) for e in typ.__args__
    )


def create_default_dis_func(
    *classes: Type[Any],
    *classes: Type[Any], use_literals: bool = True
) -> Callable[[Mapping[Any, Any]], Optional[Type[Any]]]:
    """Given attr classes, generate a disambiguation function.
    """Given attrs classes, generate a disambiguation function.

    The function is based on unique fields or unique values.

    The function is based on unique fields or unique values."""
    :param use_literals: Whether to try using fields annotated as literals for
        disambiguation.
    """
    if len(classes) < 2:
        raise ValueError("At least two classes required.")

    # first, attempt for unique values
    if use_literals:
        # requirements for a discriminator field:
        # (... TODO: a single fallback is OK)
        # - it must always be enumerated
        cls_candidates = [
            {at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)}
            for cl in classes
        ]

        # literal field names common to all members
        discriminators: Set[str] = cls_candidates[0]
        for possible_discriminators in cls_candidates:
            discriminators &= possible_discriminators

        best_result = None
        best_discriminator = None
        for discriminator in discriminators:
            # maps Literal values (strings, ints...) to classes
            mapping = defaultdict(list)

            for cl in classes:
                for key in get_args(
                    fields_dict(get_origin(cl) or cl)[discriminator].type
                ):
                    mapping[key].append(cl)

            if best_result is None or max(len(v) for v in mapping.values()) <= max(
                len(v) for v in best_result.values()
            ):
                best_result = mapping
                best_discriminator = discriminator

        if (
            best_result
            and best_discriminator
            and max(len(v) for v in best_result.values()) != len(classes)
        ):
            final_mapping = {
                k: v[0] if len(v) == 1 else Union[tuple(v)]
                for k, v in best_result.items()
            }

    # requirements for a discriminator field:
    # (... TODO: a single fallback is OK)
    # - it must be *required*
    # - it must always be enumerated
    cls_candidates = [
        {
            at.name
            for at in fields(get_origin(cl) or cl)
            if at.default is NOTHING and is_literal(at.type)
        }
        for cl in classes
    ]

    discriminators = cls_candidates[0]
    for possible_discriminators in cls_candidates:
        discriminators &= possible_discriminators

    best_result = None
    best_discriminator = None
    for discriminator in discriminators:
        mapping = defaultdict(list)

        for cl in classes:
            for key in get_args(fields_dict(get_origin(cl) or cl)[discriminator].type):
                mapping[key].append(cl)
            def dis_func(data: Mapping[Any, Any]) -> Optional[Type]:
                if not isinstance(data, Mapping):
                    raise ValueError("Only input mappings are supported.")
                return final_mapping[data[best_discriminator]]

        if best_result is None or max(len(v) for v in mapping.values()) <= max(
            len(v) for v in best_result.values()
        ):
            best_result = mapping
            best_discriminator = discriminator

    if (
        best_result
        and best_discriminator
        and max(len(v) for v in best_result.values()) != len(classes)
    ):
        final_mapping = {
            k: v[0] if len(v) == 1 else Union[tuple(v)] for k, v in best_result.items()
        }

        def dis_func(data: Mapping[Any, Any]) -> Optional[Type]:
            if not isinstance(data, Mapping):
                raise ValueError("Only input mappings are supported.")
            return final_mapping[data[best_discriminator]]

        return dis_func
            return dis_func

    # next, attempt for unique keys
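
For readers following the literal branch of this hunk, here is a simplified, standard-library-only sketch of the same idea: intersect the literal field names across all classes, then map each literal value to the class that declares it. It is an illustration under stated assumptions, not the code from this PR. It uses `dataclasses` and `typing.get_type_hints` instead of _attrs_. It accepts the first shared field whose values do not overlap, where the PR scores candidates and allows overlapping values to map to a `Union`. The names `Cat`, `Dog`, `literal_fields` and `make_literal_dis_func` are hypothetical.

```python
from dataclasses import dataclass, fields
from typing import (
    Any, Callable, Dict, Literal, Mapping, Set, Tuple,
    get_args, get_origin, get_type_hints,
)


@dataclass
class Cat:
    kind: Literal["cat"]
    name: str


@dataclass
class Dog:
    kind: Literal["dog", "puppy"]
    name: str


def literal_fields(cls: type) -> Dict[str, Tuple[Any, ...]]:
    """Map each Literal-annotated field name to its allowed values."""
    hints = get_type_hints(cls)
    return {
        f.name: get_args(hints[f.name])
        for f in fields(cls)
        if get_origin(hints[f.name]) is Literal
    }


def make_literal_dis_func(*classes: type) -> Callable[[Mapping[str, Any]], type]:
    if len(classes) < 2:
        raise ValueError("At least two classes required.")
    per_class = {cls: literal_fields(cls) for cls in classes}
    # Literal field names common to all members.
    shared: Set[str] = set.intersection(*(set(m) for m in per_class.values()))
    for name in sorted(shared):
        value_to_class: Dict[Any, type] = {}
        ambiguous = False
        for cls, literals in per_class.items():
            for value in literals[name]:
                if value in value_to_class:
                    ambiguous = True  # two classes claim the same value
                value_to_class[value] = cls
        if not ambiguous:
            # Bind the loop variables as defaults so the closure keeps them.
            def dis_func(data, _name=name, _table=value_to_class):
                return _table[data[_name]]

            return dis_func
    raise ValueError("No usable literal discriminator found.")


dis = make_literal_dis_func(Cat, Dog)
```

Here `dis({"kind": "puppy", "name": "Rex"})` selects `Dog`, because `"puppy"` appears only in the `Literal` declared on `Dog.kind`.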
