"""
Misc. ObsPlus utilities.
"""
import contextlib
import fnmatch
import hashlib
import os
import warnings
from functools import wraps, partial, singledispatch
from os.path import join
from pathlib import Path, PurePosixPath
from typing import (
Generator,
Tuple,
Any,
Set,
Optional,
Callable,
Union,
TypeVar,
Collection,
Dict,
Iterable,
)
import numpy as np
import obspy
import pandas as pd
from obspy.core import event as ev
from obspy.core.inventory import Station, Channel
from obspy.io.mseed.core import _read_mseed as mread
from obspy.io.quakeml.core import _read_quakeml
import obsplus
from obsplus.constants import NULL_SEED_CODES, NSLC
# Scalar types that yield_obj_parent_attr only yields when basic_types=True.
BASIC_NON_SEQUENCE_TYPE = (int, float, str, bool, type(None))
# Fast-path readers, keyed by format name, consulted by try_read_catalog.
READ_DICT = {"mseed": mread, "quakeml": _read_quakeml}
def _get_progressbar():
    """Import and return progressbar's ProgressBar class, silencing warnings
    emitted during the import."""
    # TODO remove this when progress no longer issues warning
    with suppress_warnings():
        from progressbar import ProgressBar as progress_bar_cls
    return progress_bar_cls
[docs]
def deprecated_callable(func=None, replacement_str=None):
    """
    Mark a function as deprecated.

    Whenever the decorated function is called a UserWarning is issued. You can
    optionally provide a string to indicate which function should be used in
    its place.

    Can be applied bare (``@deprecated_callable``) or with arguments
    (``@deprecated_callable(replacement_str="new_func")``).

    Parameters
    ----------
    func
        The callable to wrap. If it is not callable (e.g. None, when the
        decorator is used with arguments) a decorator is returned instead.
    replacement_str
        Name of the callable users should switch to; appended to the warning.

    Returns
    -------
    The wrapped callable, or a decorator waiting for the callable.
    """
    if not callable(func):
        # Used as ``@deprecated_callable(...)``: return a decorator for func.
        return partial(deprecated_callable, replacement_str=replacement_str)

    fname = str(getattr(func, "__name__", func))

    @wraps(func)
    def _wrap(*args, **kwargs):
        msg = f"{fname} is deprecated and will be removed in a future release."
        if replacement_str:
            msg += f" Please use {replacement_str} instead."
        # stacklevel=2 attributes the warning to the caller of the deprecated
        # function rather than to this wrapper line, so users can find the
        # offending call.
        warnings.warn(msg, UserWarning, stacklevel=2)
        return func(*args, **kwargs)

    return _wrap
