# Source code: plkit.utils

"""Utility functions for plkit"""
from typing import Iterable, List, Optional, Tuple, Union
import sys
import logging
import warnings
from io import StringIO
from contextlib import contextmanager

from rich.table import Table
from rich.console import Console
from rich.logging import RichHandler
from diot import FrozenDiot
from pytorch_lightning import seed_everything, _logger as logger

from .exceptions import PlkitConfigException

# A single train/val/test size: either an int or a float
RatioType = Union[int, float]

# Remove every handler currently attached to the logger imported from
# pytorch_lightning, then attach a rich handler (with show_path disabled)
del logger.handlers[:]
logger.addHandler(RichHandler(show_path=False))

# 'py.warnings' is the logger that receives warnings once
# logging.captureWarnings(True) is on; give it a rich handler as well
logging.getLogger('py.warnings').addHandler(RichHandler(show_path=False))


def check_config(config,
                 item,
                 how=lambda conf, key: key in conf,
                 msg="Configuration item {key} is required."):
    """Check a configuration item and raise if the check fails

    Args:
        config (dict): The configuration dictionary
        item (str): The configuration key to check
        how (callable): The checker, called as `how(config, item)`.
            A falsy return value fails the check. By default, it checks
            whether `item` is a key of `config`.
        msg (str): The message to show in the exception.
            `{key}` is available to refer to the key checked.

    Raises:
        PlkitConfigException: When the check fails
    """
    if not how(config, item):
        raise PlkitConfigException(msg.format(key=item))

def collapse_suggest_config(config: dict) -> dict:
    """Use the default value of OptunaSuggest for config items.
    So that the configs can be used in the case that optuna is opted out.

    Only top-level values that are `OptunaSuggest` instances are replaced
    (by their `.default`); all other items are kept as they are.

    Args:
        config: The configuration dictionary

    Returns:
        The collapsed configuration. If `config.copy()` yields a
        `FrozenDiot`, that copy is updated and returned; otherwise the
        updated copy is wrapped into a new `FrozenDiot`.
    """
    # Imported at call time rather than at module level
    # NOTE(review): presumably to avoid a circular import or to keep
    # optuna optional -- confirm
    from .optuna import OptunaSuggest

    # Work on a copy so that the passed-in object is not updated here
    config = config.copy()
    collapsed = {key: val.default
                 for key, val in config.items()
                 if isinstance(val, OptunaSuggest)}

    if isinstance(config, FrozenDiot):
        # A FrozenDiot has to be thawed before it can be updated
        with config.thaw():
            config.update(collapsed)
        return config

    config.update(collapsed)
    return FrozenDiot(config)

def normalize_tvt_ratio(
        tvt_ratio: "Optional[Union[RatioType, Iterable[RatioType]]]"
) -> "Optional[Tuple[RatioType, List[RatioType], List[RatioType]]]":
    """Normalize the train-val-test data ratio into a format of
    (.7, [.1, .1], [.05, .05]).

    For `config.data_tvt`, the first element is required. If val or test ratios
    are not provided, it will be filled with `None`

    All numbers could be absolute numbers (>1) or ratios (<=1)

    Only `tuple` and `list` are treated as containers; any other value is
    taken as a single ratio. A scalar val/test ratio is wrapped into a
    one-element list, while a falsy one (`None`, `0`) or one that is already
    a tuple/list is kept as it is. The values themselves are not validated
    here.

    Args:
        tvt_ratio: The train-val-test ratio

    Returns:
        The normalized ratios as a tuple, or `None` if `tvt_ratio` is falsy
        (`None`, `0`, or an empty container)
    """
    if not tvt_ratio:
        return None

    def _is_seq(value):
        """Tell whether value is a tuple or a list"""
        return isinstance(value, (tuple, list))

    ratios = list(tvt_ratio) if _is_seq(tvt_ratio) else [tvt_ratio]

    # Pad the missing val/test positions with None
    if len(ratios) < 3:
        ratios += [None] * (3 - len(ratios))

    # Positions 1 (val) and 2 (test) may each hold several ratios
    for idx in (1, 2):
        if ratios[idx] and not _is_seq(ratios[idx]):
            ratios[idx] = [ratios[idx]]

    return tuple(ratios)

