# Source code for rnaglib.dataset.rna_dataset

"""Main class for collections of RNAs."""

from collections.abc import Iterable
import copy
import json
import os
from pathlib import Path
from typing import Literal, Union
from torch.utils.data import Dataset

from bidict import bidict
import networkx as nx
import numpy as np

from rnaglib.transforms.featurize import FeaturesComputer
from rnaglib.transforms.represent import Representation
from rnaglib.transforms.transform import AnnotationTransform, Transform
from rnaglib.utils import download_graphs, dump_json, load_graph
from rnaglib.utils.graph_io import get_all_existing, get_name_extension


class RNADataset(Dataset):
    """This class is the main object to hold RNA data, and is compatible with Pytorch Dataset.

    A key feature is a bidict that holds a mapping between RNA names and their index in the dataset.
    This allows for constant time access to an RNA with a given name.

    The RNAs contained in an RNADataset can either live in the memory, or be specified as files.
    In the latter case, an RNADataset can be seen as an ordered list of files in a given directory.
    One can put a dataset in memory by calling to_memory().

    Once a dataset is built, you can save it, subset it and access its elements by name or by index.

    :param rnas: For use in memory, list of RNA objects represented as networkx graphs.
    :param dataset_path: If using filenames, this is the path to the folder containing the graphs to load.
    :param version: If using filenames, and no dataset_path is provided, this is the version of the RNA dataset
        download that will be used, and set as dataset_path.
    :param redundancy: Same as version, sets the redundancy mode to use if neither rnas nor dataset_path is
        provided.
    :param rna_id_subset: List of graph filenames to grab in the dataset_path, to keep instead of using all
        available ones.
    :param recompute_mapping: If True (default), the name <=> index mapping is rebuilt from the graph files found.
        If False and a ``bidict.json`` file exists in the dataset folder, the mapping stored there is used instead
        (for instance if some graphs are irrelevant).
    :param in_memory: When loading a dataset from files, set to True to load all graphs in memory.
        When ``rnas`` is given, this defaults to (and must be) True.
    :param features_computer: A FeaturesComputer object, useful to transform raw RNA data into tensors.
    :param representations: A :class:`~rnaglib.transforms.represent.Representation` or a list of them, applied to
        each item.
    :param debug: If True, only keep the first 50 items.
    :param get_pdbs: If True, will also fetch the corresponding structures (only used when the dataset is
        downloaded).
    :param multigraph: Whether to load RNAs as multi-graphs or simple graphs. Multigraphs can have backbone and
        base pairs between the same two residues.
    :param transforms: An optional transform, or list of transforms, to apply to rnas before calling the features
        computer and the representations in ``__getitem__``.

    Examples
    --------
    Create a default dataset::

        >>> from rnaglib.dataset import RNADataset
        >>> dataset = RNADataset()

    Access the first item in the dataset::

        >>> dataset[0]

    Each item is a dictionary with the key 'rna' holding annotations as a networkx Graph::

        >>> dataset[0]['rna'].nodes()
        >>> dataset[0]['rna'].edges()

    Access an RNA by its PDBID::

        >>> dataset.get_pdbid('4nlf')

    .. Hint::
        Pass ``debug=True`` to ``RNADataset`` to quickly load a small dataset for testing.
    """

    def __init__(
        self,
        rnas: list[nx.Graph] = None,
        dataset_path: Union[str, os.PathLike] = None,
        version="2.0.2",
        redundancy="nr",
        rna_id_subset: list[str] = None,
        recompute_mapping: bool = True,
        in_memory: bool = None,
        features_computer: FeaturesComputer = None,
        representations: Union[list[Representation], Representation] = None,
        debug: bool = False,
        get_pdbs: bool = True,
        multigraph: bool = False,
        transforms: Union[list[Transform], Transform] = None,
    ):
        # Normalize transforms to a list, mirroring how representations are handled below.
        # A list/tuple of transforms is kept as-is; a single transform is wrapped.
        if transforms is None:
            self.transforms = []
        elif isinstance(transforms, (list, tuple)):
            self.transforms = list(transforms)
        else:
            self.transforms = [transforms]
        self.multigraph = multigraph
        self.version = version
        # Only set when structures are available (downloaded datasets).
        self.structures_path = None
        if dataset_path is not None:
            self.dataset_path = dataset_path

        # Without rnas nor a path, download the dataset matching `version` and `redundancy`.
        if rnas is None and dataset_path is None:
            dataset_path, structures_path = download_graphs(
                redundancy=redundancy, version=version, debug=debug, get_pdbs=get_pdbs
            )
            self.dataset_path = dataset_path
            self.structures_path = structures_path

        # Distances are loaded lazily through the `distances` property.
        # These paths are resolved *after* a potential download so that they point to the actual data folder.
        self.distances_ = None
        self.distances_path = Path(dataset_path) / "distances.npz" if dataset_path is not None else None
        self.bidict_path = Path(dataset_path) / "bidict.json" if dataset_path is not None else None

        if rnas is None:
            # One can restrict the graphs to use with rna_id_subset.
            existing_all_rnas, extension = get_all_existing(dataset_path=self.dataset_path, all_rnas=rna_id_subset)
            self.extension = extension
            if recompute_mapping or not self.bidict_path.exists():
                # Keep track of a list_id <=> system mapping. First remove extensions.
                existing_all_rna_names = [get_name_extension(rna, permissive=True)[0] for rna in existing_all_rnas]
                self.all_rnas = bidict({rna: i for i, rna in enumerate(existing_all_rna_names)})
            else:
                with self.bidict_path.open() as f:
                    self.all_rnas = bidict(json.load(f))

            # If debugging, only keep the first few.
            if debug:
                nb_items_to_keep = min(50, len(self.all_rnas))
                self.all_rnas = bidict(
                    {rna: i for idx, (rna, i) in enumerate(self.all_rnas.items()) if idx < nb_items_to_keep}
                )

            self.rnas = None
            self.in_memory = False
            if in_memory:
                self.to_memory()
        else:
            # Handle default choice.
            self.in_memory = in_memory if in_memory is not None else True
            assert self.in_memory, (
                "Conflicting arguments: if an RNADataset is instantiated with a list of graphs, "
                "it must use 'in_memory=True'"
            )
            self.rnas = rnas
            # Here we assume that rna lists contain a relevant rna.name field, which is the case for defaults.
            rna_names = {rna.name for rna in rnas}
            assert "" not in rna_names, "Empty RNA name found"
            assert len(rna_names) == len(rnas), (
                "When creating a RNAdataset from rnas, please use uniquely named networkx graphs"
            )
            self.all_rnas = bidict({rna.name: i for i, rna in enumerate(rnas)})

        # Now that we have the raw data setup, let us set up the features we want to be using.
        self.features_computer = FeaturesComputer() if features_computer is None else features_computer

        # Finally, let us set up the list of representations that we will be using.
        # Copy the list so that add/remove_representation never mutates the caller's list.
        if representations is None:
            self.representations = []
        elif not isinstance(representations, list):
            self.representations = [representations]
        else:
            self.representations = list(representations)

    def __len__(self):
        """Return the length of the current dataset."""
        return len(self.all_rnas)

    def __getitem__(self, idx):
        """Fetch one RNA and convert it from raw data to a dictionary.

        Applies transforms, the features computer and representations to be used by loaders.

        :param idx: Index of dataset item to fetch. Negative indices count from the end.
        :return: A dict with keys ``rna``, ``graph_path``, ``cif_path`` plus one key per representation name.
        :raises IndexError: If ``idx`` is out of range.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        if not 0 <= position < size:
            raise IndexError(f"Index {idx} out of range for dataset of size {size}")

        # Recover rna name from passed index.
        rna_name = self.all_rnas.inv[position]

        # Initialise paths.
        nx_path, cif_path = None, None
        if getattr(self, "dataset_path", None) is not None:
            # Graphs passed in memory have no recorded extension; fall back to the JSON format written by save().
            extension = getattr(self, "extension", ".json")
            nx_path = Path(self.dataset_path) / f"{rna_name}{extension}"
        if getattr(self, "structures_path", None) is not None:
            cif_path = Path(self.structures_path) / f"{rna_name}.cif"

        if self.in_memory:
            rna_graph = self.rnas[position]
        else:
            rna_graph = load_graph(str(nx_path), multigraph=self.multigraph)
            rna_graph.name = rna_name

        # Compute features.
        rna_dict = {"rna": rna_graph, "graph_path": nx_path, "cif_path": cif_path}
        for transform in self.transforms:
            transform(rna_dict)
        features_dict = self.features_computer(rna_dict)

        # Apply representations. A transform may have replaced rna_dict["rna"], so use that graph
        # (the one the features were computed on) rather than the originally loaded one.
        for rep in self.representations:
            rna_dict[rep.name] = rep(rna_dict["rna"], features_dict)
        return rna_dict

    def get_by_name(self, rna_name):
        """Grab an RNA by its name."""
        rna_idx = self.all_rnas[rna_name]
        return self.__getitem__(rna_idx)

    def get_pdbid(self, pdbid):
        """Grab an RNA by its pdbid. Just a shortcut that includes a lower() call."""
        return self.get_by_name(pdbid.lower())

    def to_memory(self):
        """Load all graphs in memory (if not already loaded) and set in_memory=True."""
        # `self.rnas` is None (not missing) for file-backed datasets, so test its value, not its presence.
        if getattr(self, "rnas", None) is None:
            self.rnas = [
                load_graph(Path(self.dataset_path) / f"{g_name}{self.extension}", multigraph=self.multigraph)
                for g_name in self.all_rnas
            ]
        for rna, name in zip(self.rnas, self.all_rnas, strict=False):
            rna.name = name
        self.in_memory = True

    def check_consistency(self):
        """Make sure all RNAs are actually present when in_memory is true."""
        if self.in_memory:
            assert list(self.all_rnas) == [rna.name for rna in self.rnas]
        else:
            print("Check consistency only works if in_memory is true.")

    @property
    def distances(self):
        """Lazily loaded dict of pairwise distance matrices, keyed by distance name.

        On the first call, try loading them from ``distances_path``. Only square matrices matching the
        dataset length are kept. Return None if nothing is available.
        """
        if self.distances_ is not None:
            return self.distances_
        if self.distances_path is not None and Path(self.distances_path).exists():
            # Materialize arrays (lightweight anyway) and close the archive: .npz loading is lazy and keeps
            # the file handle open otherwise.
            with np.load(self.distances_path) as archive:
                all_distances = {key: archive[key] for key in archive.files}
            size = len(self)
            # Filter to keep only square matrices with dimensions matching our dataset length.
            self.distances_ = {
                k: v for k, v in all_distances.items() if v.ndim == 2 and v.shape[0] == v.shape[1] == size
            }
            return self.distances_
        return None

    def remove_distance(self, name):
        """Remove a distance from the dataset (no-op if absent)."""
        if self.distances is not None and name in self.distances:
            del self.distances_[name]

    def add_distance(self, name, distance_mat):
        """Add a square distance matrix of size len(self) to the dataset."""
        assert distance_mat.shape[0] == distance_mat.shape[1] == len(self)
        if self.distances is None:
            self.distances_ = {name: distance_mat}
        else:
            self.distances_[name] = distance_mat

    def save_distances(self, dump_path=None):
        """Save distances to ``dump_path``, defaulting to ``distances_path``."""
        if self.distances is not None:
            dump_path = dump_path if dump_path is not None else self.distances_path
            np.savez(dump_path, **self.distances)

    def add_representation(self, representations: Union[list[Representation], Representation]):
        """Add representation object(s) to the dataset.

        Provided representations are added on the fly to the dataset. Existing representations with the same
        name are replaced.

        :param representations: A ``Representation`` or a list of ``Representation`` objects to add.
        """
        representations = [representations] if not isinstance(representations, list) else representations
        if not representations:
            return
        to_print = [rep.name for rep in representations] if len(representations) > 1 else representations[0].name
        print(f">>> Adding {to_print} to dataset representations.")

        # Remove old representations with the same name.
        new_names = {rep.name for rep in representations}
        to_remove = {rep.name for rep in self.representations if rep.name in new_names}
        if to_remove:
            print(f"Removing old representations of {to_remove} from existing representation to avoid clash")
            self.representations = [rep for rep in self.representations if rep.name not in to_remove]
        self.representations.extend(representations)

    def remove_representation(self, names):
        """Remove the representation(s) with the given name(s).

        :param names: A single representation name or an iterable of names.
        """
        # A str is Iterable, but it denotes a single name, not a collection of one-letter names.
        if isinstance(names, str) or not isinstance(names, Iterable):
            names = [names]
        names_to_drop = set(names)
        self.representations = [rep for rep in self.representations if rep.name not in names_to_drop]

    def add_feature(
        self,
        feature: Union[str, AnnotationTransform],
        feature_level: Literal["residue", "rna"] = "residue",
        *,  # enforce keyword only arguments
        is_input: bool = True,
    ):
        """Add a feature to the dataset for model training.

        If you pass a string, we use it to pull the feature from the RNA dictionary.
        If you pass an AnnotationTransform, we check whether it has been applied already. If not, we apply it
        to store the annotation in the dataset, and then use it as a feature.

        :param feature: Can be a string representing a key in the RNA dict or an AnnotationTransform.
        :param feature_level: Residue-level (`residue`), or RNA-level (`rna`) feature.
        :param is_input: Are you using the feature on the input side (`True`) or as a prediction target
            (`False`)?
        """
        # By default, use an existing key in the RNA dictionary as feature.
        feature_name = feature
        custom_encoders = None
        if isinstance(feature, Transform):
            # Check whether the transform has already been applied, by inspecting the first RNA.
            g = self[0]["rna"]
            feature_exists = False
            if feature_level == "residue":
                first_node = next(iter(g.nodes), None)
                feature_exists = first_node is not None and g.nodes[first_node].get(feature.name) is not None
            elif feature_level == "rna":
                feature_exists = g.graph.get(feature.name) is not None

            # Only apply transform if it hasn't been applied yet.
            if not feature_exists:
                feature(self)
            feature_name = feature.name
            custom_encoders = {feature_name: feature.encoder}

        self.features_computer.add_feature(
            feature_names=feature_name,
            feature_level=feature_level,
            input_feature=is_input,
            custom_encoders=custom_encoders,
        )

    def subset(self, list_of_ids=None, list_of_names=None):
        """Create another dataset with only the specified graphs.

        Unknown names/ids are ignored and duplicates are dropped (first occurrence kept).
        If neither argument is given, all RNAs are kept.

        :param list_of_names: A list of rna names (no extension is expected).
        :param list_of_ids: A list of rna ids.
        :return: An RNADataset with only the specified graphs/ids.
        """
        # You can't subset on both simultaneously.
        assert list_of_ids is None or list_of_names is None
        if list_of_names is not None:
            list_of_ids = [self.all_rnas[name] for name in list_of_names if name in self.all_rnas]
        elif list_of_ids is not None:
            list_of_ids = [id_rna for id_rna in list_of_ids if id_rna in self.all_rnas.inv]
        else:
            list_of_ids = list(self.all_rnas.values())
        # Duplicates would break the contiguous 0..k-1 indexing of the subset.
        list_of_ids = list(dict.fromkeys(list_of_ids))

        # Copy the dataset while avoiding an expensive deep copy of in-memory rnas and distance matrices.
        saved_rnas, saved_distances = self.rnas, self.distances_
        self.rnas, self.distances_ = None, None
        try:
            subset = copy.deepcopy(self)
        finally:
            self.rnas, self.distances_ = saved_rnas, saved_distances

        # Subset the bidict of names, and the rnas if in_memory.
        if self.in_memory:
            subset.rnas = [self.rnas[i] for i in list_of_ids]
        subset_names = [self.all_rnas.inv[i] for i in list_of_ids]
        subset.all_rnas = bidict({rna: i for i, rna in enumerate(subset_names)})

        # Update the distance matrices.
        parent_distances = self.distances
        if parent_distances is not None:
            rows_cols = np.ix_(list_of_ids, list_of_ids)
            subset.distances_ = {name: mat[rows_cols] for name, mat in parent_distances.items()}
        return subset

    def save(self, dump_path, *, recompute=False, verbose=True):
        """Save a local copy of the dataset (graphs as JSON, distances and the name <=> index mapping).

        :param dump_path: Folder to write to (created if needed).
        :param recompute: If False and all graph files are already present, graph files are not rewritten.
        :param verbose: Print a message when graph dumping is skipped.
        """
        print(f"dumping {len(self.all_rnas)} rnas")
        dump_path = Path(dump_path)
        dump_path.mkdir(parents=True, exist_ok=True)
        self.save_distances(dump_path=dump_path / "distances.npz")
        with open(dump_path / "bidict.json", "w") as js:
            json.dump(dict(self.all_rnas), js)

        # Check if all files are already there; if so, skip rewriting them unless asked to.
        existing_files = set(os.listdir(dump_path))
        to_dump = {f"{name}.json" for name in self.all_rnas}
        if to_dump.issubset(existing_files) and not recompute:
            if verbose:
                print('files already exist, set "recompute=True" to overwrite')
            return

        extension = getattr(self, "extension", ".json")
        for rna_name, i in self.all_rnas.items():
            if self.in_memory:
                rna_graph = self.rnas[i]
            else:
                rna_graph = load_graph(Path(self.dataset_path) / f"{rna_name}{extension}", multigraph=self.multigraph)
            dump_json(dump_path / f"{rna_name}.json", rna_graph)
if __name__ == "__main__":
    # Manual smoke tests for RNADataset; uncomment the scenario you want to run.
    from rnaglib.transforms import GraphRepresentation

    features_computer = FeaturesComputer(nt_features="nt_code", nt_targets="binding_protein")
    graph_rep = GraphRepresentation(framework="dgl")

    all_rnas = [
        "1a9n.json",
        "1b23.json",
        "1b7f.json",
        "1csl.json",
        "1d4r.json",
        "1dfu.json",
        "1duq.json",
        "1e8o.json",
        "1ec6.json",
        "1et4.json",
    ]
    # Strip the ".json" extension to obtain RNA names.
    all_rna_names = [name.removesuffix(".json") for name in all_rnas]
    script_dir = Path(__file__).resolve().parent
    dataset_path = script_dir / "../data/test"

    # First case: restrict the dataset to a subset of graph files (passed through `rna_id_subset`).
    # supervised_dataset = RNADataset(rna_id_subset=all_rnas,
    #                                 features_computer=features_computer,
    #                                 representations=[graph_rep])
    # g1 = supervised_dataset[0]
    # a = list(g1['rna'].nodes(data=True))[0][1]

    # Test in_memory field
    # supervised_dataset = RNADataset(dataset_path=dataset_path, representations=graph_rep, in_memory=False)
    # g2 = supervised_dataset[0]

    # Test subsetting
    # supervised_dataset = RNADataset(dataset_path=dataset_path, representations=graph_rep, in_memory=False)
    # subset = supervised_dataset.subset(list_of_names=all_rna_names[:5])
    # subset2 = subset.subset(list_of_ids=[1, 3, 4])

    # Test saving
    # supervised_dataset = RNADataset(dataset_path=dataset_path, representations=graph_rep, in_memory=True)
    # supervised_dataset.save(os.path.join(script_dir, "../data/test_dump"))
    # supervised_dataset.check_consistency()