Source code for rnaglib.transforms.transform

import os
from joblib import Parallel, delayed
from typing import List, Union, Any, Iterable, Generator, TYPE_CHECKING


[docs] class Transform: """Transforms modify and add information to an RNA graph via the ``networkx.Graph`` data structure. Receives an RNA graph and returns an RNA graph. This can be applied at dataset construction time, or a retrieval. Implementation inspired by torch-geometric Transforms library. :param parallel: whether to run the transform in parallel. :param num_workers: if running in parallel, number of jobs to use. Example -------- Transforms are callable objects that modify an RNA graph or dataset passed to it:: >>> from rnaglib.transforms import Transform >>> t = Transform() >>> dataset = RNADataset(debug=True) >>> t(dataset[0]) >>> t(dataset) """
[docs] def __init__( self, parallel: bool = False, num_workers: int = -1, ): self.parallel = parallel self.num_workers = num_workers
def __call__(self, data: Any) -> Any: RNADataset = __import__("rnaglib.dataset").dataset.RNADataset if isinstance(data, (list, Generator, RNADataset)): if self.parallel: return list(Parallel(n_jobs=self.num_workers)(delayed(self.forward)(d) for d in data)) else: return [self.forward(d) for d in data] else: return self.forward(data) def forward(self, data: Any) -> Any: raise NotImplementedError def __repr__(self) -> str: return f"{self.__class__.__name__}()"
class IdentityTransform(Transform): def __init__(self, **kwargs): super().__init__(**kwargs) def forward(self, data: Any) -> Any: return data
[docs] class AnnotationTransform(Transform): """A transform that computes an annotation for the RNA. Same logic as the base class but implements caching logic. """
[docs] def __init__(self, **kwargs): super().__init__(**kwargs)
[docs] class FilterTransform(Transform): """Reject items from a dataset based on some conditions. The ``forward()`` method returns True/False for the given RNA and the ``__call__()`` method returns the RNAs which pass the ``forward()`` filter. """ def __call__(self, data: Any) -> Union[bool, Iterable[Any]]: """Apply the filter and return an iterator over the RNAs that pass.""" RNADataset = __import__("rnaglib.dataset").dataset.RNADataset if not isinstance(data, (list, Generator, RNADataset)): return self.forward(data) if self.parallel: keeps = Parallel(n_jobs=self.num_workers)(delayed(self.forward)(d) for d in data) return (d for d, keep in zip(data, keeps) if keep) else: return (d for d in data if self.forward(d)) def forward(self, data: dict) -> bool: """Returns true/ or false on the given RNA""" raise NotImplementedError
[docs] class PartitionTransform(Transform): """Break up a whole RNAs into substructures. Returns a new flat iterator over RNA data items. For example, splitting a list of multi-chain RNAs into a flat list of single-chain RNAs. """ def __call__(self, data: Any) -> Iterable[Any]: RNADataset = __import__("rnaglib.dataset").dataset.RNADataset if isinstance(data, (list, Generator, RNADataset)): for rna in data: yield from self.forward(rna) pass else: yield from self.forward(data) pass def new_name(self, rna_partition: dict): """Compute the name of the given partition of RNA""" raise NotImplementedError
class Compose(Transform): """Combine multiple transforms into one, applying each individual transform on each item consecutively. :param transforms: List of transforms to join together. """ def __init__(self, transforms: List[Transform], **kwargs): self.transforms = transforms super().__init__(**kwargs) pass def forward(self, data: Any): for tr in self.transforms: data = tr(data) return data def __repr__(self) -> str: """From PyG""" args = [f" {transform}" for transform in self.transforms] return "{}([\n{}\n])".format(self.__class__.__name__, ",\n".join(args)) class ComposeFilters: """Composes several filters together. :param filters: List of filter transforms to compose. """ def __init__(self, filters: List[FilterTransform]): self.filters = filters def __call__(self, data: dict) -> bool: RNADataset = __import__("rnaglib.dataset").dataset.RNADataset if not isinstance(data, (list, Generator, RNADataset)): raise ValueError("Filter compose only works on collections of RNAs") for filter_fn in self.filters: data = (d for d in data if filter_fn.forward(d)) return data def forward(self, data: dict) -> bool: all_true = True for filter_fn in self.filters: all_true = all_true and filter_fn.forward(data) if not all_true: return False return True def __repr__(self) -> str: args = [f" {filter_fn}" for filter_fn in self.filters] return "{}([\n{}\n])".format(self.__class__.__name__, ",\n".join(args))