
import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os
from model_random import Gene_ontology_network
from net_train import net_train

from net_test import net_test

from model_random import Gene_ontology_network


from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
plt.ioff()


# Fold selection: the first command-line argument, reduced modulo 10 so that
# id_fold is always in 0..9. A missing or non-integer argument exits with a
# usage message instead of an IndexError/ValueError traceback.
try:
    id_fold = int(sys.argv[1]) % 10
except (IndexError, ValueError):
    sys.exit('usage: python <this script> FOLD_ID  (FOLD_ID must be an integer)')

# Hyper-parameters.
e_losses    = []
l_dim       = 50  # last constructor argument of Gene_ontology_network; presumably the latent size
num_epochs  = 501
num_s       = 0
num_f       = 5
iter_n      = 20
temperature = gpu_ts(0.1)
lambda0     = gpu_ts(0.00001)
prob_ref    = [gpu_ts(0.01)]
# NOTE(review): e_losses, num_epochs, num_s, num_f, iter_n, temperature,
# lambda0 and prob_ref are not read anywhere else in this script; presumably
# kept in sync with the training script — confirm before removing.

# Run on CUDA when available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################################################################################################################################
# Loading data.
# Every array below lives in the same per-fold .mat file, so it is parsed once
# and the individual variables are pulled out of the resulting dict.
data_path = ('/mnt/sdb1/sayan/pathway_analysis/data/multiple_pretrain/'
             'train_test_data_gcn_5_layers_joint_model' + str(id_fold + 1) + '.mat')
fold_data = loadmat(data_path)

# Gene expression of the training subjects (one row per subject).
gene_train_l1 = fold_data['gene_train']
gene_train    = gpu(torch.tensor(gene_train_l1).float())

gene = gene_train

# Load adjacency matrix, A_p, transposed and stored as a coalesced sparse tensor.
A = fold_data['A']
A = gpu(torch.tensor(A).float().t().to_sparse().coalesce())

# Load gene to pathway embedding matrix as a coalesced sparse tensor.
A_g = fold_data['mask_gene']
A_g = gpu(torch.tensor(A_g).float().to_sparse().coalesce())

# Class labels of the training subjects.
Y_train = fold_data['class_train']
Y_train = gpu(torch.tensor(Y_train).float())

Y       = Y_train

# pool_n contains the number of nodes in each layer of the ontology.
pool_dim = fold_data['pool_n'].tolist()

############################################################################################################################################################################################################################################################################################################################################################################

# Thresholdin probabilities.
sps_g = [0] 


# Batch Size
batch_size = 113

# Create the folder that stores this fold's interaction maps, including the
# parent 'interactions' folder. exist_ok makes reruns safe and avoids the race
# between a separate existence check and the mkdir call.
os.makedirs(os.path.join('interactions', 'pretrain' + str(id_fold + 1)), exist_ok=True)

# Epoch number of the checkpoint to restore: the first element of the array
# stored under the 'ind' key of checkpoint_model_<fold>.mat.
mid = loadmat('checkpoint_model_' + str(id_fold + 1) + '.mat')['ind'][0][0]


# Build the model and restore the weights saved at checkpoint ``mid``.
net = Gene_ontology_network(A_g, A, 2, 4, [5, 5, 5, 5], pool_dim, l_dim)
# map_location lets a checkpoint written on one device (e.g. CUDA) be
# deserialised on whichever ``device`` is available in this run.
model_path = 'pretrained' + str(id_fold + 1) + '/models/model' + str(mid) + '.pth'
net.load_state_dict(torch.load(model_path, map_location=device))
net = net.to(device)

# Extract, for every subject, the importance value of each pair of nodes
# listed in net.n_loc_in[0]; one .mat file is written per batch of subjects.
bno = 0
batch_size = 257

# Edge list whose order defines the columns of the interaction map. Values
# are always read back in this same order, so each column stays paired with
# the index used to scatter it into the dense matrix.
edge_index = net.n_loc_in[0]
edge_rows, edge_cols = edge_index[0], edge_index[1]

