Skip to content

Commit cdb948f

Browse files
committed
gh-145478: Implement frozendict into functools.partial
1 parent 1dfe99a commit cdb948f

4 files changed

Lines changed: 113 additions & 69 deletions

File tree

Doc/library/functools.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,9 @@ The :mod:`!functools` module defines the following functions:
411411
.. versionchanged:: 3.14
412412
Added support for :data:`Placeholder` in positional arguments.
413413

414+
.. versionchanged:: 3.15
415+
:class:`partial` now stores keywords in a :class:`frozendict`
416+
414417
.. class:: partialmethod(func, /, *args, **keywords)
415418

416419
Return a new :class:`partialmethod` descriptor which behaves

Lib/functools.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,12 @@ def _partial_new(cls, func, /, *args, **keywords):
342342
phcount, merger = _partial_prepare_merger(tot_args)
343343
else: # works for both pto_phcount == 0 and != 0
344344
phcount, merger = pto_phcount, func._merger
345-
keywords = {**func.keywords, **keywords}
345+
keywords = frozendict(**func.keywords, **keywords)
346346
func = func.func
347347
else:
348348
tot_args = args
349349
phcount, merger = _partial_prepare_merger(tot_args)
350+
keywords = frozendict(**keywords)
350351

351352
self = object.__new__(cls)
352353
self.func = func
@@ -408,19 +409,25 @@ def __setstate__(self, state):
408409
raise TypeError(f"expected 4 items in state, got {len(state)}")
409410
func, args, kwds, namespace = state
410411
if (not callable(func) or not isinstance(args, tuple) or
411-
(kwds is not None and not isinstance(kwds, dict)) or
412412
(namespace is not None and not isinstance(namespace, dict))):
413413
raise TypeError("invalid partial state")
414+
if kwds is not None and not (
415+
isinstance(kwds, dict) or isinstance(kwds, frozendict)):
416+
raise TypeError(f"keywords must be an instance of dict or frozendict, not {type(kwds)}")
414417

415418
if args and args[-1] is Placeholder:
416419
raise TypeError("trailing Placeholders are not allowed")
417420
phcount, merger = _partial_prepare_merger(args)
418421

419422
args = tuple(args) # just in case it's a subclass
420423
if kwds is None:
421-
kwds = {}
422-
elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
423-
kwds = dict(kwds)
424+
kwds = frozendict()
425+
else:
426+
for key in kwds:
427+
if type(key) is not str:
428+
raise TypeError("keywords must be a string")
429+
kwds = frozendict(kwds)
430+
424431
if namespace is None:
425432
namespace = {}
426433

Lib/test/test_functools.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ class BadTuple(tuple):
6262
def __add__(self, other):
6363
return list(self) + list(other)
6464

65-
class MyDict(dict):
65+
66+
class MyDict(frozendict):
67+
pass
68+
69+
class MyFrozenDict(frozendict):
6670
pass
6771

6872
class TestImportTime(unittest.TestCase):
@@ -338,7 +342,7 @@ def test_pickle(self):
338342
with replaced_module('functools', self.module):
339343
f = self.partial(signature, ['asdf'], bar=[True])
340344
f.attr = []
341-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
345+
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
342346
f_copy = pickle.loads(pickle.dumps(f, proto))
343347
self.assertEqual(signature(f_copy), signature(f))
344348

@@ -404,6 +408,16 @@ def test_setstate(self):
404408
with self.assertRaisesRegex(TypeError, f'^{msg_regex}$') as cm:
405409
f.__setstate__((capture, (1, PH), dict(a=10), dict(attr=[])))
406410

411+
with self.assertRaises(TypeError):
412+
f.__setstate__((capture, (1,), {1234: 1234}, dict(attr=[])))
413+
414+
class FakeString(str):
415+
pass
416+
417+
with self.assertRaises(TypeError):
418+
f.__setstate__((capture, (1,), {FakeString("string"): 1234}, dict(attr=[])))
419+
420+
407421
def test_setstate_errors(self):
408422
f = self.partial(signature)
409423

@@ -423,7 +437,18 @@ def test_setstate_subclasses(self):
423437
s = signature(f)
424438
self.assertEqual(s, (capture, (1,), dict(a=10), {}))
425439
self.assertIs(type(s[1]), tuple)
426-
self.assertIs(type(s[2]), dict)
440+
self.assertIs(type(s[2]), frozendict)
441+
r = f()
442+
self.assertEqual(r, ((1,), {'a': 10}))
443+
self.assertIs(type(r[0]), tuple)
444+
self.assertIs(type(r[1]), dict)
445+
446+
447+
f.__setstate__((capture, MyTuple((1,)), MyFrozenDict(a=10), None))
448+
s = signature(f)
449+
self.assertEqual(s, (capture, (1,), dict(a=10), {}))
450+
self.assertIs(type(s[1]), tuple)
451+
self.assertIs(type(s[2]), frozendict)
427452
r = f()
428453
self.assertEqual(r, ((1,), {'a': 10}))
429454
self.assertIs(type(r[0]), tuple)
@@ -445,7 +470,7 @@ def test_recursive_pickle(self):
445470
f = self.partial(capture)
446471
f.__setstate__((f, (), {}, {}))
447472
try:
448-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
473+
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
449474
# gh-117008: Small limit since pickle uses C stack memory
450475
with support.infinite_recursion(100):
451476
with self.assertRaises(RecursionError):
@@ -456,7 +481,7 @@ def test_recursive_pickle(self):
456481
f = self.partial(capture)
457482
f.__setstate__((capture, (f,), {}, {}))
458483
try:
459-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
484+
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
460485
f_copy = pickle.loads(pickle.dumps(f, proto))
461486
try:
462487
self.assertIs(f_copy.args[0], f_copy)
@@ -468,7 +493,7 @@ def test_recursive_pickle(self):
468493
f = self.partial(capture)
469494
f.__setstate__((capture, (), {'a': f}, {}))
470495
try:
471-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
496+
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
472497
f_copy = pickle.loads(pickle.dumps(f, proto))
473498
try:
474499
self.assertIs(f_copy.keywords['a'], f_copy)
@@ -588,30 +613,15 @@ def test_attributes_unwritable(self):
588613
else:
589614
self.fail('partial object allowed __dict__ to be deleted')
590615

591-
def test_manually_adding_non_string_keyword(self):
616+
def test_keyword_mutations(self):
592617
p = self.partial(capture)
593-
# Adding a non-string/unicode keyword to partial kwargs
594-
p.keywords[1234] = 'value'
595-
r = repr(p)
596-
self.assertIn('1234', r)
597-
self.assertIn("'value'", r)
598-
with self.assertRaises(TypeError):
599-
p()
600618

601-
def test_keystr_replaces_value(self):
602-
p = self.partial(capture)
619+
with self.assertRaises(TypeError):
620+
p.keywords["new key"] = ['sth']
603621

604-
class MutatesYourDict(object):
605-
def __str__(self):
606-
p.keywords[self] = ['sth2']
607-
return 'astr'
608-
609-
# Replacing the value during key formatting should keep the original
610-
# value alive (at least long enough).
611-
p.keywords[MutatesYourDict()] = ['sth']
612-
r = repr(p)
613-
self.assertIn('astr', r)
614-
self.assertIn("['sth']", r)
622+
# Adding a non-string/unicode keyword to partial kwargs
623+
with self.assertRaises(TypeError):
624+
p.keywords[1234] = 'value'
615625

616626
def test_placeholders_refcount_smoke(self):
617627
PH = self.module.Placeholder

0 commit comments

Comments
 (0)