Source code for rnaglib.tasks.RNA_IF.inverse_folding

"""Inverse Folding task definitions"""

import os
from collections import defaultdict
from pathlib import Path

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, matthews_corrcoef, roc_auc_score
from tqdm import tqdm

from rnaglib.dataset import RNADataset
from rnaglib.dataset_transforms import CDHitComputer, StructureDistanceComputer, RedundancyRemover
from rnaglib.dataset_transforms import ClusterSplitter, NameSplitter
from rnaglib.encoders import BoolEncoder, NucleotideEncoder
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.transforms import ChainFilter, ConnectedComponentPartition, DummyAnnotator, FeaturesComputer


class InverseFolding(ResidueClassificationTask):
    """RNA design task, taking as input the structures with the identity of the residues masked
    and trying to find it back.

    Task type: multi-class classification
    Task level: residue-level

    :param tuple[int] size_thresholds: range of RNA sizes to keep in the task dataset (default (15, 300))
    """

    target_var = "nt_code"  # in rna graph
    input_var = "dummy"  # should be dummy variable
    nucs = ["A", "C", "G", "U"]
    name = "rna_if"
    default_metric = "accuracy"
    version = "2.0.2"

    def __init__(self, size_thresholds=(15, 300), **kwargs):
        """Set up the task as a single-label (multi-class) residue classification.

        :param tuple[int] size_thresholds: range of RNA sizes to keep in the task dataset
        :param kwargs: forwarded unchanged to the parent task constructor
        """
        meta = {"multi_label": False}
        super().__init__(additional_metadata=meta, size_thresholds=size_thresholds, **kwargs)

    @property
    def default_splitter(self):
        """Returns the splitting strategy to be used for this specific task.

        Canonical splitter is ClusterSplitter which is a similarity-based splitting relying on
        clustering which could be refined into a sequence- or structure-based clustering using
        the distance_name argument.

        :return: the default splitter to be used for the task
        :rtype: Splitter
        """
        return ClusterSplitter(distance_name="USalign")

    def process(self) -> RNADataset:
        """Creates the task-specific dataset.

        Every RNA of the source dataset is split with ``ConnectedComponentPartition``; each
        component passing the size filter is annotated with ``DummyAnnotator`` and added to the
        dataset being built.

        :return: the task-specific dataset
        :rtype: RNADataset
        """
        print(">>> RNA_IF process")

        # Define your transforms
        annotate_rna = DummyAnnotator()
        connected_components_partition = ConnectedComponentPartition()

        # Run through database, applying our filters
        dataset = RNADataset(in_memory=self.in_memory, redundancy="all", debug=self.debug, version=self.version)
        all_rnas = []
        os.makedirs(self.dataset_path, exist_ok=True)
        for rna in tqdm(dataset, total=len(dataset)):
            for component in connected_components_partition(rna):
                if self.size_thresholds is not None and not self.size_filter.forward(component):
                    continue
                annotated = annotate_rna(component)["rna"]
                self.add_rna_to_building_list(all_rnas=all_rnas, rna=annotated)
        return self.create_dataset_from_list(all_rnas)

    def post_process(self):
        """The task-specific post processing steps to remove redundancy and compute distances
        which will be used by the splitters.

        Runs CD-HIT, removes redundancy on the ``cd_hit`` distance, computes ``USalign``
        structural distances and saves the distances on the dataset.
        """
        print(">>> RNA_IF post")
        cd_hit_computer = CDHitComputer(similarity_threshold=0.99)
        cd_hit_rr = RedundancyRemover(distance_name="cd_hit", threshold=0.9)
        self.dataset = cd_hit_computer(self.dataset)
        self.dataset = cd_hit_rr(self.dataset)

        us_align_computer = StructureDistanceComputer(name="USalign")
        self.dataset = us_align_computer(self.dataset)
        self.dataset.save_distances()

    def get_task_vars(self) -> FeaturesComputer:
        """Specifies the `FeaturesComputer` object of the tasks which defines the features which
        have to be added to the RNAs (graphs) and nucleotides (graph nodes).

        :return: the features computer of the task
        :rtype: FeaturesComputer
        """
        return FeaturesComputer(
            nt_features=self.input_var,
            nt_targets=self.target_var,
            custom_encoders={
                self.input_var: BoolEncoder(),
                self.target_var: NucleotideEncoder(),
            },
        )

    def compute_one_metric(self, preds, unfiltered_preds, probs, labels, unfiltered_labels):
        """Compute the metric dictionary for one set of residues.

        :param preds: predicted classes restricted to residues whose label is a standard nucleotide
        :param unfiltered_preds: predicted classes for all residues (used for ``coverage``)
        :param probs: class probabilities for the filtered residues; column ``i + 1`` is read for
            ``self.nucs[i]``, so at least ``len(self.nucs) + 1`` columns are required
        :param labels: true classes restricted to standard nucleotides
        :param unfiltered_labels: true classes for all residues (used for ``non_standard_ratio``)
        :return: dictionary mapping metric names to values
        :rtype: dict
        """
        # Calculate metrics only on standard nucleotides
        # Note that accuracy is equivalent to sequence recovery rate
        one_metric = {
            "accuracy": accuracy_score(labels, preds),
            "mcc": matthews_corrcoef(labels, preds),
            "macro_f1": f1_score(labels, preds, average="macro"),
            "weighted_f1": f1_score(labels, preds, average="weighted"),
            # Calculate coverage (percentage of predictions that are standard nucleotides)
            "coverage": (unfiltered_preds != 0).mean(),
            # Add non-standard nucleotide statistics
            "non_standard_ratio": (unfiltered_labels == 0).mean(),
        }

        # Only calculate AUC for standard nucleotides, don't forget to offset i
        for i, nuc in enumerate(self.nucs):
            binary_labels = labels == i + 1
            binary_probs = probs[:, i + 1]
            binary_preds = preds == i + 1
            try:
                one_metric[f"auc_{nuc}"] = roc_auc_score(binary_labels, binary_probs)
                one_metric[f"f1_{nuc}"] = f1_score(binary_labels, binary_preds)
            except ValueError:
                # A ValueError from the scorers leaves both per-nucleotide entries undefined
                one_metric[f"auc_{nuc}"] = float("nan")
                one_metric[f"f1_{nuc}"] = float("nan")

        # Add average AUC
        valid_aucs = [v for k, v in one_metric.items() if k.startswith("auc_") and not np.isnan(v)]
        one_metric["mean_auc"] = np.mean(valid_aucs) if valid_aucs else float("nan")
        return one_metric

    def compute_metrics(self, all_preds, all_probs, all_labels):
        """Evaluate model performance on the nucleotide prediction task.

        Metrics are computed twice: averaged over RNAs (plain keys) and over all residues
        flattened together (keys prefixed with ``global_``). A confusion matrix over all
        residues, non-standard ones included, is added under ``confusion_matrix``.

        Note:
            Label 0 represents non-standard/unknown nucleotides and is excluded from
            performance metrics to focus on ACGU prediction quality.

        :param all_preds: one array of predicted classes per RNA
        :param all_probs: one array of class probabilities per RNA
        :param all_labels: one array of true classes per RNA
        :return: dictionary containing the metrics
        :rtype: dict
        """
        # Some metrics are computed only on standard nucleotides, so a filtered version of each
        # RNA is built next to the unfiltered one. Both are handled in the same loop so that
        # filtered and unfiltered arrays can never get out of step.
        filtered_all_preds, filtered_all_probs, filtered_all_labels = [], [], []
        sorted_keys = []
        per_rna_rows = []
        for pred, prob, label in zip(all_preds, all_probs, all_labels, strict=False):
            valid_mask = label != 0
            if len(valid_mask) == 0:
                # Empty RNA: nothing to filter and nothing to score
                continue
            filt_pred = pred[valid_mask]
            filt_prob = prob[valid_mask]
            filt_label = label[valid_mask]
            filtered_all_preds.append(filt_pred)
            filtered_all_probs.append(filt_prob)
            filtered_all_labels.append(filt_label)

            # Can't compute metrics over fewer than two classes; the check is made on the
            # filtered labels because those are the ones the metrics are computed on.
            if len(np.unique(filt_label)) < 2:
                continue
            # Argument order follows compute_one_metric's signature:
            # (filtered preds, unfiltered preds, filtered probs, filtered labels, unfiltered labels)
            one_metric = self.compute_one_metric(filt_pred, pred, filt_prob, filt_label, label)
            sorted_keys = sorted(one_metric)
            per_rna_rows.append([one_metric[k] for k in sorted_keys])

        if per_rna_rows:
            mean_metrics = np.nanmean(np.array(per_rna_rows), axis=0)
            metrics = dict(zip(sorted_keys, mean_metrics, strict=False))
        else:
            # No RNA was eligible for per-RNA metrics; only global metrics are reported
            metrics = {}

        # Get the flattened result, renamed to include "global"
        filtered_all_preds = np.concatenate(filtered_all_preds)
        all_preds = np.concatenate(all_preds)
        filtered_all_probs = np.concatenate(filtered_all_probs)
        all_labels = np.concatenate(all_labels)
        filtered_all_labels = np.concatenate(filtered_all_labels)
        global_metrics = self.compute_one_metric(
            filtered_all_preds,
            all_preds,
            filtered_all_probs,
            filtered_all_labels,
            all_labels,
        )
        metrics.update({f"global_{k}": v for k, v in global_metrics.items()})

        # Add confusion matrix (including non-standard nucleotides)
        cm = confusion_matrix(all_labels, all_preds)
        metrics["confusion_matrix"] = cm.tolist()
        return metrics
