"""Provide some utilities"""
from __future__ import annotations
import re
import sys
import importlib
import importlib.util
import logging
import textwrap
import typing
from itertools import groupby
from operator import itemgetter
from io import StringIO
from os import PathLike, get_terminal_size, environ
from collections import defaultdict
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Iterable,
List,
Mapping,
Sequence,
Tuple,
Type,
)
import diot
import simplug
from rich.console import Console
from rich.logging import RichHandler as _RichHandler
from rich.table import Table
from rich.text import Text
from simplug import SimplugContext
from .defaults import (
CONSOLE_DEFAULT_WIDTH,
CONSOLE_WIDTH_WITH_PANEL,
CONSOLE_WIDTH_SHIFT,
LOGGER_NAME,
)
from .version import __version__
from importlib import metadata as importlib_metadata
if TYPE_CHECKING: # pragma: no cover
import pandas
from rich.segment import Segment
from rich.console import RenderableType
from .pipen import Pipen
from .proc import Proc
from .procgroup import ProcGroup
# Default replacement for sys.argv[0] while a pipeline is only being loaded
# (used by load_pipeline() and checked by is_loading_pipeline())
LOADING_ARGV0 = "@pipen"
class RichHandler(_RichHandler):
    """A rich.logging.RichHandler that renders the log level as a single
    character (the first letter of the level name)"""

    def get_level_text(self, record: logging.LogRecord) -> Text:
        """Build the text shown in the level column for a record.

        Args:
            record: LogRecord instance.

        Returns:
            The upper-cased first letter of the level name, styled with
            `logging.level.<level name in lower case>`.
        """
        name = record.levelname
        initial = name[0].upper()
        return Text.styled(initial, f"logging.level.{name.lower()}")
