from __future__ import annotations
import warnings
from enum import Enum
from collections import OrderedDict
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Type, Sequence
from functools import singledispatch, update_wrapper
from .utils import (
DEFAULT_BACKEND,
MultiImplementationsWarning,
TypeHolder,
evaluate_expr,
has_expr,
update_user_wrapper,
is_piping,
)
from .context import ContextPending, ContextType
from .piping import PipeableCall
class VerbCall(PipeableCall):
    """A call to a registered verb that is waiting for its data

    Args:
        func: The registered verb. It must provide `dispatch()` and
            `get_context()`, as the wrappers created by `register_verb` do.
        *args: The positional arguments for the verb, not including the data
        **kwargs: The keyword arguments for the verb. A `__backend` entry,
            if present, is removed and used to select the backend when the
            call is evaluated.
    """

    def __init__(
        self,
        func: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Popping from the same dict that gets stored keeps `__backend`
        # out of the arguments forwarded to the implementation.
        chosen_backend = kwargs.pop("__backend", None)
        self._pipda_func = func
        self._pipda_args = args
        self._pipda_kwargs = kwargs
        self._pipda_backend = chosen_backend

    def __str__(self) -> str:
        """Render the call, e.g. `verb(., 1, x=2)`

        The leading `.` is a placeholder for the piped data and is left out
        for dependent verbs.
        """
        is_dependent = getattr(self._pipda_func, "dependent", False)
        verb_name = self._pipda_func.__name__
        rendered: List[str] = [] if is_dependent else ["."]
        rendered.extend(map(str, self._pipda_args))
        rendered.extend(
            f"{key}={val}" for key, val in self._pipda_kwargs.items()
        )
        return f"{verb_name}({', '.join(rendered)})"

    def _pipda_eval(
        self,
        data: Any,
        context: ContextType | None = None,
    ) -> Any:
        """Evaluate the call against `data`

        Args:
            data: The data. Its class selects the implementation, and the
                arguments are evaluated against it.
            context: The fallback context, used when the selected
                implementation has no positional context of its own

        Returns:
            Whatever the dispatched implementation returns
        """
        verb = self._pipda_func
        impl = verb.dispatch(data.__class__, backend=self._pipda_backend)
        pos_ctx, kw_ctxs = verb.get_context(impl, context)
        kw_ctxs = kw_ctxs or {}
        if isinstance(pos_ctx, Enum):
            # Enum members carry the context object as their value
            pos_ctx = pos_ctx.value

        if isinstance(pos_ctx, ContextPending):
            # Pending context: the implementation receives the raw,
            # unevaluated arguments and deals with them itself.
            return impl(data, *self._pipda_args, **self._pipda_kwargs)

        # Keyword arguments are evaluated first, then positional ones.
        kw_values = {
            key: evaluate_expr(val, data, kw_ctxs.get(key, pos_ctx))
            for key, val in self._pipda_kwargs.items()
        }
        pos_values = [
            evaluate_expr(arg, data, pos_ctx) for arg in self._pipda_args
        ]
        return impl(data, *pos_values, **kw_values)
def register_verb(
    cls: Type | Sequence[Type] = TypeHolder,
    *,
    func: Callable | None = None,
    context: ContextType | None = None,
    kw_context: Dict[str, ContextType] | None = None,
    name: str | None = None,
    qualname: str | None = None,
    doc: str | None = None,
    module: str | None = None,
    dependent: bool = False,
    ast_fallback: str = "piping_warning",
) -> Callable:
    """Register a verb

    A verb is a function that takes a data as the first argument, and uses it
    to evaluate the rest of the arguments. So the first argument is required
    for a verb.

    We can have multiple implementations of a verb for different types of data,
    or even the same type of data with different backends.

    Args:
        cls: The default type(s) to register `func` for, in the default
            backend. If TypeHolder, `func` is a generic function and is not
            counted as a real implementation.
        func: The function that works as the verb.
            If `None` (not provided), this function returns a decorator.
        context: The context to evaluate the arguments
        kw_context: The contexts to evaluate specific keyword arguments
        name: The name for the verb, overriding `func`'s
        qualname: The qualified name for the verb, overriding `func`'s
        doc: The docstring for the verb, overriding `func`'s
        module: The module for the verb, overriding `func`'s.
            These four are also useful when `func` does not provide the
            information itself.
        dependent: Whether the verb is dependent.
            >>> @register_verb(context=Context.EVAL, dependent=True)
            >>> def length(data):
            >>>     return len(data)
            >>> # with dependent=True
            >>> # length() -> VerbCall, waiting for data to evaluate
            >>> # with dependent=False
            >>> # length() -> TypeError, argument data is missing
        ast_fallback: What's the supposed way to call the verb when
            AST node detection fails.
            piping - Suppose this verb is called like `data >> verb(...)`
            normal - Suppose this verb is called like `verb(data, ...)`
            piping_warning - Suppose piping call, but show a warning
            normal_warning - Suppose normal call, but show a warning
            raise - Raise an error

    Returns:
        The registered verb or a decorator to register a verb
    """
    if func is None:
        return lambda fun: register_verb(
            cls,
            func=fun,
            context=context,
            kw_context=kw_context,
            name=name,
            qualname=qualname,
            doc=doc,
            module=module,
            dependent=dependent,
            ast_fallback=ast_fallback,
        )

    def _backend_generic(*args, **kwargs):
        """Placeholder for backends lacking an implementation for a type"""
        raise NotImplementedError(
            f"`{wrapper.__name__}` is not implemented by the given backend."
        )

    if not isinstance(cls, (list, tuple, set)) and cls is not TypeHolder:
        cls = (cls,)  # type: ignore

    # The fallback of the default backend. It is `func` itself for generic
    # verbs; otherwise nothing generic is available.
    generic = func if cls is TypeHolder else _backend_generic

    # backend => single-dispatch function holding that backend's impls
    registry = OrderedDict({DEFAULT_BACKEND: singledispatch(generic)})
    # backend => favored implementation
    favorables: Dict[str, Callable] = {}
    # implementation => (context, kw_context)
    # A plain dict rather than weakref.WeakKeyDictionary, because some
    # implementations (e.g. numpy ufuncs) cannot be weakly referenced.
    contexts = {generic: (context, kw_context)}

    def dispatch(cl, backend=None):
        """generic_func.dispatch(cls, backend) -> <function impl>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cl* registered on *generic_func* of the given *backend*.

        If backend is not provided, the backends are searched in reverse
        registration order. A favored implementation takes precedence over
        non-favored ones. If more than one candidate remains, a
        MultiImplementationsWarning is issued and the first one is used.
        """
        if backend is not None:
            try:
                reg = registry[backend]
            except KeyError:
                raise NotImplementedError(
                    f"[{wrapper.__name__}] "
                    f"No implementations found for backend `{backend}`."
                ) from None
            return reg.dispatch(cl)

        impls = []
        favored_impls = []
        for bkd, reg in reversed(registry.items()):
            fun = reg.dispatch(cl)
            if fun is _backend_generic or (
                fun is func and cls is TypeHolder and bkd == DEFAULT_BACKEND
            ):
                # Only the generic fallback matched for this backend
                continue
            impls.append((bkd, fun))
            if favorables.get(bkd) is fun:
                favored_impls.append((bkd, fun))

        # Favored implementations win over non-favored ones, regardless of
        # the order in which their backends were registered.
        if favored_impls:
            impls = favored_impls

        if not impls:
            return generic

        if len(impls) > 1:
            warnings.warn(
                f"Multiple implementations found for `{wrapper.__name__}` "
                f"by backends: [{', '.join(impl[0] for impl in impls)}], "
                "register with more specific types, or pass "
                "`__backend=<backend>` to specify a backend.",
                MultiImplementationsWarning,
            )
        return impls[0][1]

    def get_context(impl, default=None):
        """Get the (context, kw_context) pair for an implementation

        Contexts are kept in a dict keyed by implementation instead of as
        attributes, because numpy ufuncs may not allow setting attributes.
        Implementations without stored contexts use the verb-level ones.
        When the positional context is None, `default` is returned in its
        place.
        """
        out = contexts.get(impl, (context, kw_context))
        return (default, out[1]) if out[0] is None else out

    def register(
        cls,
        *,
        backend=DEFAULT_BACKEND,
        context=None,
        kw_context=None,
        favored=False,
        overwrite_doc=False,
        func=None,
    ):
        """Register an implementation of the verb for type(s) `cls`

        Can be called directly (with `func`) or used as a decorator.

        Args:
            cls: A type, or a list/tuple/set of types, to register for
            backend: The backend that the implementation belongs to
            context: The context to evaluate the positional arguments
            kw_context: The contexts to evaluate specific keyword arguments.
                If both `context` and `kw_context` are None, the verb-level
                contexts are used for this implementation.
            favored: Whether this implementation takes precedence over
                non-favored ones when no backend is specified
            overwrite_doc: Whether to replace the verb's docstring with
                the implementation's
            func: The implementation

        Returns:
            `func` itself, or a decorator when `func` is not given
        """
        if func is None:
            return lambda fn: register(
                cls,
                backend=backend,
                context=context,
                kw_context=kw_context,
                favored=favored,
                overwrite_doc=overwrite_doc,
                func=fn,
            )

        if backend not in registry:
            registry[backend] = singledispatch(_backend_generic)

        if isinstance(cls, (tuple, list, set)):
            for c in cls:
                registry[backend].register(c, func)
        else:
            registry[backend].register(cls, func)

        if context is not None or kw_context is not None:
            contexts[func] = context, kw_context

        if favored:
            favorables[backend] = func

        if overwrite_doc:
            wrapper.__doc__ = func.__doc__

        return func

    def wrapper(*args, **kwargs):
        """Call the verb

        Returns a VerbCall when the verb is dependent or is being piped.
        Returns a FunctionCall when the data itself is an expression.
        Otherwise evaluates the verb on the data right away.
        """
        # Pop the AST-fallback hint before anything else. This keeps it out
        # of the arguments forwarded to the implementation, including for
        # dependent verbs.
        ast_fb = kwargs.pop("__ast_fallback", wrapper.ast_fallback)
        if dependent:
            return VerbCall(wrapper, *args, **kwargs)

        if is_piping(wrapper.__name__, ast_fb):
            return VerbCall(wrapper, *args, **kwargs)

        if not args:
            raise TypeError(
                f"Missing the first argument for verb `{wrapper.__name__}`."
            )

        data, *args = args
        if has_expr(data):
            # Imported lazily, presumably to avoid a circular import
            from .function import FunctionCall

            return FunctionCall(wrapper, data, *args, **kwargs)

        return VerbCall(wrapper, *args, **kwargs)._pipda_eval(data)

    if cls is not TypeHolder:
        register(cls, context=context, kw_context=kw_context, func=func)

    wrapper.registry = MappingProxyType(registry)
    wrapper.dispatch = dispatch
    wrapper.register = register
    wrapper.favorables = MappingProxyType(favorables)
    wrapper.dependent = dependent
    wrapper.ast_fallback = ast_fallback
    wrapper.get_context = get_context
    wrapper._pipda_functype = "verb"

    update_wrapper(wrapper, func)
    update_user_wrapper(
        wrapper,
        name=name,
        qualname=qualname,
        doc=doc,
        module=module,
    )
    return wrapper