import hashlib
import json
import os
from functools import cached_property
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, roc_auc_score
from torch.utils.data import DataLoader

from rnaglib.data_loading import Collater, RNADataset
from rnaglib.splitters import RandomSplitter, Splitter
from rnaglib.transforms import FeaturesComputer
from rnaglib.utils import DummyGraphModel, DummyResidueModel, tonumpy
class Task:
    """Abstract class for a benchmarking task using the rnaglib datasets.

    This class handles the logic for building the underlying dataset, which is held in an
    ``RNADataset`` object (``self.dataset``). Once the dataset is created, the splitter is
    invoked to create the train/val/test indices. Tasks also define an ``evaluate()``
    function to yield appropriate model performance metrics.

    :param root: path to a folder where the task information will be stored for fast loading.
    :param recompute: whether to recompute the task info from scratch or use what is stored in root.
    :param splitter: rnaglib.splitters.Splitter object that handles splitting of data into
        train/val/test indices. If None uses the task's ``default_splitter`` attribute.
    :param debug: forwarded to ``RNADataset`` when the dataset is reloaded from ``root``.
    :param save: whether to persist the task to ``root`` (see :meth:`write`) once it is built.
    :param in_memory: forwarded to ``RNADataset`` when the dataset is reloaded from ``root``;
        ``task_id`` is only available when this is True.
    """

    def __init__(
        self,
        root: Union[str, os.PathLike],
        recompute: bool = False,
        splitter: "Optional[Splitter]" = None,
        debug: bool = False,
        save: bool = True,
        in_memory: bool = True,
    ):
        self.root = root
        self.dataset_path = os.path.join(self.root, "dataset")
        self.recompute = recompute
        self.debug = debug
        self.save = save
        self.in_memory = in_memory
        self.metadata = self.init_metadata()

        # Load or create dataset
        if not os.path.exists(self.dataset_path) or recompute:
            print(">>> Creating task dataset from scratch...")
            self.dataset = self.process()
        else:
            self.dataset, self.metadata, (self.train_ind, self.val_ind, self.test_ind) = self.load()

        # Set splitter after dataset is available
        self.splitter = self.default_splitter if splitter is None else splitter

        # Split dataset if it wasn't loaded from file
        if not hasattr(self, "train_ind"):
            self.split(self.dataset)

        self.dataset.features_computer = self.get_task_vars()

        # NOTE(review): persisting before describing means describe() rewrites
        # metadata.json with the description included -- confirm this ordering.
        if self.save:
            self.write()

        # compute metadata
        self.describe()

    def process(self) -> "RNADataset":
        """Build the task dataset. Tasks must implement this method.

        Executing the method should result in a list of ``.json`` files saved in
        ``{root}/dataset``. All the RNA graphs should contain all the annotations needed
        to run the task (e.g. node/edge attributes).

        :return: the dataset that ``__init__`` stores in ``self.dataset``.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def init_metadata(self) -> dict:
        """Optionally adds some key/value pairs to self.metadata."""
        return {}

    def get_task_vars(self) -> "FeaturesComputer":
        """Define a FeaturesComputer object to set which input and output variables will be used in the task."""
        return FeaturesComputer()

    @property
    def default_splitter(self):
        """Splitter used when none is passed to the constructor."""
        return RandomSplitter()

    def split(self, dataset):
        """Calls the splitter and returns train, val, test splits.

        The three index collections are also stored on the task as ``train_ind``,
        ``val_ind`` and ``test_ind``.
        """
        splits = self.splitter(dataset)
        self.train_ind, self.val_ind, self.test_ind = splits
        return splits

    def set_datasets(self, recompute=True):
        """Sets the train, val and test datasets.

        Call this each time you modify ``self.dataset``.

        :param recompute: when True (or when no split indices exist yet) the splitter is
            invoked again before the subsets are built.
        """
        if not hasattr(self, "train_ind") or recompute:
            self.train_ind, self.val_ind, self.test_ind = self.split(self.dataset)
        self.train_dataset = self.dataset.subset(self.train_ind)
        self.val_dataset = self.dataset.subset(self.val_ind)
        self.test_dataset = self.dataset.subset(self.test_ind)

    def set_loaders(self, recompute=True, **dataloader_kwargs):
        """Sets the dataloader properties.

        Call this each time you modify ``self.dataset``.

        :param recompute: forwarded to :meth:`set_datasets`.
        :param dataloader_kwargs: keyword arguments for every ``DataLoader``. A ``Collater``
            built on the training subset is used when ``collate_fn`` is absent.
        """
        # The loaders are built on the split subsets, so make sure those exist first.
        self.set_datasets(recompute=recompute)

        # If no collater is provided we need one
        if "collate_fn" not in dataloader_kwargs:
            dataloader_kwargs["collate_fn"] = Collater(self.train_dataset)

        # Now build the loaders; validation and test loaders never shuffle.
        self.train_dataloader = DataLoader(dataset=self.train_dataset, **dataloader_kwargs)
        eval_kwargs = {**dataloader_kwargs, "shuffle": False}
        self.val_dataloader = DataLoader(dataset=self.val_dataset, **eval_kwargs)
        self.test_dataloader = DataLoader(dataset=self.test_dataset, **eval_kwargs)

    def get_split_datasets(self, recompute=True):
        """Return ``(train_dataset, val_dataset, test_dataset)``, building them if needed."""
        # If datasets were not already computed or if we want to recompute them to account
        # for changes in the global dataset
        if recompute or "train_dataset" not in self.__dict__:
            print(">>> Splitting the dataset...")
            self.set_datasets(recompute=recompute)
            print(">>> Done")
        return self.train_dataset, self.val_dataset, self.test_dataset

    def get_split_loaders(self, recompute=True, **dataloader_kwargs):
        """Return ``(train_dataloader, val_dataloader, test_dataloader)``, building them if needed."""
        # If dataloaders were not already precomputed or if we want to recompute them to account
        # for changes in the global dataset
        if recompute or "train_dataloader" not in self.__dict__:
            self.set_loaders(recompute=recompute, **dataloader_kwargs)
        return self.train_dataloader, self.val_dataloader, self.test_dataloader

    def evaluate(self, model, loader) -> dict:
        """Compute performance metrics of ``model`` on ``loader``. Tasks must implement this."""
        raise NotImplementedError

    @cached_property
    def task_id(self):
        """Task hash is a hash of all RNA ids and node IDs in the dataset, followed by the split indices.

        :raises ValueError: if the task is not held in memory.
        """
        if not self.in_memory:
            raise ValueError("task id is only available (and tractable) for small, in-memory datasets")
        h = hashlib.new("sha256")
        for rna in self.dataset.rnas:
            # NOTE(review): assumes each RNA graph exposes its identifier as ``.name`` -- confirm.
            h.update(str(rna.name).encode("utf-8"))
            for nt in sorted(rna.nodes()):
                h.update(str(nt).encode("utf-8"))
        for indices in (self.train_ind, self.val_ind, self.test_ind):
            for i in indices:
                h.update(str(i).encode("utf-8"))
        return h.hexdigest()

    def write(self):
        """Save task data and splits to root.

        Saves the dataset under ``{root}/dataset`` (only when that folder is missing or
        ``recompute`` is set), three `.txt` files (`'{train, val, test}_idx.txt'`), one for
        each split with one index per line, and ``metadata.json``. For in-memory tasks the
        task id is also written to ``task_id.txt``.
        """
        if not os.path.exists(self.dataset_path) or self.recompute:
            print(">>> Saving dataset.")
            # NOTE(review): assumes RNADataset.save(path, recompute=...) -- confirm signature.
            self.dataset.save(self.dataset_path, recompute=self.recompute)

        root = Path(self.root)
        root.mkdir(parents=True, exist_ok=True)
        splits = (("train", self.train_ind), ("val", self.val_ind), ("test", self.test_ind))
        for split_name, indices in splits:
            with open(root / f"{split_name}_idx.txt", "w") as idx:
                idx.writelines(f"{ind}\n" for ind in indices)
        with open(root / "metadata.json", "w") as meta:
            json.dump(self.metadata, meta, indent=4)

        # task id is only available (and tractable) for small, in-memory datasets
        if self.in_memory:
            with open(root / "task_id.txt", "w") as tid:
                tid.write(self.task_id)
        print(">>> Done")

    @staticmethod
    def _read_indices(path):
        """Read one integer per non-blank line from ``path``."""
        with open(path, "r") as handle:
            return [int(line) for line in handle if line.strip()]

    def load(self):
        """Load dataset and splits from disk.

        :return: ``(dataset, metadata, (train_ind, val_ind, test_ind))``.
        """
        # load splits
        print(">>> Loading precomputed dataset...")
        root = Path(self.root)
        train_ind = self._read_indices(root / "train_idx.txt")
        val_ind = self._read_indices(root / "val_idx.txt")
        test_ind = self._read_indices(root / "test_idx.txt")
        dataset = RNADataset(dataset_path=self.dataset_path, in_memory=self.in_memory, debug=self.debug)
        with open(root / "metadata.json", "r") as meta:
            metadata = json.load(meta)
        return dataset, metadata, (train_ind, val_ind, test_ind)

    def __eq__(self, other):
        """Two tasks are equal when their task ids match."""
        if not isinstance(other, Task):
            return NotImplemented
        return self.task_id == other.task_id

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self, recompute=False):
        """Get description of task dataset, including dimensions needed for model initialization
        and other relevant statistics. Prints the description and returns it as a dict.

        The description is cached in ``self.metadata["description"]`` and written to
        ``{root}/metadata.json``; it is only recomputed when missing or when ``recompute``
        is True.

        :param recompute: force recomputation even if a cached description exists.
        :return: dict containing dataset information and model dimensions.
        """
        if not recompute and "description" in self.metadata:
            info = self.metadata["description"]
        else:
            print(">>> Computing description of task...")
            # Get dimensions from first graph
            first_item = self.dataset[0]
            compute_num_edge_attributes = "graph" in first_item
            first_node = sorted(first_item["rna"].nodes())[0]
            first_features = self.dataset.features_computer(first_item)["nt_features"][first_node]
            num_node_features = first_features.shape[0]

            # Dynamic class counting
            class_counts = {}
            unique_edge_attrs = set()  # only used with graphs

            # Collect statistics from dataset
            import tqdm

            for item in tqdm.tqdm(self.dataset):
                if compute_num_edge_attributes:
                    graph = item["graph"]
                    # NOTE(review): assumes the graph object carries a flat ``edge_attr``
                    # tensor of edge-type ids -- confirm.
                    unique_edge_attrs.update(graph.edge_attr.tolist())
                features_dict = self.dataset.features_computer(item)
                # RNA-level targets take precedence when both kinds are present.
                if "rna_targets" in features_dict:
                    y = features_dict["rna_targets"].clone().detach()
                elif "nt_targets" in features_dict:
                    list_y = [features_dict["nt_targets"][n] for n in sorted(item["rna"].nodes())]
                    # In the case of single target, pytorch CE loss expects shape (n,) and not (n,1)
                    # For multi-target cases, we stack to get (n,d)
                    if len(list_y[0]) == 1:
                        y = torch.cat(list_y)
                    else:
                        y = torch.stack(list_y)
                else:
                    # Nothing to count for an item without targets.
                    continue

                # Count classes in this graph
                for cls in y.unique().tolist():
                    cls_int = int(cls)
                    class_counts[cls_int] = class_counts.get(cls_int, 0) + torch.sum(y == cls).item()

            info = {
                "num_node_features": num_node_features,
                "num_classes": len(class_counts),
                "dataset_size": len(self.dataset),
                "class_distribution": class_counts,
            }
            if compute_num_edge_attributes:
                info["num_edge_attributes"] = len(unique_edge_attrs)

            # Store the description before dumping so it actually reaches the file.
            self.metadata["description"] = info
            root = Path(self.root)
            root.mkdir(parents=True, exist_ok=True)
            with open(root / "metadata.json", "w") as meta:
                json.dump(self.metadata, meta, indent=4)

        # Print description
        print("Dataset Description:")
        for k, v in info.items():
            if k != "class_distribution":
                print(k, " : ", v)
            else:
                print("Class distribution:")
                for cls in sorted(v.keys()):
                    print(f"\tClass {cls}: {v[cls]} nodes")
        return info
class ClassificationTask(Task):
    """Task whose targets are class labels, either per residue or per RNA.

    :param graph_level: True when there is one label per RNA, False when there is one
        label per residue.
    :param num_classes: number of classes; when None it is read from
        ``self.metadata["description"]["num_classes"]``.
    :param kwargs: forwarded to ``Task.__init__``.
    """

    def __init__(self, graph_level=False, num_classes=None, **kwargs):
        self.graph_level = graph_level
        # The base constructor populates self.metadata, which is read just below.
        super().__init__(**kwargs)
        self.num_classes = self.metadata["description"]["num_classes"] if num_classes is None else num_classes

    @property
    def dummy_model(self) -> torch.nn.Module:
        """A placeholder model matching the task level (graph or residue) and class count."""
        if self.graph_level:
            return DummyGraphModel(num_classes=self.num_classes)
        return DummyResidueModel(num_classes=self.num_classes)

    def dummy_inference(self):
        """Run the dummy model over ``self.test_dataloader``.

        :return: ``(0, all_preds, all_probs, all_labels)``. The leading ``0`` is a constant
            (presumably a placeholder for a loss value -- confirm with callers). For
            graph-level tasks the three collections are stacked arrays; for residue-level
            tasks they are lists holding one array per RNA.
        """
        all_probs = []
        all_preds = []
        all_labels = []
        dummy_model = self.dummy_model
        with torch.no_grad():
            for batch in self.test_dataloader:
                graph = batch["graph"]
                out = dummy_model(graph)
                labels = graph.y

                # get preds and probas + cast to numpy
                if self.num_classes == 2:
                    probs = torch.sigmoid(out.flatten())
                    preds = (probs > 0.5).float()
                else:
                    probs = torch.softmax(out, dim=1)
                    preds = probs.argmax(dim=1)
                probs = tonumpy(probs)
                preds = tonumpy(preds)
                labels = tonumpy(labels)

                # split predictions per RNA if residue level
                if not self.graph_level:
                    cumulative_sizes = tuple(tonumpy(graph.ptr))
                    bounds = list(zip(cumulative_sizes[:-1], cumulative_sizes[1:]))
                    probs = [probs[start:end] for start, end in bounds]
                    preds = [preds[start:end] for start, end in bounds]
                    labels = [labels[start:end] for start, end in bounds]

                all_probs.extend(probs)
                all_preds.extend(preds)
                all_labels.extend(labels)

        if self.graph_level:
            all_probs = np.stack(all_probs)
            all_preds = np.stack(all_preds)
            all_labels = np.stack(all_labels)
        return 0, all_preds, all_probs, all_labels

    def compute_one_metric(self, preds, probs, labels):
        """Compute accuracy, F1, MCC and (when defined) AUC for one set of predictions.

        :return: dict with keys ``accuracy``, ``f1``, ``mcc`` and, when
            ``roc_auc_score`` accepts the input, ``auc``.
        """
        one_metric = {
            "accuracy": accuracy_score(labels, preds),
            "f1": f1_score(labels, preds, average="binary" if self.num_classes == 2 else "macro"),
            "mcc": matthews_corrcoef(labels, preds),
        }
        try:
            one_metric["auc"] = roc_auc_score(
                labels, probs, average=None if self.num_classes == 2 else "macro", multi_class="ovo"
            )
        except ValueError:
            # AUC is best-effort: roc_auc_score rejects inputs for which it is undefined
            # (e.g. a single class present), in which case the key is simply left out.
            pass
        return one_metric

    def compute_metrics(self, all_preds, all_probs, all_labels):
        """Aggregate metrics over the whole evaluation set.

        For graph-level tasks this is a single call to :meth:`compute_one_metric`. For
        residue-level tasks each metric is averaged over RNAs (skipping RNAs whose labels
        contain a single class), and the metrics over all residues pooled together are
        added under ``global_``-prefixed keys.
        """
        if self.graph_level:
            return self.compute_one_metric(all_preds, all_probs, all_labels)

        # Here we have a list of preds [(n1,), (n2,)...] for each residue in each RNA
        # Either compute the overall flattened results, or aggregate by system
        per_key = {}
        for pred, prob, label in zip(all_preds, all_probs, all_labels):
            # Can't compute metrics over just one class
            if len(np.unique(label)) == 1:
                continue
            # Averaging per key tolerates RNAs for which "auc" could not be computed.
            for key, value in self.compute_one_metric(pred, prob, label).items():
                per_key.setdefault(key, []).append(value)
        metrics = {key: np.mean(per_key[key]) for key in sorted(per_key)}

        # Get the flattened result, renamed to include "global"
        if len(all_preds) > 0:
            flat_preds = np.concatenate(all_preds)
            flat_probs = np.concatenate(all_probs)
            flat_labels = np.concatenate(all_labels)
            global_metrics = self.compute_one_metric(flat_preds, flat_probs, flat_labels)
            metrics.update({f"global_{k}": v for k, v in global_metrics.items()})
        return metrics
class ResidueClassificationTask(ClassificationTask):
    """Classification task with one label per residue (``graph_level=False``)."""

    def __init__(self, **kwargs):
        super().__init__(graph_level=False, **kwargs)
class RNAClassificationTask(ClassificationTask):
    """Classification task with one label per RNA (``graph_level=True``)."""

    def __init__(self, **kwargs):
        super().__init__(graph_level=True, **kwargs)
# NOTE(review): ResidueClassificationTask is defined twice in this module with the same
# body; this later definition is the one the name ends up bound to. Consider removing one.
class ResidueClassificationTask(ClassificationTask):
    """Classification task with one label per residue (``graph_level=False``)."""

    def __init__(self, **kwargs):
        super().__init__(graph_level=False, **kwargs)