class RichConsole(Console):
    """A rich Console whose width is decided once, at construction.

    The width is stored in `self._width` (presumably overriding rich's own
    width detection - TODO confirm against rich). It is taken from, in order:

    1. the terminal size reported by `os.get_terminal_size()`
    2. the `JUPYTER_COLUMNS` environment variable
    3. the `COLUMNS` environment variable
    4. `CONSOLE_DEFAULT_WIDTH`

    Environment variables that are unset or do not hold an integer are
    skipped instead of raising.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            self._width = get_terminal_size().columns
        except (AttributeError, ValueError, OSError):  # maybe not a terminal
            self._width = CONSOLE_DEFAULT_WIDTH
            for varname in ("JUPYTER_COLUMNS", "COLUMNS"):  # pragma: no cover
                try:
                    self._width = int(environ[varname])
                except (KeyError, ValueError):
                    # Unset, empty or non-numeric: try the next candidate
                    continue
                break

    def _render_buffer(self, buffer: Iterable[Segment]) -> str:
        """Render the buffer, dropping trailing whitespace and ending the
        output with exactly one newline"""
        rendered = super()._render_buffer(buffer)
        return rendered.rstrip() + "\n"
# Replace logging's handler of last resort with a no-op handler
# (presumably so records from loggers without handlers are not echoed to
# stderr - TODO confirm)
logging.lastResort = logging.NullHandler()  # type: ignore
# The console shared by the log handler; its width is also read by
# get_logpanel_width() and log_rich_renderable()
logger_console = RichConsole()
# The single handler attached to every logger returned by get_logger()
_logger_handler = RichHandler(
    show_path=False,
    show_level=True,
    console=logger_console,
    rich_tracebacks=True,
    omit_repeated_times=False,  # rich 10+
    markup=True,
    log_time_format="%m-%d %H:%M:%S",
    tracebacks_extra_lines=0,
    tracebacks_suppress=[simplug, diot, typing],
)
# `plugin_name` has to be present on each record; get_logger() supplies it
# through the LoggerAdapter's extra dict
_logger_handler.setFormatter(
    logging.Formatter("[purple]%(plugin_name)-7s[/purple] %(message)s")
)
def _excepthook(
    type_: Type[BaseException],
    value: BaseException,
    traceback: Any,
) -> None:
    """The excepthook for pipen.

    A KeyboardInterrupt is reported through the logger only. Any other
    exception is handed to the previously installed hook, preceded by an
    empty line on stderr.

    Args:
        type_: The exception class
        value: The exception instance
        traceback: The traceback object
    """
    if not issubclass(type_, KeyboardInterrupt):
        print("", file=sys.stderr)
        _excepthook.oldhook(type_, value, traceback)
        return

    for message in ("", "Interrupted by user"):  # pragma: no cover
        logger.error(message)


# Keep the previous hook so that it can still be delegated to, then install
# ours
_excepthook.oldhook = sys.excepthook
sys.excepthook = _excepthook
def get_logger(
    name: str = LOGGER_NAME,
    level: str | int | None = None,
) -> logging.LoggerAdapter:
    """Get the logger for the given plugin name

    Args:
        name: The plugin name. The underlying logger is named
            `pipen.<name>`, and the name is passed to the adapter as the
            `plugin_name` extra, which the handler's format string uses.
        level: The initial level of the logger, either a level name
            (case-insensitive) or a numeric level. When None, the level
            of the underlying logger is not touched.

    Returns:
        The logger adapter wrapping the underlying logger
    """
    log = logging.getLogger(f"pipen.{name}")
    log.addHandler(_logger_handler)
    if isinstance(level, str):
        # logging only knows the upper-case level names
        level = level.upper()
    if level is not None:
        log.setLevel(level)
    return logging.LoggerAdapter(log, {"plugin_name": name})


# The default logger, named after LOGGER_NAME
logger = get_logger()
def desc_from_docstring(
    obj: Type[Pipen | Proc],
    base: Type[Pipen | Proc],
) -> str | None:
    """Get the description from docstring

    Only the summary is extracted, i.e. the first paragraph of the
    docstring, with its lines stripped and joined by single spaces.

    Args:
        obj: The class with docstring
        base: The root class. When `obj` has no docstring, the first class
            in the MRO of `obj` that is a subclass of `base` (other than
            `base` and `obj` themselves) is used instead.

    Returns:
        The summary as desc, or None if neither `obj` nor any qualifying
        class in its MRO can provide one
    """
    doc = obj.__doc__
    if not doc:
        # Fall back to the nearest ancestor that sits strictly between
        # obj and base in the MRO
        for cls in obj.__mro__:
            if is_subclass(cls, base) and cls != base and cls != obj:
                return desc_from_docstring(cls, base)
        return None

    summary = []
    for line in doc.splitlines():
        line = line.strip()
        if line:
            summary.append(line)
        elif summary:
            # The first blank line after the summary started ends it
            break
    return " ".join(summary)
def update_dict(
    parent: Mapping[str, Any],
    new: Mapping[str, Any],
    depth: int = 0,
) -> Mapping[str, Any]:
    """Merge `new` into a copy of `parent`, leaving `parent` unchanged

    Values are merged recursively only when both sides hold a dict for the
    same key; otherwise the value from `new` wins.

    Args:
        parent: The parent dictionary
        new: The new dictionary
        depth: The number of levels to merge. With 1, only the top level
            is merged (nested dicts are replaced as a whole).
            0 for updating to the deepest level.

    Examples:
        >>> parent = {"a": {"b": 1}}
        >>> new = {"a": {"c": 2}}
        >>> update_dict(parent, new)
        {'a': {'b': 1, 'c': 2}}

    Returns:
        The updated dictionary or None if both parent and new are None.
    """
    if parent is None and new is None:
        return None

    merged = (parent or {}).copy()
    for key, val in (new or {}).items():
        go_deeper = (
            depth != 1
            and key in merged
            and isinstance(val, dict)
            and isinstance(merged[key], dict)
        )
        if go_deeper:
            merged[key] = update_dict(merged[key], val, depth - 1)
        else:
            merged[key] = val
    return merged
def strsplit(
    string: str,
    sep: str,
    maxsplit: int = -1,
    trim: str = "both",
) -> List[str]:
    """Split the string, with the ability to trim each part.

    Args:
        string: The string to split
        sep: The separator
        maxsplit: Max number of splits, passed on to `str.split()`
        trim: How to trim the whitespaces of each part: `"left"`,
            `"right"`, or None for no trimming at all. Any other value,
            including the default `"both"`, trims both sides.

    Returns:
        The list of (trimmed) parts
    """
    pieces = string.split(sep, maxsplit=maxsplit)
    if trim is None:
        return pieces

    if trim == "left":
        method = "lstrip"
    elif trim == "right":
        method = "rstrip"
    else:
        method = "strip"
    return [getattr(piece, method)() for piece in pieces]
def get_shebang(script: str) -> str:
    """Get the shebang of the script

    Leading whitespace of the script, including blank lines, is ignored.

    Args:
        script: The script string

    Returns:
        None if the script does not start with a shebang, otherwise the
        first line without the `#!` prefix and surrounding whitespace
    """
    content = script.lstrip()
    if not content.startswith("#!"):
        return None

    first_line, _, _ = content.partition("\n")
    return first_line[2:].strip()
def ignore_firstline_dedent(text: str) -> str:
    """Like textwrap.dedent(), but with the leading blank lines dropped

    The lines are re-joined with `\\n`, so a trailing newline of the text
    is not kept.

    Args:
        text: The text to be dedented

    Returns:
        The dedented text
    """
    lines = text.splitlines()
    # Index of the first line with visible content (or past the end)
    first = next(
        (idx for idx, line in enumerate(lines) if line.strip()),
        len(lines),
    )
    return textwrap.dedent("\n".join(lines[first:]))
def copy_dict(dic: Mapping[str, Any], depth: int = 1) -> Mapping[str, Any]:
    """Copy a dict down to a given depth

    Only values that are dicts are copied recursively; everything else is
    shared with the original.

    Args:
        dic: The dict to be copied
        depth: The depth to be deep copied. With 1 (or less), it is just
            `dic.copy()`

    Returns:
        The deep-copied dict
    """
    if depth <= 1:
        return dic.copy()

    copied = {}
    for key, val in dic.items():
        if isinstance(val, dict):
            copied[key] = copy_dict(val, depth - 1)
        else:
            copied[key] = val
    return copied
def get_logpanel_width() -> int:
    """Get the width of the log content

    The width of the logger console, capped at CONSOLE_WIDTH_WITH_PANEL,
    is used. Note that the result is not the console width: the width of
    the log meta info (CONSOLE_WIDTH_SHIFT) is subtracted from it.

    Returns:
        The width of the log content
    """
    capped_width = min(logger_console.width, CONSOLE_WIDTH_WITH_PANEL)
    return capped_width - CONSOLE_WIDTH_SHIFT
def log_rich_renderable(
    renderable: RenderableType,
    color: str | None,
    logfunc: Callable,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Log a rich renderable to logger, one log call per rendered line

    The renderable is printed to an in-memory console that is
    CONSOLE_WIDTH_SHIFT narrower than the logger console, and each
    resulting line is then passed to the log function.

    Args:
        renderable: The rich renderable
        color: The markup tag to wrap each line with, as
            `[color]line[/color]`. Lines are logged as they are if this
            is None or empty.
        logfunc: The log function, if message is not the first argument,
            use functools.partial to wrap it
        *args: The arguments to the log function
        **kwargs: The keyword arguments to the log function
    """
    buffer = StringIO()
    Console(
        file=buffer,
        width=logger_console.width - CONSOLE_WIDTH_SHIFT,
    ).print(renderable)

    for line in buffer.getvalue().splitlines():
        message = f"[{color}]{line}[/{color}]" if color else line
        logfunc(message, *args, **kwargs)
