# Build the ontology graph from an adjacency matrix and save every path from each root node to each leaf node (one .npy and one .mat file per path).
import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os

from net_train import net_train

from net_test import net_test

from model_random import Gene_ontology_network


from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from collections import defaultdict
import pickle
   
# Pickle helpers for saving/loading objects such as the Graph class below,
# a directed graph stored using an adjacency-list representation.

def save_object(obj, filename):
    """Pickle ``obj`` into ``filename`` with the highest available protocol.

    The file is opened in binary write mode, so any existing file at
    ``filename`` is overwritten.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, 'wb') as handle:
        pickle.dump(obj, handle, protocol)
        
def load_object(filename):
    """Unpickle and return the object stored in ``filename``.

    The file is opened read-only; nothing on disk is modified.

    Warning: ``pickle.load`` can execute arbitrary code while unpickling, so
    only load files written by this script or another trusted source.
    """
    # ``handle`` (not ``input``) to avoid shadowing the builtin.
    with open(filename, 'rb') as handle:
        return pickle.load(handle)
        
class Graph:
    """Directed graph stored as an adjacency list, with all-paths enumeration.

    Vertices must be integers in ``range(vertices)``, because the visited
    flags used during the search are kept in a list of that length.

    Attributes:
        V: number of vertices.
        graph: ``defaultdict(list)`` mapping a vertex to the list of vertices
            it has an outgoing edge to, in insertion order.
        path: every path found by ``printAllPaths``. Paths from successive
            calls accumulate here; reset it (``g.path = []``) between queries
            to keep them separate.
    """

    def __init__(self, vertices):
        # Number of vertices; also the size of the visited-flag list.
        self.V = vertices
        # Adjacency list: vertex -> list of successor vertices.
        self.graph = defaultdict(list)
        # Paths collected by printAllPaths (each a list of vertices).
        self.path = []

    def addEdge(self, u, v):
        """Add a directed edge ``u -> v``."""
        self.graph[u].append(v)

    def printAllPathsUtil(self, u, d, visited, path):
        """Depth-first helper that records and prints every simple path to ``d``.

        Args:
            u: vertex currently being visited.
            d: destination vertex.
            visited: per-vertex flags marking vertices already on the current
                path, so a path never revisits a vertex (cycles are skipped).
            path: vertices on the current path before ``u``; extended and
                restored in place as the search backtracks.
        """
        visited[u] = True
        path.append(u)

        if u == d:
            # Store a snapshot: ``path`` keeps being mutated by the search.
            # ``append`` avoids re-copying the whole result list per path.
            self.path.append(path.copy())
            print(path)
        else:
            for nxt in self.graph[u]:
                if not visited[nxt]:
                    self.printAllPathsUtil(nxt, d, visited, path)

        # Backtrack so sibling branches may pass through this vertex again.
        path.pop()
        visited[u] = False

    def printAllPaths(self, s, d):
        """Print every simple path from ``s`` to ``d`` and collect it in ``self.path``.

        Args:
            s: source vertex.
            d: destination vertex.

        Returns:
            ``self.path``, the accumulated list of paths. It includes paths
            from earlier calls unless ``self.path`` was reset in between.
        """
        visited = [False] * self.V
        self.printAllPathsUtil(s, d, visited, [])
        return self.path
        
# Load the adjacency matrix A_p and the per-layer node counts from a single
# read of the .mat file.
DATA_FILE = '/mnt/sdb1/sayan/pathway_analysis/data/train_test_data_gcn_5_layers_joint_model.mat'
mat_data = loadmat(DATA_FILE)
A = torch.tensor(mat_data['A']).float().t()

# pool_n contains the number of nodes in each layer of the ontology.
# Cast to int: MATLAB stores numbers as doubles by default, and range()
# rejects floats.
pool_dim = [int(n) for n in mat_data['pool_n'].tolist()[0]]
n_nodes = sum(pool_dim)

g = Graph(n_nodes)

# Build the directed graph: a 1 at A[i, j] (after the transpose) is an edge
# j -> i. np.argwhere yields indices in row-major order, so every adjacency
# list is filled in the same order as a nested i/j loop would fill it.
for i, j in np.argwhere(A.numpy() == 1).tolist():
    g.addEdge(j, i)

save_object(g, 'graph.pkl')
#g=load_object('graph.pkl')

# Find all paths between each root (first layer) and each leaf (last layer),
# and store them in the 'paths' folder.
os.makedirs('paths', exist_ok=True)

roots = list(range(pool_dim[0]))
leaves = list(range(n_nodes - pool_dim[-1], n_nodes))
iterr = 0
for root in roots:
    for leaf in leaves:
        g.printAllPaths(root, leaf)
        # Paths can differ in length, so save each one on its own rather
        # than stacking them into one (possibly ragged) np.array first.
        for pp in g.path:
            pp = np.array(pp)
            np.save('paths/path' + str(iterr), pp)
            sio.savemat('paths/path' + str(iterr) + '.mat', {'path': pp})
            iterr += 1
        # printAllPaths accumulates into g.path; reset for the next pair.
        g.path = []

   