[docs]
def yield_obj_parent_attr(
    obj, cls=None, is_attr=None, has_attr=None, basic_types=False
) -> Generator[Tuple[Any, Any, str], None, None]:
    """
    Recurse an object, yield a tuple of object, parent, attr.

    Useful when data need to be changed or the provided DataFrame extractors
    don't quite perform the desired task. Can also be used to extract
    relationships between entities in object trees to build a connecting graph.

    Traversal rules: list and tuple elements are visited with the container's
    own attr and parent; dict values are visited with the key as attr and the
    dict as parent; any other object is visited through its ``__slots__`` and
    ``__dict__`` attributes. Lists, tuples and dicts themselves are never
    yielded. Each (object, parent) pair is visited at most once, which guards
    against circular references.

    Parameters
    ----------
    obj
        The object to recurse through attributes of lists, tuples, and other
        instances.
    cls
        If not None, only yield instances of cls; if None, don't filter on
        type.
    is_attr
        Only yield objects stored under the attribute (or dict key) named
        is_attr; if None yield all.
    has_attr
        Only yield objects that have attribute has_attr; if None yield all.
    basic_types
        If True, also yield non-sequence basic types (int, float, str, bool,
        None); if False they are never yielded.

    Examples
    --------
    >>> # --- get all picks from a catalog object
    >>> import obsplus
    >>> import obspy.core.event as ev
    >>> cat = obsplus.load_dataset('bingham_test').event_client.get_events()
    >>> picks = [] # put all the picks in a list.
    >>> for pick, _, _ in yield_obj_parent_attr(cat, cls=ev.Pick):
    ...     picks.append(pick)
    >>> assert len(picks)

    >>> # --- yield all objects which have resource identifiers
    >>> objects = [] # list of (rid, parent)
    >>> RID = ev.ResourceIdentifier
    >>> for rid, parent, attr in yield_obj_parent_attr(cat, cls=RID):
    ...     objects.append((str(rid), parent))
    >>> assert len(objects)

    >>> # --- Create a dict of {resource_id: [(attr, parent), ...]}
    >>> from collections import defaultdict
    >>> rid_mapping = defaultdict(list)
    >>> for rid, parent, attr in yield_obj_parent_attr(cat, cls=RID):
    ...     rid_mapping[str(rid)].append((attr, parent))
    >>> # count how many times each resource_id is referred to
    >>> count = {i: len(v) for i, v in rid_mapping.items()}
    """
    # Cache of (id(obj), id(parent)) pairs already visited; avoids infinite
    # recursion on circular references.
    ids: Set[Tuple[int, int]] = set()

    def func(obj, attr=None, parent=None):
        id_tuple = (id(obj), id(parent))
        # Skip object/parent pairs that have already been visited.
        if id_tuple in ids:
            return
        ids.add(id_tuple)
        # Check if this object is stored as the desired attribute.
        is_attribute = is_attr is None or attr == is_attr
        # Check if the object has the desired attribute.
        has_attribute = has_attr is None or hasattr(obj, has_attr)
        # Check if isinstance of desired class.
        is_instance = cls is None or isinstance(obj, cls)
        # Basic scalar types are only eligible when basic_types is True.
        is_basic = basic_types or not isinstance(obj, BASIC_NON_SEQUENCE_TYPE)
        # Iterate through basic built-in containers (never yielded themselves).
        if isinstance(obj, (list, tuple)):
            # Elements inherit the container's attr/parent.
            for val in obj:
                yield from func(val, attr=attr, parent=parent)
        elif isinstance(obj, dict):
            # The key is reported as attr and the dict as parent.
            for item, val in obj.items():
                yield from func(val, attr=item, parent=obj)
        # Yield object, parent, and attr if desired conditions are met.
        elif is_attribute and has_attribute and is_instance and is_basic:
            yield (obj, parent, attr)
        # Iterate through non built-in object attributes (done whether or not
        # obj itself was yielded, so matches nested deeper are still found).
        # NOTE(review): getattr on an unset slot raises AttributeError, and a
        # __slots__ declared as a plain string is iterated per character;
        # confirm neither occurs for the types traversed here.
        if hasattr(obj, "__slots__"):
            for attr in obj.__slots__:
                val = getattr(obj, attr)
                yield from func(val, attr=attr, parent=obj)
        if hasattr(obj, "__dict__"):
            for item, val in obj.__dict__.items():
                yield from func(val, attr=item, parent=obj)

    return func(obj)
[docs]
def get_instances_from_tree(object, cls):
    """
    Return a list of every instance of ``cls`` found in an object tree.

    A thin convenience wrapper around
    :func:`~obsplus.utils.misc.yield_obj_parent_attr`; only the matching
    objects are kept, while their parents and attribute names are discarded.
    """
    matches = yield_obj_parent_attr(object, cls=cls)
    return [instance for instance, _parent, _attr in matches]
