# Source code of `pipda.function`.

from __future__ import annotations

import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type
from types import MappingProxyType
from functools import singledispatch, update_wrapper

from .utils import (
    DEFAULT_BACKEND,
    MultiImplementationsWarning,
    TypeHolder,
    evaluate_expr,
    update_user_wrapper,
    has_expr,
    is_piping,
)
from .expression import Expression

if TYPE_CHECKING:
    from .context import ContextType


class FunctionCall(Expression):
    """A function call object that awaits evaluation

    Args:
        func: A registered function by `register_func` or an expression,
            for example, `f.col.mean`
        *args: The positional arguments for the function
        **kwargs: The keyword arguments for the function. The special key
            `__backend`, if present, is removed from the keyword arguments
            and kept as the backend passed to `func.dispatch()` at
            evaluation time.
    """

    def __init__(
        self,
        func: Callable | Expression,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Take `__backend` out first, so that it is never forwarded to the
        # implementation as a regular keyword argument.
        self._pipda_backend = kwargs.pop("__backend", None)
        self._pipda_func = func
        self._pipda_args = args
        self._pipda_kwargs = kwargs

    def __str__(self) -> str:
        """Representation of the function call

        Returns:
            A string in the form of `name(arg, ..., key=value, ...)`
        """
        func = self._pipda_func
        if isinstance(func, Expression):
            funname = str(func)
        else:
            try:
                funname = func.__name__
            except AttributeError:
                # Not every callable carries `__name__`
                # (e.g. `functools.partial` objects)
                funname = repr(func)

        strargs: List[str] = [str(arg) for arg in self._pipda_args]
        strargs.extend(
            f"{key}={val}" for key, val in self._pipda_kwargs.items()
        )
        return f"{funname}({', '.join(strargs)})"

    def _pipda_eval(
        self,
        data: Any,
        context: ContextType | None = None,
    ) -> Any:
        """Evaluate the function call

        Args:
            data: The data used to evaluate the arguments (and the function
                itself when it is an expression)
            context: The context used for the evaluation

        Returns:
            The value returned by the implementation called with the
            evaluated arguments
        """
        func = impl = self._pipda_func
        if isinstance(func, Expression):
            # f.a(1)
            impl = evaluate_expr(func, data, context)

        args = self._pipda_args
        kwargs = self._pipda_kwargs

        # Expressions and unregistered callables have no functype, they are
        # called as they are with the evaluated arguments
        functype = getattr(func, "_pipda_functype", None)
        if functype == "verb":
            # The first argument is the data of the verb: it decides the
            # implementation and the remaining arguments are evaluated
            # against it, instead of against `data`
            dt = evaluate_expr(args[0], data, context)
            impl = func.dispatch(dt.__class__, backend=self._pipda_backend)
            ctx, kw_ctx = func.get_context(impl, context)
            ctx = ctx or context
            kw_ctx = kw_ctx or {}
            args = (
                dt,
                *(evaluate_expr(arg, dt, ctx) for arg in args[1:]),
            )
            kwargs = {
                # A keyword argument may have its own context
                key: evaluate_expr(val, dt, kw_ctx.get(key, ctx))
                for key, val in kwargs.items()
            }
        else:
            args = tuple(
                evaluate_expr(arg, data, context)
                for arg in args
            )
            kwargs = {
                key: evaluate_expr(val, data, context)
                for key, val in kwargs.items()
            }
            if functype == "func":
                impl = func.dispatch(backend=self._pipda_backend)
            elif functype == "dispatchable":
                # NOTE(review): the classes of all evaluated positional
                # arguments are used here, whereas the wrapper created by
                # `register_func` selects them by the `dispatchable` mode
                # ("first", "args", "kwargs" or "all") when called
                # directly -- confirm this difference is intended.
                impl = func.dispatch(
                    *(arg.__class__ for arg in args),
                    backend=self._pipda_backend,
                )

        return impl(*args, **kwargs)


def register_func(
    func: Callable | None = None,
    cls: Type = TypeHolder,
    *,
    plain: bool = False,
    name: str | None = None,
    qualname: str | None = None,
    doc: str | None = None,
    module: str | None = None,
    dispatchable: str | bool = False,
    pipeable: bool = False,
    context: ContextType | None = None,
    kw_context: Dict[str, ContextType] | None = None,
    ast_fallback: str = "normal_warning",
) -> Callable:
    """Register a function

    A function, unlike a verb, is a function that doesn't evaluate its
    arguments by the first argument, which is the data, it depends on the
    data from a verb to evaluate the arguments if they are Expression objects.

    A function can also be defined as pipeable, so that the first argument
    can be piped in later.

    A function can also be defined as dispatchable, so that the types of
    its arguments are used to dispatch the implementation.

    Args:
        func: The generic function.
            If `None` (not provided), this function will return a decorator.
        cls: The default type to register for _default backend
            if TypeHolder, it is a generic function, and not counted as a
            real implementation.
            For plain or non-dispatchable functions, specify a different type
            than TypeHolder to indicate the func is a real implementation.
            A list, tuple or set of types can be given to register `func`
            for each of them.
        plain: If True, the function will be registered as a plain function,
            which means it will be called without any evaluation of the
            arguments. It doesn't support dispatchable and pipeable (both
            are turned off).
        name: and
        qualname: and
        doc: and
        module: The meta information about the function to overwrite `func`'s
            or when it's not available from `func`
        dispatchable: If not False nor None, the function will be registered as
            a dispatchable function, which means it will be dispatched using
            the types of the arguments when it is called directly:
            "first" - Use the first positional argument
            "args" - Use all positional arguments
            "kwargs" - Use all keyword arguments
            "all" - Use all arguments
            Any other truthy value (e.g. `True`) works as "all".
            If False, the function is not dispatchable.
        pipeable: If True, the function will work like a verb when a data is
            piping in. If dispatchable, the first argument will be used to
            dispatch the implementation.
            The rest of the arguments will be evaluated using the data from
            the first argument.
        context: The context used to evaluate the rest arguments using the
            first argument only when the function is pipeable and the data
            is piping in.
        kw_context: The context used to evaluate the keyword arguments
        ast_fallback: What's the supposed way to call the func when
            AST node detection fails.
            piping - Suppose this func is called like `data >> func(...)`
            normal - Suppose this func is called like `func(data, ...)`
            piping_warning - Suppose piping call, but show a warning
            normal_warning - Suppose normal call, but show a warning
            raise - Raise an error

    Returns:
        The registered func or a decorator to register a func
    """
    if func is None:
        return lambda fun: register_func(
            fun,
            cls=cls,
            plain=plain,
            name=name,
            qualname=qualname,
            doc=doc,
            module=module,
            dispatchable=dispatchable,
            pipeable=pipeable,
            context=context,
            kw_context=kw_context,
            ast_fallback=ast_fallback,
        )

    if plain:
        # make sure the flags are correct
        dispatchable = pipeable = False

    def _backend_generic(*args, **kwargs):  # pyright: ignore
        """Placeholder standing for "no implementation" of a backend"""
        raise NotImplementedError(
            f"`{wrapper.__name__}` is not implemented by the given backend."
        )

    if not isinstance(cls, (list, tuple, set)) and cls is not TypeHolder:
        cls = (cls,)  # type: ignore

    # backend => singledispatch registry (dispatchable)
    #            or the implementation itself (non-dispatchable)
    if dispatchable:
        registry = OrderedDict(
            {
                DEFAULT_BACKEND: singledispatch(
                    func if cls is TypeHolder else _backend_generic
                )
            }
        )
    else:
        registry = OrderedDict({DEFAULT_BACKEND: func})  # type: ignore
    # backend => implementation
    favorables: Dict[str, Callable] = {}
    # implementation => (context, kw_context)
    contexts = {
        (func if cls is TypeHolder else _backend_generic): (
            context,
            kw_context,
        )
    }

    def dispatch(*clses, backend=None):
        """generic_func.dispatch(*clses, backend) -> <function impl>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func* of given *backend*.

        If backend is not provided, we will look for the implementation of
        the backends in reverse order.

        The first cls can be dispatched is used.

        Args:
            clses: The types to dispatch
            backend: The backend to dispatch

        Returns:
            The implementation function

        Raises:
            NotImplementedError: When `backend` is given but nothing has
                been registered for it.
        """
        if not clses:
            clses = (type(None),)

        if backend is not None:
            try:
                reg = registry[backend]
            except KeyError:
                # The KeyError is an implementation detail, keep it out of
                # the traceback
                raise NotImplementedError(
                    f"[{wrapper.__name__}] "
                    f"No implementations found for backend `{backend}`."
                ) from None

            if not dispatchable:
                return reg

            for cl in clses:
                fun = reg.dispatch(cl)
                # Any impl found
                if fun is not _backend_generic:
                    return fun
            return _backend_generic

        # No backend specified, collect the candidates from all backends,
        # the ones registered later come first
        impls = []
        favored_found = False
        for bkend, reg in reversed(registry.items()):
            impl = None
            if not dispatchable:
                if (bkend == DEFAULT_BACKEND and cls is TypeHolder) or (
                    favored_found and favorables.get(bkend) is not reg
                ):
                    continue

                impl = reg
            else:
                for cl in clses:
                    fun = reg.dispatch(cl)
                    if (
                        # Not really an impl
                        fun is _backend_generic
                        or (
                            # The generic, supposed to raise NotImplementedError
                            fun is func
                            and cls is TypeHolder
                            and bkend == DEFAULT_BACKEND
                        )
                        or (
                            # Non-favored impl after favored impl found
                            favored_found
                            and favorables.get(bkend) is not fun
                        )
                        or (
                            # Previously found impl is better
                            # (not dispatched by object), but this one is
                            # dispatched by object, skip
                            impls
                            and impls[-1][1]
                            is not registry[impls[-1][0]].dispatch(object)
                            and fun is registry[bkend].dispatch(object)
                        )
                    ):  # pragma: no cover
                        continue

                    impl = fun
                    break

            if impl is not None:
                if favorables.get(bkend) is impl:
                    favored_found = True

                impls.append((bkend, impl))

        if not impls:
            # Fall back to the generic function, or to the placeholder
            # when `func` itself was registered as a real implementation
            return func if cls is TypeHolder else _backend_generic

        if len(impls) > 1:
            warnings.warn(
                f"Multiple implementations found for `{wrapper.__name__}` "
                f"by backends: [{', '.join(impl[0] for impl in impls)}], "
                "register with more specific types, or pass "
                "`__backend=<backend>` to specify a backend.",
                MultiImplementationsWarning,
            )

        return impls[0][1]

    def get_context(impl, default=None):
        """Get the context of the implementation

        numpy ufuncs may not be able to set an attribute, so we need to
        use a dict to store the context.

        Args:
            impl: The implementation function
            default: The context to use when no context (None) is recorded
                for the implementation

        Returns:
            A tuple of the context and the keyword-argument contexts
        """
        out = contexts.get(impl, (context, kw_context))
        return (default, out[1]) if out[0] is None else out

    def register(
        cls=None,
        *,
        backend=DEFAULT_BACKEND,
        favored=False,
        overwrite_doc=False,
        context=None,
        kw_context=None,
        func=None,
    ):
        """generic_func.register(
            cls,
            backend,
            favored,
            overwrite_doc,
            func
        ) -> func

        Args:
            cls: The type to register for the given backend. A tuple, list
                or set of types registers the implementation for each type.
                Not used when the function is not dispatchable.
            backend: The backend to register for
            favored: Whether this implementation is favored. If so, non-favored
                implementations will be ignored if this implementation is found.
            overwrite_doc: Whether to overwrite the docstring of the function
            context: The context used to evaluate the rest arguments using the
                first argument only when the function is pipeable and the data
                is piping in.
            kw_context: The context used to evaluate the keyword arguments
            func: The implementation function. If not provided, a decorator
                is returned.

        Returns:
            The implementation function
        """
        if func is None:
            return lambda fn: register(
                cls,
                backend=backend,
                favored=favored,
                overwrite_doc=overwrite_doc,
                context=context,
                kw_context=kw_context,
                func=fn,
            )

        if not dispatchable:
            # Only one implementation per backend
            registry[backend] = func
        else:
            if backend not in registry:
                registry[backend] = singledispatch(_backend_generic)

            if isinstance(cls, (tuple, list, set)):
                for c in cls:
                    registry[backend].register(c, func)
            else:
                registry[backend].register(cls, func)  # type: ignore

        if favored:
            favorables[backend] = func
        if context is not None or kw_context is not None:
            contexts[func] = context, kw_context
        if overwrite_doc:
            wrapper.__doc__ = func.__doc__
        return func

    def wrapper(*args, **kwargs):
        if plain:
            # Plain functions never evaluate anything, just call the impl
            backend = kwargs.pop("__backend", None)
            return dispatch(backend=backend)(*args, **kwargs)

        if pipeable:
            ast_fb = kwargs.pop("__ast_fallback", wrapper.ast_fallback)

            if is_piping(wrapper.__name__, ast_fb):
                from .verb import VerbCall

                return VerbCall(wrapper, *args, **kwargs)

        # Not pipeable, or not called in a piping way
        if has_expr(args) or has_expr(kwargs):
            return FunctionCall(wrapper, *args, **kwargs)

        # No Expression objects, call directly
        backend = kwargs.pop("__backend", None)

        if not dispatchable:
            impl = dispatch(backend=backend)
        elif dispatchable == "first":
            # Slicing instead of indexing, so that a call without any
            # positional arguments is dispatched by the default of
            # `dispatch()` rather than failing with an IndexError
            impl = dispatch(
                *(arg.__class__ for arg in args[:1]),
                backend=backend,
            )
        elif dispatchable == "args":
            impl = dispatch(
                *(arg.__class__ for arg in args),
                backend=backend,
            )
        elif dispatchable == "kwargs":
            impl = dispatch(
                *(arg.__class__ for arg in kwargs.values()),
                backend=backend,
            )
        else:  # all
            impl = dispatch(
                *(arg.__class__ for arg in args),
                *(arg.__class__ for arg in kwargs.values()),
                backend=backend,
            )

        return impl(*args, **kwargs)

    # Copy the meta information first: update_wrapper() also merges
    # `func.__dict__` into the wrapper, which must not override the
    # attributes attached below (e.g. when `func` has its own `register`
    # or `dispatch` attribute).
    update_wrapper(wrapper, func)

    if plain:
        wrapper._pipda_functype = "plain"
    elif dispatchable:
        if cls is not TypeHolder:
            # `func` is a real implementation for the given type(s)
            register(cls, context=context, kw_context=kw_context, func=func)
        wrapper._pipda_functype = "dispatchable"
    else:
        wrapper._pipda_functype = "func"

    # Expose read-only views, registrations should go through `register()`
    wrapper.registry = MappingProxyType(registry)
    wrapper.dispatch = dispatch
    wrapper.register = register
    wrapper.get_context = get_context
    wrapper.ast_fallback = ast_fallback
    wrapper.favorables = MappingProxyType(favorables)

    update_user_wrapper(
        wrapper,
        name=name,
        qualname=qualname,
        doc=doc,
        module=module,
    )
    return wrapper