@contextmanager
def warning_to_logging():
    """Redirect warnings to logging, showing only the category and message

    Within the context, `logging.captureWarnings(True)` is in effect and
    `warnings.formatwarning` is patched to produce
    `'<CategoryName>': <message>` (without file name, line number or
    source line). Both are undone on exit, even if the body raises.
    """
    orig_format = warnings.formatwarning
    logging.captureWarnings(True)
    warnings.formatwarning = (
        lambda msg, category, *args, **kwargs: f'{category.__name__!r}: {msg}'
    )
    try:
        yield
    finally:
        # Always undo the patching, otherwise an exception in the body
        # would leave the process-wide warning machinery altered
        warnings.formatwarning = orig_format
        logging.captureWarnings(False)

@contextmanager
def capture_stdout():
    """Capture the stdout

    Within the context, `sys.stdout` is replaced by a `StringIO`, which is
    also the value bound by `with ... as`. The previous `sys.stdout` is put
    back on exit, even if the body raises.

    Yields:
        The `StringIO` object receiving what is written to stdout
    """
    _stdout = sys.stdout
    sys.stdout = stringio = StringIO()
    try:
        yield stringio
    finally:
        # Restore in any case, so that a failing body does not leave the
        # process writing into the buffer
        sys.stdout = _stdout

@contextmanager
def capture_stderr():
    """Capture the stderr

    Within the context, `sys.stderr` is replaced by a `StringIO`, which is
    also the value bound by `with ... as`. The previous `sys.stderr` is put
    back on exit, even if the body raises.

    Yields:
        The `StringIO` object receiving what is written to stderr
    """
    _stderr = sys.stderr
    sys.stderr = stringio = StringIO()
    try:
        yield stringio
    finally:
        # Restore in any case, so that a failing body does not leave the
        # process writing into the buffer
        sys.stderr = _stderr

@contextmanager
def output_to_logging(stdout_level: str = 'info', stderr_level: str = 'error'):
    """Capture the stdout or stderr to logging

    Whatever is written to stdout/stderr within the context is captured
    and, once the context is left, sent to the logger as one message each.
    This also happens when the body raises, so that the output produced
    before the failure is not lost.

    Args:
        stdout_level: Name of the logger method used for the captured
            stdout (e.g. `'info'`)
        stderr_level: Name of the logger method used for the captured
            stderr (e.g. `'error'`)
    """
    try:
        with capture_stderr() as err, capture_stdout() as out:
            yield
    finally:
        # Log only after the capturing contexts are left, so the log
        # records themselves are not written while capturing is active
        getattr(logger, stdout_level)(out.getvalue())
        getattr(logger, stderr_level)(err.getvalue())


def log_config(config, title='Configurations', items_per_row=1):
    """Log the configurations in a table in terminal

    The table is rendered into an in-memory buffer and then sent to the
    logger line by line at info level, preceded by an empty record.

    Args:
        config (dict): The configuration dictionary
        title (str): The title of the table
        items_per_row (int): The number of items to print per row
    """
    table = Table(title=title)
    # One (Item, Value) column pair for each item shown in a row
    for _ in range(items_per_row):
        table.add_column("Item")
        table.add_column("Value")

    items = list(config.items())
    for start in range(0, len(items), items_per_row):
        chunk = items[start:start + items_per_row]
        cells = []
        for key, value in chunk:
            # Table cells have to be renderable; a key that is not a string
            # (e.g. an int) is therefore converted
            cells.append(str(key))
            cells.append(repr(value))
        # Fill up the last row with empty cells when it is not full
        cells.extend([''] * (2 * (items_per_row - len(chunk))))
        table.add_row(*cells)

    # markup=False: print the cell texts literally
    console = Console(file=StringIO(), markup=False)
    console.print(table)
    logger.info('')
    for line in console.file.getvalue().splitlines():
        logger.info(line)

def plkit_seed_everything(config: FrozenDiot):
    """Try to seed everything and default deterministic to True
    if seed in config has been set

    Nothing is done when `seed` is missing from the config or is `None`.
    Otherwise `seed_everything` is called with the seed, and `deterministic`
    is set to `True` in the config unless the key is already there.

    Args:
        config: The configurations. It is modified in place (thawed for
            the time of the update).
    """
    if config.get('seed') is None:
        return

    seed_everything(config.seed)
    # setdefault: an existing value for deterministic is not overwritten
    with config.thaw():
        config.setdefault('deterministic', True)