__author__ = "Christopher Tomkins-Tinch"
__copyright__ = "Copyright 2015, Christopher Tomkins-Tinch"
__email__ = "tomkinsc@broadinstitute.org"
__license__ = "MIT"
# built-ins
import os
import sys
import re
from functools import partial
from abc import ABCMeta, abstractmethod
from wrapt import ObjectProxy
from contextlib import contextmanager
try:
from connection_pool import ConnectionPool
except ImportError:
# we just won't pool connections if it's not installed
# Should there be a warning? Should there be a runtime flag?
pass
import copy
import collections
# module-specific
import snakemake.io
from snakemake.logging import logger
from snakemake.common import parse_uri
class StaticRemoteObjectProxy(ObjectProxy):
    """Proxy that implements static-ness for remote objects.

    The constructor takes a real RemoteObject and returns a proxy that
    behaves the same except for the exists(), mtime() and is_newer() methods:
    the wrapped remote file is always reported as existing, as having the
    oldest possible modification time, and as never being newer than any
    given time.
    """

    def exists(self):
        """Always report the static remote file as existing."""
        return True

    def mtime(self):
        """Report negative infinity as the modification time."""
        return float("-inf")

    def is_newer(self, time):
        """A static remote file is never newer than ``time``."""
        return False

    def __copy__(self):
        # Copy the wrapped object and re-wrap it so the copy is still static.
        copied_wrapped = copy.copy(self.__wrapped__)
        return type(self)(copied_wrapped)

    def __deepcopy__(self, memo=None):
        """Deep-copy the wrapped object and re-wrap it so the copy is still static.

        ``copy.deepcopy()`` always invokes ``__deepcopy__(memo)``; the memo
        dict is forwarded so shared and cyclic references inside the wrapped
        object are copied consistently.
        """
        copied_wrapped = copy.deepcopy(self.__wrapped__, memo)
        return type(self)(copied_wrapped)
