RNA Tasks
A “task” is a machine learning term referring to a learning problem formulated in a precise way, with a precise measure of performance. Tasks are mostly used to benchmark modeling approaches in a fair and comparable way. While tasks are not necessarily directly transferable to an actual application, the closer they get to relevance, the higher the chance that the best model on a benchmark is also the best for applications.
In RNAGlib, we propose a set of benchmark tasks relevant to the modeling of RNA structures. A detailed description of the proposed tasks can be found here.
The Task object uses all objects previously described in this tutorial section (RNA, RNATransforms, RNADataset, DSTransforms).
As sketched below, it contains:
an annotated dataset
train, validation and test splits
precise metrics computation
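To make these three ingredients concrete, here is a minimal, library-free sketch of what a task bundles together. All names in it (ToyTask, get_split, evaluate, accuracy) are hypothetical and chosen for illustration only; this is not the rnaglib Task class, whose interface is richer.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Tuple


@dataclass
class ToyTask:
    """Hypothetical stand-in for a task: data, fixed splits and a fixed metric."""

    samples: List[Tuple[str, int]]  # annotated dataset: (input, label) pairs
    splits: Dict[str, List[int]]    # split name -> indices into `samples`
    metric: Callable[[Sequence[int], Sequence[int]], float]  # performance measure

    def get_split(self, name: str) -> List[Tuple[str, int]]:
        # Return the annotated samples belonging to one split
        return [self.samples[i] for i in self.splits[name]]

    def evaluate(self, predictions: Sequence[int], split: str = "test") -> float:
        # Score predictions against the labels of the chosen split
        labels = [label for _, label in self.get_split(split)]
        return self.metric(predictions, labels)


def accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    # Fraction of positions where prediction and label agree
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


toy_task = ToyTask(
    samples=[("A", 0), ("C", 1), ("G", 0), ("U", 1), ("A", 1), ("G", 0)],
    splits={"train": [0, 1, 2], "val": [3], "test": [4, 5]},
    metric=accuracy,
)
score = toy_task.evaluate([1, 1], split="test")
print(score)
```

Because the splits and the metric are fixed inside the object, two models evaluated through the same task are scored on the same samples with the same measure, which is what makes the comparison fair.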
Creating a Task
For complete examples on how to build a task, see this detailed tutorial or the source code of the provided tasks.
Briefly, the idea is to implement a subclass of Task, with process() and post_process() methods. The process() method defines the steps to select and annotate the data (using RNATransforms). The post_process() method computes distance matrices between points and optionally removes redundancy (using DSTransforms).
In addition, one should implement a get_task_vars() method, which creates a FeaturesComputer with the features to use and predict, and a default_splitter property that returns the splitter one wants to use. Those steps are summarized below in a simplified version of the RNA_CM task:
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.dataset import RNADataset
from rnaglib.transforms import FeaturesComputer
from rnaglib.dataset_transforms import RandomSplitter


class DummyTask(ResidueClassificationTask):
    input_var = "nt_code"
    target_var = "is_modified"
    name = "rna_dummy"

    def __init__(self, **kwargs):
        super().__init__(additional_metadata={"multi_label": False}, **kwargs)

    def process(self) -> RNADataset:
        # Select and annotate the data (here: load the dataset in debug mode, unfiltered)
        return RNADataset(in_memory=self.in_memory, debug=True)

    def post_process(self):
        # No distance computation or redundancy removal in this dummy example
        pass

    def get_task_vars(self) -> FeaturesComputer:
        # Nucleotide-level input features and prediction targets
        return FeaturesComputer(nt_features=self.input_var, nt_targets=self.target_var)

    @property
    def default_splitter(self):
        return RandomSplitter()
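To illustrate how such hooks can fit together, the following library-free sketch shows a toy base class that calls process(), then post_process(), then the default splitter. It is only an assumption about the general pattern: the real Task base class may order, name or implement these steps differently, and ToyBaseTask, ToyDedupTask and toy_random_splitter are hypothetical names.

```python
import random


def toy_random_splitter(items, seed=0, n_val=1, n_test=2):
    # Hypothetical splitter: shuffle indices, then cut them into three parts
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    n_train = len(indices) - n_val - n_test
    return indices[:n_train], indices[n_train:n_train + n_val], indices[n_train + n_val:]


class ToyBaseTask:
    """Hypothetical base class that drives the hooks in a fixed order."""

    def __init__(self, seed=0):
        self.dataset = self.process()   # 1. select and annotate the data
        self.post_process()             # 2. optional clean-up, e.g. redundancy removal
        self.splits = self.default_splitter(self.dataset, seed=seed)  # 3. split

    def process(self):
        raise NotImplementedError

    def post_process(self):
        pass

    @property
    def default_splitter(self):
        return toy_random_splitter


class ToyDedupTask(ToyBaseTask):
    def process(self):
        return ["GGAU", "GGAU", "ACGU", "UUCG", "CCAG", "GGAU",
                "AUAU", "CGCG", "GAGA", "UCUC", "ACAC", "UGUG"]

    def post_process(self):
        # Drop exact duplicates while preserving the original order
        self.dataset = list(dict.fromkeys(self.dataset))


toy_dedup_task = ToyDedupTask(seed=0)
train_idx, val_idx, test_idx = toy_dedup_task.splits
print(len(toy_dedup_task.dataset), len(train_idx), len(val_idx), len(test_idx))
```

Running the redundancy removal before splitting matters: if duplicates were still present when the splitter ran, copies of the same sequence could land in both the training and the test split.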
Loading a Task
Once created, tasks can be saved and loaded seamlessly, such that we get the following output:
>>> task = DummyTask(root="test_task")
>>> task = DummyTask(root="test_task")
Task was found and not overwritten
Observe that on the second call, the computations were not run a second time.
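The behaviour above follows a common "build once, then load from disk" pattern. The sketch below illustrates that pattern with the standard library only. It is not how rnaglib stores tasks: the function load_or_build, the file name toy_task.json and the recompute flag are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def load_or_build(root, build_fn, recompute=False):
    """Return (data, was_cached); build and save only when no saved copy exists."""
    cache_file = Path(root) / "toy_task.json"  # hypothetical file name
    if cache_file.exists() and not recompute:
        print("Task was found and not overwritten")
        return json.loads(cache_file.read_text()), True
    data = build_fn()
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(json.dumps(data))
    return data, False


build_calls = []


def expensive_build():
    # Stand-in for the costly processing step; records each invocation
    build_calls.append(1)
    return {"labels": [0, 1, 1, 0]}


with tempfile.TemporaryDirectory() as tmp:
    task_root = Path(tmp) / "test_task"
    first, first_cached = load_or_build(task_root, expensive_build)
    second, second_cached = load_or_build(task_root, expensive_build)

print(first_cached, second_cached, len(build_calls))
```

The second call returns the same content as the first without invoking the build function again, which mirrors the message shown in the session above.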
Moreover, for the set of proposed tasks (available in the rnaglib.tasks.TASKS variable), we host precomputed tasks on Zenodo.
>>> from rnaglib.tasks import BindingSite
>>> task_site = BindingSite(root='my_root')
Downloading task dataset from Zenodo...
Done
Users can also access our tasks by id, using the small rnaglib.tasks.get_task function.
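Looking a task up by id is typically a thin wrapper around a registry that maps identifiers to classes. The sketch below shows that idea with hypothetical names (TOY_TASKS, toy_get_task, ToyBindingSite, ToyModification). It does not reproduce the signature or behaviour of the library's own function, which should be checked in the API reference.

```python
class ToyBindingSite:
    name = "toy_binding_site"

    def __init__(self, root):
        self.root = root


class ToyModification:
    name = "toy_modification"

    def __init__(self, root):
        self.root = root


# Registry keyed by each class's `name` attribute
TOY_TASKS = {cls.name: cls for cls in (ToyBindingSite, ToyModification)}


def toy_get_task(task_id, root):
    # Resolve the id to a class, failing with a readable message if unknown
    try:
        task_cls = TOY_TASKS[task_id]
    except KeyError:
        raise ValueError(
            f"Unknown task id {task_id!r}; choose from {sorted(TOY_TASKS)}"
        ) from None
    return task_cls(root=root)


toy_site = toy_get_task("toy_binding_site", root="my_root")
print(type(toy_site).__name__, toy_site.root)
```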
Using a Task
Once equipped with a Task, users can get datasets or dataloaders from it and train machine learning models.
>>> from rnaglib.transforms import GraphRepresentation
>>> task_site.add_representation(GraphRepresentation(framework='pyg'))
>>> train_loader, _, test_loader = task_site.get_split_loaders()
>>> for batch in train_loader:
...     print(batch['graph'])  # train
...     break
DataBatch(x=[20, 4], edge_index=[2, 54], edge_attr=[54], y=[20], batch=[20], ptr=[2])
Importantly, the task is equipped with a principled way to compute metrics over predictions.
>>> import torch
>>> task_site.set_loaders(batch_size=1, recompute=False)
>>> all_preds, all_probs, all_labels = [], [], []
>>> for batch in task_site.train_dataloader:
...     graph = batch['graph']
...     label = graph.y
...     probs = torch.rand(label.shape)  # insert predictions here
...     all_labels.append(label)
...     all_probs.append(probs)
...     all_preds.append(torch.round(probs))
...
>>> task_site.compute_metrics(all_preds, all_probs, all_labels)
{'accuracy': 0.4943, 'auc': 0.4765, 'balanced_accuracy': 0.4794, ... 'global_auc': 0.4846}
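The returned dictionary contains both plain and "global_" entries. One plausible reading, stated here as an assumption rather than as the library's definition, is that plain metrics are averaged over RNAs while global ones pool all residues before scoring. The library-free sketch below shows how the two aggregations can differ on per-RNA lists shaped like all_preds and all_labels; the helper names are hypothetical.

```python
def accuracy(preds, labels):
    # Fraction of residues where prediction and label agree
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def balanced_accuracy(preds, labels):
    # Mean of the per-class recalls, for binary labels 0 and 1
    recalls = []
    for cls in (0, 1):
        idx = [i for i, y in enumerate(labels) if y == cls]
        if idx:  # skip a class that is absent from the labels
            recalls.append(sum(preds[i] == cls for i in idx) / len(idx))
    return sum(recalls) / len(recalls)


# One inner list per RNA (batch_size=1), one entry per residue
toy_preds = [[1, 0, 0, 0], [1, 1]]
toy_labels = [[1, 0, 0, 1], [1, 0]]

# Aggregation 1: score each RNA separately, then average over RNAs
per_rna = [accuracy(p, y) for p, y in zip(toy_preds, toy_labels)]
mean_accuracy = sum(per_rna) / len(per_rna)

# Aggregation 2: pool all residues, then score once
flat_preds = [p for rna in toy_preds for p in rna]
flat_labels = [y for rna in toy_labels for y in rna]
pooled_accuracy = accuracy(flat_preds, flat_labels)
pooled_balanced = balanced_accuracy(flat_preds, flat_labels)

print(mean_accuracy, pooled_accuracy, pooled_balanced)
```

The per-RNA average gives every RNA the same weight regardless of its length, whereas pooling gives every residue the same weight, so long RNAs dominate the pooled score.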