Source code for rnaglib.dataset_transforms.structure_distance_computer

import os
import itertools
import tempfile
from pathlib import Path

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

from .distance_computer import DistanceComputer
from rnaglib.utils import (
    US_align_wrapper,
    rna_align_wrapper,
    clean_mmcif,
)
from rnaglib.utils.misc import filter_cif_with_res
from rnaglib.utils.graph_io import get_default_download_dir


class StructureDistanceComputer(DistanceComputer):
    """Distance computer computing a structure-based pairwise distance between RNAs from a dataset

    :param str name: the name identifying the distance metric; ``forward`` accepts
        ``"USalign"`` or ``"RNAalign"`` (default "USalign")
    :param bool use_substructures: whether to filter cif files of RNAs to remove the residues
        which are not present in the dataset (default True)
    :param structures_path: path to the directory where the structures are stored (as cif files).
        If None, ``forward`` falls back to the ``structures`` folder of the default download
        directory.
    :param int n_jobs: number of jobs (for parallelization) (if set to -1, use the maximum number
        of cores) (default -1)
    """

    # Maps each supported metric name to the wrapper that aligns one pair of structure files.
    _ALIGNERS = {
        "USalign": US_align_wrapper,
        "RNAalign": rna_align_wrapper,
    }

    def __init__(
        self,
        name: str = "USalign",
        use_substructures: bool = True,
        structures_path: Path = None,
        n_jobs: int = -1,
        **kwargs,
    ):
        self.name = name
        self.use_substructures = use_substructures
        self.structures_path = structures_path
        self.n_jobs = n_jobs
        super().__init__(name=self.name, **kwargs)

    def _dump_structures(self, dataset, tmpdir):
        """Write one cif file per RNA of ``dataset`` into ``tmpdir``.

        :param dataset: iterable of items exposing the RNA graph under the ``"rna"`` key.
        :param tmpdir: directory receiving the filtered/cleaned cif files.
        :returns: ``(paths, names)``, two lists in dataset order holding the written cif path
            and the graph name of each RNA.
        """
        paths = []
        names = []
        for rna in tqdm(dataset, total=len(dataset)):
            rna_graph = rna["rna"]
            cif_path = Path(self.structures_path) / f"{rna_graph.graph['pdbid'].lower()}.cif"
            if self.use_substructures:
                # Node ids are dot-separated; fields 1 and 2 are used as (chain, residue number).
                reslist = []
                for node in rna_graph.nodes():
                    fields = node.split(".")
                    reslist.append((fields[1], int(fields[2])))
                new_cif = os.path.join(tmpdir, f"{rna_graph.name}.cif")
                filter_cif_with_res(cif_path, reslist, new_cif)
                paths.append(new_cif)
            else:
                clean_path = Path(tmpdir) / f"{rna_graph.name}.cif"
                clean_mmcif(cif_path, clean_path)
                paths.append(clean_path)
            names.append(rna_graph.name)
        return paths, names

    def forward(self, dataset):
        """Computes pairwise structural similarity between all pairs of RNAs.

        The aligner is selected by ``self.name`` (USalign or RNAalign). Note from the original
        authors: rna-align stalls with RNAs > 200 nts.

        RNAs whose alignment yielded NaN against every other RNA are removed from the result.

        :param dataset: RNA dataset to compute similarity over.
        :returns: tuple ``(sim_mat, keep_dataset_names)`` where ``sim_mat`` is the square
            np.array of pairwise similarities (diagonal set to 1) restricted to the kept RNAs,
            in dataset order, and ``keep_dataset_names`` is the list of their names.
        :raises ValueError: if ``self.name`` is neither ``"RNAalign"`` nor ``"USalign"``.
        """
        if self.name not in self._ALIGNERS:
            raise ValueError("name must be 'RNAalign' or 'USalign'")
        aligner = self._ALIGNERS[self.name]

        # set default structures dir if a specific directory wasn't specified by the user
        if self.structures_path is None:
            dirname = get_default_download_dir()
            self.structures_path = os.path.join(dirname, "structures")

        # The alignments read the dumped files, so they must run before the
        # temporary directory is cleaned up.
        with tempfile.TemporaryDirectory() as tmpdir:
            print("dumping structures...")
            all_pdb_path, names = self._dump_structures(dataset, tmpdir)

            todo = list(itertools.combinations(all_pdb_path, 2))
            sims = Parallel(n_jobs=self.n_jobs)(
                delayed(aligner)(path1, path2)
                for path1, path2 in tqdm(todo, total=len(todo), desc=self.name)
            )

        # itertools.combinations and np.triu_indices(n, 1) enumerate the strict upper
        # triangle in the same row-major order, so the flat results map directly onto it.
        n_rnas = len(all_pdb_path)
        sim_mat = np.zeros((n_rnas, n_rnas))
        sim_mat[np.triu_indices(n_rnas, 1)] = sims
        sim_mat += sim_mat.T
        np.fill_diagonal(sim_mat, 1)

        # find rnas that failed against all others: every off-diagonal entry of the row is NaN.
        # With a single RNA there is nothing to compare against, so "0 NaNs == n - 1" must not
        # be read as a failure (it would drop the only RNA of the dataset).
        row_nan_count = np.isnan(sim_mat).sum(axis=1)
        if n_rnas > 1:
            keep_idx = np.where(row_nan_count != n_rnas - 1)[0]
        else:
            keep_idx = np.arange(n_rnas)

        sim_mat = sim_mat[keep_idx][:, keep_idx]
        # Names were collected while dumping, so the dataset is not iterated a second time.
        keep_dataset_names = [names[i] for i in keep_idx]
        return sim_mat, keep_dataset_names