class AbstractRemoteProvider:
    """This is an abstract class to be used to derive remote provider classes.

    These might be used to hold common credentials, and are then passed to
    RemoteObjects.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and has no effect
    # in Python 3, so the @abstractmethod markers below are not enforced at
    # instantiation. Switching to ``metaclass=ABCMeta`` would make subclasses
    # that do not override every abstract method fail to instantiate, so it is
    # intentionally left unchanged.
    __metaclass__ = ABCMeta

    supports_default = False
    allows_directories = False

    def __init__(
        self, *args, keep_local=False, stay_on_remote=False, is_default=False, **kwargs
    ):
        # args/kwargs are kept so that glob_wildcards() can fall back to them
        # when it is called without its own arguments.
        self.args = args
        self.stay_on_remote = stay_on_remote
        self.keep_local = keep_local
        self.is_default = is_default
        self.kwargs = kwargs

    def remote(
        self, value, *args, keep_local=None, stay_on_remote=None, static=False, **kwargs
    ):
        """Flag ``value`` as a remote file handled by this provider.

        Args:
            value: a path string, or an iterable of path strings that must all
                share one protocol. A known protocol prefix is split off; paths
                without one get ``self.default_protocol``.
            *args: forwarded to the provider module's ``RemoteObject``.
            keep_local: defaults to ``self.keep_local`` when None.
            stay_on_remote: defaults to ``self.stay_on_remote`` when None; when
                true, the protocol prefix is kept on the returned path(s).
            static: wrap the remote object in ``StaticRemoteObjectProxy``.
            **kwargs: forwarded to ``RemoteObject``. ``protocol`` is filled in
                from the provider's kwargs, or else from the detected protocol,
                unless given explicitly.

        Returns:
            The (possibly protocol-stripped) path(s), flagged with the created
            ``remote_object``.

        Raises:
            SyntaxError: if ``value`` is flagged temp or protected, or if an
                iterable of paths mixes protocols.
            ValueError: if ``value`` is an empty iterable.
        """
        if snakemake.io.is_flagged(value, "temp"):
            raise SyntaxError("Remote and temporary flags are mutually exclusive.")
        if snakemake.io.is_flagged(value, "protected"):
            raise SyntaxError("Remote and protected flags are mutually exclusive.")

        if keep_local is None:
            keep_local = self.keep_local
        if stay_on_remote is None:
            stay_on_remote = self.stay_on_remote

        def _set_protocol(path):
            """Return (protocol, path-without-protocol), defaulting the protocol."""
            protocol = self.default_protocol
            for candidate in self.available_protocols:
                if path.startswith(candidate):
                    return candidate, path[len(candidate) :]
            return protocol, path

        # NOTE(review): slicing/concatenating a str subclass yields a plain str,
        # so any flags already attached to ``value`` may not survive this
        # rewrite — confirm against snakemake.io flag semantics.
        if isinstance(value, str):
            protocol, value = _set_protocol(value)
            value = protocol + value if stay_on_remote else value
        else:
            pairs = [_set_protocol(v) for v in value]
            if not pairs:
                # Previously surfaced as an opaque tuple-unpacking ValueError.
                raise ValueError("remote() requires at least one path")
            protocols = {p for p, _ in pairs}
            if len(protocols) != 1:
                raise SyntaxError("A single protocol must be used per RemoteObject")
            protocol = protocols.pop()
            value = [protocol + v if stay_on_remote else v for _, v in pairs]

        if "protocol" not in kwargs:
            if "protocol" not in self.kwargs:
                kwargs["protocol"] = protocol
            else:
                kwargs["protocol"] = self.kwargs["protocol"]

        # The RemoteObject class lives in the same module as the concrete
        # provider subclass.
        provider = sys.modules[self.__module__]
        remote_object = provider.RemoteObject(
            *args,
            keep_local=keep_local,
            stay_on_remote=stay_on_remote,
            provider=self,
            **kwargs,
        )
        if static:
            remote_object = StaticRemoteObjectProxy(remote_object)
        return snakemake.io.flag(value, "remote_object", remote_object)

    def glob_wildcards(self, pattern, *args, **kwargs):
        """Match ``pattern`` against the remote object's listing.

        Positional/keyword arguments default to those given to the provider.
        Unless the object stays on remote, the pattern is rewritten to the
        object's normalized relative name before matching.
        """
        args = self.args if not args else args
        kwargs = self.kwargs if not kwargs else kwargs

        reference_obj = snakemake.io._IOFile(self.remote(pattern, *args, **kwargs))
        remote_object = snakemake.io.get_flag_value(reference_obj, "remote_object")
        if not remote_object.stay_on_remote:
            pattern = "./" + remote_object.name
            pattern = os.path.normpath(pattern)

        key_list = list(remote_object.list)
        return snakemake.io.glob_wildcards(pattern, files=key_list)

    @abstractmethod
    def default_protocol(self):
        """The protocol that is prepended to the path when no protocol is specified.

        Read as an attribute (``self.default_protocol``) in remote(), so
        subclasses presumably override it with a property or class attribute.
        """
        pass

    @abstractmethod
    def available_protocols(self):
        """List of valid protocols for this remote provider.

        Iterated as an attribute in remote(); presumably overridden with a
        property or class attribute in subclasses.
        """
        pass

    @abstractmethod
    def remote_interface(self):
        """Provider-specific hook; semantics are defined by subclasses."""
        pass
class AbstractRemoteObject:
    """Common base class for the remote objects of cloud storage providers.

    For example, there could be classes for interacting with Amazon AWS S3
    and Google Cloud Storage, both derived from this common base class.
    """

    # NOTE(review): Python 2 metaclass spelling; inert in Python 3, so the
    # @abstractmethod markers below are not enforced at instantiation.
    __metaclass__ = ABCMeta

    def __init__(
        self,
        *args,
        protocol=None,
        keep_local=False,
        stay_on_remote=False,
        provider=None,
        **kwargs,
    ):
        assert protocol is not None
        # The owning IOFile is attached later (in io.py or elsewhere); the
        # path accessors below need it before they return anything useful.
        self._iofile = None
        self.args, self.kwargs = args, kwargs
        self.keep_local, self.stay_on_remote = keep_local, stay_on_remote
        self.provider, self.protocol = provider, protocol

    async def inventory(self, cache: snakemake.io.IOCache):
        """Collect as much existence/modification-date information as possible.

        The base version does nothing. Subclasses that implement it must store
        their findings in the given IOCache ``cache``.
        """
        return None

    @abstractmethod
    def get_inventory_parent(self):
        """Abstract hook; no base implementation."""

    @property
    def _file(self):
        iofile = self._iofile
        return None if iofile is None else iofile._file

    def file(self):
        """Return the attached IOFile's path, or None if none is attached yet."""
        return self._file

    def local_file(self):
        """Return the path with the protocol prefix removed when staying on remote."""
        if not self.stay_on_remote:
            return self._file
        return self._file[len(self.protocol) :]

    def remote_file(self):
        """Return the local path with the protocol prepended."""
        local_path = self.local_file()
        return self.protocol + local_path

    @abstractmethod
    def close(self):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def exists(self):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def mtime(self):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def size(self):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def download(self, *args, **kwargs):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def upload(self, *args, **kwargs):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def list(self, *args, **kwargs):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def name(self, *args, **kwargs):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def remote(self, value, keep_local=False, stay_on_remote=False):
        """Abstract hook; no base implementation."""

    @abstractmethod
    def remove(self):
        """Remove the remote file; the base version always refuses."""
        raise NotImplementedError("Removal of files is unavailable for this remote")

    def local_touch_or_create(self):
        """Delegate to the attached IOFile's touch_or_create()."""
        self._iofile.touch_or_create()
