# Source code for rnaglib.transforms.represent.voxel

import os
import sys

import copy
import networkx as nx
import numpy as np
from sklearn.gaussian_process.kernels import RBF
import torch

from .representation import Representation
from .point_cloud import get_point_cloud_dict


def get_bins(coords, spacing, padding, xyz_min=None, xyz_max=None):
    """
    Compute the 3D bins from the coordinates

    The grid spans the bounding box of the points (or the explicitly given
    corners), enlarged by ``padding`` on every side, and is sampled every
    ``spacing`` along each axis.

    :param coords: (n, 3) array of points. NaN entries are ignored when the
        bounding box is derived. Only read for a corner that is not given
        explicitly.
    :param spacing: step between two consecutive grid coordinates
    :param padding: margin subtracted from the lower corner and added to the
        upper corner
    :param xyz_min: optional lower corner (x, y, z); any length-3 array-like
    :param xyz_max: optional upper corner (x, y, z); any length-3 array-like
    :return: xi, yi, zi, the 1D arrays of grid coordinates along each axis
        (upper bound excluded, as with ``np.arange``)
    """
    # np.asarray lets callers pass plain lists/tuples as corners; without it
    # ``[x, y, z] - padding`` raises a TypeError.
    lower = np.nanmin(coords, axis=0) if xyz_min is None else np.asarray(xyz_min)
    upper = np.nanmax(coords, axis=0) if xyz_max is None else np.asarray(xyz_max)

    xm, ym, zm = lower - padding
    xM, yM, zM = upper + padding

    xi = np.arange(xm, xM, spacing)
    yi = np.arange(ym, yM, spacing)
    zi = np.arange(zm, zM, spacing)
    return xi, yi, zi


def just_one(coord, xi, yi, zi, sigma, feature, total_grid, use_multiprocessing=False):
    """
    Splat a single point onto the voxel grid as a feature-weighted Gaussian.

    Only a local window of the grid around the point is evaluated: it extends
    ``int(4 * sigma)`` bins (at least one) on each side of the bin containing
    the point, clipped to the grid. Inside the window the weight of a voxel at
    distance ``d`` from the point is ``exp(-d**2 / (2 * sigma**2))``.

    :param coord: x, y, z position of the point. Points with a non-finite
        coordinate contribute nothing.
    :param xi: 1D array of grid coordinates along x
    :param yi: 1D array of grid coordinates along y
    :param zi: 1D array of grid coordinates along z
    :param sigma: scalar width of the Gaussian, in the units of the grid
        coordinates
    :param feature: 1D array (n_features,) of values carried by the point
    :param total_grid: array (n_features, len(xi), len(yi), len(zi)) updated
        in place. Not used when ``use_multiprocessing`` is True.
    :param use_multiprocessing: if True, do not touch ``total_grid`` and return
        the window and its content instead
    :return: None when the grid is updated in place. Otherwise the tuple
        (min_x, max_x, min_y, max_y, min_z, max_z, values) where ``values`` has
        shape (n_features, max_x - min_x, max_y - min_y, max_z - min_z).
    """
    coord = np.asarray(coord, dtype=float)

    # The bins are computed with nanmin/nanmax, so NaN coordinates can reach
    # this point; adding them would fill the window with NaN.
    if not np.isfinite(coord).all():
        if use_multiprocessing:
            return 0, 0, 0, 0, 0, 0, np.zeros((len(feature), 0, 0, 0))
        return None

    #  Find subgrid. Clamp to one bin: int(4 * sigma) is 0 for sigma < 0.25,
    #  which would produce an empty window and silently drop the point.
    half_width = max(1, int(4 * sigma))
    x, y, z = coord
    bin_x = np.digitize(x, xi)
    bin_y = np.digitize(y, yi)
    bin_z = np.digitize(z, zi)
    x_lo, x_hi = max(0, bin_x - half_width), min(xi.size, bin_x + half_width)
    y_lo, y_hi = max(0, bin_y - half_width), min(yi.size, bin_y + half_width)
    z_lo, z_hi = max(0, bin_z - half_width), min(zi.size, bin_z + half_width)

    #  Compute the Gaussian. It factorises over the axes, so three 1D
    #  profiles and an outer product replace a kernel evaluation on a meshgrid.
    gauss_x = np.exp(-0.5 * ((np.asarray(xi[x_lo:x_hi], dtype=float) - x) / sigma) ** 2)
    gauss_y = np.exp(-0.5 * ((np.asarray(yi[y_lo:y_hi], dtype=float) - y) / sigma) ** 2)
    gauss_z = np.exp(-0.5 * ((np.asarray(zi[z_lo:z_hi], dtype=float) - z) / sigma) ** 2)
    blob = gauss_x[:, None, None] * gauss_y[None, :, None] * gauss_z[None, None, :]

    # Broadcast the feature throughout the local grid.
    weighted_blob = blob[None, ...] * feature[:, None, None, None]

    if use_multiprocessing:
        return x_lo, x_hi, y_lo, y_hi, z_lo, z_hi, weighted_blob

    #  Add on the full grid
    total_grid[:, x_lo:x_hi, y_lo:y_hi, z_lo:z_hi] += weighted_blob


