"""Inverse Folding task definitions"""
import os
from collections import defaultdict
import numpy as np
from pathlib import Path
from sklearn.metrics import matthews_corrcoef, f1_score, accuracy_score, roc_auc_score, confusion_matrix
from rnaglib.data_loading import RNADataset
from rnaglib.transforms import FeaturesComputer, DummyAnnotator, ComposeFilters, RibosomalFilter, RNAAttributeFilter
from rnaglib.transforms import NameFilter, ChainFilter, ChainSplitTransform, ChainNameTransform
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.encoders import BoolEncoder, NucleotideEncoder
from rnaglib.splitters import NameSplitter
from rnaglib.utils import dump_json


class InverseFolding(ResidueClassificationTask):
    """Predict the identity of each nucleotide from structure alone.

    The native sequence is hidden behind a constant dummy input feature, and
    the model must recover the ``nt_code`` (A, C, G or U) of every residue.
    """

    target_var = "nt_code"  # native nucleotide identity stored in the RNA graph
    input_var = "dummy"  # constant placeholder feature added by DummyAnnotator
    nucs = ["A", "C", "G", "U"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def process(self) -> RNADataset:
        # Build the filters: drop ribosomal RNAs and keep only structures with
        # a resolution better than 4.0 Å
        ribo_filter = RibosomalFilter()
        resolution_filter = RNAAttributeFilter(
            attribute="resolution_high", value_checker=lambda val: float(val[0]) < 4.0
        )
        filters = ComposeFilters([ribo_filter, resolution_filter])

        # Define the transforms: add the constant dummy input feature
        annotate_rna = DummyAnnotator()

        # Run through the database, applying our filters
        dataset = RNADataset(debug=self.debug, in_memory=self.in_memory)
        all_rnas = []
        os.makedirs(self.dataset_path, exist_ok=True)
        for rna in dataset:
            if filters.forward(rna):
                rna = annotate_rna(rna)["rna"]
                if self.in_memory:
                    all_rnas.append(rna)
                else:
                    all_rnas.append(rna.name)
                    dump_json(os.path.join(self.dataset_path, f"{rna.name}.json"), rna)
        if self.in_memory:
            dataset = RNADataset(rnas=all_rnas)
        else:
            dataset = RNADataset(dataset_path=self.dataset_path, rna_id_subset=all_rnas)
        return dataset
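
    # A minimal usage sketch (hypothetical arguments: keyword arguments such as
    # `root` and `debug` are assumed to be handled by the Task base class, not
    # defined in this file):
    #
    #   task = InverseFolding(root="roots/inverse_folding", debug=True)
    #   dataset = task.process()  # filtered, dummy-annotated RNADataset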

    def get_task_vars(self) -> FeaturesComputer:
        return FeaturesComputer(
            nt_features=self.input_var,
            nt_targets=self.target_var,
            custom_encoders={
                self.input_var: BoolEncoder(),
                self.target_var: NucleotideEncoder(),
            },
        )
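
    # Encoding convention assumed by the metrics below: the target encoder maps
    # residues to integers with 0 reserved for non-standard or unknown
    # nucleotides and 1-4 for the standard bases, in the order of self.nucs:
    #
    #   {"A": 1, "C": 2, "G": 3, "U": 4}  # label 0 = non-standard (assumed)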

    def compute_one_metric(self, preds, unfiltered_preds, probs, labels, unfiltered_labels):
        # Calculate metrics only on standard nucleotides.
        # Note that accuracy is equivalent to the sequence recovery rate.
        one_metric = {
            "accuracy": accuracy_score(labels, preds),
            "mcc": matthews_corrcoef(labels, preds),
            "macro_f1": f1_score(labels, preds, average="macro"),
            "weighted_f1": f1_score(labels, preds, average="weighted"),
            # Coverage: fraction of predictions that are standard nucleotides
            "coverage": (unfiltered_preds != 0).mean(),
            # Fraction of ground-truth residues that are non-standard
            "non_standard_ratio": (unfiltered_labels == 0).mean(),
        }
        # Per-class AUC and F1 for standard nucleotides only. Class indices are
        # offset by 1 because label 0 is reserved for non-standard nucleotides.
        for i, nuc in enumerate(self.nucs):
            binary_labels = labels == i + 1
            binary_probs = probs[:, i + 1]
            binary_preds = preds == i + 1
            try:
                one_metric[f"auc_{nuc}"] = roc_auc_score(binary_labels, binary_probs)
                one_metric[f"f1_{nuc}"] = f1_score(binary_labels, binary_preds)
            except ValueError:
                # A class absent from this system makes its AUC undefined
                one_metric[f"auc_{nuc}"] = float("nan")
                one_metric[f"f1_{nuc}"] = float("nan")
        # Average AUC over the classes for which it is defined
        valid_aucs = [v for k, v in one_metric.items() if k.startswith("auc_") and not np.isnan(v)]
        one_metric["mean_auc"] = np.mean(valid_aucs) if valid_aucs else float("nan")
        return one_metric
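
    # Worked toy example (hypothetical arrays, using the encoding above).
    # With four residues where the third is non-standard:
    #
    #   labels = np.array([1, 2, 0, 4])
    #   preds = np.array([1, 3, 0, 4])
    #   probs = np.eye(5)[preds]  # one-hot stand-in for class probabilities
    #   mask = labels != 0
    #   m = task.compute_one_metric(preds[mask], preds, probs[mask], labels[mask], labels)
    #   # m["accuracy"] == 2/3, m["coverage"] == 0.75, m["non_standard_ratio"] == 0.25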

    def compute_metrics(self, all_preds, all_probs, all_labels):
        """Evaluate model performance on the nucleotide prediction task.

        Each argument is a list holding one array per RNA: predicted classes,
        class probabilities, and ground-truth labels for every residue.

        Returns:
            dict: Per-system averaged metrics, their "global_"-prefixed
            counterparts computed on the flattened arrays, and the confusion
            matrix.

        Note: Label 0 represents non-standard/unknown nucleotides and is
        excluded from performance metrics to focus on ACGU prediction quality.
        """
        # Most metrics are computed only on standard nucleotides, so build
        # filtered copies of the inputs with label-0 residues removed
        filtered_all_preds, filtered_all_probs, filtered_all_labels = [], [], []
        for pred, prob, label in zip(all_preds, all_probs, all_labels):
            valid_mask = label != 0
            if len(valid_mask) > 0:
                filt_pred = pred[valid_mask]
                filt_prob = prob[valid_mask]
                filt_label = label[valid_mask]
                filtered_all_preds.append(filt_pred)
                filtered_all_probs.append(filt_prob)
                filtered_all_labels.append(filt_label)
        # Here we have one array of predictions per RNA: [(n1,), (n2,), ...].
        # First aggregate the metrics by system, then compute flattened
        # "global" versions over all residues.
        sorted_keys = []
        metrics = []
        for pred, filt_pred, prob, label, filt_label in zip(
            all_preds, filtered_all_preds, all_probs, all_labels, filtered_all_labels
        ):
            # Metrics can't be computed over just one class
            if len(np.unique(label)) == 1:
                continue
            one_metric = self.compute_one_metric(pred, filt_pred, prob, label, filt_label)
            metrics.append([v for k, v in sorted(one_metric.items())])
            sorted_keys = sorted(one_metric.keys())
        metrics = np.array(metrics)
        mean_metrics = np.nanmean(metrics, axis=0)
        metrics = {k: v for k, v in zip(sorted_keys, mean_metrics)}
        # Get the flattened result, renamed to include a "global_" prefix
        filtered_all_preds = np.concatenate(filtered_all_preds)
        all_preds = np.concatenate(all_preds)
        filtered_all_probs = np.concatenate(filtered_all_probs)
        all_labels = np.concatenate(all_labels)
        filtered_all_labels = np.concatenate(filtered_all_labels)
        global_metrics = self.compute_one_metric(
            filtered_all_preds, all_preds, filtered_all_probs, filtered_all_labels, all_labels
        )
        metrics_global = {f"global_{k}": v for k, v in global_metrics.items()}
        metrics.update(metrics_global)
        # Add the confusion matrix (including non-standard nucleotides)
        cm = confusion_matrix(all_labels, all_preds)
        metrics["confusion_matrix"] = cm.tolist()
        return metrics
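
    # Expected input structure (hypothetical shapes): one array per RNA, e.g.
    # for two systems with n1 and n2 residues and 5 classes (0 + ACGU):
    #
    #   all_preds = [np.zeros(n1, dtype=int), np.zeros(n2, dtype=int)]
    #   all_probs = [np.zeros((n1, 5)), np.zeros((n2, 5))]
    #   all_labels = [np.zeros(n1, dtype=int), np.zeros(n2, dtype=int)]
    #   metrics = task.compute_metrics(all_preds, all_probs, all_labels)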


class gRNAde(InverseFolding):
    """Inverse folding restricted to the chain-level splits of the gRNAde dataset.

    Everything is inherited from InverseFolding except process() and the splitter.
    """

    def __init__(self, **kwargs):
        self.splits = {
            "train": [],
            "test": [],
            "val": [],
            "all": [],
            # Use sets instead of lists for chains since order doesn't matter
            "pdb_to_chain_train": defaultdict(set),
            "pdb_to_chain_test": defaultdict(set),
            "pdb_to_chain_val": defaultdict(set),
            "pdb_to_chain_all": defaultdict(set),
        }
        # Populate the structure from the split files shipped alongside this module
        data_dir = Path(__file__).resolve().parent / "data"
        for split in ["train", "test", "val"]:
            file_path = data_dir / f"{split}_ids_das.txt"
            with open(file_path, "r") as f:
                for line in f:
                    # Each line encodes a PDB id followed by dash-separated chain ids
                    line = line.strip()
                    pdb_id = line.split("_")[0].lower()
                    chain = line.split("_")[-1]
                    chain_components = list(chain.split("-"))
                    if pdb_id not in self.splits[split]:
                        self.splits[split].append(pdb_id)
                    if pdb_id not in self.splits["all"]:
                        self.splits["all"].append(pdb_id)
                    # Updating a set automatically handles duplicates
                    self.splits[f"pdb_to_chain_{split}"][pdb_id].update(chain_components)
                    self.splits["pdb_to_chain_all"][pdb_id].update(chain_components)
        super().__init__(**kwargs)
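
    # Illustrative parse of the loop above (hypothetical split-file line):
    # "1abc_A-B" yields pdb_id = "1abc" and chain_components = ["A", "B"], so
    # "1abc" is appended to splits["train"] and pdb_to_chain_train["1abc"]
    # becomes {"A", "B"}.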

    @property
    def default_splitter(self):
        """Split by chain name, using the gRNAde split assignments loaded above."""
        train_names = [
            f"{pdb.lower()}_{chain}"
            for pdb in self.splits["pdb_to_chain_train"]
            for chain in self.splits["pdb_to_chain_train"][pdb]
        ]
        val_names = [
            f"{pdb.lower()}_{chain}"
            for pdb in self.splits["pdb_to_chain_val"]
            for chain in self.splits["pdb_to_chain_val"][pdb]
        ]
        test_names = [
            f"{pdb.lower()}_{chain}"
            for pdb in self.splits["pdb_to_chain_test"]
            for chain in self.splits["pdb_to_chain_test"][pdb]
        ]
        return NameSplitter(train_names, val_names, test_names)
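
    # The names handed to NameSplitter follow the "{pdb_id}_{chain}" convention
    # that ChainNameTransform applies in process() below; a hypothetical entry
    # pdb_to_chain_train["1abc"] = {"A"} would yield the split name "1abc_A".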

    def process(self) -> RNADataset:
        """Process the dataset chain by chain to avoid memory issues.

        Returns a filtered and processed RNADataset.
        """
        name_filter = NameFilter(self.splits["train"] + self.splits["test"] + self.splits["val"])
        chain_filter = ChainFilter(self.splits["pdb_to_chain_all"])
        filters = ComposeFilters([name_filter, chain_filter])
        annotate_dummy = DummyAnnotator()
        split_chain = ChainSplitTransform()
        add_name_chains = ChainNameTransform()

        # Initialize the source dataset with in_memory=False to avoid loading
        # everything at once
        source_dataset = RNADataset(debug=self.debug, redundancy="all", in_memory=False)
        all_rnas = []
        os.makedirs(self.dataset_path, exist_ok=True)
        for rna in tqdm.tqdm(source_dataset):
            if filters.forward(rna):
                rna = annotate_dummy(rna)
                rna_chains = split_chain(rna)  # Split by chain
                renamed_chains = list(add_name_chains(rna_chains))  # Rename per chain
                for rna_chain in renamed_chains:
                    rna_chain = rna_chain["rna"]
                    if self.in_memory:
                        all_rnas.append(rna_chain)
                    else:
                        all_rnas.append(rna_chain.name)
                        dump_json(os.path.join(self.dataset_path, f"{rna_chain.name}.json"), rna_chain)
        if self.in_memory:
            dataset = RNADataset(rnas=all_rnas)
        else:
            dataset = RNADataset(dataset_path=self.dataset_path, rna_id_subset=all_rnas)
        return dataset
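

# A minimal end-to-end sketch. The `root` and `debug` keyword arguments and the
# `dataset` attribute are assumed to come from the ResidueClassificationTask
# base class; they are assumptions about the surrounding library, not
# guarantees made by this file.
if __name__ == "__main__":
    task = gRNAde(root="roots/grnade", debug=True)  # hypothetical root path
    print(f"Loaded {len(task.dataset)} RNA chains for the gRNAde task")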