class DomainObject(AbstractRemoteObject):
    """This is a mixin related to parsing components
    out of a location path specified as
    (host|IP):port/remote/location
    """

    # Optional "<letters>://" prefix, a host name or IP, an optional numeric
    # ":port", then everything that follows as the path remainder.
    _ADDRESS_RE = re.compile(
        r"^(?P<protocol>[a-zA-Z]+\://)?(?P<host>[A-Za-z0-9\-\.]+)(?:\:(?P<port>[0-9]+))?(?P<path_remainder>.*)$"
    )

    def __init__(self, *args, **kwargs):
        super(DomainObject, self).__init__(*args, **kwargs)

    @property
    def _matched_address(self):
        """Regex match of the address pattern against local_file(), or None."""
        return self._ADDRESS_RE.search(self.local_file())

    @property
    def name(self):
        return self.path_remainder

    # if we ever parse out the protocol directly
    # @property
    # def protocol(self):
    #     if self._matched_address:
    #         return self._matched_address.group("protocol")

    @property
    def host(self):
        """Host part of the address, or None if the path does not parse."""
        if self._matched_address:
            return self._matched_address.group("host")

    @property
    def port(self):
        """Port as a string, or None if absent or the path does not parse."""
        if self._matched_address:
            return self._matched_address.group("port")

    @property
    def path_prefix(self):
        """Everything in the full file path before the path remainder.

        This is the domain and port (and protocol, if present), however
        specified. Returns None when the path does not parse, like ``host``
        and ``port``.
        """
        match = self._matched_address
        if not match:
            return None
        full_path = self._iofile._file
        # local_file() (the string the regex ran on) is a suffix of the full
        # path: it only drops a leading protocol. Translate the match offset
        # instead of searching for the remainder text, which picked the first
        # occurrence (wrong for e.g. "ftp://ftp/ftp") and mishandled an empty
        # remainder.
        offset = len(full_path) - len(match.string) + match.start("path_remainder")
        return full_path[:offset]

    @property
    def path_remainder(self):
        """Path after host and port, or None if the path does not parse."""
        if self._matched_address:
            return self._matched_address.group("path_remainder")

    @property
    def local_path(self):
        return self._iofile._file

    @property
    def remote_path(self):
        return self.path_remainder
class PooledDomainObject(DomainObject):
    """This adds connection pooling to DomainObjects
    out of a location path specified as
    (host|IP):port/remote/location
    """

    # Class-level dict: shared by all instances (unless a subclass redefines
    # it), mapping a connection label tuple to its pool.
    connection_pools = {}

    def __init__(self, *args, pool_size=100, immediate_close=False, **kwargs):
        super(PooledDomainObject, self).__init__(*args, **kwargs)
        # Honour the caller's pool size (it used to be hard-coded to 100).
        self.pool_size = pool_size
        self.immediate_close = immediate_close

    def get_default_kwargs(self, **defaults):
        """Fill in ``host``/``port`` parsed from the path unless already given."""
        defaults.setdefault("host", self.host)
        defaults.setdefault("port", int(self.port) if self.port else None)
        return defaults

    def get_args_to_use(self):
        """merge the objects args with the parent provider

        Positional Args: use those of object or fall back to ones from provider
        Keyword Args: merge with any defaults

        Keyword precedence, lowest to highest: host/port parsed from the path,
        kwargs given to the RemoteProvider, kwargs given to remote().
        """
        # Positional args given to remote() replace those of RemoteProvider().
        args_to_use = self.args if len(self.args) else self.provider.args

        kwargs_to_use = self.get_default_kwargs()
        kwargs_to_use.update(self.provider.kwargs)
        kwargs_to_use.update(self.kwargs)
        return args_to_use, kwargs_to_use

    @contextmanager
    def get_connection(self):
        """Yield a connection from a pool, or a one-off connection.

        The pool is used only when the optional ``connection_pool`` package
        was importable and ``immediate_close`` is False. A one-off connection
        is closed via ``close_connection`` when the context exits.
        """
        if not self.immediate_close and "connection_pool" in sys.modules:
            with self.connection_pool.item() as conn:
                yield conn
        else:
            args_to_use, kwargs_to_use = self.get_args_to_use()
            conn = self.create_connection(*args_to_use, **kwargs_to_use)
            try:
                yield conn
            finally:
                self.close_connection(conn)

    @property
    def conn_keywords(self):
        """returns list of keywords relevant to a unique connection"""
        return ["host", "port", "username"]

    @property
    def connection_pool(self):
        """Return the shared pool for this connection, creating it on first use.

        Pools are keyed on (object type, positional args, values of
        ``conn_keywords``); the merged args/kwargs are bound into the pool's
        create callback.
        """
        args_to_use, kwargs_to_use = self.get_args_to_use()
        # hashing connection pool on tuple of relevant arguments. There
        # may be a better way to do this
        conn_pool_label_tuple = (
            type(self),
            *args_to_use,
            *[kwargs_to_use.get(k, None) for k in self.conn_keywords],
        )
        if conn_pool_label_tuple not in self.connection_pools:
            create_callback = partial(
                self.create_connection, *args_to_use, **kwargs_to_use
            )
            self.connection_pools[conn_pool_label_tuple] = ConnectionPool(
                create_callback, close=self.close_connection, max_size=self.pool_size
            )
        return self.connection_pools[conn_pool_label_tuple]

    @abstractmethod
    def create_connection(self):
        """handle the protocol specific job of creating a connection

        Called with the merged positional and keyword args from
        get_args_to_use().
        """
        pass

    @abstractmethod
    def close_connection(self, connection):
        """handle the protocol specific job of closing a connection"""
        pass
