"""
Data fetcher class stuff
"""
from __future__ import annotations
import copy
import functools
import warnings
from collections import namedtuple
from collections.abc import Callable
from functools import partial
from typing import ClassVar
import numpy as np
import obspy
import pandas as pd
from obspy import Stream
import obsplus
from obsplus import events_to_df, picks_to_df, stations_to_df
from obsplus.bank.wavebank import WaveBank
from obsplus.constants import (
LARGEDT64,
NSLC,
WAVEFETCHER_OVERRIDES,
bulk_waveform_arg_type,
event_clientable_type,
event_time_type,
get_waveforms_parameters,
station_clientable_type,
stream_proc_type,
waveform_clientable_type,
)
from obsplus.exceptions import TimeOverflowWarning
from obsplus.utils.docs import compose_docstring
from obsplus.utils.events import get_event_client
from obsplus.utils.misc import register_func, suppress_warnings
from obsplus.utils.pd import filter_index, get_seed_id_series
from obsplus.utils.stations import get_station_client
from obsplus.utils.time import (
get_reference_time,
make_time_chunks,
to_datetime64,
to_timedelta64,
to_utc,
)
from obsplus.utils.waveforms import get_waveform_client
# Lightweight container pairing an event's resource id with its obspy Stream.
EventStream = namedtuple("EventStream", "event_id stream")
# ---------------------- fetcher constructor helpers
def _enable_swaps(cls):
    """
    Enable swapping out events, stations, and picks info on any
    function that gets waveforms.

    Wraps every class attribute whose name contains "waveform" (plus
    ``__call__``) with the temporary-override decorator and returns the
    class, so it can be used as a class decorator.
    """
    # Snapshot the items: setattr mutates cls.__dict__ while we iterate,
    # which is fragile even where CPython happens to tolerate it.
    for name, value in list(cls.__dict__.items()):
        if "waveform" in name or name == "__call__":
            setattr(cls, name, _temporary_override(value))
    return cls
