Source code for rnaglib.transforms.filter.filters

from typing import Iterator, Any, Callable
import requests

import networkx as nx
from rnaglib.transforms import FilterTransform

""" Filters return a boolean after receiving an RNA.
This can be used to exclude RNAs from a datataset based on some
desired conditione.
"""


class DummyFilter(FilterTransform):
    """Always return True"""

    def forward(self, rna_dict: dict) -> bool:
        return True


[docs] class SizeFilter(FilterTransform): """Reject RNAs that are not in the given size bounds. :param min_size: smallest allowed number of residues :param max_size: largest allowed number of residues. Default -1 which means no upper bound. """
[docs] def __init__(self, min_size: int = 0, max_size: int = -1, **kwargs): self.min_size = min_size self.max_size = max_size super().__init__(**kwargs)
def forward(self, rna_dict: dict) -> bool: n = len(rna_dict["rna"].nodes()) if self.max_size == -1: return n > self.min_size else: return n > self.min_size and n < self.max_size
[docs] class RNAAttributeFilter(FilterTransform): """Reject RNAs that lack a certain annotation at the whole RNA level. :param attribute: which RNA-level attribute to look for. """
[docs] def __init__(self, attribute: str, value_checker: Callable = None, **kwargs): self.attribute = attribute self.value_checker = value_checker super().__init__(**kwargs) pass
def forward(self, data: dict): try: val = data["rna"].graph[self.attribute] except KeyError: return False else: return self.value_checker(val)
[docs] class ResidueAttributeFilter(FilterTransform): """Reject RNAs that lack a certain annotation at the whole residue-level. :param attribute: which node-level attribute to look for. :param aggregation_mode: str (either "aggfunc" or "min_valid"); if set to "aggfunc", keeps an RNA if the output of the aggregation function of the residue attribute at the RNA level passes the value_checker; if set to "min_valid", keeps an RNA if more than min_valid nodes pass the value_checker :param value_checker: function with accepts the value of the desired attribute and returns True/False :param aggfunc: function to aggregate the residue labels at the RNA level (only if aggregarion_mode is "aggfunc") :param min_valid: minium number of valid nodes that pass the filter for keeping the RNA. (only if aggregation_mode is "min_valid") Example --------- Keep RNAs with at least 1 chemically modified residue:: >>> from rnaglib.dataset import RNADataset >>> from rnaglib.transforms import ResidueAttributeFilter >>> dset = RNADataset(debug=True) >>> t = ResidueAttributeFilter(attribute='is_modified', value_checker: lambda val: val == True, min_valid=1) >>> len(dset) >>> rnas = list(t(dset)) >>> len(rnas) """
[docs] def __init__( self, attribute: str, aggregation_mode: str = "min_valid", value_checker: Callable = None, min_valid: int = 1, aggfunc: Callable = None, **kwargs, ): self.attribute = attribute self.aggregation_mode = aggregation_mode self.min_valid = min_valid self.aggfunc = aggfunc self.value_checker = value_checker super().__init__(**kwargs) pass
def forward(self, data: dict): n_valid = 0 g = data["rna"] if self.aggregation_mode == "aggfunc": vals_list = [] for node, ndata in g.nodes(data=True): try: val = ndata[self.attribute] except KeyError: continue else: if self.aggregation_mode == "min_valid" and self.value_checker(val): n_valid += 1 elif self.aggregation_mode == "aggfunc": vals_list.append(val) if self.aggregation_mode == "min_valid" and n_valid >= self.min_valid: return True if self.aggregation_mode == "min_valid": return False else: return self.aggfunc(vals_list)
[docs] class ResidueNameFilter(FilterTransform): """ Filter RNAs based on their residuess' names. This filter keeps only the RNAs such that a minimal number of their residues' names match a specific criterion. :param Callable value_checker: a method taking as input an RNA residue name and returning a boolean defining the filter's criterion on the residues' names (default None) :param int min_valid: the minimal number of residues within an RNA which have to match the above defined criterion so that the RNA is kept by the filter """
[docs] def __init__( self, value_checker: Callable = None, min_valid: int = 1, **kwargs, ): self.min_valid = min_valid self.value_checker = value_checker super().__init__(**kwargs) pass
def forward(self, data: dict): """ Check if the RNA's name is in the list of allowed names. :param data: Dictionary containing RNA data. :return: True if the RNA's name is in the allowed list, False otherwise. """ n_valid = 0 g = data["rna"] for node, ndata in g.nodes(data=True): if self.value_checker(node): n_valid += 1 if n_valid >= self.min_valid: return True return False
[docs] class RibosomalFilter(FilterTransform): """Remove RNA if ribosomal""" ribosomal_keywords = ["ribosomal", "rRNA", "50S", "30S", "60S", "40S"]
[docs] def __init__(self, **kwargs): super().__init__(**kwargs) pass
def forward(self, data: dict): pdbid = data["rna"].graph["pdbid"][0] url = f"https://data.rcsb.org/rest/v1/core/entry/{pdbid}" response = requests.get(url) data = response.json() # Check title and description title = data.get("struct", {}).get("title", "").lower() if any(keyword in title for keyword in self.ribosomal_keywords): return False # Check keywords keywords = data.get("struct_keywords", {}).get("pdbx_keywords", "").lower() if any(keyword in keywords for keyword in self.ribosomal_keywords): return False # Check polymer descriptions (for RNA and ribosomal proteins) for polymer in data.get("polymer_entities", []): description = polymer.get("rcsb_polymer_entity", {}).get("pdbx_description", "").lower() if any(keyword in description for keyword in self.ribosomal_keywords): return False return True
[docs] class NameFilter(FilterTransform): """ Filter RNAs based on their names. This filter keeps only the RNAs whose names are present in the provided list. :param names: A list of RNA names to keep. """
[docs] def __init__(self, names: list, **kwargs): self.names = {name.lower() for name in names} super().__init__(**kwargs)
def forward(self, data: dict) -> bool: """ Check if the RNA's name is in the list of allowed names. :param data: Dictionary containing RNA data. :return: True if the RNA's name is in the allowed list, False otherwise. """ rna_name = data["rna"].name return rna_name in self.names
[docs] class ChainFilter(FilterTransform): """ Filter RNAs based on valid chain names for each structure. Keeps any RNA with at least one residue having a valid chain name, and removes residues with invalid chain names from kept RNAs. :param valid_chains_dict: Dictionary mapping structure names to lists of valid chain names. """
[docs] def __init__(self, valid_chains_dict: dict, **kwargs): self.valid_chains_dict = { pdb.lower(): [chain for chain in chains] for pdb, chains in valid_chains_dict.items() # .upper() } super().__init__(**kwargs)
def forward(self, data: dict) -> bool: g = data["rna"] structure_name = g.name valid_chains = set(self.valid_chains_dict.get(structure_name, [])) nodes_to_remove = [] has_valid_node = False for node, ndata in g.nodes(data=True): chain_name = node.split(".")[1] # .upper() if chain_name in valid_chains: has_valid_node = True else: nodes_to_remove.append(node) if has_valid_node: # Remove nodes with invalid chain names g.remove_nodes_from(nodes_to_remove) return True else: return False def __repr__(self) -> str: return f"{self.__class__.__name__}(valid_chains_dict={self.valid_chains_dict})"
[docs] class ResolutionFilter(RNAAttributeFilter): """Filters RNA based on their resolution. Only keeps RNAs which resolution is less than a certain threshold. :param float resolution_threshold: resolution (in Angstroms) below which the RNA will be kept and above which it will be discarded """
[docs] def __init__(self, resolution_threshold: float, **kwargs): def value_checker(val): try: return float(val[0]) < resolution_threshold except: return False super().__init__(attribute="resolution_high", value_checker=value_checker, **kwargs)