# Source code for rnaglib.transforms.represent.point_cloud
import torch
import numpy as np
from .representation import Representation
def get_point_cloud_dict(rna_graph, features_dict, sort=False):
    """
    Build the point cloud tensors for one RNA graph.

    This is factored out because this computation is also used by the voxel based representation.

    :param rna_graph: graph whose ``nodes.data()`` yields ``(node, attrs)`` pairs. Every ``attrs``
        must have a ``'C5prime_xyz'`` entry holding the coordinates of that node.
    :param features_dict: dict that may hold ``'nt_features'`` and/or ``'nt_targets'``; when
        present, each one is indexed by node and its values are stacked with ``torch.stack``.
    :param sort: if True, nodes are visited in sorted order of their identifier; otherwise in
        the order produced by ``rna_graph.nodes.data()``.
    :return: a dict with

        * ``'point_cloud_coords'``: the per-node coordinates stacked along dim 0,
        * ``'point_cloud_feats'``: stacked node features (only if ``'nt_features'`` was given),
        * ``'point_cloud_targets'``: stacked node targets (only if ``'nt_targets'`` was given),
        * ``'point_cloud_nodes'``: the list of node identifiers, in the same order as the rows.

    Note: the per-node lists are passed to ``torch.stack``, so a graph without any node is not
    handled here.
    """
    # These lookups do not depend on the node, so resolve them once up front.
    with_feats = "nt_features" in features_dict
    with_targets = "nt_targets" in features_dict

    res_dict = {'point_cloud_coords': []}
    if with_feats:
        res_dict['point_cloud_feats'] = []
    if with_targets:
        res_dict['point_cloud_targets'] = []

    node_iterator = rna_graph.nodes.data()
    if sort:
        # Order on the node identifier only, so the attribute dicts are never compared.
        node_iterator = sorted(node_iterator, key=lambda item: item[0])

    node_names = []
    for node, attrs in node_iterator:
        node_names.append(node)
        # Going through a float numpy array yields a double precision tensor.
        node_coords = torch.as_tensor(np.array(attrs['C5prime_xyz'], dtype=float))
        res_dict['point_cloud_coords'].append(node_coords)
        if with_feats:
            res_dict['point_cloud_feats'].append(features_dict['nt_features'][node])
        if with_targets:
            res_dict['point_cloud_targets'].append(features_dict['nt_targets'][node])

    stacked_res_dict = {key: torch.stack(value, dim=0) for key, value in res_dict.items()}
    # Node names stay a plain Python list: they are identifiers, not numeric data.
    stacked_res_dict['point_cloud_nodes'] = node_names
    return stacked_res_dict
[docs]
class PointCloudRepresentation(Representation):
    """
    Converts RNA into a point cloud based representation
    """

    def __init__(self, hstack=True, sorted_nodes=True, **kwargs):
        """
        :param hstack: if True, ``batch`` merges all samples into single tensors / a single
            node list; if False, ``batch`` keeps one entry per sample.
        :param sorted_nodes: forwarded as ``sort`` to ``get_point_cloud_dict`` on each call.
        :param kwargs: passed through unchanged to the parent ``Representation``.
        """
        super().__init__(**kwargs)
        self.hstack = hstack
        self.sorted_nodes = sorted_nodes

    def __call__(self, rna_graph, features_dict):
        """
        Compute the point cloud dict of one RNA.

        :param rna_graph: the RNA graph, handed over to ``get_point_cloud_dict``.
        :param features_dict: the features dict, handed over to ``get_point_cloud_dict``.
        :return: the dict returned by ``get_point_cloud_dict``.
        """
        return get_point_cloud_dict(rna_graph=rna_graph,
                                    features_dict=features_dict,
                                    sort=self.sorted_nodes)

    @property
    def name(self):
        """Name of this representation: ``"point_cloud"``."""
        return "point_cloud"

    def batch(self, samples):
        """
        Batch a list of point cloud samples

        The keys of the batch are those of the first sample; every sample is indexed with them.

        :param samples: A list of the output from this representation
        :return: a batched version of it. With ``hstack`` set, the ``'point_cloud_nodes'`` lists
            are flattened into one list and every other entry is concatenated along dim 0.
            Otherwise each key maps to the list of the per-sample values.
        """
        keys = samples[0].keys()

        # No merging requested: keep one entry per sample under each key.
        if not self.hstack:
            return {key: [sample[key] for sample in samples] for key in keys}

        pc_batch = {}
        for key in keys:
            if key == 'point_cloud_nodes':
                # Node names are Python lists, they cannot go through torch.cat.
                pc_batch[key] = [node_id for sample in samples for node_id in sample[key]]
            else:
                pc_batch[key] = torch.cat([sample[key] for sample in samples], dim=0)
        return pc_batch