"""
Module for loading (and downloading) datasets.
"""
import abc
import copy
import inspect
import json
import os
import shutil
from collections import OrderedDict
from contextlib import suppress
from distutils.dir_util import copy_tree
from functools import lru_cache
from pathlib import Path
from types import MappingProxyType as MapProxy
from typing import Union, Optional, Tuple, TypeVar
from warnings import warn
from obspy.clients.fdsn import Client
from pkg_resources import iter_entry_points
import obsplus
from obsplus import copy_dataset
from obsplus.constants import DATA_TYPES
from obsplus.exceptions import (
FileHashChangedError,
MissingDataFileError,
DataVersionError,
)
from obsplus.interfaces import WaveformClient, EventClient, StationClient
from obsplus.utils.dataset import _create_opsdata
from obsplus.utils.events import get_event_client
from obsplus.utils.misc import (
hash_directory,
iterate,
get_version_tuple,
validate_version_str,
)
from obsplus.utils.stations import get_station_client
from obsplus.utils.waveforms import get_waveform_client
DataSetType = TypeVar("DataSetType", bound="DataSet")
class DataSet(abc.ABC):
"""
Abstract Base Class for downloading and serving datasets.
This is not intended to be used directly, but rather through subclassing.
Parameters
----------
base_path
The path to which the dataset will be saved.
Attributes
----------
data_path
The path containing the data. By default it is base_path / name.
source_path
The path which contains the original files included in the dataset
before download. By default this is found in the same directory as
the dataset's code (.py) file in a folder with the same name as the
dataset.
Notes
-----
Importantly, each dataset references *two* directories, the source_path
and data_path. The source_path contains all data included within the
dataset and should not be altered. The data_path has a copy of
everything in the source_path, plus the files created during the
downloading process.
The base_path (the parent of data_path) is resolved for each
dataset using the following priorities:
    1. The `base_path` provided to `DataSet`'s __init__ method.
    2. The .dataset_data_path.txt file stored in the data source.
    3. An environment variable named OPSDATA_PATH.
    4. The OPSDATA_PATH variable from obsplus.constants.
    By default the data will be downloaded to the user's home directory
    in a folder called "opsdata", but this is easily changed by setting
    the OPSDATA_PATH environment variable.
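
    Examples
    --------
    A minimal sketch of a subclass; the name, version, and download
    logic shown here are hypothetical placeholders:

    >>> import obsplus
    >>> class MyDataSet(obsplus.DataSet):
    ...     name = "my_dataset"
    ...     version = "0.0.1"
    ...
    ...     def download_waveforms(self):
    ...         # a real dataset would fetch waveform data here
    ...         self.waveform_path.mkdir(exist_ok=True, parents=True)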
"""
_entry_points = {}
_datasets = {}
data_loaded = False
# variables for hashing datafiles and versioning
_version_filename = "dataset_version.txt"
_hash_filename = "dataset_hash.json"
    # the name of the file that records where the data files were downloaded
_saved_dataset_path_filename = ".dataset_data_path.txt"
_hash_excludes = (
"readme.txt",
_version_filename,
_hash_filename,
_saved_dataset_path_filename,
)
    # generic functions for loading data (waveforms, events, stations)
_load_funcs = MapProxy(
dict(
waveform=get_waveform_client,
event=get_event_client,
station=get_station_client,
)
)
# flags to determine if data should be loaded into memory
_load_waveforms = False
_load_stations = True
_load_events = True
# cache for instantiated datasets
_loaded_datasets = {}
_verbose = True
def __init_subclass__(cls, **kwargs):
"""Register subclasses of datasets."""
assert isinstance(cls.name, str), "name must be a string"
validate_version_str(cls.version)
# Register the subclass as a dataset.
DataSet._datasets[cls.name.lower()] = cls
# --- logic for loading and caching data
def __init__(self, base_path=None):
"""download and load data into memory."""
self.base_path = self._get_opsdata_path(base_path)
# create the dataset's base directory
self.data_path.mkdir(exist_ok=True, parents=True)
# run the download logic if needed
self._run_downloads()
# cache loaded dataset
self.data_loaded = True
if not base_path and self.name not in self._loaded_datasets:
self._loaded_datasets[self.name] = self.copy(deep=True)
def _get_opsdata_path(self, opsdata_path: Optional[Path] = None) -> Path:
"""
Get the location where datasets are stored.
Returns
-------
A path to the opsdata directory.
"""
if opsdata_path is None:
opsdata_path = getattr(self._saved_data_path, "parent", None)
if opsdata_path is None:
# next look for env variable
opsdata_path_default = obsplus.constants.OPSDATA_PATH
opsdata_path = os.getenv("OPSDATA_PATH", opsdata_path_default)
# ensure the data path exists
_create_opsdata(opsdata_path)
return Path(opsdata_path)
def _run_downloads(self) -> None:
"""Iterate each kind of data and download if needed."""
# Make sure the version of the dataset is okay
version_ok = self.check_version()
downloaded = False
for what in DATA_TYPES:
needs_str = f"{what}s_need_downloading"
if getattr(self, needs_str) or (not version_ok):
                # first type of data to be downloaded; copy data from the
                # source and run the pre-download hook.
if not downloaded and self.source_path.exists():
copy_tree(str(self.source_path), str(self.data_path))
self.pre_download_hook()
downloaded = True
            # download the data, then verify the download succeeded
self._log(f"downloading {what} data for {self.name} dataset ...")
getattr(self, "download_" + what + "s")()
assert not getattr(self, needs_str), f"Download {what} failed"
self._log(f"finished downloading {what} data for {self.name}")
self._write_readme() # make sure readme has been written
# some data were downloaded, call post download hook
if downloaded:
self.check_hashes()
self.post_download_hook()
# write a new version file
self.write_version()
# write out a new saved datafile path
self._save_data_path()
def _load(self, what, path):
"""Load the client-like objects from disk."""
try:
client = self._load_funcs[what](path)
except TypeError:
warn(f"failed to load {what} from {path}, returning None")
return None
        # load data into memory (e.g., load event bank contents into a catalog)
if getattr(self, f"_load_{what}s"):
return getattr(client, f"get_{what}s")()
else:
return client
def copy(self: DataSetType, deep=True) -> DataSetType:
"""
Return a copy of the dataset.
Parameters
----------
deep
If True deep copy the objects attached to the dataset.
Notes
-----
        This only copies data in memory, not on disk. If you plan to make
        any changes to the dataset's on-disk resources please use
        :meth:`~obsplus.DataSet.copy_to`.
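
        Examples
        --------
        For example, using the bundled 'default_test' dataset:

        >>> import obsplus
        >>> ds = obsplus.load_dataset('default_test')
        >>> ds2 = ds.copy()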
"""
return copy.deepcopy(self) if deep else copy.copy(self)
def copy_to(
self: DataSetType, destination: Optional[Union[str, Path]] = None
) -> DataSetType:
"""
Copy the dataset to a destination.
        If the destination already exists, simply do nothing.
Parameters
----------
destination
            The destination to copy the dataset. It will be created if it
            doesn't exist. If None is provided, use tempfile to create a
            temporary directory.
Returns
-------
A new dataset object which refers to the copied files.
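
        Examples
        --------
        A short sketch using the bundled 'default_test' dataset:

        >>> import obsplus
        >>> ds = obsplus.load_dataset('default_test')
        >>> # copy all dataset files to a temporary directory
        >>> ds2 = ds.copy_to()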
"""
return copy_dataset(self, destination)
def get_fetcher(self, **kwargs) -> "obsplus.Fetcher":
"""
Return a Fetcher from the data.
kwargs are passed to :class:`~obsplus.structures.Fetcher`'s constructor.
See its documentation for acceptable kwargs.
"""
assert self.data_loaded, "data have not been loaded into memory"
# get events/waveforms/stations and put into dict for the Fetcher
fetch_kwargs = {
"waveforms": self.waveform_client,
"events": self.event_client,
"stations": self.station_client,
}
fetch_kwargs.update(kwargs)
return obsplus.Fetcher(**fetch_kwargs)
__call__ = get_fetcher
def _write_readme(self, filename="readme.txt"):
"""Writes the classes docstring to a file."""
path = self.data_path / filename
if not path.exists():
with path.open("w") as fi:
fi.write(str(self.__doc__))
def _save_data_path(self, path=None):
"""Save the path to where the data where downloaded in source folder."""
path = Path(path or self._path_to_saved_path_file)
path.parent.mkdir(exist_ok=True, parents=True)
with path.open("w") as fi:
fi.write(str(self.data_path))
@classmethod
def load_dataset(
cls: DataSetType, name: Union[str, "DataSet"], silent=False
) -> DataSetType:
"""
Get a loaded dataset.
Will ensure all files are downloaded and the appropriate data are
loaded into memory.
Parameters
----------
name
The name of the dataset to load or a DataSet object. If a DataSet
object is passed a copy of it will be returned.
Examples
--------
>>> # --- Load an example dataset for testing
>>> import obsplus
>>> ds = obsplus.load_dataset('default_test')
>>> # If you plan to make changes to the dataset be sure to copy it first
>>> # The following will copy all files in the dataset to a tmpdir
>>> ds2 = obsplus.copy_dataset('default_test')
>>> # --- Use dataset clients to load waveforms, stations, and events
>>> cat = ds.event_client.get_events()
>>> st = ds.waveform_client.get_waveforms()
>>> inv = ds.station_client.get_stations()
>>> # --- get a fetcher for more "dataset aware" querying
>>> fetcher = ds.get_fetcher()
"""
# Just copy and return if a dataset is passed.
if isinstance(name, DataSet):
return name.copy()
name = name.lower()
cls._load_dataset_entry_point(name)
        if name not in cls._datasets:
            # The dataset was not found, even after loading entry points.
            msg = f"{name} is not in the known datasets {list(cls._datasets)}"
            raise ValueError(msg)
if name in cls._loaded_datasets:
# The dataset has already been loaded, simply return a copy
return cls._loaded_datasets[name].copy()
        else:  # The dataset has been discovered but not loaded; load it.
return cls._datasets[name]()
def delete_data_directory(self):
"""
Delete the datafiles of a dataset.
This will force the data to be re-copied from the source files and
download logic to be run.
"""
dataset = DataSet.load_dataset(self)
shutil.rmtree(dataset.data_path)
@classmethod
def _load_dataset_entry_point(cls, name=None, load=True):
"""
Load and cache the dataset entry points.
Parameters
----------
name
A string id of the dataset
load
If True, load the code associated with the entry point.
"""
        def _load_ep(ep):
            """Load the entry point, ignoring removed datasets."""
            # If a plugin was registered but no longer exists, loading can raise.
            with suppress(ModuleNotFoundError):
                ep.load()
                assert name in cls._datasets, "dataset should be registered."
if name in cls._entry_points: # entry point has been registered
if name in cls._datasets: # and loaded, return
return
elif load: # it has not been loaded, try loading it.
_load_ep(cls._entry_points[name])
# it has not been found, iterate entry points and update
eps = {x.name: x for x in iter_entry_points("obsplus.datasets")}
cls._entry_points.update(eps)
# stop if we don't need to load
if not load:
return
# now iterate through all names, or just selected name, and load
for name in set(iterate(name or eps)) & set(eps):
_load_ep(eps[name])
# --- prescribed Paths for data
@property
def data_path(self) -> Path:
"""
Return a path to where the dataset's data was/will be downloaded.
"""
return self.base_path / self.name
@property
def source_path(self) -> Path:
"""
Return a path to the directory where the data files included with
the dataset live.
"""
try:
path = Path(inspect.getfile(self.__class__)).parent
except (AttributeError, TypeError):
path = Path(__file__)
return path / self.name
@property
def _saved_data_path(self):
"""Load the saved data source path, else return None."""
expected_path = self._path_to_saved_path_file
if expected_path.exists():
            loaded_path = Path(expected_path.read_text())
if loaded_path.exists():
return loaded_path
return None
@property
def _path_to_saved_path_file(self):
"""
A path to the file which keeps track of where data are downloaded.
"""
return self.source_path / self._saved_dataset_path_filename
@property
def _version_path(self):
"""A path to the saved version file."""
return self.data_path / self._version_filename
@property
@lru_cache()
def data_files(self) -> Tuple[Path, ...]:
"""
        Return a tuple of top-level files associated with the dataset.
Hidden files are ignored.
"""
file_iterator = self.source_path.glob("*")
files = [x for x in file_iterator if not x.is_dir()]
return tuple([x for x in files if not x.name.startswith(".")])
@property
def waveform_path(self) -> Path:
"""Return the path to the waveforms."""
return self.data_path / "waveforms"
@property
def event_path(self) -> Path:
"""Return the path to the events."""
return self.data_path / "events"
@property
def station_path(self) -> Path:
"""Return the path to the stations."""
return self.data_path / "stations"
# --- checks for if each type of data is downloaded
@property
def waveforms_need_downloading(self) -> bool:
"""
Returns True if waveform data need to be downloaded.
"""
return not self.waveform_path.exists()
@property
def events_need_downloading(self) -> bool:
"""
Returns True if event data need to be downloaded.
"""
return not self.event_path.exists()
@property
def stations_need_downloading(self) -> bool:
"""
Returns True if station data need to be downloaded.
"""
return not self.station_path.exists()
@property
@lru_cache()
def waveform_client(self) -> Optional[WaveformClient]:
"""A cached property for a waveform client"""
return self._load("waveform", self.waveform_path)
@property
@lru_cache()
def event_client(self) -> Optional[EventClient]:
"""A cached property for an event client"""
return self._load("event", self.event_path)
@property
@lru_cache()
def station_client(self) -> Optional[StationClient]:
"""A cached property for a station client"""
return self._load("station", self.station_path)
@property
@lru_cache()
def _download_client(self):
"""
        Return an instance of the IRIS client; subclasses can override
        this to use different clients.
"""
return Client("IRIS")
@_download_client.setter
def _download_client(self, item):
"""just allow this to be overwritten"""
self.__dict__["client"] = item
def _log(self, msg):
"""Simple way to customize dataset logging."""
print(msg)
def create_sha256_hash(self, path=None, hidden=False) -> dict:
"""
Create a sha256 hash of the dataset's data files.
        The output is stored in a simple json file. Keys are paths (relative
        to the dataset base path) and values are file hashes.
If you want to update/create the hash file in the dataset's source
this can be done by passing the dataset's source_path as the path
argument.
Parameters
----------
path
The path to which the hash data is saved. If None use data_path.
hidden
If True also include hidden files.
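
        Examples
        --------
        A sketch of refreshing the hash file shipped with a dataset's
        source files (requires write access to the source directory):

        >>> import obsplus
        >>> ds = obsplus.load_dataset('default_test')
        >>> hashes = ds.create_sha256_hash(ds.source_path)  # doctest: +SKIP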
"""
kwargs = dict(exclude=self._hash_excludes, hidden=hidden)
out = hash_directory(self.data_path, **kwargs)
        # sort the dict to minimize git diffs
sort_dict = OrderedDict(sorted(out.items()))
# get path and dump json
default_path = Path(self.data_path) / self._hash_filename
        _path = Path(path) if path is not None else default_path
hash_path = _path / self._hash_filename if _path.is_dir() else _path
with hash_path.open("w") as fi:
json.dump(sort_dict, fi, sort_keys=True, indent=2)
return out
def check_hashes(self, check_hash=False):
"""
        Check that the files are all there and have the correct hashes.
Parameters
----------
check_hash
If True check the hash of the files.
Raises
------
FileHashChangedError
            If one of the file hashes is not as expected.
MissingDataFileError
            If one of the data files was not downloaded.
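
        Examples
        --------
        A minimal check on the bundled 'default_test' dataset:

        >>> import obsplus
        >>> ds = obsplus.load_dataset('default_test')
        >>> ds.check_hashes()  # raises if data files are missing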
"""
        # If there is no pre-existing hash file, return early
hash_path = Path(self.data_path / self._hash_filename)
if not hash_path.exists():
return
# get old and new hash, and overlaps
        old_hash = json.loads(hash_path.read_text())
current_hash = hash_directory(self.data_path, exclude=self._hash_excludes)
overlap = set(old_hash) & set(current_hash) - set(self._hash_excludes)
# get any files with new hashes
has_changed = {x for x in overlap if old_hash[x] != current_hash[x]}
missing = (set(old_hash) - set(current_hash)) - set(self._hash_excludes)
if has_changed and check_hash:
msg = (
f"The hash for dataset {self.name} did not match the "
f"expected values for the following files:\n{has_changed}"
)
raise FileHashChangedError(msg)
if missing:
msg = f"Dataset {self.name} is missing files: \n{missing}"
raise MissingDataFileError(msg)
def check_version(self) -> bool:
"""
Check the version of the dataset.
        Verifies the version string in the dataset class definition matches
        the one saved on disk. Returns True if the versions match, or False
        if the version file could not be read (in which case the data are
        re-downloaded).
Raises
------
DataVersionError
If any version problems are discovered.
"""
redownload_msg = f"Delete the following directory {self.data_path}"
try:
version = self.read_data_version()
except (DataVersionError, ValueError): # failed to read version
need_dl = (getattr(self, f"{x}s_need_downloading") for x in DATA_TYPES)
if not any(need_dl): # Something is a little weird
warn("Version file is missing. Attempting to re-download the dataset.")
return False
# Check the version number
if get_version_tuple(version) < get_version_tuple(self.version):
msg = f"Dataset version is out of date: {version} < {self.version}. "
raise DataVersionError(msg + redownload_msg)
elif get_version_tuple(version) > get_version_tuple(self.version):
msg = f"Dataset version mismatch: {version} > {self.version}."
msg = msg + " It may be necessary to reload the dataset."
warn(msg + redownload_msg)
return True # All is well. Continue.
def write_version(self, path: Optional[Union[Path, str]] = None):
"""Write the version string to disk."""
        version_path = Path(path) if path is not None else self._version_path
with version_path.open("w") as fi:
fi.write(self.version)
def read_data_version(self, path: Optional[Union[Path, str]] = None) -> str:
"""
Read the data version from disk.
        Return the version string stored on disk (of the form x.y.z).
        Raise a DataVersionError if the version file is not found.
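
        Examples
        --------
        For instance, with the bundled 'default_test' dataset:

        >>> import obsplus
        >>> ds = obsplus.load_dataset('default_test')
        >>> version_str = ds.read_data_version()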
"""
        version_path = Path(path) if path is not None else self._version_path
if not version_path.exists():
raise DataVersionError(f"{version_path} does not exist!")
with version_path.open("r") as fi:
version_str = fi.read()
validate_version_str(version_str)
return version_str
# --- Abstract properties subclasses should implement
@property
@abc.abstractmethod
def name(self) -> str:
"""
Name of the dataset
"""
@property
@abc.abstractmethod
def version(self) -> str:
"""
Dataset version. Should be a str of the form x.y.z
"""
@property
def version_tuple(self) -> Tuple[int, int, int]:
"""
        Return the version string parsed into a tuple of ints.
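
        Examples
        --------
        For example, unpacking the parsed version of a loaded dataset:

        >>> import obsplus
        >>> ds = obsplus.load_dataset('default_test')
        >>> major, minor, patch = ds.version_tuple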
"""
validate_version_str(self.version)
vsplit = self.version.split(".")
return int(vsplit[0]), int(vsplit[1]), int(vsplit[2])
# --- Abstract methods subclasses should implement
def download_events(self) -> None:
"""
Method to ensure the events have been downloaded.
Events should be written in an obspy-readable format to
self.event_path. If not implemented this method will create an empty
directory.
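
        Examples
        --------
        A hedged sketch of an override that pulls events from the download
        client (the magnitude query shown is a hypothetical choice):

        >>> def download_events(self):  # doctest: +SKIP
        ...     self.event_path.mkdir(exist_ok=True, parents=True)
        ...     cat = self._download_client.get_events(minmagnitude=7.0)
        ...     cat.write(str(self.event_path / "events.xml"), "QUAKEML")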
"""
self.event_path.mkdir(exist_ok=True, parents=True)
def download_stations(self) -> None:
"""
Method to ensure inventories have been downloaded.
        Station data should be written in an obspy-readable format to
        self.station_path. If not overridden, this method simply creates an
        empty directory; since there is not yet a functional StationBank,
        most subclasses will need to implement it.
"""
self.station_path.mkdir(exist_ok=True, parents=True)
def pre_download_hook(self):
"""Code to run before any downloads."""
def post_download_hook(self):
"""Code to run after any downloads."""
def __str__(self):
return f"Dataset: {self.name}"
def __repr__(self):
return f"{str(self)} with description: {self.__doc__}"
load_dataset = DataSet.load_dataset