class AutoRemoteProvider:
    """Dispatch ``remote()`` to the RemoteProvider matching each path's URI scheme."""

    @property
    def protocol_mapping(self):
        """Map URI schemes (protocol without "://") to RemoteProvider classes.

        Every module in the snakemake.remote package is loaded on each access.
        Modules that fail to load or do not define ``RemoteProvider`` are
        skipped, with a debug log message.
        """
        import pkgutil
        import importlib.util

        provider_list = []
        for remote_submodule in pkgutil.iter_modules(snakemake.remote.__path__):
            path = (
                os.path.join(remote_submodule.module_finder.path, remote_submodule.name)
                + ".py"
            )
            module_name = remote_submodule.name
            spec = importlib.util.spec_from_file_location(module_name, path)
            module = importlib.util.module_from_spec(spec)
            # Registered under its bare name because
            # AbstractRemoteProvider.remote() resolves the RemoteObject class
            # via sys.modules[self.__module__].
            previous = sys.modules.get(module_name)
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except Exception as e:
                # Don't leave a half-initialised module registered.
                if previous is None:
                    sys.modules.pop(module_name, None)
                else:
                    sys.modules[module_name] = previous
                logger.debug(f"Autoloading {module_name} failed: {e}")
                continue
            provider = getattr(module, "RemoteProvider", None)
            if provider is None:
                logger.debug(f"Autoloading {module_name} failed: no RemoteProvider")
                continue
            provider_list.append(provider)

        # assemble scheme mapping
        protocol_dict = {}
        for Provider in provider_list:
            for protocol in Provider().available_protocols:
                protocol_short = protocol[:-3]  # remove "://" suffix
                protocol_dict[protocol_short] = Provider
        return protocol_dict

    def remote(self, value, *args, provider_kws=None, **kwargs):
        """Create remote file(s) using the provider matching each URI's scheme.

        Args:
            value: a URI string or an iterable of URI strings.
            *args, **kwargs: forwarded to the selected provider's ``remote()``.
            provider_kws: keyword arguments for constructing each provider.

        Returns:
            The single flagged path when exactly one was produced, otherwise
            the list of them (an empty list for an empty iterable).

        Raises:
            TypeError: if ``value`` is neither a string nor an iterable, or if
                no provider handles a URI's scheme.
        """
        import collections.abc

        if isinstance(value, str):
            values = [value]
        elif isinstance(value, collections.abc.Iterable):
            values = list(value)
        else:
            raise TypeError(f"Invalid type ({type(value)}) passed to remote: {value}")

        if not values:
            return []

        # protocol_mapping re-imports every remote module on access, so build
        # it once rather than once per value.
        protocol_mapping = self.protocol_mapping
        provider_kws = {} if provider_kws is None else provider_kws.copy()

        provider_remote_list = []
        for item in values:
            Provider = protocol_mapping.get(parse_uri(item).scheme)
            if Provider is None:
                raise TypeError(f"Could not find remote provider for: {item}")
            provider_remote_list.append(
                Provider(**provider_kws).remote(item, *args, **kwargs)
            )

        return (
            provider_remote_list[0]
            if len(provider_remote_list) == 1
            else provider_remote_list
        )


# Module-level instance used to pick a provider automatically from URI schemes.
AUTO = AutoRemoteProvider()