def brief_list(blist: List[int]) -> str:
    """Briefly show an integer list, combine the continuous numbers.

    Numbers that follow each other in the list and increase by exactly 1
    are collapsed into `first-last`. The list is not sorted beforehand.

    Args:
        blist: The list

    Returns:
        The string to show for the briefed list.
    """
    runs: List[List[int]] = []
    for number in blist:
        if runs and number - runs[-1][-1] == 1:
            # Continues the current run
            runs[-1].append(number)
        else:
            runs.append([number])

    return ", ".join(
        f"{run[0]}-{run[-1]}" if len(run) > 1 else str(run[0])
        for run in runs
    )
def pipen_banner() -> RenderableType:
    """The banner for pipen

    A borderless, single-column table as wide as the log panel, with the
    banner rows centered and the version as caption.

    Returns:
        The banner renderable
    """
    # The rows are whitespace-sensitive, keep them byte-for-byte
    banner_rows = (
        r" _____________________________________ __",
        r" ___ __ \___ _/__ __ \__ ____/__ | / /",
        r" __ /_/ /__ / __ /_/ /_ __/ __ |/ / ",
        r" _ ____/__/ / _ ____/_ /___ _ /| / ",
        r"/_/ /___/ /_/ /_____/ /_/ |_/ ",
        "",
    )

    table = Table(
        width=get_logpanel_width(),
        show_header=False,
        show_edge=False,
        show_footer=False,
        show_lines=False,
        caption=f"version: {__version__}",
    )
    table.add_column(justify="center")
    for row in banner_rows:
        table.add_row(row)
    return table
