Source code for rnaglib.dataset_transforms.redundancy_remover

import numpy as np
from scipy.sparse.csgraph import connected_components

from rnaglib.dataset_transforms import DSTransform


[docs] class RedundancyRemover(DSTransform): """Dataset transform removing redundancy in a dataset by performing clustering on the dataset then keeping only the RNA with the highest resolution within each cluster :param str distance_name: the name of the distance metric which has to be used to perform clustering. The distance must have been computed on the dataset (see DistanceComputer) :param float threshold: the similarity threshold (considering similarity as 1-distance) to use to perform clustering """
[docs] def __init__( self, distance_name: str = "USalign", threshold: float = 0.95, ): self.distance_name = distance_name self.threshold = threshold
def __call__(self, dataset): """"Removes redundancy to a specific dataset following the parameters specified in the RedundancyRemover object :return: the dataset with redundancy removed according to specified criteria :rtype: RNADataset """ if dataset.distances is None or not self.distance_name in dataset.distances: raise ValueError(f"The distance matrix using distances {self.distance_name} has not been computed") adjacency_matrix = (dataset.distances[self.distance_name] <= 1 - self.threshold).astype(int) n_components, labels = connected_components(adjacency_matrix) neighbors = [] for i in range(n_components): neighborhood = np.where(labels == i)[0].tolist() neighbors.append(neighborhood) final_list_ids = [] for neighborhood in neighbors: highest_resolution = 100 highest_resolution_idx = neighborhood[0] for rna_idx in neighborhood: rna_dict = dataset[rna_idx] try: resolution = rna_dict['rna'].graph['resolution_high'] if resolution < highest_resolution: highest_resolution = resolution highest_resolution_idx = rna_idx except: continue final_list_ids.append(highest_resolution_idx) dataset = dataset.subset(list_of_ids=final_list_ids) return dataset