Source code for obsplus.utils.bank

"""
Utils for banks
"""
import contextlib
import os
import re
import sqlite3
import time
import warnings
from typing import Optional, Sequence, List

import obspy
import pandas as pd
import numpy as np
from tables.exceptions import ClosedNodeError

from obsplus.constants import (
    NSLC,
    WAVEFORM_STRUCTURE,
    WAVEFORM_NAME_STRUCTURE,
    SMALLDT64,
    LARGEDT64,
    READ_HDF5_KWARGS,
)
from obsplus.exceptions import UnsupportedKeyword
from obsplus.utils.geodetics import map_longitudes
from obsplus.utils.misc import READ_DICT, _get_path
from obsplus.utils.mseed import summarize_mseed
from obsplus.utils.time import to_datetime64, _dict_times_to_ns

# functions for summarizing the various formats
summarizing_functions = dict(mseed=summarize_mseed)

# extensions
WAVEFORM_EXT = ".mseed"
EVENT_EXT = ".xml"
STATION_EXT = ".xml"


# name structures


def _get_time_values(time1, time2=None):
    """get the time values from a UTCDateTime object or two"""
    tvals = "year month day hour minute second microsecond".split()
    utc1 = time1
    split = re.split("-|:|T|[.]", str(utc1).replace("Z", ""))
    assert len(tvals) == len(split)
    out = {key: val for key, val in zip(tvals, split)}
    out["julday"] = "%03d" % utc1.julday
    out["starttime"] = utc1.timestamp
    if time2:
        out["endtime"] = time2.timestamp
    out["time"] = str(utc1).replace(":", "-").split(".")[0]
    return out
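

# A small illustration of the helper above (the time used is arbitrary and
# _example_time_values is not part of obsplus): the returned dict feeds the
# bank's path/name templates, e.g. "{year}/{julday}" style structures.
def _example_time_values():
    """Show the keys produced for a single UTCDateTime."""
    utc = obspy.UTCDateTime("2017-09-18T01:23:45.123456")
    out = _get_time_values(utc)
    # e.g. out["year"] == "2017", out["julday"] == "261",
    # out["time"] == "2017-09-18T01-23-45"
    return out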


def summarize_generic_stream(path, format=None) -> List[dict]:
    """
    Return summary information for a stream stored on disk.

    Parameters
    ----------
    path
        The path to the stream.
    format
        The format code, as used by obspy.
    """
    st = _try_read_stream(path, format=format) or _try_read_stream(path) or []
    out = []
    for tr in st:
        summary = {
            "starttime": tr.stats.starttime._ns,
            "endtime": tr.stats.endtime._ns,
            "sampling_period": int(tr.stats.delta * 1_000_000_000),
            "path": path,
        }
        summary.update(dict((x, c) for x, c in zip(NSLC, tr.id.split("."))))
        out.append(summary)
    return out
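

# A minimal usage sketch (not part of the module; the file path below is
# hypothetical): one summary dict is produced per trace, keyed by the NSLC
# codes plus integer-nanosecond start/end times and the sampling period.
def _example_summarize_stream(path="waveforms/BW.RJOB..EHZ.mseed"):
    """Print one summary row per trace found at ``path``."""
    for row in summarize_generic_stream(path, format="mseed"):
        print(row["network"], row["station"], row["starttime"], row["endtime"])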


def _summarize_wave_file(path, format, summarizer=None):
    """
    Summarize waveform files for indexing.

    Note: this function is a bit ugly, but it gets called *a lot* so be sure
    to profile before refactoring.
    """
    if summarizer is not None:
        try:
            return summarizer(path)
        except Exception:
            pass
    return summarize_generic_stream(path, format)


def _summarize_trace(
    trace: obspy.Trace,
    path: Optional[str] = None,
    name: Optional[str] = None,
    path_struct: Optional[str] = None,
    name_struct: Optional[str] = None,
) -> dict:
    """
    Function to extract info from traces for indexing.

    Parameters
    ----------
    trace
        The trace object.
    path
        The path to the file.
    name
        Name of the file.
    path_struct
        Directory structure to create.
    name_struct
        Structure of the file name, used if name is not provided.
    """
    assert hasattr(trace, "stats"), "only a trace object is accepted"
    out = {"seedid": trace.id, "ext": WAVEFORM_EXT}
    t1, t2 = trace.stats.starttime, trace.stats.endtime
    out.update(_get_time_values(t1, t2))
    out.update(dict((x, c) for x, c in zip(NSLC, trace.id.split("."))))
    path_struct = path_struct if path_struct is not None else WAVEFORM_STRUCTURE
    name_struct = name_struct or WAVEFORM_NAME_STRUCTURE
    out.update(_get_path(out, path, name, path_struct, name_struct))
    return out


def _remove_base_path(series: pd.Series, base="") -> pd.Series:
    """
    Ensure paths stored in column name use unix style paths and
    have base path removed.
    """
    if series.empty:
        return series
    unix_paths = series.str.replace(os.sep, "/", regex=False)
    unix_base_path = str(base).replace(os.sep, "/")
    # NOTE: this uses "/" not os.sep because we've already swapped out os.sep
    return unix_paths.str.replace(unix_base_path + "/", "", regex=False)


def _natify_paths(series: pd.Series) -> pd.Series:
    """
    Natify paths in a series, i.e. on windows replace / with \\.
    """
    if series.empty:
        return series
    return series.str.replace("/", os.sep)


class _IndexCache:
    """A simple class for caching indexes."""

    def __init__(self, bank, cache_size=5):
        self.max_size = cache_size
        self.bank = bank
        self.cache = pd.DataFrame(
            index=range(cache_size), columns="t1 t2 kwargs cindex".split()
        )
        self._current_index = 0
        # self.next_index = itertools.cycle(self.cache.index)

    def __call__(self, starttime, endtime, buffer, **kwargs):
        """Get start and end times, perform in-kernel lookup."""
        starttime, endtime = self._get_times(starttime, endtime)
        self._validate_kwargs(kwargs)
        # find out if the query falls within one of the cached time ranges
        con1 = self.cache.t1 <= starttime
        con2 = self.cache.t2 >= endtime
        con3 = self.cache.kwargs == self._kwargs_to_str(kwargs)
        cached_index = self.cache[con1 & con2 & con3]
        if not len(cached_index):
            # query is not cached; get it from the hdf5 file
            where = _get_kernel_query(
                starttime.astype(np.int64), endtime.astype(np.int64), int(buffer)
            )
            raw_index = self._get_index(where, **kwargs)
            # replace "None" with None
            ic = self.bank.index_str
            raw_index.loc[:, ic] = raw_index.loc[:, ic].replace(["None"], [None])
            # convert data types used by bank back to those seen by user
            index = raw_index.astype(dict(self.bank._dtypes_output))
            self._set_cache(index, starttime, endtime, kwargs)
        else:
            index = cached_index.iloc[0]["cindex"]
        # trim down index
        con1 = index["starttime"] >= (endtime + buffer)
        con2 = index["endtime"] <= (starttime - buffer)
        return index[~(con1 | con2)]

    @staticmethod
    def _get_times(starttime, endtime):
        """Return starttimes and endtimes."""
        # get defaults if starttime or endtime is none
        starttime = None if pd.isnull(starttime) else starttime
        endtime = None if pd.isnull(endtime) else endtime
        starttime = to_datetime64(starttime or SMALLDT64)
        endtime = to_datetime64(endtime or LARGEDT64)
        if starttime is not None and endtime is not None:
            if starttime > endtime:
                msg = "starttime cannot be greater than endtime."
                raise ValueError(msg)
        return starttime, endtime

    def _validate_kwargs(self, kwargs):
        """Ensure kwargs are supported."""
        kwarg_set = set(kwargs)
        if not kwarg_set.issubset(READ_HDF5_KWARGS):
            bad_kwargs = kwarg_set - set(READ_HDF5_KWARGS)
            msg = f"The following kwargs are not supported: {bad_kwargs}. "
            raise UnsupportedKeyword(msg)

    def _set_cache(self, index, starttime, endtime, kwargs):
        """Cache the current index."""
        ser = pd.Series(
            {
                "t1": starttime,
                "t2": endtime,
                "cindex": index,
                "kwargs": self._kwargs_to_str(kwargs),
            }
        )
        self.cache.loc[self._get_next_index()] = ser

    def _kwargs_to_str(self, kwargs):
        """Convert kwargs to a string."""
        keys = sorted(list(kwargs.keys()))
        out = str([(item, kwargs[item]) for item in keys])
        return out

    def _get_index(self, where, fail_counts=0, **kwargs):
        """Read the hdf5 file."""
        try:
            return pd.read_hdf(
                self.bank.index_path, self.bank._index_node, where=where, **kwargs
            )
        except (ClosedNodeError, Exception) as e:
            # Sometimes in concurrent updates the nodes need time to open/close
            if fail_counts > 10:
                raise e
            # Wait a bit and try again (up to 10 times)
            time.sleep(0.1)
            return self._get_index(where, fail_counts=fail_counts + 1, **kwargs)

    def clear_cache(self):
        """Remove all cached dataframes."""
        self.cache = pd.DataFrame(
            index=range(self.max_size), columns="t1 t2 kwargs cindex".split()
        )

    def _get_next_index(self):
        """
        Get the next index value on cache.

        Note: we can't use itertools.cycle here because it can't be pickled.
        """
        if self._current_index == len(self.cache.index) - 1:
            self._current_index = 0
        else:
            self._current_index += 1
        return self.cache.index[self._current_index]
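

# Illustrative sketch (the bank instance here is hypothetical; it must expose
# index_path, _index_node, index_str and _dtypes_output as _IndexCache
# expects): repeated reads over the same window are served from the in-memory
# frame rather than re-querying the hdf5 index.
def _example_cached_query(bank):
    """Read the same window twice; the second call hits the cache."""
    cache = _IndexCache(bank, cache_size=5)
    first = cache("2019-01-01", "2019-01-02", buffer=1_000_000_000)
    second = cache("2019-01-01", "2019-01-02", buffer=1_000_000_000)
    assert first.equals(second)
    return second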


@contextlib.contextmanager
def sql_connection(path, **kwargs):
    """
    A simple context manager for a sqlite connection.

    Parameters
    ----------
    path
        A path to the sqlite database.
    """
    con = sqlite3.connect(str(path), **kwargs)
    with con:
        yield con
    con.close()  # this is needed on windows but not linux, weird...
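

# Usage sketch (the database path is hypothetical): the context manager opens
# the connection, wraps the body in a transaction, and closes the connection
# afterwards, which sidesteps the windows file-locking issue noted above.
def _example_sql_connection(path="index.sqlite"):
    """Create a table inside a managed sqlite connection."""
    with sql_connection(path) as con:
        con.execute("CREATE TABLE IF NOT EXISTS demo (event_id TEXT)")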


def _get_kernel_query(starttime: int, endtime: int, buffer: int):
    """
    Create an HDF5 kernel query based on start and end times.

    This is necessary because hdf5 doesn't accept inverted conditions.
    A slight buffer is applied to the ranges to make sure no edge files
    are excluded.
    """
    t1 = starttime - buffer
    t2 = endtime + buffer
    con = (
        f"(starttime>{t1:d} & starttime<{t2:d}) | "
        f"((endtime>{t1:d} & endtime<{t2:d}) | "
        f"(starttime<{t1:d} & endtime>{t2:d}))"
    )
    return con


# --- SQL stuff


def _str_of_params(value):
    """
    Make sure a list of params is returned.

    This allows the user to specify a single parameter, a list, set,
    nparray, etc. to match on.
    """
    if isinstance(value, str):
        return value
    else:
        # try to coerce into a list of str
        try:
            return [str(x) for x in value]
        except TypeError:  # else fall back to str repr
            return str(value)


def _make_wheres(queries):
    """Create the where queries, join with AND clauses."""

    def _rename_keys(kwargs):
        """Re-word some keys to make automatic sql generation easier."""
        if "eventid" in kwargs:
            kwargs["event_id"] = kwargs.pop("eventid")
        if "event_id" in kwargs:
            kwargs["event_id"] = _str_of_params(kwargs["event_id"])
        if "event_description" in kwargs:
            kwargs["event_description"] = _str_of_params(kwargs["event_description"])
        if "endtime" in kwargs:
            kwargs["maxtime"] = kwargs.pop("endtime")
        if "starttime" in kwargs:
            kwargs["mintime"] = kwargs.pop("starttime")
        return kwargs

    def _handle_nat(kwargs):
        """Add a mintime that will exclude NaT values if endtime is used."""
        if "maxtime" in kwargs and "mintime" not in kwargs:
            kwargs["mintime"] = SMALLDT64.astype(np.int64) + 1
        return kwargs

    def _handle_dateline_transversal(kwargs, out):
        """Check if the dateline should be traversed by the query."""
        # if longitudes aren't being used bail out
        if not {"minlongitude", "maxlongitude"}.issubset(set(kwargs)):
            return kwargs, out
        # if the dateline is not to be traversed by the query bail out
        long_array = np.array([kwargs["minlongitude"], kwargs["maxlongitude"]])
        minlong, maxlong = map_longitudes(long_array)
        if not minlong > maxlong:
            return kwargs, out
        # remove min/max long from the query dict and reform into two queries
        kwargs.pop("minlongitude"), kwargs.pop("maxlongitude")
        cond = f"(( longitude > {minlong}) OR ( longitude < {maxlong})) "
        out.append(cond)
        return kwargs, out

    def _build_query(kwargs):
        """Iterate each key/value and build the query."""
        out = []
        kwargs, out = _handle_dateline_transversal(kwargs, out)
        for key, val in kwargs.items():
            # deal with simple min/max
            if key.startswith("min"):
                out.append(f"{key.replace('min', '')} > {val}")
            elif key.startswith("max"):
                out.append(f"{key.replace('max', '')} < {val}")
            # deal with equals or ins
            elif isinstance(val, Sequence):
                if isinstance(val, str):
                    val = [val]
                tup = str(tuple(val)).replace(",)", ")")  # no trailing comma
                out.append(f"{key} IN {tup}")
            else:
                out.append(f"{key} = {val}")
        return " AND ".join(out).replace("'", '"')

    kwargs = _dict_times_to_ns(queries)
    kwargs = _rename_keys(kwargs)
    kwargs = _handle_nat(kwargs)
    return _build_query(kwargs)


def _make_sql_command(cmd, table_name, columns=None, **kwargs) -> str:
    """Build a sql command."""
    # get columns
    if columns:
        col = [columns] if isinstance(columns, str) else columns
        col += ["event_id"]  # event_id is used as the index
        columns = ", ".join(col)
    elif cmd.upper() == "DELETE":
        columns = ""
    else:
        columns = "*"
    limit = kwargs.pop("limit", None)
    wheres = _make_wheres(kwargs)
    sql = f'{cmd.upper()} {columns} FROM "{table_name}"'
    if wheres:
        sql += f" WHERE {wheres}"
    if limit:
        sql += f" LIMIT {limit}"
    return sql + ";"


def _read_table(table_name, con, columns=None, **kwargs) -> pd.DataFrame:
    """
    Read a SQLite table.

    Parameters
    ----------
    table_name
        The name of the table to read.
    con
        An open connection to the sqlite database.
    columns
        The columns to load from the table; if None use all columns.
    """
    # ensure all times are ns (as ints) and build the select statement
    sql = _make_sql_command("select", table_name, columns=columns, **kwargs)
    return pd.read_sql(sql, con)


def _get_tables(con):
    """Return a set of table names in the sqlite database."""
    out = con.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return set(out)


def _drop_rows(table_name, con, columns=None, **kwargs):
    """Drop rows from the table matching the query."""
    sql = _make_sql_command("delete", table_name, columns=columns, **kwargs)
    con.execute(sql)


def _try_read_stream(stream_path, format=None, **kwargs):
    """Try to read a waveform file; if reading raises, return None."""
    read = READ_DICT.get(format, obspy.read)
    stt = None
    try:
        stt = read(stream_path, **kwargs)
    except Exception:
        try:
            stt = obspy.read(stream_path, **kwargs)
        except Exception:
            warnings.warn("obspy failed to read %s" % stream_path, UserWarning)
        else:
            msg = f"{stream_path} was read but is not of format {format}"
            warnings.warn(msg, UserWarning)
    finally:
        return stt if stt else None
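

# A short sketch of the query builders above (the table and values are
# hypothetical): min/max prefixes become > and < comparisons, an event_id
# column is appended for use as the index, and limit renders a LIMIT clause.
def _example_sql_command():
    """Return a SELECT statement built from bank-style kwargs."""
    sql = _make_sql_command(
        "select", "events", columns=["time", "magnitude"], minmagnitude=2.0, limit=10
    )
    # sql == 'SELECT time, magnitude, event_id FROM "events" '
    #        'WHERE magnitude > 2.0 LIMIT 10;'
    return sql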