"""Training loops for rnaglib models (module ``rnaglib.learning.learn``)."""

import os
import sys
import time

import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score

import torch
import torch.nn.functional as F

if __name__ == "__main__":
    # Only when this file is executed directly (not imported as part of the package):
    # append the directory two levels above the one containing this file to sys.path,
    # so that the `rnaglib` imports below can be resolved. `script_dir` is also read by
    # the `__main__` demo at the end of the file.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    sys.path.append(os.path.join(script_dir, os.pardir, os.pardir))

from rnaglib.config.graph_keys import GRAPH_KEYS, TOOL
from rnaglib.utils import misc
from rnaglib.learning import learning_utils


def pretrain_unsupervised(model,
                          train_loader,
                          optimizer,
                          node_sim=None,
                          learning_routine=None,
                          rec_params=None,
                          ):
    """
    Perform the pretraining routine to get embeddings from graph nodes, that correlate with a node kernel.

    :param model: The model to train. Must expose a ``current_device`` attribute; calling it on a graph
        returns a ``(graph, embeddings)`` pair.
    :param train_loader: The loader to use for training, as defined in GraphLoader. Each batch is read as
        ``batch['graph']`` and ``batch['ring'] == (K, node_ids)``.
    :param optimizer: the optimizer to use (eg SGD or Adam)
    :param node_sim: Not used by this function (kept for backward compatibility); the similarity
        targets ``K`` come from the data loader.
    :param learning_routine: A LearningRoutine object, if we want to also use a validation phase and early
        stopping. When None, a fresh ``learning_utils.LearningRoutine()`` is created for this call.
    :param rec_params: Keyword arguments forwarded to ``learning_utils.rec_loss`` and
        ``learning_utils.evaluate_model_unsupervised``. When None, defaults to
        ``{"similarity": True, "normalize": False, "use_graph": False, "hops": 2}``.
    :return: The best loss obtained (``learning_routine.best_loss``)
    :raises ValueError: if ``train_loader`` yields no batches.
    """
    # Build defaults per call: a default evaluated in the signature would be created once and
    # shared (including its best_loss / early-stopping state) by every call of this function.
    if learning_routine is None:
        learning_routine = learning_utils.LearningRoutine()
    if rec_params is None:
        rec_params = {"similarity": True, "normalize": False, "use_graph": False, "hops": 2}

    device = model.current_device
    learning_routine.device = device
    start_time = time.time()
    for epoch in range(learning_routine.num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        num_batches = len(train_loader)
        if num_batches == 0:
            raise ValueError("train_loader is empty: cannot compute a training loss")
        for batch_idx, batch in enumerate(train_loader):
            graph, (K, node_ids) = batch['graph'], batch['ring']
            # Get data on the devices
            K = K.to(device)
            graph = learning_utils.send_graph_to_device(graph, device)

            # Clear gradients before the forward pass so nothing accumulated before this
            # batch (e.g. before this function was called) contaminates the update.
            model.zero_grad()

            # Do the computations for the forward pass
            graph, out = model(graph)
            loss = learning_utils.rec_loss(embeddings=out,
                                           target_K=K,
                                           graph=graph,
                                           node_ids=node_ids,
                                           **rec_params)
            # Backward
            loss.backward()
            optimizer.step()

            # Metrics
            loss = loss.item()
            running_loss += loss

            if batch_idx % learning_routine.print_each == 0:
                time_elapsed = time.time() - start_time
                print(
                    f'Train Epoch: {epoch + 1} [{(batch_idx + 1)}/{num_batches} '
                    f'({100. * (batch_idx + 1) / num_batches:.0f}%)]\t'
                    f'Loss: {loss:.6f} Time: {time_elapsed:.2f}')

            # tensorboard logging
            if learning_routine.writer is not None:
                step = epoch * num_batches + batch_idx
                learning_routine.writer.add_scalar("Training loss", loss, step)

        train_loss = running_loss / num_batches
        if learning_routine.writer is not None:
            learning_routine.writer.add_scalar("Training epoch loss", train_loss, epoch)

        # Validation phase: early stopping is always used, on the training loss when there is
        # no validation loader.
        if learning_routine.validation_loader is None:
            early_stop = learning_routine.early_stopping_routine(validation_loss=train_loss,
                                                                 epoch=epoch,
                                                                 model=model,
                                                                 optimizer=optimizer)
        else:
            validation_loss = learning_utils.evaluate_model_unsupervised(
                model,
                validation_loader=learning_routine.validation_loader,
                rec_params=rec_params)
            if learning_routine.writer is not None:
                learning_routine.writer.add_scalar("Validation loss during training", validation_loss, epoch)
            early_stop = learning_routine.early_stopping_routine(validation_loss=validation_loss,
                                                                 epoch=epoch,
                                                                 model=model,
                                                                 optimizer=optimizer)
        if early_stop:
            break
    return learning_routine.best_loss
def train_supervised(model,
                     optimizer,
                     train_loader,
                     learning_routine=None):
    """
    Performs the entire training routine for a supervised task.

    The loss is the mean squared error between ``model(graph)`` and ``graph.ndata['nt_targets']``.

    :param model: The model to train. Must expose a ``current_device`` attribute.
    :param optimizer: the optimizer to use (eg SGD or Adam)
    :param train_loader: The loader to use for training, as defined in dataset/GraphLoader. Each batch
        is read as ``batch['graph']``.
    :param learning_routine: A LearningRoutine object, if we want to also use a validation phase and early
        stopping. When None, a fresh ``learning_utils.LearningRoutine()`` is created for this call.
    :return: The best loss obtained (``learning_routine.best_loss``)
    :raises ValueError: if ``train_loader`` yields no batches.
    """
    # Build the default per call: a default evaluated in the signature would be created once and
    # shared (including its best_loss / early-stopping state) by every call of this function.
    if learning_routine is None:
        learning_routine = learning_utils.LearningRoutine()

    device = model.current_device
    criterion = torch.nn.MSELoss()
    start_time = time.time()
    for epoch in range(learning_routine.num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        num_batches = len(train_loader)
        if num_batches == 0:
            raise ValueError("train_loader is empty: cannot compute a training loss")
        for batch_idx, batch in enumerate(train_loader):
            # Get data on the devices
            graph = batch['graph']
            graph = learning_utils.send_graph_to_device(graph, device)

            # Clear gradients before the forward pass so nothing accumulated before this
            # batch (e.g. before this function was called) contaminates the update.
            model.zero_grad()

            # Do the computations for the forward pass
            out = model(graph)
            labels = graph.ndata['nt_targets']
            loss = criterion(out, labels)

            # Backward
            loss.backward()
            optimizer.step()

            # Metrics
            loss = loss.item()
            running_loss += loss

            if batch_idx % learning_routine.print_each == 0:
                time_elapsed = time.time() - start_time
                print(
                    f'Train Epoch: {epoch + 1} [{(batch_idx + 1)}/{num_batches} '
                    f'({100. * (batch_idx + 1) / num_batches:.0f}%)]\t'
                    f'Loss: {loss:.6f} Time: {time_elapsed:.2f}')

            # tensorboard logging
            if learning_routine.writer is not None:
                step = epoch * num_batches + batch_idx
                learning_routine.writer.add_scalar("Training loss", loss, step)

        train_loss = running_loss / num_batches
        if learning_routine.writer is not None:
            learning_routine.writer.add_scalar("Training epoch loss", train_loss, epoch)

        # Test phase: without a validation loader, early stopping runs on the training loss.
        # Otherwise evaluate on the validation loader first.
        if learning_routine.validation_loader is None:
            early_stop = learning_routine.early_stopping_routine(validation_loss=train_loss,
                                                                 epoch=epoch,
                                                                 model=model,
                                                                 optimizer=optimizer)
        else:
            validation_loss = learning_utils.evaluate_model_supervised(
                model,
                loader=learning_routine.validation_loader)
            if learning_routine.writer is not None:
                learning_routine.writer.add_scalar("Validation loss during training", validation_loss, epoch)
            early_stop = learning_routine.early_stopping_routine(validation_loss=validation_loss,
                                                                 epoch=epoch,
                                                                 model=model,
                                                                 optimizer=optimizer)
        if early_stop:
            break
    return learning_routine.best_loss
def train_linkpred(model,
                   optimizer,
                   train_loader_generator,
                   validation_loader_generator,
                   num_epochs=3,
                   ):
    """
    Train a link prediction model: given RNA graphs, predict whether nodes are bound.

    Each epoch trains on the edge loaders returned by ``train_loader_generator.get_edge_loader()``
    with a binary cross-entropy (logits) loss, then prints the mean ROC-AUC over the batches returned
    by ``validation_loader_generator.get_edge_loader()``.

    :param model: The model to train. Called as ``model(positive_graph)`` for positive edge scores and
        ``model(positive_graph, negative_graph=negative_graph)`` for negative edge scores.
    :param optimizer: the optimizer to use (eg SGD or Adam)
    :param train_loader_generator: The edge loader generator to use for training, as defined in dataset/GraphLoader
    :param validation_loader_generator: The edge loader generator to use for validation, as defined in
        dataset/GraphLoader
    :param num_epochs: Number of training epochs (default 3, the previously hard-coded value).
    :return: The mean validation ROC-AUC of the last epoch; ``nan`` if that epoch saw no validation
        batch, ``None`` if ``num_epochs`` is 0.
    """
    mean_auc = None
    for epoch in range(num_epochs):
        time_start = time.time()

        # Training phase. The validation phase below switches the model to eval mode, and the
        # caller may hand us a model in eval mode, so training mode is restored every epoch.
        model.train()
        last_loss = None
        train_loader = train_loader_generator.get_edge_loader()
        for g in train_loader:
            for _input_nodes, positive_graph, negative_graph, _blocks in g:
                pos_score = model(positive_graph)
                neg_score = model(positive_graph, negative_graph=negative_graph)
                score = torch.cat([pos_score, neg_score])
                label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
                loss = F.binary_cross_entropy_with_logits(score, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                last_loss = loss.item()
        # last_loss stays None when the train loader yielded no batch.
        loss_str = "n/a" if last_loss is None else f"{last_loss:.2f}"
        print(f"EPOCH {epoch}, time for the epoch : {time.time() - time_start:.2f}, last loss {loss_str}")

        # Validation phase
        aucs = []
        model.eval()
        validation_loader = validation_loader_generator.get_edge_loader()
        with torch.no_grad():
            for i, g in enumerate(validation_loader):
                print("val graph ", i)
                for _input_nodes, positive_graph, negative_graph, _blocks in g:
                    pos_score = model(positive_graph)
                    neg_score = model(positive_graph, negative_graph=negative_graph)
                    # .cpu() so that scores computed on a GPU can be converted to numpy.
                    score = torch.cat([pos_score, neg_score]).cpu().numpy()
                    label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)]).cpu().numpy()
                    aucs.append(roc_auc_score(label, score))
        print('Time used : ', time.time() - time_start)
        # Avoid numpy's "mean of empty slice" warning when no validation batch was seen.
        mean_auc = float(np.mean(aucs)) if aucs else float('nan')
        print("AUC", mean_auc)
    return mean_auc