[docs]
def try_read_catalog(catalog_path, **kwargs):
    """
    Try to read events from a file; return None if reading fails.

    Parameters
    ----------
    catalog_path
        The path to the file to read.
    **kwargs
        Passed to the reader. ``format`` selects a fast-path reader from
        ``READ_DICT`` when it is one of its keys. Any other value is forwarded
        to :func:`obspy.read_events`; None (the default) lets obspy detect
        the format.

    Returns
    -------
    The object returned by the reader if it is not None and non-empty,
    otherwise None. A warning is issued if the reader raises.
    """
    fmt = kwargs.pop("format", None)
    read = READ_DICT.get(fmt)
    if read is None:
        read = obspy.read_events
        if fmt is not None:
            # Forward the requested format so obspy does not have to guess it
            # (e.g. format="SC3ML" must reach obspy.read_events).
            kwargs["format"] = fmt
    try:
        cat = read(catalog_path, **kwargs)
    except Exception:
        # Best effort: any reader failure becomes a warning plus None.
        warnings.warn(f"obspy failed to read {catalog_path}")
        return None
    if cat is not None and len(cat):
        return cat
    return None
[docs]
def read_file(file_path, funcs=(pd.read_csv,)) -> Optional[Any]:
    """
    For a given file_path, try reading it with each function in funcs.

    Parameters
    ----------
    file_path
        The path to the file to read
    funcs
        A tuple of functions to try to read the file (starting with first)

    Returns
    -------
    The output of the first function in funcs that does not raise.

    Raises
    ------
    IOError
        If no function in funcs can read the file. The exception raised by
        the last function tried is attached as ``__cause__``.
    """
    last_error = None
    for func in funcs:
        assert callable(func)
        try:
            return func(file_path)
        except Exception as e:  # best effort: fall through to the next reader
            # Keep the error; the name bound by ``except ... as`` is cleared
            # when the handler exits.
            last_error = e
    raise IOError(f"failed to read {file_path}") from last_error
[docs]
def apply_to_files_or_skip(func: Callable, directory: Union[str, Path]):
    """
    Lazily apply ``func`` to every file found (recursively) under directory.

    Files for which ``func`` raises an exception are silently skipped.

    Parameters
    ----------
    func
        Any callable that takes a file path as the only input.
    directory
        A directory that exists.

    Yields
    ------
    The result of ``func`` for each file it could process.
    """
    root = Path(directory)
    assert root.is_dir(), f"{directory} is not a directory"
    # Lazily filter out directories (and anything else that is not a file).
    candidates = (entry for entry in root.rglob("*") if os.path.isfile(entry))
    for file_path in candidates:
        try:
            yield func(file_path)
        except Exception:
            continue
[docs]
def get_progressbar(max_value, min_value=None, *args, **kwargs) -> Optional:
    """
    Build and start a ProgressBar2 progress bar, or return None.

    Failures while creating or starting the bar (e.g. when stdout is not
    usable) are swallowed and result in None. Extra args and kwargs are
    passed to the ProgressBar constructor.

    Parameters
    ----------
    max_value
        The highest number expected
    min_value
        The minimum number of updates required to show the bar; if given
        and max_value is below it, None is returned.
    """

    def _swallow_update_errors(original_update):
        """Wrap a bar's update method so index/value/attribute errors are ignored."""

        def update(value=None, force=False, **kwargs):
            with contextlib.suppress(IndexError, ValueError, AttributeError):
                original_update(value=value, force=force, **kwargs)

        return update

    if min_value and max_value < min_value:
        # Too few expected updates to warrant displaying a bar.
        return None
    try:
        bar_class = _get_progressbar()
        bar = bar_class(max_value=max_value, *args, **kwargs)
        bar.start()
        bar.update = _swallow_update_errors(bar.update)
        bar.update(1)
    except Exception:  # this can happen when stdout is being redirected
        return None
    return bar
[docs]
def iterate(obj):
    """
    Return an iterable for any object.

    None becomes an empty tuple. A string (or any non-iterable object) is
    wrapped in a 1-tuple, so a string's characters are never iterated.
    Other iterables are returned unchanged.
    """
    if obj is None:
        return ()
    if isinstance(obj, Iterable) and not isinstance(obj, str):
        return obj
    return (obj,)
