Source code for rnaglib.dataset_transforms.similarity_splitter

"""Various splitters taking similarity into account for RNA tasks."""

from collections import Counter
from collections.abc import Iterable

import numpy as np
from scipy.sparse.csgraph import connected_components

from rnaglib.dataset_transforms import Splitter
from rnaglib.dataset_transforms.linear_optimisation import assign_clusters
from rnaglib.dataset_transforms.splitting_utils import label_counter


class ClusterSplitter(Splitter):
    """Abstract class for splitting by clustering with a similarity function.

    :param float similarity_threshold: similarity threshold (using similarity defined as 1 - distance)
        above which two RNAs will be clustered in the same cluster (default 0.5)
    :param int n_jobs: number of jobs (for parallelization); if set to -1, use the maximum number of cores (default -1)
    :param int seed: seed for shuffling (default 0)
    :param bool balanced: whether to use balanced clusters (default True)
    :param str distance_name: name of the distance metric to use to perform clustering
        (must have been computed for this dataset, see DistanceComputer if it hasn't) (default "USalign")
    :param bool verbose: whether to display messages (default False)
    """
    def __init__(
        self,
        similarity_threshold: float = 0.5,
        n_jobs: int = -1,
        seed: int = 0,
        balanced: bool = True,
        distance_name: str = "USalign",
        verbose=False,
        *args,
        **kwargs,
    ):
        self.similarity_threshold = similarity_threshold
        self.n_jobs = n_jobs
        self.seed = seed
        self.balanced = balanced
        self.distance_name = distance_name
        self.verbose = verbose
        super().__init__(**kwargs)
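    # Hypothetical usage sketch (illustrative only, not part of the module): assuming
    # ``dataset`` is an rnaglib dataset whose ``distances["USalign"]`` matrix has already
    # been computed (e.g. with a DistanceComputer transform), a split could look like:
    #
    #     splitter = ClusterSplitter(similarity_threshold=0.5, balanced=True)
    #     train_idx, val_idx, test_idx = splitter.forward(dataset)
    #
    # The name ``dataset`` and the exact entry point are assumptions; the base
    # ``Splitter`` class may expose a different calling convention.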
    def forward(self, dataset):
        print(f"pre cluster len: {len(dataset)}")
        clusters = self.cluster_split(dataset, frac=0, split=False)
        if self.balanced and not self.debug:
            _, label_counts = label_counter(dataset)
        else:
            label_counts = None
        named_clusters = []
        for cluster in clusters:
            named_clusters.append([dataset[i]["rna"].name for i in cluster])
        train, val, test = self.balancer(
            named_clusters,
            label_counts,
            dataset,
            (self.split_train, self.split_valid, self.split_test),
        )
        return train, val, test

    def balancer(self, clusters, label_counts, dataset, fracs, n=0.2):
        """Split clusters into train, val and test while taking label balance into account.

        ``fracs`` is a tuple of fractions giving the desired proportions.
        The dataset needs to be passed since the cluster indices apply to keep_dataset,
        not necessarily the original one.
        """
        balanced = self.balanced if not self.debug else 0
        if balanced:
            print("Computing balanced clusters...")
            # Here we need to choose from clusters while taking labels into account.
            # Like Plinder, we should (potentially) make sure that singleton
            # clusters don't go into test in a second step.
            # First, we need to know what the label balance is.
            labelcounts = []
            for cluster in clusters:
                # Sum the label counts from each element of the cluster.
                labelcount = sum([label_counts[i] for i in cluster], Counter())
                labelcounts.append(labelcount)
        else:
            labelcounts = [Counter({0: len(c)}) for c in clusters]
        train, val, test, metrics = assign_clusters(
            clusters,
            labelcounts,
            split_ratios=fracs,
            label_weight=int(balanced),
            verbose=self.verbose,
        )
        print("Done.")
        return (
            [x for x in range(len(dataset)) if dataset[x]["rna"].name in sum(train, [])],
            [x for x in range(len(dataset)) if dataset[x]["rna"].name in sum(val, [])],
            [x for x in range(len(dataset)) if dataset[x]["rna"].name in sum(test, [])],
        )
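    # Illustrative sketch of the per-cluster label tally used in ``balancer`` (toy values,
    # not taken from any real dataset): if ``label_counts`` maps RNA names to Counters,
    # e.g. {"1abc": Counter({0: 3, 1: 1}), "2xyz": Counter({0: 2})}, then for the cluster
    # ["1abc", "2xyz"] the summed tally is Counter({0: 5, 1: 1}). These per-cluster totals
    # are what ``assign_clusters`` balances across the train/val/test splits.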
    def cluster_split(
        self,
        dataset: Iterable,
        frac: float,
        n: float = 0.05,
        split: bool = True,
    ):
        """Fast cluster-based splitting adapted from ProteinShake
        (https://github.com/BorgwardtLab/proteinshake_release/blob/main/structure_split.py).

        Splits the dataset into two splits, with the guarantee that no two points above
        ``similarity_threshold`` of each other belong to the same split.

        Computes a similarity matrix used to identify redundant clusters based on the
        ``similarity_threshold``. To split the dataset, we iterate over a pool of data points
        until the desired size of the test set is reached. The pool initially consists of the
        whole dataset. At each step, we choose a random point from the pool and fetch all
        points from the pool with similarity above ``similarity_threshold``; we call this the
        current cluster. If the cluster contains more than ``test_size * n`` points, we
        sub-sample the cluster. If the cluster would make the test set larger than
        ``test_size``, we sub-sample it to the difference between the current test set and
        ``test_size``. We then remove the current cluster from the pool and add it to the
        test set. Points that remain in the pool are kept as the training set.

        :param dataset: dataset to split
        :param frac: fraction of dataset to use as the test set
        :param n: portion of the test set size to use as largest test set cluster size
        :param split: if split is False, we return all clusters instead of splitting them
        """
        if dataset.distances is not None:
            if self.distance_name not in dataset.distances:
                raise ValueError(
                    f"The distance matrix using distances {self.distance_name} has not been computed"
                )
        similarity_matrix = 1 - dataset.distances[self.distance_name]
        adjacency_matrix = (similarity_matrix >= self.similarity_threshold).astype(int)
        n_components, labels = connected_components(adjacency_matrix)
        neighbors = []
        for i in range(n_components):
            neighborhood = np.where(labels == i)[0].tolist()
            neighbors.append(neighborhood)
        print(f"We have {len(neighbors)} clusters.")
        return neighbors
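
# Minimal, self-contained sketch of the clustering step used in ``cluster_split`` above:
# threshold a toy similarity matrix and group points by connected components. The toy
# matrix values are assumptions for illustration only and do not come from any dataset.
if __name__ == "__main__":
    toy_similarity = np.array(
        [
            [1.0, 0.8, 0.1, 0.0],
            [0.8, 1.0, 0.2, 0.1],
            [0.1, 0.2, 1.0, 0.9],
            [0.0, 0.1, 0.9, 1.0],
        ]
    )
    toy_adjacency = (toy_similarity >= 0.5).astype(int)
    n_components, labels = connected_components(toy_adjacency)
    toy_clusters = [np.where(labels == i)[0].tolist() for i in range(n_components)]
    # With a 0.5 threshold, points 0 and 1 form one cluster and points 2 and 3 another.
    print(toy_clusters)  # expected: [[0, 1], [2, 3]]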