# Source code for plotnine_prism.scale

"""Provides scales"""

from plotnine.exceptions import PlotnineError
from plotnine.scales.scale import scale_discrete, scale_continuous
from .pal import prism_color_pal, prism_fill_pal, prism_shape_pal


def get_minor_breaks(
    self,
    major,
    limits=None,
):
    """
    Return minor breaks

    See https://github.com/has2k1/plotnine/issues/696

    Parameters
    ----------
    major : array-like
        Major breaks. Breaks produced by a callable ``self.minor_breaks``
        are passed through ``self.trans.transform`` before being compared
        with these, so they are expected to be in transformed space.
    limits : array-like, optional
        Limits handed to the transformation (or, inverted, to a callable
        ``self.minor_breaks``). Defaults to ``self.limits``.

    Returns
    -------
    array-like
        ``[]`` when minor breaks are disabled (``self.minor_breaks`` is
        ``False`` or ``None``); otherwise whatever
        ``self.trans.minor_breaks`` / ``self.trans.transform`` returns,
        or a list of the transformed callable-generated breaks that do
        not coincide with a major break.
    """
    if limits is None:
        limits = self.limits

    spec = self.minor_breaks

    # ``bool`` is a subclass of ``int``, so the "disabled" case has to be
    # ruled out before the integer check; otherwise ``False`` would be
    # forwarded to the transformation as ``n=False``.
    if spec is False or spec is None:
        return []

    if spec is True:
        # TODO: Remove ignore when mizani is static typed
        return self.trans.minor_breaks(major, limits)  # pyright: ignore

    if isinstance(spec, int):
        # An integer is forwarded as the ``n`` argument
        # TODO: Remove ignore when mizani is static typed
        return self.trans.minor_breaks(
            major, limits, n=spec
        )  # pyright: ignore

    # ``len`` rather than truthiness: ``major`` may be an array
    if not len(major):
        return []

    if callable(spec):
        # The callable receives the inverse-transformed limits; its result
        # is transformed back and anything equal to a major break dropped.
        candidates = self.trans.transform(spec(self.trans.inverse(limits)))
        major_values = set(major)
        return [x for x in candidates if x not in major_values]

    # Explicit user-supplied break values
    return self.trans.transform(spec)


def get_labels(self, breaks=None):
    """
    Generate labels for the axis or legend

    Parameters
    ----------
    breaks: None or array-like
        If None, use ``self.get_breaks()``. The breaks are passed through
        ``self.inverse`` before labels are generated.

    Returns
    -------
    sequence
        One label per break. Depending on ``self.labels`` this is the
        output of ``self.trans.format``, the output of the user callable,
        a list of strings built from a dict lookup, or the user-supplied
        sequence (possibly filtered).

    Raises
    ------
    PlotnineError
        If the number of labels differs from the number of breaks. The
        check applies to every branch, so ``labels=False``/``None``
        (which yields no labels) raises whenever there are breaks.
    """
    if breaks is None:
        breaks = self.get_breaks()

    breaks = self.inverse(breaks)
    spec = self.labels

    if spec is True:
        labels = self.trans.format(breaks)
    elif spec is False or spec is None:
        labels = []
    elif callable(spec):
        labels = spec(breaks)
    elif isinstance(spec, dict):
        # Breaks without an entry in the mapping fall back to str(break)
        labels = [str(spec[b]) if b in spec else str(b) for b in breaks]
    else:
        # When user sets breaks and labels of equal size,
        # but the limits exclude some of the breaks.
        # We remove the corresponding labels
        from collections.abc import Sized

        labels = spec
        paired_with_user_breaks = (
            len(labels) != len(breaks)
            and isinstance(self.breaks, Sized)
            and len(labels) == len(self.breaks)
        )
        if paired_with_user_breaks:
            wanted = set(breaks)
            labels = [
                label
                for label, brk in zip(labels, self.breaks)
                if brk in wanted
            ]

    if len(labels) != len(breaks):
        raise PlotnineError("Breaks and labels are different lengths")

    return labels


# Monkey-patch plotnine's scale_continuous with the get_minor_breaks and
# get_labels functions defined in this module
# (see https://github.com/has2k1/plotnine/issues/696).
# The assignments run at import time and replace the methods on the class
# itself.
scale_continuous.get_minor_breaks = get_minor_breaks
scale_continuous.get_labels = get_labels


class scale_color_prism(scale_discrete):
    """Prism color scale

    Discrete scale for the ``color`` aesthetic whose palette is obtained
    from ``prism_color_pal``.

    Args:
        palette: The color palette name, passed to ``prism_color_pal``
        **kwargs: Forwarded unchanged to the ``scale_discrete`` constructor
    """

    _aesthetics = ["color"]
    # Mid-grey hex colour; presumably used for missing values -- confirm
    # against the plotnine scale base class
    na_value = "#7F7F7F"

    def __init__(self, palette="colors", **kwargs):
        """Construct"""
        # The palette is assigned before the base constructor runs
        self.palette = prism_color_pal(palette)
        super().__init__(**kwargs)


class scale_fill_prism(scale_color_prism):
    """Prism fill scale

    Discrete scale for the ``fill`` aesthetic whose palette is obtained
    from ``prism_fill_pal``.

    Args:
        palette: The fill palette name, passed to ``prism_fill_pal``
        **kwargs: Forwarded unchanged to the ``scale_discrete`` constructor
    """

    _aesthetics = ["fill"]
    # Mid-grey hex colour; presumably used for missing values -- confirm
    # against the plotnine scale base class
    na_value = "#7F7F7F"

    def __init__(self, palette="colors", **kwargs):
        """Construct"""
        self.palette = prism_fill_pal(palette)
        # Call scale_discrete directly rather than super(): the parent
        # constructor would look the name up in the color palettes and
        # overwrite the fill palette assigned above.
        scale_discrete.__init__(self, **kwargs)


class scale_shape_prism(scale_discrete):
    """Prism shape scale

    Discrete scale for the ``shape`` aesthetic whose palette is obtained
    from ``prism_shape_pal``.

    Args:
        palette: The shape palette name, passed to ``prism_shape_pal``
        **kwargs: Forwarded unchanged to the ``scale_discrete`` constructor
    """

    _aesthetics = ["shape"]

    def __init__(self, palette="default", **kwargs):
        """Construct"""
        # The palette is assigned before the base constructor runs
        self.palette = prism_shape_pal(palette)
        super().__init__(**kwargs)


# British-spelling alias: both names refer to the same class object
scale_colour_prism = scale_color_prism