class gRNAde(InverseFolding):
    """This class is a subclass of InverseFolding and is used to train a model on the gRNAde dataset.

    Task type: multi-class classification
    Task level: residue-level

    :param tuple[int] size_thresholds: range of RNA sizes to keep in the task dataset (default (15, 300))
    """

    # everything is inherited except for process, post_process and splitter.
    name = "rna_if_bench"

    def __init__(self, size_thresholds=(15, 300), **kwargs):
        """Load the train/val/test identifiers shipped in the ``data`` directory next to this file.

        Each non-empty line of ``{split}_ids_das.txt`` is parsed as ``<pdb>_..._<chains>``: the text
        before the first underscore is the (lower-cased) PDB id, the text after the last underscore
        is the chain specification, itself a ``-``-separated list of chain names.

        :param tuple[int] size_thresholds: range of RNA sizes to keep in the task dataset
        :param kwargs: forwarded unchanged to the parent task constructor
        """
        self.splits = {
            # Use sets instead of lists for chains since order doesn't matter
            "pdb_to_chain_train": defaultdict(set),
            "pdb_to_chain_test": defaultdict(set),
            "pdb_to_chain_val": defaultdict(set),
            "pdb_to_chain_all": defaultdict(set),
            "pdb_to_chain_all_single": defaultdict(set),
        }

        # Populate the structure
        data_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "data"
        for split in ["train", "test", "val"]:
            file_path = data_dir / f"{split}_ids_das.txt"
            with open(file_path) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A blank line would otherwise register an empty pdb id and chain
                        continue
                    fields = line.split("_")
                    pdb_id = fields[0].lower()
                    chain = fields[-1]
                    chain_components = chain.split("-")
                    # Sets automatically handle duplicates
                    self.splits[f"pdb_to_chain_{split}"][pdb_id].add(chain)
                    self.splits["pdb_to_chain_all"][pdb_id].add(chain)
                    self.splits["pdb_to_chain_all_single"][pdb_id].update(chain_components)
        # self.splits is filled before the parent constructor runs, as in the original ordering
        super().__init__(size_thresholds=size_thresholds, **kwargs)

    def _split_names(self, split):
        """Return the ``<pdb>_<chains>`` names registered for ``split`` ("train", "val" or "test")."""
        pdb_to_chains = self.splits[f"pdb_to_chain_{split}"]
        return [f"{pdb.lower()}_{chain}" for pdb in pdb_to_chains for chain in pdb_to_chains[pdb]]

    @property
    def default_splitter(self):
        """Returns the splitting strategy to be used for this specific task.

        In this case, an ad hoc splitter is being applied to match the train, val and test
        splits used in gRNAde.

        :return: the default splitter to be used for the task
        :rtype: Splitter
        """
        train_names = self._split_names("train")
        val_names = self._split_names("val")
        test_names = self._split_names("test")
        return NameSplitter(train_names, val_names, test_names)

    def process(self) -> RNADataset:
        """Creates the task-specific dataset.

        For every RNA kept by the chain filter, one sub-graph is built per chain specification
        listed in the split files and named ``<pdb>_<chains>``.

        :return: the task-specific dataset
        :rtype: RNADataset
        """
        pdb_to_single_chains = {
            pdb.lower(): list(chains) for pdb, chains in self.splits["pdb_to_chain_all_single"].items()
        }
        chain_filter = ChainFilter(pdb_to_single_chains)
        annote_dummy = DummyAnnotator()

        # Initialize dataset with in_memory=False to avoid loading everything at once
        rna_ids = list(pdb_to_single_chains.keys())
        source_dataset = RNADataset(
            rna_id_subset=rna_ids,
            redundancy="all",
            in_memory=False,
            debug=self.debug,
            version=self.version,
        )

        all_rnas = []
        os.makedirs(self.dataset_path, exist_ok=True)
        pdb_to_chains = self.splits["pdb_to_chain_all"]
        for rna in tqdm(source_dataset):
            if not chain_filter.forward(rna):
                continue
            rna = annote_dummy(rna)
            base_graph = rna["rna"]
            # Keys were lower-cased when the split files were read; .get() avoids inserting
            # empty entries into the defaultdict for graphs that are not listed.
            pdb = base_graph.name.lower()
            for chain in pdb_to_chains.get(pdb, ()):
                chain_components = set(chain.split("-"))
                # The chain name is read from the second "."-separated field of the node id
                selected_nodes = [node for node in base_graph.nodes() if node.split(".")[1] in chain_components]
                selected_chains = base_graph.copy().subgraph(selected_nodes)
                selected_chains.name = f"{pdb}_{chain}"
                self.add_rna_to_building_list(all_rnas=all_rnas, rna=selected_chains)
        return self.create_dataset_from_list(all_rnas)

    def post_process(self):
        """Override the parent post-processing: only a message is printed, no redundancy removal
        or distance computation is performed here."""
        print("gRNAde post process")