Source code for rnaglib.transforms.annotate.rnafm

import os
import sys
import json
from pathlib import Path
from typing import Dict, Tuple, List

import torch
import numpy as np
import networkx as nx
import torch

try:
    import fm
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please make sure rna-fm is installed with pip install rna-fm")

from rnaglib.transforms import Transform
from rnaglib.encoders import ListEncoder
from rnaglib.algorithms import get_sequences


[docs] class RNAFMTransform(Transform): """Use the RNA-FM model to compute residue-level embeddings. Make sure rna-fm is installed by running ``pip install rna-fm``. Sets a node attribute to `'rnafm'` with a numpy array of the resulting embedding. Go `here <https://github.com/ml4bio/RNA-FM>`_ for the RNA-FM source code. :param chunking_strategy: how to process sequences longer than 1024. ``'simple'`` just splits into non-overlapping segments. :param chunk_size: size of chunks to use (default is 512) :param cache_path: a directory containing pre-computed npz embeddings :param expand_mean: True .. note:: Maximum size for basic RNA-FM model is 1024. If sequence is larger than 1024 we apply ``'chunking_strategy'`` to process the sequence. """ name = "rnafm" encoder = ListEncoder(640)
[docs] def __init__(self, chunking_strategy: str = "simple", chunk_size: int = 512, cache_path=None, expand_mean=True, verbose=False, debug=False, **kwargs ): # Load RNA-FM model if not debug: self.model, self.alphabet = fm.pretrained.rna_fm_t12() self.batch_converter = self.alphabet.get_batch_converter() self.chunking_strategy = chunking_strategy self.chunk_size = chunk_size self.model.eval() self.debug = debug self.cache_path = cache_path self.expand_mean = expand_mean self.verbose = verbose super().__init__(**kwargs)
def basic_chunking(self, seq): return [seq[i: i + self.chunk_size] for i in range(0, len(seq), self.chunk_size)] def chunk(self, seq_data: List[Tuple]) -> List[Tuple]: """Apply a chunking strategy to sequences longer than 1024.""" chunked = {} for chain_id, (seq, nodes) in seq_data.items(): if self.chunking_strategy == "simple": chunks = self.basic_chunking(list(zip(seq, nodes))) for i, chunk in enumerate(chunks): nodelist = [n for _, n in chunk] seq = "".join([s for s, _ in chunk]) chunked[chain_id + "_" + str(i)] = (seq, nodelist) return chunked def forward(self, rna_dict: Dict) -> Dict: if self.debug: dummy_feats = {node: [0 for _ in range(640)] for node in rna_dict["rna"].nodes()} nx.set_node_attributes(rna_dict["rna"], name=self.name, values=dummy_feats) return rna_dict chain_seqs = get_sequences(rna_dict["rna"], verbose=self.verbose) # Try to load the embs if possible. if self.cache_path is not None: if not self.quiet: print(f"Loading embeddings from {self.cache_path}.") chains = list(chain_seqs.keys()) for chain in chains: embs_path = Path(self.cache_path) / f"{chain}.npz" if embs_path.exists(): embs = np.load(embs_path) # If they are complete, remove from the chains to do and put in the graph if len(chain_seqs[chain][0]) == len(embs): nx.set_node_attributes(rna_dict["rna"], embs, self.name) chain_seqs.pop(chain) if len(chain_seqs) == 0: return rna_dict # Otherwise make the actual computations # Prepare data seq_data = self.chunk(chain_seqs) input_seqs = [(chain_id, seq) for chain_id, (seq, _) in seq_data.items()] batch_labels, batch_strs, batch_tokens = self.batch_converter(input_seqs) # Extract embeddings (on CPU) with torch.no_grad(): results = self.model(batch_tokens, repr_layers=[12]) token_embeddings = results["representations"][12] all_embs = [] for i, (_, (seq, node_ids)) in enumerate(seq_data.items()): Z = token_embeddings[i, : len(seq)] emb_dict = {n: Z[i].clone().detach() for i, n in enumerate(node_ids)} for i, n in enumerate(node_ids): z = Z[i].clone().detach() rna_dict["rna"].nodes[n][self.name] = list(z.numpy()) all_embs.append(z) # Add mean embedding to missing nodes, if self.expand mean if self.expand_mean: existing_nodes, _ = list(zip(*nx.get_node_attributes(rna_dict["rna"], self.name).items())) missing_nodes = set(rna_dict["rna"].nodes()) - set(existing_nodes) if len(missing_nodes) > 0: embs = torch.stack(all_embs, dim=0) mean_emb = torch.mean(embs, dim=0) missing_embs = {node: mean_emb.tolist() for node in missing_nodes} nx.set_node_attributes(rna_dict["rna"], name=self.name, values=missing_embs) return rna_dict