def gaussian_blur(coords, xi, yi, zi, features=None, sigma=1., use_multiprocessing=False):
    """
    Build a voxel grid by summing one feature-weighted blob per point.

    Each point is handled by ``just_one``; this function allocates the grid
    and accumulates the per-point contributions.

    :param coords: (n_points, 3)
    :param xi: 1D array of grid coordinates along x
    :param yi: 1D array of grid coordinates along y
    :param zi: 1D array of grid coordinates along z
    :param features: (n_points, dim) or None. None gives a single channel in
        which every point has value 1.
    :param sigma: width of the blobs, forwarded to ``just_one``
    :param use_multiprocessing: if True, the per-point blobs are computed in a
        ``multiprocessing.Pool`` and summed in the calling process
    :return: array of shape (dim, len(xi), len(yi), len(zi))
    """
    n_x, n_y, n_z = xi.size, yi.size, zi.size
    features = np.ones((len(coords), 1)) if features is None else features
    feature_len = features.shape[1]
    total_grid = np.zeros(shape=(feature_len, n_x, n_y, n_z))

    if not use_multiprocessing:
        for i, coord in enumerate(coords):
            just_one(coord, feature=features[i], xi=xi, yi=yi, zi=zi, sigma=sigma, total_grid=total_grid)
        return total_grid

    import multiprocessing
    args = [(coord, xi, yi, zi, sigma, features[i], None, True) for i, coord in enumerate(coords)]
    # The context manager shuts the workers down once starmap has returned;
    # a bare Pool() leaves its processes alive after every call.
    with multiprocessing.Pool() as pool:
        grids_to_add = pool.starmap(just_one, args)

    # Workers cannot write into the parent's array, so the windows are summed here.
    for x_lo, x_hi, y_lo, y_hi, z_lo, z_hi, subgrid in grids_to_add:
        total_grid[:, x_lo:x_hi, y_lo:y_hi, z_lo:z_hi] += subgrid
    return total_grid


def get_grid(coords, features=None, spacing=2, padding=3, xyz_min=None, xyz_max=None, sigma=1.):
    """
    Voxelize a set of points: derive the grid axes, then blur the points onto them.

    :param coords: (n,3) array
    :param features: (n,k) array, or None
    :param spacing: grid step, forwarded to ``get_bins``
    :param padding: margin around the points, forwarded to ``get_bins``
    :param xyz_min: optional lower corner of the grid, forwarded to ``get_bins``
    :param xyz_max: optional upper corner of the grid, forwarded to ``get_bins``
    :param sigma: width of the blur, forwarded to ``gaussian_blur``
    :return: the grid returned by ``gaussian_blur``
    """
    axes = get_bins(coords, spacing, padding, xyz_min, xyz_max)
    return gaussian_blur(coords, *axes, features=features, sigma=sigma)


class VoxelRepresentation(Representation):
    """
    Converts RNA into a voxel based representation
    """

    def __init__(self, spacing=2, padding=3, sigma=1., **kwargs):
        """
        :param spacing: grid step, forwarded to ``get_grid``
        :param padding: margin around the point cloud, forwarded to ``get_grid``
        :param sigma: width of the blur, forwarded to ``get_grid``
        :param kwargs: forwarded to the ``Representation`` constructor
        """
        super().__init__(**kwargs)
        self.spacing = spacing
        self.padding = padding
        self.sigma = sigma

    def __call__(self, rna_graph, features_dict):
        """
        Compute the voxel grids of features (and targets, if any) for one RNA.

        :param rna_graph: forwarded to ``get_point_cloud_dict``
        :param features_dict: forwarded to ``get_point_cloud_dict``
        :return: dict with 'voxel_feats' (the feature channels of the grid) and,
            when the point cloud has targets with at least one column,
            'voxel_target' (the target channels)
        """
        # If we need voxels, let's do the computations.
        # We redo the point cloud computations that are fast compared to voxels
        point_cloud_dict = get_point_cloud_dict(rna_graph, features_dict, sort=False)
        point_cloud_coords = point_cloud_dict['point_cloud_coords']

        if "point_cloud_feats" in point_cloud_dict:
            stacked_feats = point_cloud_dict['point_cloud_feats']
        else:
            # If no features are provided, use a single channel of ones
            stacked_feats = torch.ones(size=(len(point_cloud_coords), 1))
        input_dim = stacked_feats.shape[1]

        # Features and targets are blurred together in one grid (features
        # first) and split again along the channel axis afterwards.
        output_dim = 0
        features = stacked_feats
        if "point_cloud_targets" in point_cloud_dict:
            stacked_targets = point_cloud_dict['point_cloud_targets']
            output_dim = stacked_targets.shape[1]
            if output_dim > 0:
                features = torch.hstack((stacked_feats, stacked_targets))
        features = features.numpy()

        # TODO : port in torch to avoid back and forth
        coords = point_cloud_coords.numpy()
        voxel_representation = get_grid(coords=coords,
                                        features=features,
                                        spacing=self.spacing,
                                        padding=self.padding,
                                        sigma=self.sigma)
        voxel_representation = torch.from_numpy(voxel_representation)

        res_dict = {'voxel_feats': voxel_representation[:input_dim]}
        if output_dim > 0:
            res_dict['voxel_target'] = voxel_representation[-output_dim:]
        return res_dict

    @property
    def name(self):
        return "voxel"

    def batch(self, samples):
        """
        Batch a list of voxel samples

        Grids are gathered into lists rather than stacked.

        :param samples: A list of the output from this representation
        :return: a batched version of it: a dict with the keys of the first
            sample, each mapped to the list of the samples' values. An empty
            list of samples gives an empty dict.
        """
        if not samples:
            return {}
        return {key: [sample[key] for sample in samples] for key in samples[0]}