[docs]
class DummyFile:
    """Stand-in for stdout: offers ``write``/``flush`` but discards everything."""

    def write(self, x):
        """Accept ``x`` and throw it away."""
        return None

    def flush(self):
        """No-op; there is nothing buffered to flush."""
        return None
[docs]
def getattrs(obj: object, col_set: Collection, default_value: object = np.nan) -> dict:
    """
    Collect a set of attributes from an object into a dict.

    Any attribute that is missing (AttributeError), whose access raises a
    ValueError, or whose value is None is recorded as ``default_value``.

    Parameters
    ----------
    obj
        Any object; if None an empty dict is returned.
    col_set
        A sequence of attribute names to extract from obj.
    default_value
        The value used when an attribute is unavailable or None.
    """
    if obj is None:
        return {}

    def _fetch(name):
        try:
            value = getattr(obj, name)
        except (ValueError, AttributeError):
            return default_value
        return default_value if value is None else value

    return {name: _fetch(name) for name in col_set}
# Generic type variable: replace_null_nlsc_codes is annotated to return the
# same type it receives.
any_type = TypeVar("any_type")
[docs]
@singledispatch
def replace_null_nlsc_codes(
    obspy_object: any_type, null_codes=NULL_SEED_CODES, replacement_value=""
) -> any_type:
    """
    Replace nullish network/station/location/channel codes in an obspy object.

    The generic implementation walks the object tree and rewrites the
    ``*_code`` attributes of every WaveformStreamID whose value is in
    null_codes. Streams, traces and inventory objects have registered
    overloads. The object is modified in place and also returned.

    Parameters
    ----------
    obspy_object
        An obspy catalog, event, (or any sub element), stream, trace,
        inventory, etc.
    null_codes
        The codes that are considered null values and should be replaced.
    replacement_value
        The value with which to replace the null_codes.
    """
    code_attrs = tuple(f"{name}_code" for name in NSLC)
    stream_ids = (
        wid
        for wid, _parent, _attr in yield_obj_parent_attr(
            obspy_object, cls=ev.WaveformStreamID
        )
    )
    for wid in stream_ids:
        for attr_name in code_attrs:
            if getattr(wid, attr_name) in null_codes:
                setattr(wid, attr_name, replacement_value)
    return obspy_object
@replace_null_nlsc_codes.register(obspy.Stream)
def _replace_null_stream(st, null_codes=NULL_SEED_CODES, replacement_value=""):
    """Stream overload: fix the null NSLC codes of each trace in place."""
    for trace in st:
        _replace_null_trace(trace, null_codes, replacement_value)
    return st
@replace_null_nlsc_codes.register(obspy.Trace)
def _replace_null_trace(tr, null_codes=NULL_SEED_CODES, replacement_value=""):
    """Trace overload: replace null NSLC values on ``tr.stats`` in place."""
    stats = tr.stats
    for code_name in NSLC:
        if getattr(stats, code_name) in null_codes:
            setattr(stats, code_name, replacement_value)
    return tr
@replace_null_nlsc_codes.register(obspy.Inventory)
@replace_null_nlsc_codes.register(Station)
@replace_null_nlsc_codes.register(Channel)
def _replace_inv_nulls(inv, null_codes=NULL_SEED_CODES, replacement_value=""):
    """
    Inventory/Station/Channel overload.

    Every object in the tree that has a ``location_code`` or ``code``
    attribute holding a null value gets it replaced, in place.
    """
    for attr_name in ("location_code", "code"):
        for node, _parent, _attr in yield_obj_parent_attr(inv, has_attr=attr_name):
            if getattr(node, attr_name) in null_codes:
                setattr(node, attr_name, replacement_value)
    return inv