def get_mtime(path: str | PathLike, dir_depth: int = 1) -> float:
    """Get the modification time of a path.

    If path is a directory, try to get the last modification time of the
    contents in the directory at given dir_depth. For a symbolic link that
    is not followed into a directory, the time of the link itself is used.

    Args:
        path: The path to check
        dir_depth: The depth of the directory to check the
            last modification time. With 0, the modification time of the
            directory itself is returned.

    Returns:
        The last modification time of path. 0.0 for a directory without
        any content when dir_depth is not 0.
    """
    target = Path(path)
    if dir_depth != 0 and target.is_dir():
        # Latest modification time among the direct children, each of them
        # checked one level less deep
        return max(
            [0.0, *(get_mtime(child, dir_depth - 1) for child in target.glob("*"))]
        )

    stat_result = target.lstat() if target.is_symlink() else target.stat()
    return stat_result.st_mtime
def is_subclass(obj: Any, cls: type) -> bool:
    """Tell if obj is a subclass of cls

    Unlike the builtin issubclass(), this does not raise a TypeError
    (e.g. when obj is not a class), but returns False instead.

    Args:
        obj: The object to check
        cls: The class to check

    Returns:
        True if obj is a subclass of cls otherwise False
    """
    try:
        result = issubclass(obj, cls)
    except TypeError:
        result = False
    return result
def load_entrypoints(
    group: str
) -> Iterable[Tuple[str, Any]]:  # pragma: no cover
    """Load objects from setuptools entrypoints by given group name

    This is a generator: the entrypoints are looked up and loaded lazily,
    while iterating.

    Args:
        group: The group name of the entrypoints

    Returns:
        An iterable of tuples with name and the loaded object
    """
    try:
        entrypoints = importlib_metadata.entry_points(group=group)
    except TypeError:
        # entry_points() without keyword selection: pick the group from
        # the returned mapping
        entrypoints = importlib_metadata.entry_points().get(  # type: ignore
            group, []
        )

    for entrypoint in entrypoints:
        yield entrypoint.name, entrypoint.load()
def truncate_text(text: str, width: int, end: str = "…") -> str:
    """Truncate a text not based on words/whitespaces

    Otherwise, we could use textwrap.shorten.

    Args:
        text: The text to be truncated
        width: The max width of the truncated text
        end: The end string of the truncated text

    Returns:
        The text itself if it fits into width. Otherwise the truncated
        text with end appended, never longer than width. If width does
        not even leave room for end, end itself is cut to width (an empty
        string for a width of 0 or less).
    """
    if len(text) <= width:
        return text
    if width <= len(end):
        # No room for any character of the text. Without this guard a
        # negative slice bound would cut the text from the wrong side and
        # return more than `width` characters.
        return end[: max(width, 0)]
    return text[: width - len(end)] + end
def make_df_colnames_unique_inplace(thedf: pandas.DataFrame) -> None:
    """Make the columns of a data frame unique

    The first occurrence of a column name is kept. Later occurrences get
    `_1`, `_2`, ... appended. A suffix is skipped when the resulting name
    is already in use, either by an original column or by a name generated
    earlier, so that the final names are guaranteed to be unique.

    Args:
        thedf: The data frame, whose `columns` are replaced in place
    """
    columns = list(thedf.columns)
    # All names that must not be generated (again)
    taken = set(columns)
    seen = set()
    # The last suffix used for each duplicated name
    suffixes: DefaultDict = defaultdict(int)

    new_cols = []
    for col in columns:
        if col not in seen:
            seen.add(col)
            new_cols.append(col)
            continue

        suffix = suffixes[col] + 1
        candidate = f"{col}_{suffix}"
        while candidate in taken:
            suffix += 1
            candidate = f"{col}_{suffix}"
        suffixes[col] = suffix
        taken.add(candidate)
        new_cols.append(candidate)

    thedf.columns = new_cols
