"""Collection of functions operating on RNA graphs"""
import pickle
import os
from typing import Hashable, Dict, List, Tuple
from tqdm import tqdm
import networkx as nx
import numpy as np
from rnaglib.config.graph_keys import GRAPH_KEYS, TOOL
CANONICALS = GRAPH_KEYS["canonical"][TOOL]
VALID_EDGES = GRAPH_KEYS["edge_map"][TOOL].keys()
def multigraph_to_simple(g: nx.MultiDiGraph) -> nx.DiGraph:
"""Convert directed multi graph to simple directed graph.
When multiple edges are found between two nodes, we keep backbone.
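
    Example (a minimal sketch; the node names and ``LW`` labels are toy values):

    >>> mg = nx.MultiDiGraph()
    >>> mg.add_edge("a", "b", LW="B53")  # backbone edge
    0
    >>> mg.add_edge("a", "b", LW="CWW")  # parallel base pair, dropped in favour of the backbone
    1
    >>> mg.add_edge("b", "c", LW="CWW")  # lone base pair, kept
    0
    >>> multigraph_to_simple(mg)["a"]["b"]["LW"]
    'B53'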
"""
simple_g = nx.DiGraph()
backbone_types = ["B53", "B35"]
# first pass adds the backbones
for u, v, data in g.edges(data=True):
etype = data["LW"]
if etype in backbone_types:
simple_g.add_edge(u, v, **data)
# second pass adds non-canonicals when no backbone exists
basepairs = []
for u, v, data in g.edges(data=True):
etype = data["LW"]
if etype not in backbone_types and not simple_g.has_edge(u, v):
basepairs.append((u, v, data))
simple_g.add_edges_from(basepairs)
simple_g.graph = g.graph.copy()
simple_g_nodes = set(simple_g.nodes())
simple_g_node_attrs = {k: v for k, v in dict(g.nodes(data=True)).items() if k in simple_g_nodes}
nx.set_node_attributes(simple_g, simple_g_node_attrs)
return simple_g
def reorder_nodes(g: nx.DiGraph) -> nx.DiGraph:
"""
    Reorder the nodes of a graph according to the default ``sorted()`` order.

    :param g: Pass a graph for node reordering.
    :type g: networkx.DiGraph
    :return: the reordered graph (nx.DiGraph)
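
    Example (a minimal sketch on toy nodes):

    >>> g = nx.DiGraph()
    >>> g.add_edges_from([("b", "a"), ("a", "c")])
    >>> list(reorder_nodes(g).nodes())
    ['a', 'b', 'c']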
"""
reordered_graph = nx.DiGraph()
reordered_graph.add_nodes_from(sorted(g.nodes.data()))
reordered_graph.add_edges_from(g.edges.data())
for key, value in g.graph.items():
reordered_graph.graph[key] = value
return reordered_graph
def induced_edge_filter(graph: nx.DiGraph, roots: List[Hashable], depth: int = 1) -> nx.DiGraph:
    """
    Remove edges in the graph introduced by the induced subgraph routine.
    Only keep edges which fall within a single node's neighbourhood.

    :param graph: networkx subgraph
    :param roots: nodes to use for filtering
    :param depth: size of the neighbourhood to take around each node
    :returns: the filtered subgraph
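
    Example (a minimal sketch; the ``label`` values are toy data required by :func:`bfs`):

    >>> g = nx.Graph()
    >>> g.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (1, 3)], label="B53")
    >>> sorted(induced_edge_filter(g, [0, 4], depth=1).edges())
    [(0, 1), (3, 4)]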
"""
# a depth of zero does not make sense for this operation as it would remove
# all edges
if depth < 1:
depth = 1
neighbourhoods = []
flat_neighbors = set()
for root in roots:
root_neighbors = bfs(graph, [root], depth=depth)
neighbourhoods.append(root_neighbors)
flat_neighbors = flat_neighbors.union(root_neighbors)
flat_neighbors = list(flat_neighbors)
subgraph = graph.subgraph(flat_neighbors)
subgraph = subgraph.copy()
kill = []
for u, v in subgraph.edges():
for nei in neighbourhoods:
if u in nei and v in nei:
break
else:
kill.append((u, v))
subgraph.remove_edges_from(kill)
return subgraph
def get_nc_nodes(graph: nx.DiGraph, depth: int = 4, return_index: bool = False) -> set:
"""
    Return the set of nodes in the graph which lie in a loop or have a
    non-canonical interaction in their neighbourhood.

    :param graph: a networkx graph
    :param depth: the depth up to which we consider nodes neighbors of an NC
    :param return_index: if True, return the node's index (in sorted node order) instead of the node itself
    :return: set of nodes (or their indices) in loops or near an NC
"""
keep = []
for i, node in enumerate(sorted(graph.nodes())):
to_keep = i if return_index else node
if graph.degree(node) == 2:
keep.append(to_keep)
elif has_NC_bfs(graph, node, depth=depth):
keep.append(to_keep)
return set(keep)
def nc_clean_dir(graph_dir, dump_dir):
"""
    Copy graphs from graph_dir to dump_dir; the copied graphs are
    trimmed according to :func:`get_nc_nodes`.
:param graph_dir: A directory that should contain networkx pickles.
:param dump_dir: The directory where to dump the trimmed graphs
"""
for g in tqdm(os.listdir(graph_dir)):
graph = nx.read_gpickle(os.path.join(graph_dir, g))
keep_nodes = get_nc_nodes(graph)
print(f">>> kept {len(keep_nodes)} nodes of {len(graph.nodes())}.")
kill_nodes = set(graph.nodes()) - keep_nodes
graph.remove_nodes_from(kill_nodes)
dangle_trim(graph)
if len(graph.nodes()) > 4:
nx.write_gpickle(graph, os.path.join(dump_dir, g))
def incident_nodes(graph, nodes):
"""
    Return the set of nodes in the graph incident to the input nodes.

    :param graph: A networkx graph
    :param nodes: set of nodes in the graph
    :return: set of nodes neighbouring the input set, according to the connectivity of the graph
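
    Example (a minimal sketch on a toy path graph):

    >>> g = nx.path_graph(4)  # edges 0-1, 1-2, 2-3
    >>> sorted(incident_nodes(g, {1, 2}))
    [0, 3]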
"""
core = set(nodes)
hits = set()
for u, v in graph.edges():
if u in core and v not in core:
hits.add(v)
if u not in core and v in core:
hits.add(u)
return hits
def nx_to_dgl(graph_path, edge_map, label="label"):
    """
    Load a pickled networkx graph and convert it to DGL.

    Note: this relies on the legacy ``DGLGraph.from_networkx`` API of older DGL releases.

    :param graph_path: path to a pickle containing a 3-tuple whose first element is the graph
    :param edge_map: mapping from edge labels to integer edge types
    :param label: the edge attribute holding the label
    :return: the DGL graph
    """
    import dgl

    graph, _, _ = pickle.load(open(graph_path, "rb"))
    edge_type = {edge: edge_map[lab] for edge, lab in (nx.get_edge_attributes(graph, label)).items()}
    nx.set_edge_attributes(graph, name="edge_type", values=edge_type)
    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=graph, edge_attrs=["edge_type"])
    return g_dgl
def dgl_to_nx(graph, edge_map, label="label"):
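    """
    Convert a DGL graph back to networkx, restoring string edge labels via ``edge_map``.

    :param graph: DGL graph with an integer ``edge_type`` edge attribute
    :param edge_map: mapping from edge labels to integer edge types (reversed here)
    :param label: the edge attribute to write the restored labels to
    :return: Nx graph
    """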
import dgl
g = dgl.to_networkx(graph, edge_attrs=["edge_type"])
edge_map_r = {v: k for k, v in edge_map.items()}
nx.set_edge_attributes(
g,
{(n1, n2): edge_map_r[d["edge_type"].item()] for n1, n2, d in g.edges(data=True)},
label,
)
return g
def bfs_generator(graph, initial_node):
"""
    Generator version of BFS given a graph and an initial node.
    Yields the nodes of the next hop at each call.

    :param graph: Nx graph
    :param initial_node: single node or iterable of nodes
    :return: the successive rings
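
    Example (a minimal sketch on a toy path graph):

    >>> rings = bfs_generator(nx.path_graph(4), 0)
    >>> next(rings)
    [1]
    >>> next(rings)
    [2]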
"""
    if isinstance(initial_node, (list, set)):
        previous_ring = set(initial_node)
    else:
        previous_ring = {initial_node}
    visited = set()
    while len(visited) < len(graph):
        depth_ring = set()
        for n in previous_ring:
            visited.add(n)
            for nei in graph.neighbors(n):
                if nei not in visited:
                    depth_ring.add(nei)
        if not depth_ring:
            # no new nodes reachable: stop instead of looping forever on disconnected graphs
            break
        previous_ring = depth_ring
        yield list(depth_ring)
def bfs(graph, initial_nodes, nc_block=False, depth=2, label="label"):
"""
    BFS from seed nodes given a graph and initial node(s).

    :param graph: Nx graph
    :param initial_nodes: single node or iterable of nodes
    :param nc_block: if True, stop the expansion as soon as a hop contains only canonical (``CWW``/backbone) edges
    :param depth: the number of hops to conduct from our roots
    :param label: the edge attribute holding the label
    :return: set of nodes
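
    Example (a minimal sketch with toy ``label`` values):

    >>> g = nx.Graph()
    >>> g.add_edges_from([(1, 2), (2, 3), (3, 4)], label="B53")
    >>> sorted(bfs(g, 1, depth=2))
    [1, 2, 3]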
"""
    if isinstance(initial_nodes, (list, set)):
        total_nodes = [set(initial_nodes)]
    else:
        total_nodes = [{initial_nodes}]
for d in range(depth):
depth_ring = set()
e_labels = set()
for n in total_nodes[d]:
for nei in graph.neighbors(n):
depth_ring.add(nei)
e_labels.add(graph[n][nei][label])
if nc_block and e_labels.issubset({"CWW", "B53", ""}):
break
else:
total_nodes.append(depth_ring)
total_nodes = set().union(*total_nodes)
return total_nodes
def remove_self_loops(graph):
"""
    Remove all self-loop connections by modifying the graph in place.
:param graph: The graph to trim
:return: None
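
    Example (a minimal sketch):

    >>> g = nx.Graph([(1, 1), (1, 2)])
    >>> remove_self_loops(g)
    >>> list(g.edges())
    [(1, 2)]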
"""
graph.remove_edges_from([(n, n) for n in graph.nodes()])
def remove_non_standard_edges(graph, label="LW"):
"""
    Remove all edges whose label is not in the ``VALID_EDGES`` variable.

    :param graph: Nx Graph
    :param label: The name of the labels to check
    :return: None; the graph is pruned in place
"""
remove = []
for n1, n2, d in graph.edges(data=True):
if d[label] not in VALID_EDGES:
remove.append((n1, n2))
graph.remove_edges_from(remove)
def to_orig(graph, label="LW"):
"""
    Deprecated; used to include only the non-canonical (NC) interactions.

    :param graph: Nx graph
    :param label: the edge attribute holding the label
    :return: a new undirected graph with the retained edges and node attributes
"""
H = nx.Graph()
for n1, n2, d in graph.edges(data=True):
if d[label] in VALID_EDGES:
assert d[label] != "B35"
H.add_edge(n1, n2, label=d[label])
for attrib in [
"mg",
"lig",
"lig_id",
"chemically_modified",
"pdb_pos",
"bgsu",
"carnaval",
"chain",
]:
graph_data = graph.nodes(data=True)
attrib_dict = {n: graph_data[n][attrib] for n in H.nodes()}
nx.set_node_attributes(H, attrib_dict, attrib)
remove_self_loops(H)
return H
def to_orig_all(graph_dir, dump_dir):
"""
    Deprecated; apply :func:`to_orig` to every graph pickle in a directory.

    :param graph_dir: directory of input graph pickles
    :param dump_dir: directory where converted graphs are written
    :return: None
"""
for g in tqdm(os.listdir(graph_dir)):
try:
graph = nx.read_gpickle(os.path.join(graph_dir, g))
except Exception as e:
print(f">>> failed on {g} with exception {e}")
continue
H = to_orig(graph)
nx.write_gpickle(H, os.path.join(dump_dir, g))
def find_node(graph, chain, pos):
"""
Get a node from its PDB identification
:param graph: Nx graph
:param chain: The PDB chain
:param pos: The PDB 'POS' field
:return: The node if it was found, else None
"""
for n, d in graph.nodes(data=True):
if (n[0] == chain) and (d["nucleotide"].pdb_pos == str(pos)):
return n
return None
def has_NC(graph, label="LW"):
"""
    Does the input graph contain non-canonical edges?
:param graph: Nx graph
:param label: The label to use
:return: Boolean
"""
for n1, n2, d in graph.edges(data=True):
if d[label] not in CANONICALS:
return True
return False
def has_NC_bfs(graph, node_id, depth=2):
"""
    Return True if the node has a non-canonical interaction in its neighbourhood.

    :param graph: Nx graph
    :param node_id: the node from which to start our search
    :param depth: the number of hops to conduct from our root
:return: Boolean
"""
subg = list(bfs(graph, node_id, depth=depth))
sG = graph.subgraph(subg).copy()
return has_NC(sG)
def floaters(graph):
"""
    Remove floating base pairs: single base pairs whose two endpoints
    are not attached to the backbone (i.e. both have degree 1).

    :param graph: Nx graph
    :return: trimmed graph
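
    Example (a minimal sketch; the edge (1, 2) is a floating pair):

    >>> g = nx.Graph([(1, 2), (3, 4), (4, 5)])
    >>> sorted(floaters(g).edges())
    [(3, 4), (4, 5)]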
"""
deg_ok = lambda H, u, v, d: (H.degree(u) == d) and (H.degree(v) == d)
floaters = []
for u, v in graph.edges():
if deg_ok(graph, u, v, 1):
floaters.append((u, v))
graph.remove_edges_from(floaters)
return graph
def dangle_trim(graph):
"""
    Iteratively remove dangling nodes (degree < 2) from the graph, modifying it in place.

    :param graph: Nx graph
    :return: trimmed graph
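
    Example (a minimal sketch; node 4 dangles off a triangle):

    >>> g = nx.Graph([(1, 2), (2, 3), (3, 1), (3, 4)])
    >>> sorted(dangle_trim(g).nodes())
    [1, 2, 3]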
"""
dangles = lambda graph: [n for n in graph.nodes() if graph.degree(n) < 2]
while dangles(graph):
graph.remove_nodes_from(dangles(graph))
return graph
def stack_trim(graph):
"""
Remove stacks from graph.
:param graph: Nx graph
:return: trimmed graph
"""
    is_ww = lambda n, graph: "CWW" in [info["LW"] for node, info in graph[n].items()]
    cur_graph = graph.copy()
    while True:
        stacks = []
        for n in cur_graph.nodes:
            if cur_graph.degree(n) == 2 and is_ww(n, cur_graph):
                # potential stack opening
                partner = None
                stacker = None
                for node, info in cur_graph[n].items():
                    if info["LW"] == "B53":
                        stacker = node
                    elif info["LW"] == "CWW":
                        partner = node
                # guard against a missing backbone neighbour or pairing partner
                if stacker is None or partner is None:
                    continue
                if cur_graph.degree(partner) > 3:
                    continue
                stacker_2 = None
                for node, info in cur_graph[partner].items():
                    if info["LW"] == "B53":
                        stacker_2 = node
                try:
                    if cur_graph[stacker][stacker_2]["LW"] == "CWW":
                        stacks.append(n)
                        stacks.append(partner)
                except KeyError:
                    continue
        if len(stacks) == 0:
            break
        cur_graph.remove_nodes_from(stacks)
        cur_graph = cur_graph.copy()
return cur_graph
def in_stem(graph, u, v):
"""
Find if two nodes are part of a stem and engage in NC interactions
:param graph: Nx graph
:param u: one graph node
:param v: one graph node
:return: Boolean
"""
    non_bb = lambda graph, n: len([info["LW"] for node, info in graph[n].items() if info["LW"] not in CANONICALS])
    is_ww = lambda graph, u, v: graph[u][v]["LW"] in {"CWW", "cWW"}
if is_ww(graph, u, v) and (non_bb(graph, u) in (1, 2)) and (non_bb(graph, v) in (1, 2)):
return True
return False
def gap_fill(original_graph, graph_to_expand):
"""
    After subgraphing, get rid of all degree-1 nodes by completing them with one more hop in the original graph.

    :param original_graph: nx graph
    :param graph_to_expand: nx graph that needs to be expanded to fix dangles
    :return: the expanded graph
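
    Example (a minimal sketch on a toy cycle):

    >>> big = nx.cycle_graph(4)
    >>> sub = big.subgraph([0, 1]).copy()  # nodes 0 and 1 both have degree 1 here
    >>> sorted(gap_fill(big, sub).nodes())
    [0, 1, 2, 3]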
"""
    new_nodes = list(graph_to_expand.nodes())
    for n in graph_to_expand.nodes():
        if graph_to_expand.degree(n) == 1:
            # pull the missing neighbours from the original graph
            new_nodes.extend(original_graph.neighbors(n))
res_graph = original_graph.subgraph(new_nodes).copy()
return res_graph
def symmetric_elabels(graph):
"""
    Make edge labels symmetric for a graph: non-backbone labels get their last
    two characters sorted (e.g. ``CWH`` -> ``CHW``) and ``B35`` becomes ``B53``.

    :param graph: Nx graph
    :return: the same graph, but with symmetric edge labels so that converting to an undirected graph is straightforward
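
    Example (a minimal sketch on toy labels):

    >>> g = nx.Graph()
    >>> g.add_edge(1, 2, label="CWH")
    >>> g.add_edge(2, 3, label="B35")
    >>> h = symmetric_elabels(g)
    >>> h[1][2]["label"], h[2][3]["label"]
    ('CHW', 'B53')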
"""
H = graph.copy()
new_e_labels = {}
for n1, n2, d in graph.edges(data=True):
old_label = d["label"]
if old_label not in ["B53", "B35"]:
new_label = old_label[0] + "".join(sorted(old_label[1:]))
else:
new_label = "B53"
new_e_labels[(n1, n2)] = new_label
nx.set_edge_attributes(H, new_e_labels, "label")
return H
def relabel_graphs(graph_dir, dump_path):
"""
    Take graphs in graph_dir and dump symmetrized versions in dump_path.
"""
for g in tqdm(os.listdir(graph_dir)):
graph = nx.read_gpickle(os.path.join(graph_dir, g))
graph_new = symmetric_elabels(graph)
nx.write_gpickle(graph_new, os.path.join(dump_path, g))
def weisfeiler_lehman_graph_hash(graph, edge_attr=None, node_attr=None, iterations=3, digest_size=16):
"""Return Weisfeiler Lehman (WL) graph hash.
The function iteratively aggregates and hashes neighbourhoods of each node.
After each node's neighbors are hashed to obtain updated node labels,
a hashed histogram of resulting labels is returned as the final hash.
    Hashes are identical for isomorphic graphs and there are strong guarantees
    that non-isomorphic graphs will get different hashes. See [1] for details.
Note: Similarity between hashes does not imply similarity between graphs.
If no node or edge attributes are provided, the degree of each node
is used as its initial label.
Otherwise, node and/or edge labels are used to compute the hash.
Parameters
----------
graph: graph
The graph to be hashed.
Can have node and/or edge attributes. Can also have no attributes.
edge_attr: string
The key in edge attribute dictionary to be used for hashing.
If None, edge labels are ignored.
node_attr: string
The key in node attribute dictionary to be used for hashing.
If None, and no edge_attr given, use
degree of node as label.
iterations: int
Number of neighbor aggregations to perform.
Should be larger for larger graphs.
digest_size: int
Size of blake2b hash digest to use for hashing node labels.
Returns
-------
h : string
Hexadecimal string corresponding to hash of the input graph.
Examples
--------
Two graphs with edge attributes that are isomorphic, except for
differences in the edge labels.
>>> import networkx as nx
>>> G1 = nx.Graph()
>>> G1.add_edges_from([(1, 2, {'label': 'A'}),\
(2, 3, {'label': 'A'}),\
(3, 1, {'label': 'A'}),\
(1, 4, {'label': 'B'})])
>>> G2 = nx.Graph()
>>> G2.add_edges_from([(5,6, {'label': 'B'}),\
(6,7, {'label': 'A'}),\
(7,5, {'label': 'A'}),\
(7,8, {'label': 'A'})])
    Omitting the `edge_attr` option results in identical hashes.
>>> weisfeiler_lehman_graph_hash(G1)
'0db442538bb6dc81d675bd94e6ebb7ca'
>>> weisfeiler_lehman_graph_hash(G2)
'0db442538bb6dc81d675bd94e6ebb7ca'
With edge labels, the graphs are no longer assigned
the same hash digest.
>>> weisfeiler_lehman_graph_hash(G1, edge_attr='label')
'408c18537e67d3e56eb7dc92c72cb79e'
>>> weisfeiler_lehman_graph_hash(G2, edge_attr='label')
'f9e9cb01c6d2f3b17f83ffeaa24e5986'
    References
    ----------
.. [1] Shervashidze, Nino, Pascal Schweitzer, Erik Jan Van Leeuwen,
Kurt Mehlhorn, and Karsten M. Borgwardt. Weisfeiler Lehman
Graph Kernels. Journal of Machine Learning Research. 2011.
http://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf
"""
from collections import Counter
from hashlib import blake2b
def neighborhood_aggregate(graph, node, node_labels, edge_attr=None):
"""
Compute new labels for given node by aggregating
the labels of each node's neighbors.
"""
label_list = [node_labels[node]]
for nei in graph.neighbors(node):
prefix = "" if not edge_attr else graph[node][nei][edge_attr]
label_list.append(prefix + node_labels[nei])
return "".join(sorted(label_list))
def weisfeiler_lehman_step(graph, labels, edge_attr=None, node_attr=None):
"""
Apply neighborhood aggregation to each node
in the graph.
Computes a dictionary with labels for each node.
"""
new_labels = dict()
for node in graph.nodes():
new_labels[node] = neighborhood_aggregate(graph, node, labels, edge_attr=edge_attr)
return new_labels
items = []
node_labels = dict()
# set initial node labels
for node in graph.nodes():
if (not node_attr) and (not edge_attr):
node_labels[node] = str(graph.degree(node))
elif node_attr:
node_labels[node] = str(graph.nodes[node][node_attr])
else:
node_labels[node] = ""
for k in range(iterations):
node_labels = weisfeiler_lehman_step(graph, node_labels, edge_attr=edge_attr)
counter = Counter()
# count node labels
for node, d in node_labels.items():
h = blake2b(digest_size=digest_size)
h.update(d.encode("ascii"))
counter.update([h.hexdigest()])
# sort the counter, extend total counts
items.extend(sorted(counter.items(), key=lambda x: x[0]))
# hash the final counter
h = blake2b(digest_size=digest_size)
h.update(str(tuple(items)).encode("ascii"))
h = h.hexdigest()
return h
def fix_buggy_edges(graph, label="LW", strategy="remove", edge_map=GRAPH_KEYS["edge_map"][TOOL]):
"""
    Sometimes edges have weird names, such as ``t.W``, representing fuzziness.
    We just remove those, as they don't carry reliable information.

    :param graph: Nx graph
    :param label: the edge attribute holding the label
    :param strategy: how to deal with buggy edges; for now we just remove them.
        In the future, maybe add a dedicated edge type to the edge map?
    :param edge_map: mapping of valid edge labels to integer types
    :return: the fixed graph
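
    Example (a minimal sketch with an explicit toy ``edge_map``):

    >>> g = nx.Graph()
    >>> g.add_edge(1, 2, LW="CWW")
    >>> g.add_edge(2, 3, LW="t.W")  # fuzzy label, not in the edge map
    >>> g = fix_buggy_edges(g, edge_map={"CWW": 0, "B53": 1})
    >>> list(g.edges())
    [(1, 2)]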
"""
if strategy == "remove":
# Filter weird edges for now
to_remove = list()
for start_node, end_node, nodedata in graph.edges(data=True):
if nodedata[label] not in edge_map:
to_remove.append((start_node, end_node))
for start_node, end_node in to_remove:
graph.remove_edge(start_node, end_node)
else:
raise ValueError(f"The edge fixing strategy : {strategy} was not implemented yet")
return graph
def get_sequences(graph: nx.Graph,
                  gap_tolerance=2,
                  longest_only=True,
                  min_size_return=5,
                  verbose=True) -> Dict[str, Tuple[str, List[str]]]:
"""Extract ordered sequences from each chain of the RNA.
    Returns a dictionary mapping ``<pdbid.chain>`` to (sequence, list of node IDs).

    .. warning::
        Currently does not handle missing residues. If a residue is missing, it is simply skipped.

    :param graph: an nx.Graph of an RNA
    :param gap_tolerance: gaps smaller than this many residues do not split a chain into segments
    :param longest_only: if True, return only the longest contiguous segment of each chain
    :param min_size_return: when returning several segments, keep only those longer than this (5 works well for CD-HIT)
    :param verbose: if True, print a warning when a chain is discontinuous
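
    Example (a minimal sketch on a toy single-chain graph):

    >>> g = nx.Graph()
    >>> for i, nt in enumerate("GCAU"):
    ...     g.add_node(f"1abc.A.{i + 1}", nt_code=nt)
    >>> get_sequences(g)
    {'1abc.A': ('GCAU', ['1abc.A.1', '1abc.A.2', '1abc.A.3', '1abc.A.4'])}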
"""
sequences = {}
chains = set([n.split(".")[1] for n in graph.nodes()])
seqs = {c: [] for c in chains}
for nt, d in graph.nodes(data=True):
pdbid, ch, pos = nt.split(".")
nuc = d["nt_code"].upper()
if nuc not in ["A", "U", "C", "G"]:
nuc = "N"
seqs[ch].append((nuc, int(pos)))
for ch, seq in seqs.items():
sorted_seq = sorted(seq, key=lambda x: x[1])
sorted_ids = [f"{pdbid}.{ch}.{pos}" for _, pos in sorted_seq]
# check if sequence is discontinuous and keep track of all its consecutive segments
previous = 0
consecutives = []
for i in range(len(sorted_ids) - 1):
fivep = int(sorted_ids[i].split(".")[2])
threep = int(sorted_ids[i + 1].split(".")[2])
if threep != fivep + 1:
if verbose:
print(f"WARNING: chain discontinuous.")
gap = threep - fivep - 1
if gap >= gap_tolerance:
consecutives.append((previous, i + 1))
previous = i + 1
consecutives.append((previous, len(sorted_ids)))
# Simply return the longest
if longest_only:
longest = sorted(consecutives, key=lambda x: x[1] - x[0])[-1]
consecutives = [longest]
# If we return more than one, only keep ones larger than a threshold, using 5 is nice for CD-Hit usage
else:
consecutives = [x for x in consecutives if x[1] - x[0] > min_size_return]
# Finally, return all such chunks, named with their start/end residues
for i, (start, end) in enumerate(consecutives):
sorted_seq_chunk = "".join([s for s, _ in sorted_seq[start:end]])
sorted_ids_chunk = sorted_ids[start:end]
if len(consecutives) == 1:
chunk_name = f"{pdbid}.{ch}"
else:
start_id = sorted_ids_chunk[0].split('.')[-1]
end_id = sorted_ids_chunk[-1].split('.')[-1]
chunk_name = f"{pdbid}.{ch}.{start_id}.{end_id}"
sequences[chunk_name] = sorted_seq_chunk, sorted_ids_chunk
return sequences