mirror of
https://github.com/ytdl-org/youtube-dl
synced 2025-12-05 02:54:45 +00:00
[utils] Add partial_application decorator function
Thx: yt-dlp/yt-dlp#10653
This commit is contained in:
parent
a96a778750
commit
23a848c314
@ -69,6 +69,7 @@ from youtube_dl.utils import (
|
|||||||
parse_iso8601,
|
parse_iso8601,
|
||||||
parse_resolution,
|
parse_resolution,
|
||||||
parse_qs,
|
parse_qs,
|
||||||
|
partial_application,
|
||||||
pkcs1pad,
|
pkcs1pad,
|
||||||
prepend_extension,
|
prepend_extension,
|
||||||
read_batch_urls,
|
read_batch_urls,
|
||||||
@ -1723,6 +1724,21 @@ Line 1
|
|||||||
'a', 'b', 'c', 'd',
|
'a', 'b', 'c', 'd',
|
||||||
from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
|
from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
|
||||||
|
|
||||||
|
def test_partial_application(self):
|
||||||
|
test_fn = partial_application(lambda x, kwarg=None: '{0}, kwarg={1!r}'.format(x, kwarg))
|
||||||
|
self.assertTrue(
|
||||||
|
callable(test_fn(kwarg=10)),
|
||||||
|
'missing positional parameter should apply partially')
|
||||||
|
self.assertEqual(
|
||||||
|
test_fn(10, kwarg=0.1), '10, kwarg=0.1',
|
||||||
|
'positionally passed argument should call function')
|
||||||
|
self.assertEqual(
|
||||||
|
test_fn(x=10), '10, kwarg=None',
|
||||||
|
'keyword passed positional should call function')
|
||||||
|
self.assertEqual(
|
||||||
|
test_fn(kwarg=0.1)(10), '10, kwarg=0.1',
|
||||||
|
'call after partial application should call the function')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@ -1861,6 +1861,39 @@ def write_json_file(obj, fn):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class partial_application(object):
|
||||||
|
"""Allow a function to use pre-set argument values"""
|
||||||
|
|
||||||
|
# see _try_bind_args()
|
||||||
|
try:
|
||||||
|
inspect.signature
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def required_args(fn):
|
||||||
|
return [
|
||||||
|
param.name for param in inspect.signature(fn).parameters.values()
|
||||||
|
if (param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||||
|
and param.default is inspect.Parameter.empty)]
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
|
||||||
|
# Py < 3.3
|
||||||
|
@staticmethod
|
||||||
|
def required_args(fn):
|
||||||
|
fn_args = inspect.getargspec(fn)
|
||||||
|
n_defaults = len(fn_args.defaults or [])
|
||||||
|
return (fn_args.args or [])[:-n_defaults if n_defaults > 0 else None]
|
||||||
|
|
||||||
|
def __new__(cls, func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
if set(cls.required_args(func)[len(args):]).difference(kwargs):
|
||||||
|
return functools.partial(func, *args, **kwargs)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info >= (2, 7):
|
if sys.version_info >= (2, 7):
|
||||||
def find_xpath_attr(node, xpath, key, val=None):
|
def find_xpath_attr(node, xpath, key, val=None):
|
||||||
""" Find the xpath xpath[@key=val] """
|
""" Find the xpath xpath[@key=val] """
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user