def get_base(
    klass: Type,
    abc_base: Type,
    value: Any,
    value_getter: Callable,
) -> Type:
    """Get the base class where the value was first defined

    Starting from klass, follow the first direct base that is a subclass
    of abc_base and still yields the value, until no such base is left.

    Args:
        klass: The class
        abc_base: The very base class to check in __bases__
        value: The value to check
        value_getter: How to get the value from the class

    Returns:
        The base class
    """
    current = klass
    while True:
        candidates = [
            parent
            for parent in current.__bases__
            if issubclass(parent, abc_base) and value_getter(parent) == value
        ]
        if not candidates:
            return current
        # Only the first matching base is followed
        current = candidates[0]
def mark(**kwargs) -> Callable[[type], type]:
    """Mark a class (e.g. Proc) with given kwargs as metadata

    The marks are stored in the `__meta__` dict of the class. If the class
    only inherits a non-empty `__meta__`, the class gets its own copy of
    it before the new marks are added, so that the marks are never written
    into the dict of a parent class.

    Note (from the original documentation): the marks will not be
    inherited by the subclasses if the class is a subclass of `Proc` or
    `ProcGroup`; that is not handled by this function.

    Args:
        **kwargs: The kwargs to mark the proc

    Returns:
        The decorator
    """

    def decorator(cls: type) -> type:
        if not vars(cls).get("__meta__"):
            # No (or an empty) `__meta__` of its own: start from a copy of
            # what is visible on the class, which may come from a parent
            cls.__meta__ = dict(getattr(cls, "__meta__", None) or {})
        cls.__meta__.update(kwargs)
        return cls

    return decorator
def get_marked(cls: type, mark_name: str, default: Any = None) -> Any:
    """Get the marked value from a proc

    Args:
        cls: The proc
        mark_name: The mark name
        default: The default value if the mark is not found, or if the
            class has no (or an empty) `__meta__`

    Returns:
        The marked value
    """
    meta = getattr(cls, "__meta__", None)
    if meta:
        return meta.get(mark_name, default)
    return default
def is_valid_name(name: str) -> bool:
    """Check if a name is valid for a proc or pipen

    A valid name is non-empty and consists only of word characters,
    dots and hyphens.

    Args:
        name: The name to check

    Returns:
        True if valid, otherwise False
    """
    # fullmatch() instead of match() with `$`, since `$` also matches right
    # before a trailing newline and would let "name\n" pass
    return re.fullmatch(r"[\w.-]+", name) is not None
def _get_obj_from_spec(spec: str) -> Any:
    """Get the object from a spec like `<module[.submodule]>:name` or
    `/path/to/script.py:name`

    The part before the last colon is executed as a script if it is an
    existing file, otherwise it is imported as a module.

    Args:
        spec: The spec

    Returns:
        The object

    Raises:
        ValueError: If there is no colon in the spec
        AttributeError: If name cannot be found in the module
    """
    modpath, colon, name = spec.rpartition(":")
    if not colon:
        raise ValueError(
            f"Invalid specification: {spec}.\n"
            "It must be in the format '<module[.submodule]>:name' or \n"
            "'/path/to/spec.py:name'"
        )

    script = Path(modpath)
    if not script.is_file():
        return getattr(importlib.import_module(modpath), name)

    # Execute the file as a module named after the file stem
    mspec = importlib.util.spec_from_file_location(script.stem, modpath)
    module = importlib.util.module_from_spec(mspec)
    mspec.loader.exec_module(module)
    return getattr(module, name)
