#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 22 16:43:55 2024

@author: anonymous
"""

import numpy as np
import torch 
import torch_geometric as tg
import matplotlib.pyplot as plt 
import tqdm
import pickle


###############################################################################



class GAN(torch.nn.Module):
    """ Graph attention neural network operating on graphcodes.

    A stack of GATv2 convolutions is applied to the node features. The node
    embeddings are max-pooled per (graph, slice) pair, the pooled slice
    vectors of each graph are concatenated, and a stack of dense layers
    produces the class logits.

    Parameters
    ----------
    num_node_features : Int
        size of node attribute vectors; column ``num_node_features`` of
        ``data.x`` is used as the slice index of each node
    num_classes : Int
        number of classes
    slices : Int
        number of slices produced by graphcode
    GAN_layers : List(List(Int))
        one [out_channels, heads] pair per graph attention layer
    LIN_layers : List(Int)
        output width of each hidden dense layer
    reg : List(Float)
        dropout probabilities; reg[i] is applied before hidden dense layer i
        and reg[-1] before the output layer
    bn : Bool
        if True, batch normalization is applied after each hidden dense layer
    edge : Bool
        if False, the graph edges are ignored (attention runs on an empty
        edge set)
    """
    def __init__(self, num_node_features, num_classes, slices, GAN_layers, LIN_layers, reg, bn, edge):
        super().__init__()

        self.slices = slices
        self.reg = reg
        self.node = num_node_features
        # Output channels and heads of the last attention layer; together with
        # `slices` they fix the width of the pooled graph representation.
        self.N = GAN_layers[-1][0]
        self.H = GAN_layers[-1][1]

        # Attention layers. The input width of each layer is the previous
        # layer's out_channels * heads, as in the original construction.
        self.GAN_layers = torch.nn.ModuleList()
        in_channels = num_node_features
        for layer in GAN_layers:
            out_channels, heads = layer[0], layer[1]
            self.GAN_layers.append(tg.nn.GATv2Conv(in_channels, out_channels, heads=heads))
            in_channels = out_channels * heads

        # Hidden dense layers, each with a matching batch-norm layer, followed
        # by the output layer (which has no batch norm).
        self.LIN_layers = torch.nn.ModuleList()
        self.BN = torch.nn.ModuleList()
        in_features = slices * self.N * self.H
        for width in LIN_layers:
            self.LIN_layers.append(torch.nn.Linear(in_features, width))
            self.BN.append(torch.nn.BatchNorm1d(width))
            in_features = width
        self.LIN_layers.append(torch.nn.Linear(in_features, num_classes))

        self.relu = torch.nn.ReLU()
        self.bn = bn
        self.edge = edge

    def forward(self, data):
        """ Return class logits with one row per graph in the batch ``data``."""
        x = data.x[:, :self.node]
        if self.edge:
            edge_index = data.edge_index
        else:
            # Empty edge set, created on the same device as the inputs (was
            # previously hard-coded to CUDA, which broke CPU runs).
            edge_index = torch.empty(2, 0, dtype=torch.long, device=x.device)

        for conv in self.GAN_layers:
            x = self.relu(conv(x, edge_index))

        num_graphs = data.num_graphs
        # One pooling cluster per (graph, slice) pair.
        cluster = (self.slices * data.batch + data.x[:, self.node]).long()
        # Fixing the output size keeps the reshape valid even when the highest
        # slices of the last graph in the batch contain no nodes.
        x = tg.nn.global_max_pool(x, cluster, size=num_graphs * self.slices)
        x = x.view(num_graphs, self.slices * self.N * self.H)

        for i, (lin, norm) in enumerate(zip(self.LIN_layers[:-1], self.BN)):
            x = torch.nn.functional.dropout(x, p=self.reg[i], training=self.training)
            x = lin(x)
            if self.bn:
                x = norm(x)
            x = self.relu(x)
        x = torch.nn.functional.dropout(x, p=self.reg[-1], training=self.training)
        return self.LIN_layers[-1](x)
    
    
#############################################################################################



def fit_verbose(model, train_loader, test_loader, optimizer, loss_function, Epochs):
    """ Trains model on training set and tests it on validation/test
        set after every epoch. Prints loss and accuracy on both sets after
        each epoch and plots the progression of training and validation
        accuracy at the end.

    Parameters
    ----------
    model : Torch neural network
        the model to train
    train_loader : Torch geometric data loader
        training dataset
    test_loader : Torch geometric data loader
        validation dataset
    optimizer : Torch optimizer
        optimizer for training
    loss_function : Torch loss function
        loss function for training
    Epochs : Int
        number of training epochs
    """
    torch.backends.cudnn.benchmark = True
    # Run on whatever device the model was moved to instead of assuming CUDA.
    device = next(model.parameters()).device

    def _score(loader):
        """Return (summed batch loss, correct predictions, graph count) over loader."""
        loss_sum, correct, count = 0, 0, 0
        for data in tqdm.tqdm(loader):
            data = data.to(device)
            count += len(data)
            output = model(data)
            loss_sum += loss_function(output, data.y).item()
            correct += int((output.argmax(dim=1) == data.y).sum())
        return loss_sum, correct, count

    Train_Acc = []
    Test_Acc = []

    for epoch in range(Epochs):
        print('\n')
        print('Epoch', epoch+1, '/', Epochs, '\n')
        print('Training:', '\n')

        model.train()
        for data in tqdm.tqdm(train_loader):
            data = data.to(device)
            model.zero_grad()
            loss = loss_function(model(data), data.y)
            loss.backward()
            optimizer.step()

        model.eval()
        # Evaluation needs no gradients; disabling autograd avoids building
        # and holding a computation graph for every evaluation batch.
        with torch.no_grad():
            print('\n')
            print('Evaluating train data:', '\n')
            train_loss, train_acc, c1 = _score(train_loader)

            print('\n')
            print('Evaluating test data', '\n')
            test_loss, test_acc, c2 = _score(test_loader)

        print('\n')
        print('Train_Loss: ', round(train_loss/len(train_loader), 3), ' Train_Acc: ', round(train_acc/c1, 3), ' Test_Loss: ', round(test_loss/len(test_loader), 3), ' Test_Acc: ', round(test_acc/c2, 3))
        Train_Acc.append(round(train_acc/c1, 3))
        Test_Acc.append(round(test_acc/c2, 3))

    x_axis = np.arange(1, Epochs+1, 1)
    fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')
    ax.plot(x_axis, Train_Acc, label='Train Accuracy')
    ax.plot(x_axis, Test_Acc, label='Test Accuracy')
    # Thin horizontal guide lines at 50%-90% accuracy.
    for level in (0.5, 0.6, 0.7, 0.8, 0.9):
        ax.axhline(y=level, color='k', linewidth=0.5, label='_nolegend_')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.legend()
    print('\n')
    print('Mean test accuracy over last 10 epochs: ', round(np.mean(np.array(Test_Acc)[-10:]), 3))
    


#############################################################################################


def fit_verbose_shuffle(model, train_loaders, test_loader, optimizer, loss_function, Epochs):
    """ Trains model on training set and tests it on validation/test
        set after every epoch, cycling to the next graphcode basis
        (the next loader in train_loaders) each epoch. Prints loss and
        accuracy after each epoch and plots the progression of training
        and validation accuracy at the end.

    Parameters
    ----------
    model : Torch neural network
        the model to train
    train_loaders : List of torch geometric data loaders
        training dataset with multiple graphcodes for each instance;
        epoch e uses train_loaders[e % len(train_loaders)]
    test_loader : Torch geometric data loader
        validation dataset
    optimizer : Torch optimizer
        optimizer for training
    loss_function : Torch loss function
        loss function for training
    Epochs : Int
        number of training epochs
    """
    torch.backends.cudnn.benchmark = True
    # Run on whatever device the model was moved to instead of assuming CUDA.
    device = next(model.parameters()).device

    def _score(loader):
        """Return (summed batch loss, correct predictions, graph count) over loader."""
        loss_sum, correct, count = 0, 0, 0
        for data in tqdm.tqdm(loader):
            data = data.to(device)
            count += len(data)
            output = model(data)
            loss_sum += loss_function(output, data.y).item()
            correct += int((output.argmax(dim=1) == data.y).sum())
        return loss_sum, correct, count

    Train_Acc = []
    Test_Acc = []

    for epoch in range(Epochs):
        train_loader = train_loaders[epoch % len(train_loaders)]
        print('\n')
        print('Epoch', epoch+1, '/', Epochs, '\n')
        print('Training:', '\n')

        model.train()
        for data in tqdm.tqdm(train_loader):
            data = data.to(device)
            model.zero_grad()
            loss = loss_function(model(data), data.y)
            loss.backward()
            optimizer.step()

        model.eval()
        # Evaluation needs no gradients; disabling autograd avoids building
        # and holding a computation graph for every evaluation batch.
        with torch.no_grad():
            print('\n')
            print('Evaluating train data:', '\n')
            train_loss, train_acc, c1 = _score(train_loader)

            print('\n')
            print('Evaluating test data', '\n')
            test_loss, test_acc, c2 = _score(test_loader)

        print('\n')
        print('Train_Loss: ', round(train_loss/len(train_loader), 3), ' Train_Acc: ', round(train_acc/c1, 3), ' Test_Loss: ', round(test_loss/len(test_loader), 3), ' Test_Acc: ', round(test_acc/c2, 3))
        Train_Acc.append(round(train_acc/c1, 3))
        Test_Acc.append(round(test_acc/c2, 3))

    x_axis = np.arange(1, Epochs+1, 1)
    fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')
    ax.plot(x_axis, Train_Acc, label='Train Accuracy')
    ax.plot(x_axis, Test_Acc, label='Test Accuracy')
    # Thin horizontal guide lines at 50%-90% accuracy.
    for level in (0.5, 0.6, 0.7, 0.8, 0.9):
        ax.axhline(y=level, color='k', linewidth=0.5, label='_nolegend_')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.legend()
    print('\n')
    print('Mean test accuracy over last 10 epochs: ', round(np.mean(np.array(Test_Acc)[-10:]), 3))
    
    
#############################################################################################



def fit(model, train_loader, optimizer, loss_function, Epochs):
    """ Trains model on training set.

    Parameters
    ----------
    model : Torch neural network
        the model to train
    train_loader : Torch geometric data loader
        training dataset
    optimizer : Torch optimizer
        optimizer for training
    loss_function : Torch loss function
        loss function for training
    Epochs : Int
        number of training epochs
    """
    torch.backends.cudnn.benchmark = True
    # Run on whatever device the model was moved to instead of assuming CUDA.
    device = next(model.parameters()).device

    print('\n')
    print('Training model:', '\n')

    for _ in tqdm.tqdm(range(Epochs)):
        model.train()
        for data in train_loader:
            data = data.to(device)
            model.zero_grad()
            loss = loss_function(model(data), data.y)
            loss.backward()
            optimizer.step()
            
#############################################################################################



def fit_shuffle(model, train_loaders, optimizer, loss_function, Epochs):
    """ Trains model on training set, cycling to the next graphcode basis
        (the next loader in train_loaders) after each epoch.

    Parameters
    ----------
    model : Torch neural network
        the model to train
    train_loaders : List of torch geometric data loaders
        training dataset with multiple graphcodes for each instance;
        epoch e uses train_loaders[e % len(train_loaders)]
    optimizer : Torch optimizer
        optimizer for training
    loss_function : Torch loss function
        loss function for training
    Epochs : Int
        number of training epochs
    """
    torch.backends.cudnn.benchmark = True
    # Run on whatever device the model was moved to instead of assuming CUDA.
    device = next(model.parameters()).device

    print('\n')
    print('Training model:', '\n')

    for epoch in tqdm.tqdm(range(Epochs)):
        train_loader = train_loaders[epoch % len(train_loaders)]
        model.train()
        for data in train_loader:
            data = data.to(device)
            model.zero_grad()
            loss = loss_function(model(data), data.y)
            loss.backward()
            optimizer.step()

  

#############################################################################################    



def evaluate(model, test_loader, loss_function):
    """ Evaluates model on test set and prints the mean batch loss and accuracy.

    Parameters
    ----------
    model : Torch neural network
        the model to evaluate
    test_loader : Torch geometric data loader
        test dataset
    loss_function : Torch loss function
        loss function used to report the test loss

    Returns
    -------
    Float
        test accuracy, rounded to three decimals

    Raises
    ------
    ValueError
        if test_loader yields no graphs
    """
    torch.backends.cudnn.benchmark = True
    # Run on whatever device the model was moved to instead of assuming CUDA.
    device = next(model.parameters()).device

    n_batches = len(test_loader)
    c = 0
    test_loss = 0
    test_acc = 0

    print('\n')
    print('Evaluating test data', '\n')

    model.eval()
    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for data in tqdm.tqdm(test_loader):
            data = data.to(device)
            c += len(data)
            output = model(data)
            test_loss += loss_function(output, data.y).item()
            test_acc += int((output.argmax(dim=1) == data.y).sum())

    if c == 0:
        raise ValueError('test_loader yielded no graphs')

    print('\n')
    print(' Test_Loss: ', round(test_loss/n_batches, 3), ' Test_Acc: ', round(test_acc/c, 3))
    return round(test_acc/c, 3)



#############################################################################################

# NOTE: unpickling can execute arbitrary code; only load dataset files that
# come from a trusted source.
with open('Data/torch_graph_dataset_pointclouds.txt', 'rb') as f:
    data_list = pickle.load(f)

# vert_stats=[]
# edge_stats=[]

# for G in data_list:
#     vert_stats.append(len(G.x))
#     if len(G.x)==0: print(0)
#     if len(G.edge_index)>0:
#         edge_stats.append(G.edge_index.shape[1])
#     else: print(0)
    

# print('#Vertices: ',np.mean(np.array(vert_stats)),'\n')
# print('#Edges: ',np.mean(np.array(edge_stats)),'\n')


l = 5000  # number of instances drawn from data_list for each split
s = 0.8  # Train/Test split
Batchsize = 25  # batchsize

# Fail early with a clear message instead of an IndexError inside the
# permutation loops below, which index data_list with values in range(l).
if l > len(data_list):
    raise ValueError(f'l={l} exceeds the dataset size {len(data_list)}')

num_classes = 5  # number of classes
num_node_features = 4  # dimension of feature vectors on nodes
slices = 10  # number of slices in the graphcodes
GAN_layers = [[20, 5], [40, 5], [60, 5], [80, 5]]  # [out_channels, heads] per GAT layer
LIN_layers = [1000, 500]  # widths of the hidden dense layers
reg = [0.7, 0.7, 0.7]  # dropout probabilities for the dense layers
bn = True  # if True batch normalization is used
epochs = 100  # number of training epochs
edge = True  # if False trains without edges

torch.backends.cudnn.benchmark = True
# Prefer the GPU, but fall back to the CPU so the script still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



###############################################################################

"""Train and evaluate the model Keras-style, using the test set as the
   validation set, and plot the progression of train and test accuracy"""

# perm=np.random.permutation(l)
# data_list_perm=[] 
# for i in range(l):
#     data_list_perm.append(data_list[perm[i]])
        
# train_dataset=data_list_perm[:int(s*l)]
# test_dataset=data_list_perm[int(s*l):]   

# train_loader=tg.loader.DataLoader(train_dataset, batch_size=Batchsize, shuffle=True)
# test_loader=tg.loader.DataLoader(test_dataset, batch_size=Batchsize, shuffle=False)

# model=GAN(num_node_features=num_node_features,num_classes=num_classes,slices=slices,GAN_layers=GAN_layers,LIN_layers=LIN_layers,reg=reg, bn=bn, edge=edge)
# model.to(device)
# print(model,'\n')

# optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
# loss_function=torch.nn.CrossEntropyLoss()
    
# fit_verbose(model, train_loader, test_loader, optimizer, loss_function, epochs)


###############################################################################

"""Train and evaluate model with base shuffle """

# shuffle_data_sets=20
# train_loaders=[]

# perm=np.random.permutation(l)

# for q in range(shuffle_data_sets):
#     data_list=pickle.load(open('Data/torch_graph_dataset_pointclouds_shuffle_'+str(q+1)+'.txt','rb'))
#     data_list_perm=[]
#     for i in range(l):
#         data_list_perm.append(data_list[perm[i]])
#     train_loaders.append(tg.loader.DataLoader(data_list_perm[1000:5000], batch_size=Batchsize, shuffle=True))
    
# test_loader=tg.loader.DataLoader(data_list_perm[:1000], batch_size=Batchsize, shuffle=False)
    

# model=GAN(num_node_features=num_node_features,num_classes=num_classes,slices=slices,GAN_layers=GAN_layers,LIN_layers=LIN_layers,reg=reg, bn=bn, edge=edge)
# model.to(device)
# print(model,'\n')

# optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
# loss_function=torch.nn.CrossEntropyLoss()
    
# fit_verbose_shuffle(model, train_loaders, test_loader, optimizer, loss_function, epochs)
   
   

###############################################################################

"""Train and evaluate model N times on training and test set respectively
   and compute average test set accuracy"""


def _repeated_runs(use_edges, runs):
    """Train a fresh GAN `runs` times on random splits; return the test accuracies.

    Each run draws a new random permutation of the first `l` instances of
    `data_list`. The first `int(s*l)` permuted instances are used for
    training and the rest for testing. The model is trained for `epochs`
    epochs with Adam (lr=0.001) and cross-entropy loss.
    """
    accs = []
    for j in range(runs):
        print('\n')
        print('Run '+str(j+1))

        perm = np.random.permutation(l)
        data_list_perm = [data_list[k] for k in perm]
        split = int(s*l)

        train_loader = tg.loader.DataLoader(data_list_perm[:split], batch_size=Batchsize, shuffle=True)
        test_loader = tg.loader.DataLoader(data_list_perm[split:], batch_size=Batchsize, shuffle=False)

        model = GAN(num_node_features=num_node_features, num_classes=num_classes, slices=slices, GAN_layers=GAN_layers, LIN_layers=LIN_layers, reg=reg, bn=bn, edge=use_edges)
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_function = torch.nn.CrossEntropyLoss()

        fit(model, train_loader, optimizer, loss_function, epochs)
        accs.append(evaluate(model, test_loader, loss_function))
    return accs


def _report(title, accs, runs):
    """Print the accuracies with their mean and standard deviation; return them as an array."""
    accs = np.array(accs)
    print(title, '\n')
    print(accs)
    print('\n')
    print('Mean test accuracy over '+str(runs)+' runs: ', round(np.mean(accs), 3), '\n')
    print('Standard deviation of test accuracy over '+str(runs)+' runs: ', round(np.std(accs), 3), '\n')
    return accs


N = 20

accs1 = _repeated_runs(edge, N)

edge = False

accs2 = _repeated_runs(edge, N)

accs1 = _report('With edges', accs1, N)
accs2 = _report('Without edges', accs2, N)



###############################################################################

"""Train and evaluate model N times on training and test set respectively with base shuffle
   and compute average test set accuracy"""
   
# shuffle_data_sets=20
# N=20

# accs=[]
# data_lists=[]

# for i in range(shuffle_data_sets):
#     data_lists.append(pickle.load(open('Data/torch_graph_dataset_pointclouds_shuffle_'+str(i+1)+'.txt','rb')))

# for i in range(N):
#     print('\n')
#     print('Run '+str(i+1))
    
#     train_loaders=[]
#     perm=np.random.permutation(l)

#     for j in range(shuffle_data_sets):
#         data_list_perm=[]
#         for k in range(l):
#             data_list_perm.append(data_lists[j][perm[k]])
#         train_loaders.append(tg.loader.DataLoader(data_list_perm[1000:5000], batch_size=Batchsize, shuffle=True))
        
#     test_loader=tg.loader.DataLoader(data_list_perm[:1000], batch_size=Batchsize, shuffle=False)
    

#     model=GAN(num_node_features=num_node_features,num_classes=num_classes,slices=slices,GAN_layers=GAN_layers,LIN_layers=LIN_layers,reg=reg,bn=bn,edge=edge)
#     model.to(device)

#     optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
#     loss_function=torch.nn.CrossEntropyLoss()
        
#     fit_shuffle(model, train_loaders, optimizer, loss_function, epochs)
#     acc=evaluate(model, test_loader, loss_function)
#     accs.append(acc)

# accs=np.array(accs)
# print('With edges','\n')
# print(accs)
# print('\n')
# print('Mean test accuracy over '+str(N)+' runs: ',round(np.mean(accs),3),'\n')
# print('Standard deviation of test accuracy over '+str(N)+' runs: ',round(np.std(accs),3),'\n')

###############################################################################

"""Parameters: Pointclouds"""

# num_classes=5 
# num_node_features=4 
# slices=10 
# GAN_layers=[[20,5],[40,5],[60,5],[80,5]] 
# LIN_layers=[1000,500] 
# reg=[0.7,0.7,0.7]  
# bn=True 
# epochs=100
# edge=True


"""Parameters: PROTEINS"""

# num_classes=5
# num_node_features=4
# slices=20 
# GAN_layers=[[20,6],[40,6],[60,6],[80,6],[100,6],[120,6]]
# LIN_layers=[1000,500]
# reg=[0.6,0.6,0.6]
# bn=True
# epochs=70
# edge=True


"""Parameters: DHFR"""

# num_classes=5
# num_node_features=4
# slices=20 
# GAN_layers=[[20,6],[40,6],[60,6]]
# LIN_layers=[1000,500]
# reg=[0.3,0.3,0.3]
# bn=True
# epochs=100
# edge=True


"""Parameters: COX2"""

# num_classes=5
# num_node_features=4
# slices=20 
# GAN_layers=[[20,6],[40,6],[60,6],[80,6],[100,6],[120,6]]
# LIN_layers=[1000,500]
# reg=[0.6,0.6,0.6]
# bn=True
# epochs=70
# edge=True


"""Parameters: MUTAG"""

# num_classes=5
# num_node_features=4
# slices=20 
# GAN_layers=[[20,6],[40,6],[60,6],[80,6],[100,6],[120,6]]
# LIN_layers=[1000,500]
# reg=[0.6,0.6,0.6]
# bn=True
# epochs=100
# edge=True


"""Parameters: IMDB"""

# num_classes=5
# num_node_features=4
# slices=20 
# GAN_layers=[[20,6],[40,6],[60,6],[80,6]]
# LIN_layers=[1000,500]
# reg=[0.2,0.2,0.2]
# bn=True
# epochs=100
# edge=True
