# Source code for rnaglib.transforms.represent.graph

import torch
import networkx as nx

from rnaglib.config.graph_keys import GRAPH_KEYS, TOOL
from rnaglib.algorithms import fix_buggy_edges

from .representation import Representation


class GraphRepresentation(Representation):
    """
    Converts RNA into a Leontis-Westhof graph (2.5D) where nodes are residues
    and edges are either base pairs or backbones. Base pairs are annotated with
    the Leontis-Westhof classification for canonical and non-canonical base pairs.

    The object returned by calling the representation depends on ``framework``:
    the networkx graph itself (``"nx"``), the result of ``dgl.from_networkx``
    (``"dgl"``), or a ``torch_geometric.data.Data`` object (``"pyg"``).
    """

    def __init__(
        self,
        clean_edges=True,
        framework="nx",
        edge_map=GRAPH_KEYS["edge_map"][TOOL],
        etype_key="LW",
        **kwargs,
    ):
        """
        :param clean_edges: if True, the input graph is passed through
            ``fix_buggy_edges`` before being converted.
        :param framework: output framework, one of ``"nx"``, ``"dgl"`` or ``"pyg"``.
        :param edge_map: mapping from the edge label stored under ``etype_key``
            to the value written as edge type.
        :param etype_key: key of the edge attribute holding the edge label.
        :param kwargs: forwarded unchanged to ``Representation.__init__``.
        :raises AssertionError: if ``framework`` is not one of the supported values.
        """
        authorized_frameworks = {"nx", "dgl", "pyg"}
        assert framework in authorized_frameworks, (
            f"Framework {framework} not supported for this representation. "
            f"Choose one of {authorized_frameworks}."
        )
        self.framework = framework
        self.clean_edges = clean_edges
        self.etype_key = etype_key
        self.edge_map = edge_map
        super().__init__(**kwargs)

    def __call__(self, rna_graph, features_dict):
        """
        Convert ``rna_graph`` into the configured framework's graph object.

        :param rna_graph: networkx graph whose edges carry the ``etype_key`` attribute.
        :param features_dict: mapping ``feature name -> {node: value}``
            (``"rna_targets"`` is used as a graph-level value by ``to_pyg``).
        :return: the converted graph; ``None`` if ``self.framework`` is not one
            of the three supported values.
        """
        base_graph = fix_buggy_edges(graph=rna_graph) if self.clean_edges else rna_graph

        if self.framework == "nx":
            return self.to_nx(base_graph, features_dict)
        if self.framework == "dgl":
            return self.to_dgl(base_graph, features_dict)
        if self.framework == "pyg":
            return self.to_pyg(base_graph, features_dict)

    def to_nx(self, graph, features_dict):
        """
        Annotate ``graph`` in place with edge types and node features.

        :param graph: networkx graph to annotate; it is modified and returned.
        :param features_dict: mapping ``attribute name -> {node: value}``; each
            entry is written as a node attribute of that name.
        :return: the same ``graph`` object, with an ``"edge_type"`` edge attribute
            and one node attribute per key of ``features_dict``.
        """
        # Get Edge Labels
        edge_type = {(u, v): self.edge_map[data[self.etype_key]] for u, v, data in graph.edges(data=True)}
        nx.set_edge_attributes(graph, name="edge_type", values=edge_type)

        # Add features and targets
        for name, encoding in features_dict.items():
            nx.set_node_attributes(graph, name=name, values=encoding)
        return graph

    def to_dgl(self, graph, features_dict):
        """
        Convert ``graph`` to a DGL graph carrying edge types and node features.

        :param graph: networkx graph (annotated in place through ``to_nx``).
        :param features_dict: mapping ``attribute name -> {node: value}``.
        :return: the graph produced by ``dgl.from_networkx``.
        """
        import dgl

        nx_graph = self.to_nx(graph, features_dict)
        # Careful ! When doing this, the graph nodes get sorted.
        # Pass a concrete list of attribute names rather than a dict view.
        g_dgl = dgl.from_networkx(
            nx_graph=nx_graph,
            edge_attrs=["edge_type"],
            node_attrs=list(features_dict),
        )
        return g_dgl

    def to_pyg(self, graph, features_dict):
        """
        Convert ``graph`` to a ``torch_geometric.data.Data`` object.

        Nodes are indexed by their sorted order. ``x`` comes from
        ``features_dict["nt_features"]`` and ``y`` from ``"nt_targets"``
        (overridden by ``"rna_targets"`` when present); either stays ``None``
        when the corresponding key is absent.

        :param graph: networkx graph whose edges carry the ``etype_key`` attribute.
        :param features_dict: mapping of feature names to per-node tensors
            (``"nt_features"``, ``"nt_targets"``) or a graph-level value (``"rna_targets"``).
        :return: ``Data(x, y, edge_attr, edge_index)`` with ``edge_index`` of
            shape ``(2, num_edges)``, including when the graph has no edges.
        """
        from torch_geometric.data import Data

        # for some reason from_networkx is not working so doing by hand
        # not super efficient at the moment
        ordered_nodes = sorted(graph.nodes())
        node_map = {n: i for i, n in enumerate(ordered_nodes)}

        x, y = None, None
        if "nt_features" in features_dict:
            nt_features = features_dict["nt_features"]
            x = torch.stack([nt_features[n] for n in ordered_nodes])

        if "nt_targets" in features_dict:
            nt_targets = features_dict["nt_targets"]
            list_y = [nt_targets[n] for n in ordered_nodes]
            # In the case of single target, pytorch CE loss expects shape (n,) and not (n,1)
            # For multi-target cases, we stack to get (n,d)
            if len(list_y[0]) == 1:
                y = torch.cat(list_y)
            else:
                y = torch.stack(list_y)

        if "rna_targets" in features_dict:
            y = torch.tensor(features_dict["rna_targets"])

        # Sort the edges once, on their endpoints only, so that edge_index and
        # edge_attr are built from the very same ordering and the attribute
        # dicts are never compared by the sort.
        edges = sorted(graph.edges(data=True), key=lambda edge: edge[:2])

        # reshape(-1, 2) is a no-op for a non-empty edge list and guarantees a
        # (2, 0) edge_index (instead of a 1-D empty tensor) for edgeless graphs.
        edge_index = torch.tensor(
            [[node_map[u], node_map[v]] for u, v, _ in edges],
            dtype=torch.long,
        ).reshape(-1, 2).T

        # Explicit integer dtype: an empty list would otherwise yield a float32 tensor.
        edge_attrs = torch.tensor(
            [self.edge_map[data[self.etype_key]] for _, _, data in edges],
            dtype=torch.long,
        )
        return Data(x=x, y=y, edge_attr=edge_attrs, edge_index=edge_index)

    @property
    def name(self):
        """Name of this representation: ``"graph"``."""
        return "graph"

    def batch(self, samples):
        """
        Batch a list of graph samples

        :param samples: A list of the output from this representation
        :return: a batched version of it: the list itself for ``"nx"``, the
            result of ``dgl.batch`` for ``"dgl"``, a ``torch_geometric`` ``Batch``
            for ``"pyg"``; ``None`` for any other framework value.
        """
        if self.framework == "nx":
            return samples
        if self.framework == "dgl":
            import dgl

            batched_graph = dgl.batch(list(samples))
            return batched_graph
        if self.framework == "pyg":
            from torch_geometric.data import Batch

            batch = Batch.from_data_list(samples)
            # sometimes batching changes dtype from int to float32?
            batch.edge_index = batch.edge_index.to(torch.int64)
            batch.edge_attr = batch.edge_attr.to(torch.int64)
            return batch