Skip to content

Commit e61527c

Browse files
committed
Better narrowing for enums and other known types with equality
1 parent 93bc02b commit e61527c

File tree

4 files changed

+352
-115
lines changed

4 files changed

+352
-115
lines changed

mypy/checker.py

Lines changed: 202 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(self) -> None:
108108
from mypy.expandtype import expand_type
109109
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
110110
from mypy.maptype import map_instance_to_supertype
111-
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
111+
from mypy.meet import is_overlapping_types, meet_types
112112
from mypy.message_registry import ErrorMessage
113113
from mypy.messages import (
114114
SUGGESTED_TEST_FIXTURES,
@@ -6720,22 +6720,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
67206720
narrowable_indices={0},
67216721
)
67226722

6723-
# TODO: This remove_optional code should no longer be needed. The only
6724-
# thing it does is paper over a pre-existing deficiency in equality
6725-
# narrowing w.r.t to enums.
6726-
# We only try and narrow away 'None' for now
6727-
if (
6728-
not is_unreachable_map(if_map)
6729-
and is_overlapping_none(item_type)
6730-
and not is_overlapping_none(collection_item_type)
6731-
and not (
6732-
isinstance(collection_item_type, Instance)
6733-
and collection_item_type.type.fullname == "builtins.object"
6734-
)
6735-
and is_overlapping_erased_types(item_type, collection_item_type)
6736-
):
6737-
if_map[operands[left_index]] = remove_optional(item_type)
6738-
67396723
if right_index in narrowable_operand_index_to_hash:
67406724
if_type, else_type = self.conditional_types_for_iterable(
67416725
item_type, iterable_type
@@ -6820,17 +6804,15 @@ def narrow_type_by_identity_equality(
68206804
# have to be more careful about what narrowing we can conclude from a successful comparison
68216805
custom_eq_indices: set[int]
68226806

6823-
# enum_comparison_is_ambiguous:
6824-
# `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...`
6825-
# it could e.g. be an int or str if Fruits is an IntEnum or StrEnum.
6826-
# See ambiguous_enum_equality_keys for more details
6827-
enum_comparison_is_ambiguous: bool
6807+
# Equality can use value semantics, so `if x == Fruits.APPLE: ...` may also
6808+
# match non-enum values for IntEnum/StrEnum-like enums. Identity checks don't
6809+
# have this ambiguity.
6810+
is_identity_comparison = operator in {"is", "is not"}
68286811

6829-
if operator in {"is", "is not"}:
6812+
if is_identity_comparison:
68306813
is_target_for_value_narrowing = is_singleton_identity_type
68316814
should_coerce_literals = True
68326815
custom_eq_indices = set()
6833-
enum_comparison_is_ambiguous = False
68346816

68356817
elif operator in {"==", "!="}:
68366818
is_target_for_value_narrowing = is_singleton_equality_type
@@ -6843,7 +6825,6 @@ def narrow_type_by_identity_equality(
68436825
break
68446826

68456827
custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])}
6846-
enum_comparison_is_ambiguous = True
68476828
else:
68486829
raise AssertionError
68496830

