Source code for rnaglib.tasks.RNA_CM.chemical_modification

import os
from tqdm import tqdm

from rnaglib.dataset import RNADataset
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.transforms import FeaturesComputer
from rnaglib.transforms import ResidueAttributeFilter, DummyFilter
from rnaglib.transforms import ConnectedComponentPartition
from rnaglib.dataset_transforms import ClusterSplitter


[docs] class ChemicalModification(ResidueClassificationTask): """Residue-level binary classification task to predict whether a given residue is chemically modified. Task type: binary classification Task level: residue-level :param tuple[int] size_thresholds: range of RNA sizes to keep in the task dataset(default (15, 500)) """ target_var = "is_modified" input_var = "nt_code" name = "rna_cm" default_metric = "balanced_accuracy" version = "2.0.2"
[docs] def __init__(self, size_thresholds=(15, 500), **kwargs): meta = {'multi_label': False} super().__init__(additional_metadata=meta, size_thresholds=size_thresholds, **kwargs)
@property def default_splitter(self): """Returns the splitting strategy to be used for this specific task. Canonical splitter is ClusterSplitter which is a similarity-based splitting relying on clustering which could be refined into a sequencce- or structure-based clustering using distance_name argument :return: the default splitter to be used for the task :rtype: Splitter """ return ClusterSplitter(distance_name="USalign") def get_task_vars(self): """Specifies the `FeaturesComputer` object of the tasks which defines the features which have to be added to the RNAs (graphs) and nucleotides (graph nodes) :return: the features computer of the task :rtype: FeaturesComputer """ return FeaturesComputer(nt_targets=self.target_var, nt_features=self.input_var) def process(self) -> RNADataset: """ Creates the task-specific dataset. :return: the task-specific dataset :rtype: RNADataset """ # Define your transforms residue_attribute_filter = ResidueAttributeFilter( attribute=self.target_var, value_checker=lambda val: val == True ) if self.debug: residue_attribute_filter = DummyFilter() connected_components_partition = ConnectedComponentPartition() # Run through database, applying our filters dataset = RNADataset(debug=self.debug, in_memory=self.in_memory, version=self.version) all_rnas = [] for rna in tqdm(dataset): for rna_connected_component in connected_components_partition(rna): if residue_attribute_filter.forward(rna_connected_component): if self.size_thresholds is not None and not self.size_filter.forward(rna_connected_component): continue rna = rna_connected_component["rna"] self.add_rna_to_building_list(all_rnas=all_rnas, rna=rna) dataset = self.create_dataset_from_list(all_rnas) print(f"len of process: {len(dataset)}") return dataset