with torch.no_grad():
    net.eval()
    for beg_i in range(0, gene.size(0), batch_size):
        # Slice the current batch of subjects.
        x_batch = gene[beg_i:beg_i + batch_size, :]
        y_batch = Y[beg_i:beg_i + batch_size, :]

        # Gene encoding: one sparse projection per entry of net.t, stacked
        # along a new last axis -> (batch, rows of W, len(net.t)).
        x = torch.stack(
            [
                torch.sparse.mm(torch.sparse_coo_tensor(net.i, t_vals, net.size), x_batch.t()).t()
                for t_vals in net.t
            ],
            dim=2,
        )

        ###########################################################################################

        n_sub = x.size(0)
        # Matrix of [subjects x number of interactions]; column c holds the
        # importance of the edge (edge_rows[c], edge_cols[c]).
        int_map_final = gpu(torch.zeros(n_sub, len(edge_rows)))

        for jj in range(len(net.w_inc)):
            n_nodes = x.size(1)
            out = gpu(torch.zeros((n_sub, n_nodes, net.f_dim[jj + 1])))  # Temporary var

            i_in = net.n_loc_in[jj].clone()
            store_in = net.store_in[jj]

            x_in = net.w_inc[jj](x)
            x_s = net.w_s_loop[jj](x)

            # Advanced indexing already returns copies, so no explicit clone
            # of x_in (or of the index rows) is needed here.
            v_inc = net.helper(x_in[:, i_in[0, :], :], x_in[:, i_in[1, :], :],
                               net.w_att_in[jj], net.w_att_in_act[jj])
            v_s = net.w_att_s_act[jj](net.w_att_s[jj](x_s))

            print(jj)

            # Shift from this layer's node indices to A's index space: each
            # earlier layer dropped its first net.pool[layer] nodes (see the
            # slicing of out1 below). Loop-invariant, so computed once.
            offset = sum(net.pool[:jj])

            for k in range(n_sub):
                # Dense copy of subject k's interaction map accumulated so far.
                A_int = torch.sparse_coo_tensor(edge_index, int_map_final[k, :], A.size()).to_dense()

                # v contains the interaction importance values for the edges in ind.
                A_hat_in, ind, v, _ = net.attention_adj(i_in, v_inc[k, :].squeeze(),
                                                        torch.Size([n_nodes, n_nodes]), store_in, i_in)

                x_incoming = net.gcn(A_hat_in, x_in[k, :, :])
                x_self = x_s[k, :, :] * v_s[k, :, :]
                out[k, :, :] = x_incoming + x_self

                # Record this layer's importance scores (per the original
                # author: between leaf nodes and their immediate parents,
                # before pooling discards those nodes).
                A_int.index_put_((offset + ind[0, :], offset + ind[1, :]), v)

                # Read the values back in edge_index order. Taking
                # to_sparse().coalesce().values() instead would drop zero
                # entries and sort the indices, so its length/order need not
                # match edge_index.
                int_map_final[k, :] = A_int[edge_rows, edge_cols]

            out1 = net.gcn_D[jj](net.w_act[jj](net.G_B[jj](out.permute(0, 2, 1)).permute(0, 2, 1)))

            # Drop the first net.pool[jj] nodes before the next layer.
            x = out1[:, net.pool[jj]:, :].clone()

        map_data = cpu(int_map_final).data.numpy()
        # Column c of 'map' corresponds to the edge (indx[0, c], indx[1, c]).
        int_index = cpu(edge_index).data.numpy()
        # Save the interaction maps for all the subjects of this batch.
        vis = {'map': map_data, 'indx': int_index, 'class': cpu(y_batch).data.numpy()}
        sio.savemat('interactions/pretrain' + str(id_fold + 1) + '/batch' + str(bno) + '.mat', vis)
        bno = bno + 1


        