@@ -6859,8 +6840,6 @@ def narrow_type_by_identity_equality(
68596840
continue
68606841

68616842
expr_type = operand_types[i]
6862-
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
6863-
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
68646843
for j in expr_indices:
68656844
if i == j:
68666845
continue
@@ -6872,18 +6851,30 @@ def narrow_type_by_identity_equality(
68726851
if should_coerce_literals:
68736852
target_type = coerce_to_literal(target_type)
68746853

6875-
if (
6876-
# See comments in ambiguous_enum_equality_keys
6877-
enum_comparison_is_ambiguous
6878-
and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1
6879-
):
6880-
continue
6854+
narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types(
6855+
expr_type, target_type, is_identity=is_identity_comparison
6856+
)
68816857

6882-
target = TypeRange(target_type, is_upper_bound=False)
6858+
if narrowable_expr_type is None:
6859+
if_type = else_type = ambiguous_expr_type
6860+
else:
6861+
narrowable_expr_type = try_expanding_sum_type_to_union(
6862+
coerce_to_literal(narrowable_expr_type), None
6863+
)
6864+
if_type, else_type = conditional_types(
6865+
narrowable_expr_type,
6866+
[TypeRange(target_type, is_upper_bound=False)],
6867+
from_equality=True,
6868+
)
6869+
if ambiguous_expr_type is not None:
6870+
if_type = make_simplified_union(
6871+
[if_type or narrowable_expr_type, ambiguous_expr_type]
6872+
)
6873+
else_type = make_simplified_union(
6874+
[else_type or narrowable_expr_type, ambiguous_expr_type]
6875+
)
68836876

6884-
if_map, else_map = conditional_types_to_typemaps(
6885-
operands[i], *conditional_types(expr_type, [target], from_equality=True)
6886-
)
6877+
if_map, else_map = conditional_types_to_typemaps(operands[i], if_type, else_type)
68876878
if is_target_for_value_narrowing(get_proper_type(target_type)):
68886879
all_if_maps.append(if_map)
68896880
all_else_maps.append(else_map)
@@ -6964,13 +6955,29 @@ def narrow_type_by_identity_equality(
69646955
target_type = operand_types[j]
69656956
if should_coerce_literals:
69666957
target_type = coerce_to_literal(target_type)
6967-
target = TypeRange(target_type, is_upper_bound=False)
6958+
6959+
narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types(
6960+
expr_type, target_type, is_identity=is_identity_comparison
6961+
)
6962+
6963+
if narrowable_expr_type is None:
6964+
if_type = else_type = ambiguous_expr_type
6965+
else:
6966+
narrowable_expr_type = coerce_to_literal(
6967+
try_expanding_sum_type_to_union(narrowable_expr_type, None)
6968+
)
6969+
if_type, else_type = conditional_types(
6970+
narrowable_expr_type,
6971+
[TypeRange(target_type, is_upper_bound=False)],
6972+
default=narrowable_expr_type,
6973+
from_equality=True,
6974+
)
6975+
if ambiguous_expr_type is not None:
6976+
if_type = make_simplified_union([if_type, ambiguous_expr_type])
6977+
else_type = make_simplified_union([else_type, ambiguous_expr_type])
69686978

69696979
if_map, else_map = conditional_types_to_typemaps(
6970-
operands[i],
6971-
*conditional_types(
6972-
expr_type, [target], default=expr_type, from_equality=True
6973-
),
6980+
operands[i], if_type, else_type
69746981
)
69756982
or_if_maps.append(if_map)
69766983
if is_target_for_value_narrowing(get_proper_type(target_type)):
@@ -8564,17 +8571,10 @@ def conditional_types(
85648571
# We erase generic args because values with different generic types can compare equal
85658572
# For instance, cast(list[str], []) and cast(list[int], [])
85668573
proposed_type = shallow_erase_type_for_equality(proposed_type)
8567-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=False):
8568-
# Equality narrowing is one of the places at runtime where subtyping with promotion
8569-
# does happen to match runtime semantics
8570-
# Expression is never of any type in proposed_type_ranges
8571-
return UninhabitedType(), default
8572-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8573-
return default, default
8574-
else:
8575-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8576-
# Expression is never of any type in proposed_type_ranges
8577-
return UninhabitedType(), default
8574+
8575+
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8576+
# Expression is never of any type in proposed_type_ranges
8577+
return UninhabitedType(), default
85788578

85798579
# we can only restrict when the type is precise, not bounded
85808580
proposed_precise_type = UnionType.make_union(
@@ -8844,8 +8844,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
88448844

88458845

88468846
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
8847-
"builtins.bytearray",
8848-
"builtins.memoryview",
88498847
"builtins.frozenset",
88508848
"_collections_abc.dict_keys",
88518849
"_collections_abc.dict_items",
@@ -8857,9 +8855,8 @@ def has_custom_eq_checks(t: Type) -> bool:
88578855
custom_special_method(t, "__eq__", check_all=False)
88588856
or custom_special_method(t, "__ne__", check_all=False)
88598857
# custom_special_method has special casing for builtins.* and typing.* that make the
8860-
# above always return False. So here we return True if the a value of a builtin type
8861-
# will ever compare equal to value of another type, e.g. a bytes value can compare equal
8862-
# to a bytearray value.
8858+
# above always return False. Some builtin collections still have equality behavior that
8859+
# crosses nominal type boundaries and isn't captured by VALUE_EQUALITY_DOMAINS.
88638860
or (
88648861
isinstance(pt := get_proper_type(t), Instance)
88658862
and pt.type.fullname in BUILTINS_CUSTOM_EQ_CHECKS
@@ -9637,45 +9634,158 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
96379634
self.lvalue = False
96389635

96399636

9640-
def ambiguous_enum_equality_keys(t: Type) -> set[str]:
9641-
"""
9642-
Used when narrowing types based on equality.
9637+
# Value-equality domain tables.
#
# Each table maps a fully qualified type name to the label of the value domain
# that the type's instances compare equal within.  Types that share a label
# (e.g. bool/int/float/complex) may compare equal to each other even though
# they are different nominal types.

# Open domains also block cross-type narrowing for known domain members, but they
# don't provide an exhaustive union to narrow top types to.
OPEN_VALUE_EQUALITY_DOMAINS: Final = {
    "builtins.str": "builtins.str",
    "builtins.bool": "builtins.numeric",
    "builtins.int": "builtins.numeric",
    "builtins.float": "builtins.numeric",
    "builtins.complex": "builtins.numeric",
}
# The distinct domain labels used by the open table (not the member type names).
OPEN_VALUE_EQUALITY_DOMAIN_NAMES: Final = frozenset(OPEN_VALUE_EQUALITY_DOMAINS.values())

# Closed domains also block ordinary cross-type narrowing within the domain.
CLOSED_VALUE_EQUALITY_DOMAINS: Final = {
    "builtins.bytes": "builtins.bytes",
    "builtins.bytearray": "builtins.bytes",
    "builtins.memoryview": "builtins.bytes",
}

# Combined lookup table: type full name -> domain label, for both kinds of domain.
VALUE_EQUALITY_DOMAINS: Final = {**OPEN_VALUE_EQUALITY_DOMAINS, **CLOSED_VALUE_EQUALITY_DOMAINS}
9656+
96439657

9644-
Certain kinds of enums can compare equal to values of other types, so doing type math
9645-
the way `conditional_types` does will be misleading if you expect it to correspond to
9646-
conditions based on equality comparisons.
9658+
class EqualityDomainInfo(NamedTuple):
    """The types seen for a single value-equality domain."""

    # Full names of every type that contributed to this domain.
    type_names: set[str]
    # Full names of the contributing types that are enums (empty if none are).
    enum_type_names: set[str]
96479661

9648-
For example, StrEnum classes can compare equal to str values. So if we see
9649-
`val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9650-
Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
9662+
9663+
class EqualityValueInfo(NamedTuple):
    """Summary of how values of a type can take part in equality comparisons."""

    # Domain label -> information about the types seen for that domain.
    domains: dict[str, EqualityDomainInfo]
    # True when the type is top-like, i.e. it places no bound on the value.
    is_top: bool
9666+
9667+
9668+
def closed_equality_domain_type_names(info: EqualityValueInfo) -> list[str]:
    """Return the full names of all closed-domain types whose domain is in ``info``.

    Every type registered in CLOSED_VALUE_EQUALITY_DOMAINS is included when the
    domain it maps to is one of the keys of ``info.domains``; the result follows
    the table's iteration order.
    """
    return [
        fullname
        for fullname, domain in CLOSED_VALUE_EQUALITY_DOMAINS.items()
        if domain in info.domains
    ]
9674+
9675+
9676+
def partition_equality_ambiguous_types(
    current_type: Type, target_type: Type, *, is_identity: bool
) -> tuple[Type | None, Type | None]:
    """Split current_type into ordinary-narrowable and equality-ambiguous pieces.

    Some values compare equal through a value domain broader than their nominal type. For
    example, an IntEnum member can compare equal to an int, and a StrEnum member can compare
    equal to a str. When narrowing `x: MyStrEnum | str` against `MyStrEnum.MEMBER`, we can
    still narrow the enum portion of the union, but we must keep the str portion in both
    branches.

    Returns a pair ``(narrowable, ambiguous)``. Either element is None when no item of
    current_type falls into that group. Identity comparisons are never ambiguous, so they
    return ``(current_type, None)`` without inspecting the type.
    """
    if is_identity:
        return current_type, None

    typ = get_proper_type(current_type)
    # A non-union type is treated as a single item and keeps its original (unexpanded) form.
    items = typ.relevant_items() if isinstance(typ, UnionType) else [current_type]
    narrowable_items: list[Type] = []
    ambiguous_items: list[Type] = []
    for item in items:
        if is_equality_ambiguous_for_narrowing(item, target_type):
            ambiguous_items.append(item)
        else:
            narrowable_items.append(item)
    return (
        UnionType.make_union(narrowable_items) if narrowable_items else None,
        UnionType.make_union(ambiguous_items) if ambiguous_items else None,
    )
9703+
9704+
9705+
def is_equality_ambiguous_for_narrowing(left: Type, right: Type) -> bool:
    """Can left compare equal to right through a value domain outside nominal overlap?

    The answer is computed from the equality_value_info() summaries of both types:

    * If either side is top-like, the comparison is ambiguous only when the other
      side has an enum type in an open value domain.
    * Otherwise the comparison is ambiguous when the two sides share a domain and,
      within it, either the member types differ or an enum is involved -- except
      when both sides consist solely of the same enum type(s).
    """
    left_info = equality_value_info(left)
    right_info = equality_value_info(right)

    if left_info.is_top or right_info.is_top:
        # Only open-domain enum values can make a top-like type ambiguous.
        # Closed domains can be narrowed to their complete known set instead.
        other_info = right_info if left_info.is_top else left_info
        return any(
            domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info.enum_type_names
            for domain, domain_info in other_info.domains.items()
        )

    shared_domains = left_info.domains.keys() & right_info.domains.keys()
    if not shared_domains:
        return False

    for domain in shared_domains:
        left_domain = left_info.domains[domain]
        right_domain = right_info.domains[domain]
        # Equality between two values from the same enum can still narrow by literal member.
        if (
            left_domain.enum_type_names
            and left_domain.enum_type_names == right_domain.enum_type_names
            and left_domain.type_names == left_domain.enum_type_names
            and right_domain.type_names == right_domain.enum_type_names
        ):
            continue
        # Different domain-member types may compare equal, but nominal narrowing would
        # otherwise treat them as disjoint.
        if left_domain.type_names != right_domain.type_names:
            return True
        # Same domain-member types are only ambiguous if an enum value may compare equal to
        # its underlying value type.
        if left_domain.enum_type_names or right_domain.enum_type_names:
            return True

    return False
9744+
9745+
9746+
def equality_value_info(t: Type) -> EqualityValueInfo:
96569747
t = get_proper_type(t)
96579748
if isinstance(t, UnionType):
9658-
for item in t.items:
9659-
result.update(ambiguous_enum_equality_keys(item))
9660-
elif isinstance(t, Instance):
9661-
if t.last_known_value:
9662-
result.update(ambiguous_enum_equality_keys(t.last_known_value))
9663-
elif t.type.is_enum and any(
9664-
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
9665-
for base in t.type.mro
9666-
):
9667-
result.add(t.type.fullname)
9668-
elif not t.type.is_enum:
9669-
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9670-
# let's be conservative
9671-
result.add("<other>")
9672-
elif isinstance(t, LiteralType):
9673-
result.update(ambiguous_enum_equality_keys(t.fallback))
9674-
elif isinstance(t, NoneType):
9675-
pass
9676-
else:
9677-
result.add("<other>")
9678-
return result
9749+
return combine_equality_value_info(equality_value_info(item) for item in t.items)
9750+
if isinstance(t, TypeVarType):
9751+
if t.values:
9752+
return combine_equality_value_info(equality_value_info(item) for item in t.values)
9753+
return equality_value_info(t.upper_bound)
9754+
if isinstance(t, Instance) and t.last_known_value is not None:
9755+
return equality_value_info(t.last_known_value)
9756+
if isinstance(t, LiteralType):
9757+
return equality_value_info(t.fallback)
9758+
if isinstance(t, Instance):
9759+
if t.type.fullname == "builtins.object":
9760+
return EqualityValueInfo({}, is_top=True)
9761+
9762+
enum_type_names = {t.type.fullname} if t.type.is_enum else set()
9763+
domains = {}
9764+
for base in t.type.mro:
9765+
if domain := VALUE_EQUALITY_DOMAINS.get(base.fullname):
9766+
domains[domain] = EqualityDomainInfo({t.type.fullname}, enum_type_names)
9767+
9768+
return EqualityValueInfo(domains, is_top=False)
9769+
if isinstance(t, AnyType):
9770+
return EqualityValueInfo({}, is_top=True)
9771+
return EqualityValueInfo({}, is_top=False)
9772+
9773+
9774+
def combine_equality_value_info(infos: Iterable[EqualityValueInfo]) -> EqualityValueInfo:
    """Merge several EqualityValueInfo summaries into one.

    For each domain, the type-name and enum-type-name sets of all inputs are unioned.
    The result is top-like if any input is top-like.
    """
    domains: dict[str, EqualityDomainInfo] = {}
    is_top = False
    for info in infos:
        for domain, domain_info in info.domains.items():
            existing_domain_info = domains.get(domain)
            if existing_domain_info is None:
                # Copy the sets so the in-place updates below never mutate an input.
                domains[domain] = EqualityDomainInfo(
                    set(domain_info.type_names), set(domain_info.enum_type_names)
                )
            else:
                existing_domain_info.type_names.update(domain_info.type_names)
                existing_domain_info.enum_type_names.update(domain_info.enum_type_names)
        is_top = is_top or info.is_top
    return EqualityValueInfo(domains, is_top)
96799789

96809790

96819791
def is_typeddict_type_context(lvalue_type: Type) -> bool:

0 commit comments

Comments
 (0)