Source code for rnaglib.dataset_transforms.cd_hit
import itertools
from collections import defaultdict
import numpy as np
from tqdm import tqdm
from rnaglib.algorithms import get_sequences
from rnaglib.dataset_transforms import DistanceComputer
from rnaglib.utils import cdhit_wrapper
[docs]
class CDHitComputer(DistanceComputer):
    """Pairwise RNA sequence similarity derived from CD-Hit clusters.

    Every sequence chunk of every RNA is clustered with CD-Hit; two RNAs are
    then compared through the sets of cluster ids their chunks were assigned to.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.9,
        **kwargs,
    ):
        """Store the clustering threshold and initialise the parent computer.

        :param similarity_threshold: value forwarded to ``cdhit_wrapper`` as ``sim_thresh``.
        :param kwargs: forwarded unchanged to ``DistanceComputer.__init__``
            (which is also given ``name="cd_hit"``).
        """
        self.similarity_threshold = similarity_threshold
        super().__init__(name="cd_hit", **kwargs)

    def forward(self, dataset) -> tuple[np.ndarray, list]:
        """Computes sequence similarity between all pairs of RNAs.

        To deal with multi-chain RNAs we cluster all chains independently
        using CD-Hit. For a given pair of multi-chained RNAs, their overall similarity score
        is given by the Tanimoto coefficient of the sets of clusters assigned to each of the RNA's chains.

        RNAs for which CD-Hit reports no clustered chunk do not appear in the output.

        :param dataset: RNA dataset to compute similarity over. Each item must expose
            its graph under the ``"rna"`` key, and the dataset must provide ``all_rnas``.
        :returns: ``(sim_mat, keep_dataset_names)`` where ``sim_mat`` is a symmetric
            square array of pairwise similarities with ones on the diagonal, and
            ``keep_dataset_names`` lists the RNA names in the row/column order of ``sim_mat``.
        """
        # Prepare input for CD-Hit: one entry per consecutive chunk in sequence.
        ids, sequences = [], []
        for idx, rna in enumerate(dataset):
            seqs = get_sequences(rna["rna"], longest_only=False, min_size_return=5, verbose=False)
            for seq_id, (seq, _) in seqs.items():
                # Each chunk gets a unique ID prefixed by the position of its RNA in the
                # dataset. Dots are replaced so that "-" is the only separator in the ID.
                ids.append(f"{idx}-{seq_id.replace('.', '-')}")
                sequences.append(seq)

        if not ids:
            # Nothing to cluster: return the same empty result the code below would
            # build, without invoking CD-Hit on an empty input.
            return np.zeros((0, 0)), []

        # For each chunk, get its cluster assignment.
        ids_to_cluster, _cluster_to_ids = cdhit_wrapper(ids, sequences, sim_thresh=self.similarity_threshold)

        # Group together the clusters of all chunks coming from one RNA.
        # TODO: this should be a Counter
        idx_to_clusters = defaultdict(set)
        for seq_id, cluster_id in ids_to_cluster.items():
            rna_idx = int(seq_id.split("-", 1)[0])
            idx_to_clusters[rna_idx].add(cluster_id)
        idxs = sorted(idx_to_clusters)

        def tanimoto(set_1, set_2):
            # Both sets hold at least one cluster by construction, so the union is never empty.
            return len(set_1 & set_2) / len(set_1 | set_2)

        # Compute an RNA-level pairwise similarity from the clusters its chunks belong to.
        todo = list(itertools.combinations(idxs, 2))
        sims = [
            tanimoto(idx_to_clusters[rna_1], idx_to_clusters[rna_2])
            for rna_1, rna_2 in tqdm(todo, desc="CD-Hit", total=len(todo))
        ]

        # combinations() over the sorted indices enumerates pairs in the same row-major
        # order as the strict upper triangle returned by np.triu_indices.
        n_rnas = len(idxs)
        sim_mat = np.zeros((n_rnas, n_rnas))
        sim_mat[np.triu_indices(n_rnas, 1)] = sims
        sim_mat += sim_mat.T
        np.fill_diagonal(sim_mat, 1)

        if n_rnas != len(dataset):
            # Some RNAs were dropped: look up the name of each kept index.
            # NOTE(review): assumes ``all_rnas.inv`` maps dataset index -> name; confirm.
            keep_dataset_names = [dataset.all_rnas.inv[i] for i in idxs]
        else:
            keep_dataset_names = list(dataset.all_rnas)
        return sim_mat, keep_dataset_names