[docs]
def iter_files(
    paths: Union[str, Iterable[str]],
    ext: Optional[str] = None,
    mtime: Optional[float] = None,
    skip_hidden: bool = True,
) -> Iterable[str]:
    """
    Use os.scandir to iterate files recursively, optionally only those with
    a given extension (ext) or modified at or after mtime.

    Parameters
    ----------
    paths
        The path to the base directory to traverse. Can also be a collection
        of paths. A path to a file is yielded as-is, without the ext, mtime
        or skip_hidden filters being applied.
    ext : str or None
        If not None, only yield files whose name ends with ext.
    mtime : int or float
        Time stamp indicating the minimum mtime.
    skip_hidden : bool
        If True skip hidden files and folders (names beginning with a '.').

    Yields
    ------
    Paths, as strings, meeting requirements.
    """
    # Only the scandir call decides how ``paths`` is interpreted. Errors
    # raised later (e.g. a TypeError from comparing a non-numeric mtime) must
    # propagate instead of being mistaken for "multiple paths", which would
    # iterate the characters of a path string.
    try:
        scanner = os.scandir(paths)
    except TypeError:  # multiple paths were passed
        for path in paths:
            yield from iter_files(path, ext, mtime, skip_hidden)
        return
    except NotADirectoryError:  # a file path was passed, just return it
        yield paths
        return
    # The context manager closes the directory handle even if the consumer
    # stops iterating this generator early.
    with scanner:
        for entry in scanner:
            if entry.is_file() and (ext is None or entry.name.endswith(ext)):
                if mtime is None or entry.stat().st_mtime >= mtime:
                    if entry.name[0] != "." or not skip_hidden:
                        yield entry.path
            elif entry.is_dir() and not (skip_hidden and entry.name[0] == "."):
                yield from iter_files(
                    entry.path, ext=ext, mtime=mtime, skip_hidden=skip_hidden
                )
[docs]
def hash_file(path: Union[str, Path]):
    """
    Return the hex sha256 digest of a file's contents.

    The file is consumed in 4096-byte blocks, so large files never need to
    fit in memory. Based on this stack overflow answer:
    http://bit.ly/2Jqb1Jr

    Parameters
    ----------
    path
        The path to the file to read.

    Returns
    -------
    A str of hex for file hash
    """
    digest = hashlib.sha256()
    with Path(path).open("rb") as handle:
        for block in iter(partial(handle.read, 4096), b""):
            digest.update(block)
    return digest.hexdigest()
[docs]
def hash_directory(
    path: Union[Path, str],
    match: str = "*",
    exclude: Optional[Union[str, Collection[str]]] = None,
    hidden=False,
) -> Dict[str, str]:
    """
    Calculate the sha256 hash of all files in a directory, recursively.

    Parameters
    ----------
    path
        The path to the directory
    match
        A unix-style glob pattern passed to ``Path.rglob``; only matching
        paths are considered.
    exclude
        A unix-style pattern, or a collection of them; files whose name
        matches any of them are skipped.
    hidden
        If False (default) skip files whose name starts with a '.'; if True
        include them. Only the file name is checked, not parent directories.

    Returns
    -------
    A dict mapping each file's path, relative to ``path`` and in POSIX form,
    to its sha256 hex digest.
    """
    path = Path(path)
    out = {}
    # Materialize once so a one-shot iterable (e.g. a generator) is not
    # exhausted after the first file is checked.
    excludes = tuple(iterate(exclude))
    for sub_path in path.rglob(match):
        # skip directories
        if sub_path.is_dir():
            continue
        name = sub_path.name
        # skip if matches on exclusion
        if any(fnmatch.fnmatch(name, exc) for exc in excludes):
            continue
        if not hidden and name.startswith("."):
            continue
        relative_path = sub_path.relative_to(path)
        # PurePosixPath gives the same key format on every platform.
        out[str(PurePosixPath(relative_path))] = hash_file(sub_path)
    return out