async def load_pipeline(
    obj: str | Type[Proc] | Type[ProcGroup] | Type[Pipen] | Pipen,
    argv0: str | None = None,
    argv1p: Sequence[str] | None = None,
    **kwargs: Any,
) -> Pipen:
    """Load a pipeline from a Pipen, Proc or ProcGroup object

    It does not only load the Pipen object or convert the Proc/ProcGroup
    object to Pipen, but also build the process relationships. So that we
    can access `pipeline.procs` and `requires/nexts` of each proc.

    To avoid running the pipeline and notify the plugins that this is just
    for loading the pipeline, `sys.argv[0]` is set to `@pipen`.
    `sys.argv` is restored when loading is done, whether it succeeded or
    not.

    Args:
        obj: The Pipen, Proc or ProcGroup class, or a Pipen object.
            It can also be a string in
            the format of `part1:part2` to load the pipeline, where part1 is
            a path to a python file or a module that can be imported,
            and part2 is the name
            of the proc, procgroup or pipeline to load.
            It should be able to be loaded by `getattr(module, part2)`, where
            module is loaded from `part1`.
        argv0: The value to replace sys.argv[0]. "@pipen" will be used
            by default.
        argv1p: The values to replace sys.argv[1:]. Do not replace by default.
        kwargs: The kwargs to pass to the Pipen constructor. For an
            existing Pipen object, they are merged into its `_kwargs`.

    Returns:
        The loaded Pipen object

    Raises:
        TypeError: If obj or loaded obj is not a Pipen, Proc or ProcGroup
            object
    """
    # Imported here rather than at the top of the module
    # (presumably to avoid circular imports - TODO confirm)
    from .pipen import Pipen
    from .proc import Proc
    from .procgroup import ProcGroup

    old_argv = sys.argv
    if argv0 is None:
        # Set it at runtime to allow LOADING_ARGV0 to be monkey-patched
        argv0 = LOADING_ARGV0
    if argv1p is None:
        # Set it at runtime to adopt sys.argv changes
        argv1p = sys.argv[1:]
    sys.argv = [argv0] + list(argv1p)

    try:
        if isinstance(obj, str):
            # The module/script is imported/executed with the replaced
            # sys.argv
            obj = _get_obj_from_spec(obj)
        if isinstance(obj, Pipen) or (
            isinstance(obj, type) and issubclass(obj, (Pipen, Proc, ProcGroup))
        ):
            pass
        else:
            raise TypeError(
                "Expected a Pipen, Proc, ProcGroup class, or a Pipen object, "
                f"got {type(obj)}"
            )
        pipeline = obj
        if isinstance(obj, type) and issubclass(obj, Proc):
            # A single proc: wrap it in a pipeline with the proc as start
            kwargs.setdefault("name", f"{obj.name}Pipeline")
            pipeline = Pipen(**kwargs).set_starts(obj)
        elif isinstance(obj, type) and issubclass(obj, ProcGroup):
            pipeline = obj().as_pipen(**kwargs)  # type: ignore
        elif isinstance(obj, type) and issubclass(obj, Pipen):
            # Avoid "pipeline" to be used as pipeline name by varname
            (pipeline, ) = (obj(**kwargs), )  # type: ignore
        elif isinstance(obj, Pipen):
            pipeline._kwargs.update(kwargs)
        # Initialize the pipeline so that the arguments defined by
        # other plugins (i.e. pipen-args) can take effect.
        pipeline.workdir = Path(pipeline.config.workdir).joinpath(
            kwargs.get("name", pipeline.name)
        )
        await pipeline._init()
        pipeline.workdir.mkdir(parents=True, exist_ok=True)
        pipeline.build_proc_relationships()
    finally:
        # Always put the original sys.argv back
        sys.argv = old_argv
    return pipeline
def is_loading_pipeline(*flags: str, argv: Sequence[str] | None = None) -> bool:
    """Check if we are loading the pipeline. Works only when
    `argv0` is "@pipen" while loading the pipeline.

    Note if you are using this function at compile time, make
    sure you load your pipeline using the string form (`part1:part2`)
    See more with `load_pipeline()`.

    Args:
        *flags: Additional flags to check in sys.argv (e.g. "-h", "--help")
            to determine if we are loading the pipeline
        argv: The arguments to check. sys.argv is used by default.
            Note that the first argument should be included in the check.
            You could typically pass `[sys.argv[0], *your_args]` to this if
            you want to check if `sys.argv[0]` is "@pipen" or `your_args`
            contains some flags.

    Returns:
        True if we are loading the pipeline (argv[0] == "@pipen") or any
        of the flags is found in the arguments, otherwise False
    """
    arguments = sys.argv if argv is None else argv
    if len(arguments) > 0 and arguments[0] == LOADING_ARGV0:
        return True
    # Without any flags given, this is simply False
    return any(flag in arguments for flag in flags)