if __name__ == '__main__':
    # Small demo of the training routines above. It reads `script_dir`, which is only defined
    # when this file is executed directly.
    from rnaglib.learning import models
    from rnaglib.dataset_transforms import rna_loader

    test_unsupervised = False
    test_supervised = True

    if test_unsupervised:
        from rnaglib.kernels import node_sim

        embedder_model = models.Embedder([10, 10])
        optimizer = torch.optim.Adam(embedder_model.parameters())

        node_sim_func = node_sim.SimFunctionNode(method='R_1', depth=2)
        data_path = os.path.join(script_dir, '..', 'data/annotated/NR_annot/')
        node_features = ['nt_code']
        unsupervised_dataset = rna_loader.GraphDataset(node_simfunc=node_sim_func,
                                                       node_features=node_features,
                                                       data_path=data_path,
                                                       chop=True)
        train_loader = rna_loader.get_loader(dataset=unsupervised_dataset,
                                             split=False,
                                             num_workers=0,
                                             max_size_kernel=100)

        pretrain_unsupervised(model=embedder_model,
                              optimizer=optimizer,
                              node_sim=node_sim_func,
                              train_loader=train_loader)

    if test_supervised:
        # Resolve the graphs directory relative to this file rather than the current working
        # directory (identical to the previous '../data/graphs' when run from this file's folder).
        annotated_path = os.path.join(script_dir, '..', 'data', 'graphs')
        node_features = ['nt_code']
        node_target = ['binding_protein']

        # GET THE DATA GOING
        supervised_dataset = rna_loader.GraphDataset(data_path=annotated_path,
                                                     node_features=node_features,
                                                     node_target=node_target)
        # NOTE(review): validation_loader and test_loader are not used by this demo.
        train_loader, validation_loader, test_loader = rna_loader.get_loader(dataset=supervised_dataset,
                                                                             split=True,
                                                                             num_workers=0)

        # Define model
        embedder_model = models.Embedder([10, 10], infeatures_dim=1)
        classifier_model = models.Classifier(embedder=embedder_model, classif_dims=[1])
        optimizer = torch.optim.Adam(classifier_model.parameters(), lr=0.001)

        train_supervised(model=classifier_model,
                         optimizer=optimizer,
                         train_loader=train_loader)