@@ -108,7 +108,7 @@ def __init__(self) -> None:
108108from mypy .expandtype import expand_type
109109from mypy .literals import Key , extract_var_from_literal_hash , literal , literal_hash
110110from mypy .maptype import map_instance_to_supertype
111- from mypy .meet import is_overlapping_erased_types , is_overlapping_types , meet_types
111+ from mypy .meet import is_overlapping_types , meet_types
112112from mypy .message_registry import ErrorMessage
113113from mypy .messages import (
114114 SUGGESTED_TEST_FIXTURES ,
@@ -6720,22 +6720,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
67206720 narrowable_indices = {0 },
67216721 )
67226722
6723- # TODO: This remove_optional code should no longer be needed. The only
6724- # thing it does is paper over a pre-existing deficiency in equality
6725- # narrowing w.r.t to enums.
6726- # We only try and narrow away 'None' for now
6727- if (
6728- not is_unreachable_map (if_map )
6729- and is_overlapping_none (item_type )
6730- and not is_overlapping_none (collection_item_type )
6731- and not (
6732- isinstance (collection_item_type , Instance )
6733- and collection_item_type .type .fullname == "builtins.object"
6734- )
6735- and is_overlapping_erased_types (item_type , collection_item_type )
6736- ):
6737- if_map [operands [left_index ]] = remove_optional (item_type )
6738-
67396723 if right_index in narrowable_operand_index_to_hash :
67406724 if_type , else_type = self .conditional_types_for_iterable (
67416725 item_type , iterable_type
@@ -6820,17 +6804,15 @@ def narrow_type_by_identity_equality(
68206804 # have to be more careful about what narrowing we can conclude from a successful comparison
68216805 custom_eq_indices : set [int ]
68226806
6823- # enum_comparison_is_ambiguous:
6824- # `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...`
6825- # it could e.g. be an int or str if Fruits is an IntEnum or StrEnum.
6826- # See ambiguous_enum_equality_keys for more details
6827- enum_comparison_is_ambiguous : bool
6807+ # Equality can use value semantics, so `if x == Fruits.APPLE: ...` may also
6808+ # match non-enum values for IntEnum/StrEnum-like enums. Identity checks don't
6809+ # have this ambiguity.
6810+ is_identity_comparison = operator in {"is" , "is not" }
68286811
6829- if operator in { "is" , "is not" } :
6812+ if is_identity_comparison :
68306813 is_target_for_value_narrowing = is_singleton_identity_type
68316814 should_coerce_literals = True
68326815 custom_eq_indices = set ()
6833- enum_comparison_is_ambiguous = False
68346816
68356817 elif operator in {"==" , "!=" }:
68366818 is_target_for_value_narrowing = is_singleton_equality_type
@@ -6843,7 +6825,6 @@ def narrow_type_by_identity_equality(
68436825 break
68446826
68456827 custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks (operand_types [i ])}
6846- enum_comparison_is_ambiguous = True
68476828 else :
68486829 raise AssertionError
68496830
@@ -6859,8 +6840,6 @@ def narrow_type_by_identity_equality(
68596840 continue
68606841
68616842 expr_type = operand_types [i ]
6862- expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
6863- expr_type = try_expanding_sum_type_to_union (coerce_to_literal (expr_type ), None )
68646843 for j in expr_indices :
68656844 if i == j :
68666845 continue
@@ -6872,18 +6851,30 @@ def narrow_type_by_identity_equality(
68726851 if should_coerce_literals :
68736852 target_type = coerce_to_literal (target_type )
68746853
6875- if (
6876- # See comments in ambiguous_enum_equality_keys
6877- enum_comparison_is_ambiguous
6878- and len (expr_enum_keys | ambiguous_enum_equality_keys (target_type )) > 1
6879- ):
6880- continue
6854+ narrowable_expr_type , ambiguous_expr_type = partition_equality_ambiguous_types (
6855+ expr_type , target_type , is_identity = is_identity_comparison
6856+ )
68816857
6882- target = TypeRange (target_type , is_upper_bound = False )
6858+ if narrowable_expr_type is None :
6859+ if_type = else_type = ambiguous_expr_type
6860+ else :
6861+ narrowable_expr_type = try_expanding_sum_type_to_union (
6862+ coerce_to_literal (narrowable_expr_type ), None
6863+ )
6864+ if_type , else_type = conditional_types (
6865+ narrowable_expr_type ,
6866+ [TypeRange (target_type , is_upper_bound = False )],
6867+ from_equality = True ,
6868+ )
6869+ if ambiguous_expr_type is not None :
6870+ if_type = make_simplified_union (
6871+ [if_type or narrowable_expr_type , ambiguous_expr_type ]
6872+ )
6873+ else_type = make_simplified_union (
6874+ [else_type or narrowable_expr_type , ambiguous_expr_type ]
6875+ )
68836876
6884- if_map , else_map = conditional_types_to_typemaps (
6885- operands [i ], * conditional_types (expr_type , [target ], from_equality = True )
6886- )
6877+ if_map , else_map = conditional_types_to_typemaps (operands [i ], if_type , else_type )
68876878 if is_target_for_value_narrowing (get_proper_type (target_type )):
68886879 all_if_maps .append (if_map )
68896880 all_else_maps .append (else_map )
@@ -6964,13 +6955,29 @@ def narrow_type_by_identity_equality(
69646955 target_type = operand_types [j ]
69656956 if should_coerce_literals :
69666957 target_type = coerce_to_literal (target_type )
6967- target = TypeRange (target_type , is_upper_bound = False )
6958+
6959+ narrowable_expr_type , ambiguous_expr_type = partition_equality_ambiguous_types (
6960+ expr_type , target_type , is_identity = is_identity_comparison
6961+ )
6962+
6963+ if narrowable_expr_type is None :
6964+ if_type = else_type = ambiguous_expr_type
6965+ else :
6966+ narrowable_expr_type = coerce_to_literal (
6967+ try_expanding_sum_type_to_union (narrowable_expr_type , None )
6968+ )
6969+ if_type , else_type = conditional_types (
6970+ narrowable_expr_type ,
6971+ [TypeRange (target_type , is_upper_bound = False )],
6972+ default = narrowable_expr_type ,
6973+ from_equality = True ,
6974+ )
6975+ if ambiguous_expr_type is not None :
6976+ if_type = make_simplified_union ([if_type , ambiguous_expr_type ])
6977+ else_type = make_simplified_union ([else_type , ambiguous_expr_type ])
69686978
69696979 if_map , else_map = conditional_types_to_typemaps (
6970- operands [i ],
6971- * conditional_types (
6972- expr_type , [target ], default = expr_type , from_equality = True
6973- ),
6980+ operands [i ], if_type , else_type
69746981 )
69756982 or_if_maps .append (if_map )
69766983 if is_target_for_value_narrowing (get_proper_type (target_type )):
@@ -8564,17 +8571,10 @@ def conditional_types(
85648571 # We erase generic args because values with different generic types can compare equal
85658572 # For instance, cast(list[str], []) and cast(list[int], [])
85668573 proposed_type = shallow_erase_type_for_equality (proposed_type )
8567- if not is_overlapping_types (current_type , proposed_type , ignore_promotions = False ):
8568- # Equality narrowing is one of the places at runtime where subtyping with promotion
8569- # does happen to match runtime semantics
8570- # Expression is never of any type in proposed_type_ranges
8571- return UninhabitedType (), default
8572- if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8573- return default , default
8574- else :
8575- if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8576- # Expression is never of any type in proposed_type_ranges
8577- return UninhabitedType (), default
8574+
8575+ if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8576+ # Expression is never of any type in proposed_type_ranges
8577+ return UninhabitedType (), default
85788578
85798579 # we can only restrict when the type is precise, not bounded
85808580 proposed_precise_type = UnionType .make_union (
@@ -8844,8 +8844,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
88448844
88458845
88468846BUILTINS_CUSTOM_EQ_CHECKS : Final = {
8847- "builtins.bytearray" ,
8848- "builtins.memoryview" ,
88498847 "builtins.frozenset" ,
88508848 "_collections_abc.dict_keys" ,
88518849 "_collections_abc.dict_items" ,
@@ -8857,9 +8855,8 @@ def has_custom_eq_checks(t: Type) -> bool:
88578855 custom_special_method (t , "__eq__" , check_all = False )
88588856 or custom_special_method (t , "__ne__" , check_all = False )
88598857 # custom_special_method has special casing for builtins.* and typing.* that make the
8860- # above always return False. So here we return True if the a value of a builtin type
8861- # will ever compare equal to value of another type, e.g. a bytes value can compare equal
8862- # to a bytearray value.
8858+ # above always return False. Some builtin collections still have equality behavior that
8859+ # crosses nominal type boundaries and isn't captured by VALUE_EQUALITY_TYPE_DOMAINS.
88638860 or (
88648861 isinstance (pt := get_proper_type (t ), Instance )
88658862 and pt .type .fullname in BUILTINS_CUSTOM_EQ_CHECKS
@@ -9637,45 +9634,158 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
96379634 self .lvalue = False
96389635
96399636
9640- def ambiguous_enum_equality_keys (t : Type ) -> set [str ]:
9641- """
9642- Used when narrowing types based on equality.
9637+ # Open domains also block cross-type narrowing for known domain members, but they
9638+ # don't provide an exhaustive union to narrow top types to.
9639+ OPEN_VALUE_EQUALITY_DOMAINS : Final = {
9640+ "builtins.str" : "builtins.str" ,
9641+ "builtins.bool" : "builtins.numeric" ,
9642+ "builtins.int" : "builtins.numeric" ,
9643+ "builtins.float" : "builtins.numeric" ,
9644+ "builtins.complex" : "builtins.numeric" ,
9645+ }
9646+ OPEN_VALUE_EQUALITY_DOMAIN_NAMES : Final = frozenset (OPEN_VALUE_EQUALITY_DOMAINS .values ())
9647+
9648+ # Closed domains also block ordinary cross-type narrowing within the domain.
9649+ CLOSED_VALUE_EQUALITY_DOMAINS : Final = {
9650+ "builtins.bytes" : "builtins.bytes" ,
9651+ "builtins.bytearray" : "builtins.bytes" ,
9652+ "builtins.memoryview" : "builtins.bytes" ,
9653+ }
9654+
9655+ VALUE_EQUALITY_DOMAINS : Final = {** OPEN_VALUE_EQUALITY_DOMAINS , ** CLOSED_VALUE_EQUALITY_DOMAINS }
9656+
96439657
9644- Certain kinds of enums can compare equal to values of other types, so doing type math
9645- the way `conditional_types` does will be misleading if you expect it to correspond to
9646- conditions based on equality comparisons.
9658+ class EqualityDomainInfo ( NamedTuple ):
9659+ type_names : set [ str ]
9660+ enum_type_names : set [ str ]
96479661
9648- For example, StrEnum classes can compare equal to str values. So if we see
9649- `val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9650- Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
9662+
9663+ class EqualityValueInfo (NamedTuple ):
9664+ domains : dict [str , EqualityDomainInfo ]
9665+ is_top : bool
9666+
9667+
9668+ def closed_equality_domain_type_names (info : EqualityValueInfo ) -> list [str ]:
9669+ return [
9670+ fullname
9671+ for fullname , domain in CLOSED_VALUE_EQUALITY_DOMAINS .items ()
9672+ if domain in info .domains
9673+ ]
9674+
9675+
9676+ def partition_equality_ambiguous_types (
9677+ current_type : Type , target_type : Type , * , is_identity : bool
9678+ ) -> tuple [Type | None , Type | None ]:
9679+ """Split current_type into ordinary-narrowable and equality-ambiguous pieces.
9680+
9681+ Some values compare equal through a value domain broader than their nominal type. For
9682+ example, an IntEnum member can compare equal to an int, and a StrEnum member can compare
9683+ equal to a str. When narrowing `x: MyStrEnum | str` against `MyStrEnum.MEMBER`, we can
9684+ still narrow the enum portion of the union, but we must keep the str portion in both
9685+ branches.
96519686 """
9652- # We need these things for this to be ambiguous:
9653- # (1) an IntEnum or StrEnum type or enum subclass of int or str
9654- # (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9655- result = set ()
9687+ if is_identity :
9688+ return current_type , None
9689+
9690+ typ = get_proper_type (current_type )
9691+ items = typ .relevant_items () if isinstance (typ , UnionType ) else [current_type ]
9692+ narrowable_items = []
9693+ ambiguous_items = []
9694+ for item in items :
9695+ if is_equality_ambiguous_for_narrowing (item , target_type ):
9696+ ambiguous_items .append (item )
9697+ else :
9698+ narrowable_items .append (item )
9699+ return (
9700+ UnionType .make_union (narrowable_items ) if narrowable_items else None ,
9701+ UnionType .make_union (ambiguous_items ) if ambiguous_items else None ,
9702+ )
9703+
9704+
9705+ def is_equality_ambiguous_for_narrowing (left : Type , right : Type ) -> bool :
9706+ """Can left compare equal to right through a value domain outside nominal overlap?"""
9707+ left_info = equality_value_info (left )
9708+ right_info = equality_value_info (right )
9709+
9710+ if left_info .is_top or right_info .is_top :
9711+ # Only open-domain enum values can make a top-like type ambiguous.
9712+ # Closed domains can be narrowed to their complete known set instead.
9713+ other_info = right_info if left_info .is_top else left_info
9714+ return any (
9715+ domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info .enum_type_names
9716+ for domain , domain_info in other_info .domains .items ()
9717+ )
9718+
9719+ shared_domains = left_info .domains .keys () & right_info .domains .keys ()
9720+ if not shared_domains :
9721+ return False
9722+
9723+ for domain in shared_domains :
9724+ left_domain = left_info .domains [domain ]
9725+ right_domain = right_info .domains [domain ]
9726+ # Equality between two values from the same enum can still narrow by literal member.
9727+ if (
9728+ left_domain .enum_type_names
9729+ and left_domain .enum_type_names == right_domain .enum_type_names
9730+ and left_domain .type_names == left_domain .enum_type_names
9731+ and right_domain .type_names == right_domain .enum_type_names
9732+ ):
9733+ continue
9734+ # Different domain-member types may compare equal, but nominal narrowing would
9735+ # otherwise treat them as disjoint.
9736+ if left_domain .type_names != right_domain .type_names :
9737+ return True
9738+ # Same domain-member types are only ambiguous if an enum value may compare equal to
9739+ # its underlying value type.
9740+ if left_domain .enum_type_names or right_domain .enum_type_names :
9741+ return True
9742+
9743+ return False
9744+
9745+
9746+ def equality_value_info (t : Type ) -> EqualityValueInfo :
96569747 t = get_proper_type (t )
96579748 if isinstance (t , UnionType ):
9658- for item in t .items :
9659- result .update (ambiguous_enum_equality_keys (item ))
9660- elif isinstance (t , Instance ):
9661- if t .last_known_value :
9662- result .update (ambiguous_enum_equality_keys (t .last_known_value ))
9663- elif t .type .is_enum and any (
9664- base .fullname in ("enum.IntEnum" , "enum.StrEnum" , "builtins.str" , "builtins.int" )
9665- for base in t .type .mro
9666- ):
9667- result .add (t .type .fullname )
9668- elif not t .type .is_enum :
9669- # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9670- # let's be conservative
9671- result .add ("<other>" )
9672- elif isinstance (t , LiteralType ):
9673- result .update (ambiguous_enum_equality_keys (t .fallback ))
9674- elif isinstance (t , NoneType ):
9675- pass
9676- else :
9677- result .add ("<other>" )
9678- return result
9749+ return combine_equality_value_info (equality_value_info (item ) for item in t .items )
9750+ if isinstance (t , TypeVarType ):
9751+ if t .values :
9752+ return combine_equality_value_info (equality_value_info (item ) for item in t .values )
9753+ return equality_value_info (t .upper_bound )
9754+ if isinstance (t , Instance ) and t .last_known_value is not None :
9755+ return equality_value_info (t .last_known_value )
9756+ if isinstance (t , LiteralType ):
9757+ return equality_value_info (t .fallback )
9758+ if isinstance (t , Instance ):
9759+ if t .type .fullname == "builtins.object" :
9760+ return EqualityValueInfo ({}, is_top = True )
9761+
9762+ enum_type_names = {t .type .fullname } if t .type .is_enum else set ()
9763+ domains = {}
9764+ for base in t .type .mro :
9765+ if domain := VALUE_EQUALITY_DOMAINS .get (base .fullname ):
9766+ domains [domain ] = EqualityDomainInfo ({t .type .fullname }, enum_type_names )
9767+
9768+ return EqualityValueInfo (domains , is_top = False )
9769+ if isinstance (t , AnyType ):
9770+ return EqualityValueInfo ({}, is_top = True )
9771+ return EqualityValueInfo ({}, is_top = False )
9772+
9773+
9774+ def combine_equality_value_info (infos : Iterable [EqualityValueInfo ]) -> EqualityValueInfo :
9775+ domains : dict [str , EqualityDomainInfo ] = {}
9776+ is_top = False
9777+ for info in infos :
9778+ for domain , domain_info in info .domains .items ():
9779+ existing_domain_info = domains .get (domain )
9780+ if existing_domain_info is None :
9781+ domains [domain ] = EqualityDomainInfo (
9782+ set (domain_info .type_names ), set (domain_info .enum_type_names )
9783+ )
9784+ else :
9785+ existing_domain_info .type_names .update (domain_info .type_names )
9786+ existing_domain_info .enum_type_names .update (domain_info .enum_type_names )
9787+ is_top = is_top or info .is_top
9788+ return EqualityValueInfo (domains , is_top )
96799789
96809790
96819791def is_typeddict_type_context (lvalue_type : Type ) -> bool :
0 commit comments