"""Cast annotations to feature tensors."""
from typing import Dict, Union, List, TYPE_CHECKING, Literal
import torch
import networkx as nx
from rnaglib.config import NODE_FEATURE_MAP, EDGE_FEATURE_MAP
from rnaglib.transforms import Transform
[docs]
class FeaturesComputer(Transform):
    """
    This class takes as input an RNA in the networkX form
    and computes the ``features_dict`` which maps node IDs to a tensor of features.

    The ``features_dict`` contains keys: ``'nt_features'`` for node features,
    ``'nt_targets'`` for node-level prediction targets. In :class:`~rnaglib.dataset.RNADataset` construction,
    the ``FeaturesComputer.compute_features()`` method is called during the ``RNADataset`` ``__getitem__()`` call.

    :param nt_features: List of keys to use as node (nucleotide) features and are meant to be inputs of the ML task, choose from the `dataset[i]['rna']` node attributes dictionary.
    :param nt_targets: List of keys to use as node (nucleotide) features and are meant to be outputs of the ML task, choose from the `dataset[i]['rna']` node attributes dictionary.
    :param rna_features: List of keys to use as graph features of graphs representing the whole RNA and are meant to be inputs of the ML task
    :param rna_targets: List of keys to use as graph features of graphs representing the whole RNA and are meant to be outputs of the ML task
    :param bp_features: List of keys to use as graph features of graphs representing an RNA binding pocket and are meant to be inputs of the ML task
    :param bp_targets: List of keys to use as graph features of graphs representing an RNA binding pocket and are meant to be outputs of the ML task
    :param extra_useful_keys: List of keys that are not RNA, nucleotide or binding pocket features or targets but must be preserved when applying the FeaturesComputer to the dataset
    :param dict custom_encoders: Dictionary of the form {feature_name : encoder}
    """

    def __init__(
        self,
        nt_features: Union[List, str] = None,
        nt_targets: Union[List, str] = None,
        rna_features: Union[List, str] = None,
        rna_targets: Union[List, str] = None,
        bp_features: Union[List, str] = None,
        bp_targets: Union[List, str] = None,
        extra_useful_keys: Union[List, str] = None,
        custom_encoders: dict = None,
    ):
        self.rna_features_parser = self.build_feature_parser(rna_features, custom_encoders=custom_encoders)
        self.rna_targets_parser = self.build_feature_parser(rna_targets, custom_encoders=custom_encoders)
        self.node_features_parser = self.build_feature_parser(nt_features, custom_encoders=custom_encoders)
        self.node_targets_parser = self.build_feature_parser(nt_targets, custom_encoders=custom_encoders)
        # This is only useful when using a FeatureComputer to create a dataset, and avoid removing important features
        # of the graph that are not used during loading
        self.extra_useful_keys = extra_useful_keys
        # experimental: binding-pocket level features are stored but no parser is built for them yet
        self.bp_features = bp_features
        self.bp_targets = bp_targets
        self.rna_features = rna_features
        self.rna_targets = rna_targets

    def add_feature(
        self,
        feature_names=None,
        custom_encoders=None,
        input_feature=True,
        feature_level: Literal["rna", "residue"] = "residue",
    ):
        """
        Update the input/output feature selector with either an extra available named feature or a custom encoder

        :param feature_names: Name of the input feature to add
        :param dict custom_encoders: Dictionary of the form {feature_name : encoder}
        :param input_feature: Set to true to modify the input feature encoder, false for the target one
        :param feature_level: If feature is RNA-level (`rna`) or residue-level (`residue`)
        :return: None
        """
        # Select the right parser to extend. NOTE: fixed attribute names; the
        # previous code referenced self.node_target_parser / self.rna_target_parser,
        # which are never set in __init__ and raised AttributeError.
        if feature_level == "residue":
            old_parser = self.node_features_parser if input_feature else self.node_targets_parser
        elif feature_level == "rna":
            old_parser = self.rna_features_parser if input_feature else self.rna_targets_parser
        else:
            raise ValueError(f"Invalid feature level {feature_level}, must be 'rna' or 'residue'")
        new_parser = self.build_feature_parser(asked_features=feature_names, custom_encoders=custom_encoders)
        old_parser.update(new_parser)

    def remove_feature(self, feature_name=None, input_feature=True):
        """
        Remove a named feature (or list of features) from the node feature/target selector.

        :param feature_name: Name (or list of names) of the input feature to remove
        :param input_feature: Set to true to modify the input feature encoder, false for the target one
        :return: None
        """
        if not isinstance(feature_name, list):
            feature_name = [feature_name]
        # Select the right node parser and rebuild it without the removed keys.
        # NOTE: fixed attribute name (was self.node_target_parser, never set in __init__).
        node_parser = self.node_features_parser if input_feature else self.node_targets_parser
        filtered_node_parser = {k: v for k, v in node_parser.items() if k not in feature_name}
        if input_feature:
            self.node_features_parser = filtered_node_parser
        else:
            self.node_targets_parser = filtered_node_parser

    @staticmethod
    def compute_dim(node_parser):
        """
        Based on the encoding scheme, we can compute the shapes of the in and out tensors

        :param node_parser: dictionary of the form {feature_name : encoder}
        :return: the total length of the concatenated default encodings
        """
        if len(node_parser) == 0:
            return 0
        # The default encoding of each feature fixes its tensor length.
        default_encodings = [encoder.encode_default() for encoder in node_parser.values()]
        return len(torch.cat(default_encodings))

    @property
    def input_dim(self):
        # Length of the concatenated node-feature tensor
        return self.compute_dim(self.node_features_parser)

    @property
    def output_dim(self):
        # NOTE: fixed attribute name (was self.node_target_parser, never set in __init__)
        return self.compute_dim(self.node_targets_parser)

    def remove_useless_keys(self, rna_graph):
        """
        Copy the original graph to only retain keys relevant to this FeaturesComputer

        :param rna_graph: a networkx graph of the RNA
        :return: The graph with only keys which are either features, targets or extra useful keys according to the arguments of FeaturesComputer object
        """
        # NOTE: fixed attribute name (was self.node_target_parser, never set in __init__)
        useful_keys = set(self.node_features_parser.keys()).union(set(self.node_targets_parser.keys()))
        if self.extra_useful_keys is not None:
            useful_keys = useful_keys.union(set(self.extra_useful_keys))
        cleaned_graph = nx.DiGraph(name=rna_graph.name)
        cleaned_graph.add_edges_from(rna_graph.edges(data=True))
        for key in useful_keys:
            val = nx.get_node_attributes(rna_graph, key)
            nx.set_node_attributes(cleaned_graph, name=key, values=val)
        return cleaned_graph

    @staticmethod
    def encode_rna(g: nx.Graph, parser):
        """
        Simply apply the rna encoding functions in ``parser`` for all features.
        Then use torch.cat over the result to get one tensor for the whole graph.

        :param g: a nx graph
        :param parser: dictionary of the form {feature_name : encoder}
        :return: the concatenated feature encoding, or None if the parser is empty
        """
        if len(parser) == 0:
            return None
        all_feature_encoding = list()
        for feature, feature_encoder in parser.items():
            try:
                # Graph-level attributes live in g.graph
                feature_encoding = feature_encoder.encode(g.graph[feature])
            except KeyError:
                # Missing attribute: fall back to the encoder's default value
                feature_encoding = feature_encoder.encode_default()
            all_feature_encoding.append(feature_encoding)
        encodings = torch.cat(all_feature_encoding) if len(all_feature_encoding) > 1 else all_feature_encoding[0]
        return encodings

    @staticmethod
    def encode_nodes(g: nx.Graph, node_parser):
        """
        Simply apply the node encoding functions in node_parser to each node in the graph
        Then use torch.cat over the result to get a tensor for each node in the graph.

        :param g: a nx graph
        :param node_parser: dictionary of the form {feature_name : encoder}
        :return: A dict that maps nodes to encodings, or None if the parser is empty
        """
        if len(node_parser) == 0:
            return None
        node_encodings = {}
        for node, attrs in g.nodes.data():
            all_node_feature_encoding = list()
            for feature, feature_encoder in node_parser.items():
                try:
                    node_feature_encoding = feature_encoder.encode(attrs[feature])
                except KeyError:
                    # Missing attribute: fall back to the encoder's default value
                    node_feature_encoding = feature_encoder.encode_default()
                all_node_feature_encoding.append(node_feature_encoding)
            node_encodings[node] = torch.cat(all_node_feature_encoding)
        return node_encodings

    def build_feature_parser(
        self,
        asked_features: Union[List, str] = None,
        custom_encoders: dict = None,
        feature_map: dict = None,
    ) -> dict:
        """
        This function will load the predefined feature maps available globally.
        Then for each of the features in 'asked_features', it will return an encoder object
        in the form of a dict {asked_feature : EncoderObject}

        If some keys don't exist in the feature map, this raises a ValueError. Keys that exist
        but have no encoder (None) are skipped with a warning print.

        :param asked_features: A list of string keys that are present in the encoder
        :param custom_encoders: Dictionary of the form {feature_name : encoder}
        :param feature_map: Dictionary mapping feature key to an Encoder() object.
        :return: A dict {asked_feature : EncoderObject}
        """
        if asked_features is None:
            return {}
        # Default to the global node+edge feature map; copy a user-supplied one before mutating.
        if feature_map is None:
            feature_map = {**NODE_FEATURE_MAP, **EDGE_FEATURE_MAP}
        else:
            feature_map = feature_map.copy()
        if not isinstance(asked_features, list):
            asked_features = [asked_features]
        # Make a non-redundant list that keeps the features order.
        # BUGFIX: this de-duplicated list is now actually used below; previously it was
        # built and then ignored, so duplicated asked features were encoded twice.
        nr_asked_features = list(dict.fromkeys(asked_features))
        # Attach any custom encoders (they may override predefined ones)
        if custom_encoders is not None:
            feature_map.update(custom_encoders)
        # Ensure every asked feature is in the encoding map.
        problematic_keys = tuple(feature for feature in nr_asked_features if feature not in feature_map)
        if problematic_keys:
            raise ValueError(f"{problematic_keys} were asked as a feature or target but do not exist")
        # Filter out None encoder functions, we don't know how to encode those...
        encoding_features = [feature for feature in nr_asked_features if feature_map[feature] is not None]
        if len(encoding_features) < len(nr_asked_features):
            unencodable_keys = [feature for feature in nr_asked_features if feature_map[feature] is None]
            # BUGFIX: the previous message wrongly claimed these keys "do not exist";
            # they exist but have no encoder, so they are skipped.
            print(f"{unencodable_keys} were asked as a feature or target but have no encoder and are skipped")
        # Finally, keep only the relevant keys to include in the encoding dict.
        return {k: feature_map[k] for k in encoding_features}

    def build_edge_feature_parser(self, asked_features=None):
        # Edge-level feature parsing is not supported yet.
        raise NotImplementedError

    def forward(self, rna_dict: Dict):
        """
        Add up to 4 entries to the returned dict which map the whole graph and its nodes
        to a feature vector each. The final converter uses these to include the data in the
        framework-specific object.

        :param rna_dict: dict with key 'rna' holding the networkx graph of the RNA
        :return: dict with keys among 'rna_features', 'rna_targets', 'nt_features', 'nt_targets'
        """
        features_dict = {}
        if len(self.rna_features_parser) > 0:
            rna_feature_encoding = self.encode_rna(rna_dict["rna"], parser=self.rna_features_parser)
            features_dict["rna_features"] = rna_feature_encoding
        if len(self.rna_targets_parser) > 0:
            rna_targets_encoding = self.encode_rna(rna_dict["rna"], parser=self.rna_targets_parser)
            features_dict["rna_targets"] = rna_targets_encoding
        # Get node labels
        if len(self.node_features_parser) > 0:
            feature_encoding = self.encode_nodes(
                rna_dict["rna"],
                node_parser=self.node_features_parser,
            )
            features_dict["nt_features"] = feature_encoding
        if len(self.node_targets_parser) > 0:
            target_encoding = self.encode_nodes(
                rna_dict["rna"],
                node_parser=self.node_targets_parser,
            )
            features_dict["nt_targets"] = target_encoding
        return features_dict