Skip to content

SOURCE CODE pipda.piping DOCS

from __future__ import annotations

import ast
import functools
from abc import ABC
from typing import Type, Dict, Callable

from .expression import Expression

PIPING_OPS = {
    # op: (method, ast node, numpy ufunc name)
    ">>": ("__rrshift__", ast.RShift, "right_shift"),
    "|": ("__ror__", ast.BitOr, "bitwise_or"),
    "//": ("__rfloordiv__", ast.FloorDiv, "floor_divide"),
    "@": ("__rmatmul__", ast.MatMult, "matmul"),
    "%": ("__rmod__", ast.Mod, "remainder"),
    "&": ("__rand__", ast.BitAnd, "bitwise_and"),
    "^": ("__rxor__", ast.BitXor, "bitwise_xor"),
}

PATCHED_CLASSES: Dict[Type, Dict[str, Callable]] = {
    # kls:
    #    {}  # registered but not patched
    #    {"method": <method>, "imethod": <imethod>}  # patched
}


class PipeableCall(Expression, ABC):DOCS
    """A pipeable call that waits for the data to be piped in

    >>> data >> pipeable_call(...)
    """
    PIPING: str | None = None


def _patch_cls_method(kls: Type, method: str) -> None:
    """Borrowed from https://github.com/sspipe/sspipe"""
    try:
        original = getattr(kls, method)
    except AttributeError:
        return

    PATCHED_CLASSES[kls][method] = original

    @functools.wraps(original)
    def wrapper(self, x, *args, **kwargs):
        if isinstance(x, PipeableCall):
            return NotImplemented
        return original(self, x, *args, **kwargs)

    setattr(kls, method, wrapper)


def _unpatch_cls_method(kls: Type, method: str) -> None:
    if method in PATCHED_CLASSES[kls]:
        setattr(kls, method, PATCHED_CLASSES[kls].pop(method))


def _patch_cls_operator(kls: Type, op: str | None) -> None:
    if op is None:  # pragma: no cover
        return
    method = PIPING_OPS[op][0].replace("__r", "__")
    imethod = PIPING_OPS[op][0].replace("__r", "__i")
    _patch_cls_method(kls, method)
    _patch_cls_method(kls, imethod)


def _unpatch_cls_operator(kls: Type, op: str | None) -> None:
    if op is None:  # pragma: no cover
        return
    method = PIPING_OPS[op][0].replace("__r", "__")
    imethod = PIPING_OPS[op][0].replace("__r", "__i")
    _unpatch_cls_method(kls, method)
    _unpatch_cls_method(kls, imethod)


def patch_classes(*classes: Type) -> None:DOCS
    """Patch the classes in case it has piping operator defined

    For example, DataFrame.__or__ has already been defined, so we need to
    patch it to force it to use __ror__ of PipeableCall if `|` is registered
    for piping.

    Args:
        classes: The classes to patch
    """
    for kls in classes:
        if kls not in PATCHED_CLASSES:
            PATCHED_CLASSES[kls] = {}

        if not PATCHED_CLASSES[kls]:
            _patch_cls_operator(kls, PipeableCall.PIPING)


def unpatch_classes(*classes: Type) -> None:DOCS
    """Unpatch the classes

    Args:
        classes: The classes to unpatch
    """
    for kls in classes:
        if PATCHED_CLASSES[kls]:
            _unpatch_cls_operator(kls, PipeableCall.PIPING)
        # Don't patch it in the future
        del PATCHED_CLASSES[kls]


def _patch_all(op: str) -> None:
    """Patch all registered classes that has the operator defined

    Args:
        op: The operator used for piping
            Avaiable:  ">>", "|", "//", "@", "%", "&" and "^"
        un: Unpatch the classes
    """
    for kls in PATCHED_CLASSES:
        _patch_cls_operator(kls, op)


def _unpatch_all(op: str) -> None:
    """Unpatch all registered classes

    Args:
        op: The operator used for piping
            Avaiable:  ">>", "|", "//", "@", "%", "&" and "^"
    """
    for kls in PATCHED_CLASSES:
        _unpatch_cls_operator(kls, op)


def _patch_default_classes() -> None:
    """Patch the default/commonly used classes"""

    try:
        import pandas
        patch_classes(
            pandas.DataFrame,
            pandas.Series,
            pandas.Index,
            pandas.Categorical,
        )
    except ImportError:
        pass

    try:  # pragma: no cover
        from modin import pandas  # pyright: ignore
        patch_classes(
            pandas.DataFrame,
            pandas.Series,
            pandas.Index,
            pandas.Categorical,
        )
    except ImportError:
        pass

    try:  # pragma: no cover
        import torch  # pyright: ignore
        patch_classes(torch.Tensor)
    except ImportError:
        pass

    try:  # pragma: no cover
        from django.db.models import query  # pyright: ignore
        patch_classes(query.QuerySet)
    except ImportError:
        pass


def register_piping(op: str) -> None:DOCS
    """Register the piping operator for verbs

    Args:
        op: The operator used for piping
            Avaiable:  ">>", "|", "//", "@", "%", "&" and "^"
    """
    if op not in PIPING_OPS:
        raise ValueError(f"Unsupported piping operator: {op}")

    from .verb import VerbCall

    if PipeableCall.PIPING:
        curr_method = PIPING_OPS[PipeableCall.PIPING][0]
        verb_orig_method = VerbCall.__orig_opmethod__  # type: ignore
        setattr(VerbCall, curr_method, verb_orig_method)
        _unpatch_all(PipeableCall.PIPING)

    PipeableCall.PIPING = op
    VerbCall.__orig_opmethod__ = getattr(  # type: ignore
        VerbCall,
        PIPING_OPS[op][0],
    )
    setattr(VerbCall, PIPING_OPS[op][0], VerbCall._pipda_eval)

    _patch_all(op)