Source code for rnaglib.tasks.RNA_Prot.protein_binding_site
import os
from tqdm import tqdm
from rnaglib.dataset import RNADataset
from rnaglib.dataset_transforms import ClusterSplitter
from rnaglib.encoders import BoolEncoder
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.transforms import ConnectedComponentPartition, DummyFilter, FeaturesComputer, ResidueAttributeFilter
[docs]
class ProteinBindingSite(ResidueClassificationTask):
    """Predict, for each residue, whether it belongs to a protein-binding interface.

    Task type: binary classification

    Task level: residue-level

    The per-residue target is read from the ``protein_content_8.0`` node attribute and
    encoded with :class:`BoolEncoder`; the input feature is the ``nt_code`` attribute.

    :param tuple[int] size_thresholds: range of RNA sizes to keep in the task dataset
        (default ``(15, 500)``)
    """

    target_var = "protein_content_8.0"  # alternative noted in the original code: "protein_binding"
    input_var = "nt_code"
    name = "rna_prot"
    default_metric = "balanced_accuracy"
    version = "2.0.2"  # passed as ``version`` to the source RNADataset in ``process``

    def __init__(self, size_thresholds=(15, 500), **kwargs):
        """Initialise the task.

        :param tuple[int] size_thresholds: range of RNA sizes to keep (default ``(15, 500)``)
        :param kwargs: forwarded unchanged to :class:`ResidueClassificationTask`
        """
        # The target is a single binary label per residue, not a multi-label vector.
        meta = {"multi_label": False}
        super().__init__(additional_metadata=meta, size_thresholds=size_thresholds, **kwargs)

    @property
    def default_splitter(self):
        """Return the splitting strategy used by default for this task.

        The canonical splitter is :class:`ClusterSplitter`, a similarity-based split relying
        on clustering, which can be refined into a sequence- or structure-based clustering
        through its ``distance_name`` argument. This task uses ``distance_name="USalign"``.

        :return: the default splitter to be used for the task
        :rtype: Splitter
        """
        return ClusterSplitter(distance_name="USalign", debug=self.debug)

    def get_task_vars(self):
        """Build the :class:`FeaturesComputer` describing this task's features and targets.

        Nucleotide (graph node) features come from ``input_var``; the nucleotide target is
        ``target_var``, encoded with a :class:`BoolEncoder`.

        :return: the features computer of the task
        :rtype: FeaturesComputer
        """
        return FeaturesComputer(
            nt_features=self.input_var,
            nt_targets=self.target_var,
            custom_encoders={self.target_var: BoolEncoder()},
        )

    def process(self) -> RNADataset:
        """Create the task-specific dataset.

        Iterates over the source :class:`RNADataset` (loaded with ``self.version``), keeps the
        RNAs accepted by the target-attribute filter, splits each one into connected
        components, drops components rejected by ``self.size_filter`` (only when
        ``self.size_thresholds`` is not ``None``), and builds the final dataset from the
        remaining component graphs. In debug mode the target-attribute filter is replaced
        by :class:`DummyFilter`.

        :return: the task-specific dataset
        :rtype: RNADataset
        """
        if self.debug:
            rna_filter = DummyFilter()
        else:
            # value_checker accepts only non-None target values; how per-residue results
            # are aggregated into an RNA-level decision is up to ResidueAttributeFilter.
            rna_filter = ResidueAttributeFilter(
                attribute=self.target_var,
                value_checker=lambda val: val is not None,
            )
        partition = ConnectedComponentPartition()
        # Loop-invariant: decide once whether component size filtering applies.
        use_size_filter = self.size_thresholds is not None

        source_dataset = RNADataset(debug=self.debug, in_memory=False, version=self.version)
        os.makedirs(self.dataset_path, exist_ok=True)

        all_rnas = []
        for rna in tqdm(source_dataset, total=len(source_dataset)):
            if not rna_filter.forward(rna):
                continue
            for component in partition(rna):
                if use_size_filter and not self.size_filter.forward(component):
                    continue
                # Each component exposes its graph under "rna"; use it directly instead of
                # rebinding the outer loop variable.
                self.add_rna_to_building_list(all_rnas=all_rnas, rna=component["rna"])
        return self.create_dataset_from_list(rnas=all_rnas)