def _get_path(info, path, name, path_struct, name_strcut):
"""return a dict with path, and file name"""
if path is None: # if the path needs to be created
ext = info.get("ext", "")
# get name
fname = name or name_strcut.format_map(info)
fname = fname if fname.endswith(ext) else fname + ext # add ext
# get structure
psplit = path_struct.format_map(info).split("/")
path = join(*psplit, fname)
out_name = fname
else: # if the path is already known
out_name = os.path.basename(path)
return dict(path=path, filename=out_name)
[docs]
@contextlib.contextmanager
def suppress_warnings(category=Warning):
    """
    Silence warnings of ``category`` for the duration of a ``with`` block.

    The warnings filter state is saved on entry and restored on exit.

    Parameters
    ----------
    category
        The types of warnings to suppress. Must be a subclass of Warning.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=category)
        yield
[docs]
def register_func(list_or_dict: Union[list, dict], key: Optional[str] = None):
    """
    Build a decorator that records the decorated function in a registry.

    If the registry has an ``append`` method (e.g. a list), the function's
    name is appended. Otherwise the registry is treated as a mapping and
    the function is stored under its name. The function itself is returned
    unchanged.

    Parameters
    ----------
    list_or_dict
        A list or dict to which the wrapped function will be added.
    key
        The name to use, if different than the name of the function.
    """

    def decorator(func):
        label = key if key else func.__name__
        if not hasattr(list_or_dict, "append"):
            list_or_dict[label] = func
        else:
            list_or_dict.append(label)
        return func

    return decorator
[docs]
def validate_version_str(version_str: str) -> str:
    """
    Check the version string is of the form x.y.z, ignoring dev/local tags.

    Development ("dev...") and local ("+...") suffixes are stripped before
    validation, so "1.2.3", "1.2.3dev4", "1.2.3.dev4" and
    "0.1.1+18.g1234.dirty" are all accepted and reduced to "x.y.z".

    Returns
    -------
    The version string with any dev/local suffix removed.

    Raises
    ------
    ValueError
        If version_str is not a str, or the stripped version does not
        consist of exactly three dot-separated parts.
    """
    msg = f"version must be a string of the form x.y.z, not {version_str}"
    if not isinstance(version_str, str):
        raise ValueError(msg)
    # Strip dev and local version tags first, then any separator left behind
    # (e.g. the trailing '.' of "1.2.3.dev4"), so the 3-part check sees only
    # the release segment.
    out = version_str.split("dev")[0].split("+")[0].rstrip(".-_")
    if len(out.split(".")) != 3:
        raise ValueError(msg)
    return out
[docs]
def get_version_tuple(version_str: str) -> Tuple[int, int, int]:
    """
    Turn a semantic version string into a (major, minor, patch) int tuple.

    The string is first passed through validate_version_str.

    Parameters
    ----------
    version_str
        A version of the form "x.y.z". Google semantic versioning for more
        details.
    """
    parts = validate_version_str(version_str).split(".")
    # Convert major, minor and patch in order; index 2 must exist.
    return tuple(int(parts[index]) for index in range(3))
[docs]
def strip_prefix(some_str: str, prefixes: Union[str, Collection[str]]) -> str:
    """
    Strip prefixes from the start of a string.

    Prefixes are tried in the given order. Each one that matches the current
    (possibly already shortened) string is removed, so several prefixes can
    be removed in sequence.

    Parameters
    ----------
    some_str
        The string to strip.
    prefixes
        A single prefix or a collection of prefixes.
    """
    out = some_str
    # Use this module's own iterate rather than reaching back through the
    # ``obsplus.utils`` package namespace.
    for prefix in iterate(prefixes):
        if out.startswith(prefix):
            out = out[len(prefix) :]
    return out
[docs]
class ObjectWrapper:
    """
    Minimal container that holds a single object in its ``data`` attribute.

    This is useful for packaging array-like things into numpy arrays and
    pandas dataframes.
    """

    # A fixed slot instead of a per-instance __dict__ keeps instances small
    # and cheap to create.
    __slots__ = ("data",)

    def __init__(self, data) -> None:
        # Store the wrapped object as-is.
        self.data = data