def _temporary_override(func):
    """
    Decorator to enable temporary override of various parameters.

    If any override keyword (events, stations, waveforms) is passed to the
    wrapped method, the call operates on a copy of the fetcher with those
    clients swapped in, leaving the original fetcher untouched. When no
    override is given the fetcher is used as-is; the original code re-set
    all three clients on every call, which needlessly mutated the original
    fetcher and discarded its cached picks dataframe.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        inter = WAVEFETCHER_OVERRIDES.intersection(kwargs)
        if inter:
            # Work on a copy so the overrides do not leak into the original.
            self = self.copy()
            self.set_events(kwargs.pop("events", self.event_client))
            self.set_stations(kwargs.pop("stations", self.station_client))
            self.set_waveforms(kwargs.pop("waveforms", self.waveform_client))
        return func(self, *args, **kwargs)

    return wrapper
# ---------------------------------- Wavefetcher class
# Union aliases describing what each Fetcher argument may be built from.
fetcher_waveform_type = waveform_clientable_type | obsplus.WaveBank
fetcher_event_type = event_clientable_type | pd.DataFrame | obsplus.EventBank
fetcher_station_type = station_clientable_type | pd.DataFrame
@_enable_swaps
class Fetcher:
    """
    A class for serving up data from various sources.

    Integrates station, event, and waveform clients to enable dataset-aware
    querying.

    Parameters
    ----------
    waveforms
        Any argument from which a waveform client can be extracted. This
        includes an obspy waveforms, directory of waveform files, or an object
        with a `get_waveforms` method.
    stations
        Any argument from which a station client can be extracted. This
        includes an obspy Inventory, directory of station files, or an object
        with a `get_stations` method.
    events
        Any argument from which an event client can be extracted. This
        includes an obspy Catalog, directory of event files, or an object with
        a `get_events` method.
    picks
        Data from which picks can be extracted. A dataframe, events, or
        event_client are all acceptable.
    stream_processor
        A callable that takes an obspy waveforms as input and returns an obspy
        waveforms.
    time_before
        The default time before a given time to fetch.
    time_after
        The default time after a supplied time to fetch.
    event_query
        A dict of arguments used to filter events.

    Examples
    --------
    >>> import obsplus
    >>> import obspy
    >>> #--- Init a Fetcher
    >>> # from a dataset
    >>> ds = obsplus.load_dataset('bingham_test')
    >>> ds_fetcher = ds.get_fetcher()
    >>> assert isinstance(ds_fetcher, obsplus.Fetcher)
    >>> # from separate clients (includes Stream, Inventory, Catalog)
    >>> waveforms = ds.waveform_client
    >>> events = ds.event_client
    >>> stations = ds.station_client
    >>> kwargs = dict(events=events, waveforms=waveforms, stations=stations)
    >>> fetcher = obsplus.Fetcher(**kwargs)
    >>> assert isinstance(fetcher, obsplus.Fetcher)
    >>> # --- get contiguous (not event) waveform data
    >>> # simple get_waveform calls are passed to the waveforms client
    >>> fetcher = obsplus.load_dataset('ta_test').get_fetcher()
    >>> t1 = obspy.UTCDateTime('2007-02-15')
    >>> t2 = t1 + 60
    >>> station = 'M14A'
    >>> st = fetcher.get_waveforms(starttime=t1, endtime=t2, station=station)
    >>> print(st)
    3 Trace(s) ...
    >>> # iterate over a range of times
    >>> t1 = obspy.UTCDateTime('2007-02-16')
    >>> t2 = t1 + (3600 * 24)
    >>> for st in fetcher.yield_waveforms(starttime=t1, endtime=t2):
    ...     assert len(st)
    >>> # --- get event waveforms
    >>> fetcher = obsplus.load_dataset('bingham_test').get_fetcher()
    >>> # iterate each event yielding streams 30 seconds after origin
    >>> kwargs = dict(time_before=0, time_after=30, reference='origin')
    >>> for event_id, st in fetcher.yield_event_waveforms(**kwargs):
    ...     assert isinstance(event_id, str)
    ...     assert isinstance(st, obspy.Stream)
    """
[docs]
def __init__(
self,
waveforms: fetcher_waveform_type,
stations: fetcher_event_type | None = None,
events: fetcher_event_type | None = None,
picks: pd.DataFrame | None = None,
stream_processor: stream_proc_type | None = None,
time_before: float | None = None,
time_after: float | None = None,
event_query: dict | None = None,
):
# if fetch_arg is a WaveFetcher just update dict and return
if isinstance(waveforms, Fetcher):
self.__dict__.update(waveforms.__dict__)
return
# get clients for each data types
self.set_waveforms(waveforms)
self.set_events(events)
self.set_stations(stations)
self._picks_input = picks
# waveforms processor for applying filters and such
self.stream_processor = stream_processor
# set event time/query parameters
self.time_before = to_timedelta64(time_before)
self.time_after = to_timedelta64(time_after)
self.event_query = event_query or {}
[docs]
def set_events(self, events: fetcher_event_type):
"""
Set event state in fetcher.
Parameters
----------
events
Data representing events, from which a client or dataframe can
be obtained.
"""
# set event and dataframe
try:
self.event_client = get_event_client(events)
except TypeError:
self.event_client = getattr(self, "event_client", None)
try:
self.event_df = events_to_df(events)
except TypeError:
self.event_df = None
self._picks_df = None
    def set_stations(self, stations: fetcher_station_type):
        """
        Set the station state in fetcher.

        Falls back to deriving station info from the waveform client, then
        the event client, when ``stations`` itself cannot provide it.

        Parameters
        ----------
        stations
            Data representing stations, from which a client or dataframe
            can be inferred.
        """
        try:
            self.station_client = get_station_client(stations)
        except TypeError:
            # keep any previously set client if input is not station-like
            self.station_client = getattr(self, "station_client", None)
        try:
            # since its common for inventories to have far out enddates this
            # can raise a warning. These are safe to ignore.
            with suppress_warnings(category=TimeOverflowWarning):
                self.station_df = stations_to_df(stations)
        except TypeError:
            # if unable to get station info from stations use waveform client
            try:
                self.station_df = stations_to_df(self.waveform_client)
            except TypeError:
                # if no waveforms try events
                try:
                    self.station_df = stations_to_df(self.event_client)
                except TypeError:
                    # no source could provide station info
                    self.station_df = None
        # make sure seed_id is set
        if self.station_df is not None:
            self.station_df["seed_id"] = get_seed_id_series(self.station_df)
# ------------------------ continuous data fetching methods
# ------------------------ event waveforms fetching methods
reference_funcs: ClassVar = {} # stores funcs for getting event reference times
def __call__(
self,
time_arg: event_time_type,
time_before: float | None = None,
time_after: float | None = None,
*args,
**kwargs,
) -> obspy.Stream:
"""
Using a reference time, return a waveforms that encompasses that time.
Parameters
----------
time_arg
The argument that will indicate a start time. Can be a one
length events, and event, a float, or a UTCDatetime object
time_before
The time before time_arg to include in waveforms
time_after
The time after time_arg to include in waveforms
Returns
-------
obspy.Stream
"""
tbefore = to_timedelta64(time_before, default=self.time_before)
tafter = to_timedelta64(time_after, default=self.time_after)
assert (tbefore is not None) and (tafter is not None)
# get the reference time from the object
time = to_datetime64(get_reference_time(time_arg))
t1 = time - tbefore
t2 = time + tafter
return self.get_waveforms(starttime=to_utc(t1), endtime=to_utc(t2), **kwargs)
# ------------------------------- misc
[docs]
def copy(self) -> Fetcher:
"""Return a deep copy of the fetcher."""
return copy.deepcopy(self)
def _get_bulk_wf(self, *args, **kwargs):
"""
get the wave forms using the client, apply processor if it is defined
"""
out = self.waveform_client.get_waveforms_bulk(*args, **kwargs)
if callable(self.stream_processor):
return self.stream_processor(out) or out
else:
return out
def _get_bulk_args(
self, starttime=None, endtime=None, **kwargs
) -> bulk_waveform_arg_type:
"""
Get the bulk waveform arguments based on given start/end times.
This method also takes into account data availability as contained
in the stations data.
Parameters
----------
starttime
Start times for query.
endtime
End times for query.
Returns
-------
List of tuples of the form:
[(network, station, location, channel, starttime, endtime)]
"""
station_df = self.station_df.copy()
inv = station_df[filter_index(station_df, **kwargs)]
# replace None/Nan with larger number
inv.loc[inv["end_date"].isnull(), "end_date"] = LARGEDT64
inv["end_date"] = inv["end_date"].astype("datetime64[ns]")
# get start/end of the inventory
inv_start = inv["start_date"].min()
inv_end = inv["end_date"].max()
# remove station/channels that dont have data for requested time
min_time = to_datetime64(starttime, default=inv_start).min()
max_time = to_datetime64(endtime, default=inv_end).max()
con1, con2 = (inv["start_date"] > max_time), (inv["end_date"] < min_time)
df = inv[~(con1 | con2)].set_index("seed_id")[list(NSLC)]
if df.empty: # return empty list if no data found
return []
if isinstance(starttime, pd.Series):
# Have to get clever here to make sure only active stations get used
# and indices are not duplicated.
indexer = list(set(starttime.index).intersection(df.index))
new_start = starttime.loc[indexer]
new_end = endtime.loc[list(set(endtime.index).intersection(df.index))]
df["starttime"] = new_start.loc[~new_start.index.duplicated()]
df["endtime"] = new_end.loc[~new_end.index.duplicated()]
else:
df["starttime"] = starttime
df["endtime"] = endtime
# remove any rows that don't have defined start/end times
out = df[~(df["starttime"].isnull() | df["endtime"].isnull())]
# ensure we have UTCDateTime objects
out["starttime"] = [to_utc(x) for x in out["starttime"]]
out["endtime"] = [to_utc(x) for x in out["endtime"]]
# convert to list of tuples and return
return [tuple(x) for x in out.to_records(index=False)]
@property
def picks_df(self):
"""Return a dataframe from the picks (if possible)"""
if self._picks_df is None:
try:
df = picks_to_df(self.event_client)
except TypeError:
self._picks_df = None
else:
self._picks_df = df
return self._picks_df
@picks_df.setter
def picks_df(self, item):
setattr(self, "_picks_df", item)
# ------------------------ functions for getting reference times
@register_func(Fetcher.reference_funcs, key="origin")
def _get_origin_reference_times(fetcher: Fetcher) -> pd.DataFrame:
    """
    Get the reference times for origins.

    Returns a dataframe indexed by seed_id with ``event_id`` and ``time``
    columns, one row per channel per event; rows with undefined origin
    times are dropped. (Return annotation corrected: the result is a
    two-column DataFrame, not a Series.)
    """
    event_df = fetcher.event_df[["time", "event_id"]].set_index("event_id")
    inv_df = fetcher.station_df
    # iterate each event and add rows for each channel in inventory
    # NOTE(review): assumes at least one event; pd.concat raises on [].
    dfs = []
    for eid, ser in event_df.iterrows():
        inv = inv_df.copy()
        inv["event_id"] = eid
        inv["time"] = ser["time"]
        dfs.append(inv)
    # concat (ignore_index guarantees a fresh range index, so the previous
    # reset_index() call was redundant) and keep only reference columns
    out = (
        pd.concat(dfs, ignore_index=True)
        .set_index("seed_id")[["event_id", "time"]]
        .dropna(subset=["time"])
    )
    return out
@register_func(Fetcher.reference_funcs, key="p")
def _get_p_reference_times(fetcher: Fetcher) -> pd.DataFrame:
    """
    Get the reference times for p arrivals.

    Return annotation corrected: `_get_phase_reference_time` returns a
    seed_id-indexed DataFrame with time and event_id columns.
    """
    return _get_phase_reference_time(fetcher, "p")
@register_func(Fetcher.reference_funcs, key="s")
def _get_s_reference_times(fetcher: Fetcher) -> pd.DataFrame:
    """
    Get the reference times for s arrivals.

    Return annotation corrected: `_get_phase_reference_time` returns a
    seed_id-indexed DataFrame with time and event_id columns.
    """
    return _get_phase_reference_time(fetcher, "s")
def _get_phase_reference_time(fetcher: Fetcher, phase) -> pd.DataFrame:
    """
    Get reference times to specified phases, apply over all channels in a
    station.

    Parameters
    ----------
    fetcher
        Fetcher whose ``picks_df`` and ``station_df`` supply the data.
    phase
        The phase hint to select (eg "p" or "s"); matched case-insensitively.

    Returns
    -------
    Dataframe indexed by seed_id with ``time`` and ``event_id`` columns.
    """
    pha = phase.upper()
    # ensure the pick_df and inventory df exist
    pick_df = fetcher.picks_df
    inv_df = fetcher.station_df
    assert pick_df is not None and inv_df is not None
    # Filter for the phase of interest using the same case-insensitive
    # comparison as the availability check; the original filtered with an
    # exact match, silently dropping picks whose hints differed in case.
    has_phase = pick_df["phase_hint"].str.upper() == pha
    assert has_phase.any(), f"no {phase} picks found"
    pick_df = pick_df[has_phase]
    # merge inventory and pick df together, ensure time is datetime64
    columns = ["time", "station", "event_id"]
    merge = pd.merge(inv_df, pick_df[columns], on="station", how="left")
    merge["time"] = to_datetime64(merge["time"])
    assert merge["seed_id"].astype(bool).all()
    return merge.set_index("seed_id")[["time", "event_id"]]