import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import torch as pt
import torch_geometric as ptg
import pytorch_lightning as ptl
import time

from torch.nn import Linear, Sequential, ReLU, BatchNorm1d, ModuleList, Sigmoid, Tanh, ELU, CELU, Identity, Parameter, LeakyReLU, GELU
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, SWALR
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool, global_add_pool
from torch_geometric.nn.norm import BatchNorm, GraphNorm, pair_norm, PairNorm, InstanceNorm
from torch_geometric.nn.inits import zeros, ones, normal
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_torch_coo_tensor, scatter, to_scipy_sparse_matrix

from torchmetrics import Accuracy

from lightning.pytorch.callbacks import LearningRateFinder
            
from copy import deepcopy     

from scipy.sparse.linalg import eigs, eigsh    
from scipy.sparse import diags, csr_matrix
            

def np_to_pt(A_list, num_features=32):
    """Convert dense adjacency matrices into PyTorch Geometric graphs.

    Each matrix is read as a directed graph (``nx.DiGraph``) and converted with
    ``from_networkx``. The resulting object additionally carries:

    * ``x`` -- i.i.d. standard-normal node features of shape
      ``(A.shape[0], num_features)``, drawn from NumPy's global RNG;
    * ``A`` -- the adjacency matrix as a float32 tensor;
    * ``edge_weight`` -- ``A[src, dst]`` for every column ``(src, dst)`` of
      ``edge_index``, as float32.

    Args:
        A_list: Iterable of square 2-D adjacency arrays.
        num_features: Width of the random feature matrix (default 32).

    Returns:
        List of converted graphs, one per input matrix, in input order.
    """
    G_pt_list = []
    for A in A_list:
        G = nx.from_numpy_array(A, create_using=nx.DiGraph)
        G_pt = ptg.utils.convert.from_networkx(G)
        X = np.random.normal(size=(A.shape[0], num_features))
        G_pt.x = pt.tensor(X, dtype=pt.float32)
        G_pt.A = pt.tensor(A, dtype=pt.float32)
        # Vectorised gather of the edge weights (one fancy-index instead of a
        # Python loop over edges). reshape(2, -1) keeps edgeless graphs working,
        # np.asarray keeps np.matrix inputs from returning a 2-D result.
        src, dst = G_pt.edge_index.reshape(2, -1).cpu().numpy()
        G_pt.edge_weight = pt.tensor(np.asarray(A)[src, dst], dtype=pt.float32)
        G_pt_list.append(G_pt)
    return G_pt_list

class AppendEVs(BaseTransform):
    """Attach eigenvectors of three GNN propagation operators to each graph.

    The adjacency matrix (optionally symmetrised first) is augmented with one
    extra "virtual" node connected to every node, and to itself, with weight
    1e-5. From it three operators are built:

    * GIN: the augmented adjacency ``A`` itself;
    * GCN: ``D^-1/2 A D^-1/2`` with ``D`` the row sums of ``A``;
    * GAT: ``D^-1 A`` (row normalisation).

    For each operator the largest-magnitude eigenvectors (ARPACK ``eigs``) are
    ordered by ascending real eigenvalue. Their sign is aligned so the virtual
    node's entry is non-negative, and the virtual-node row is then dropped.

    The results are stored on ``data``:

    * ``gin_EVs`` / ``gcn_EVs`` -- ``num_evs`` columns. These are up to
      ``num_evs - 1`` eigenvectors, each scaled by (its overlap with the
      normalised all-ones vector + 1e-4), then zero padding, then a final
      residual column chosen so every row sums to the normalised all-ones vector.
    * ``gat_EVs`` -- ``num_evs`` columns: up to ``num_evs`` eigenvectors
      followed by zero padding.

    Args:
        num_evs: Number of columns of each stored feature matrix.
        make_undirected: Symmetrise ``data.edge_index`` in place before use.
        resize_y: Flatten ``data.y`` to 1-D.
        verbose: Print a per-graph progress line. This was unconditional
            before; it is now opt-in.
    """

    def __init__(self, num_evs, make_undirected=False, resize_y=False, verbose=False):
        # Number of eigenvector columns before the residual column (GIN/GCN).
        self.num_evs = num_evs - 1
        # NOTE(review): self.v and self.first_EVs are not used by this class. They
        # are kept because the randn call advances torch's global RNG (seeded runs
        # stay reproducible) and other code may read the attributes.
        self.v = pt.randn((10000, ))
        self.v = self.v * pt.sign(pt.sum(self.v[0]))
        self.i = 0  # number of graphs processed so far
        self.make_undirected = make_undirected
        self.resize_y = resize_y
        self.verbose = verbose
        self.first_EVs = None

    @staticmethod
    def _augmented_adjacency(data):
        """Return the (n+1)x(n+1) float32 CSR adjacency with a weakly connected virtual node appended."""
        # num_nodes is required: without it the matrix size is max(edge_index)+1,
        # silently dropping trailing isolated nodes and misaligning rows with x.
        topology = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
        topology = np.asarray(topology.todense())
        topology = np.concatenate((topology, 1e-5 * np.ones((topology.shape[0], 1))), axis=1)
        topology = np.concatenate((topology, 1e-5 * np.ones((1, topology.shape[1]))), axis=0)
        return csr_matrix(topology, dtype=np.float32)

    @staticmethod
    def _leading_eigenvectors(op, k):
        """Top-``k`` |eigenvalue| eigenvectors of ``op``, sorted ascending by real eigenvalue,
        sign-aligned on the virtual node and with the virtual-node row removed."""
        if k < 1:
            # eigs needs 0 < k < n - 1; tiny graphs simply get no eigenvectors.
            return pt.zeros((op.shape[0] - 1, 0), dtype=pt.float32)
        evalues, evectors = eigs(op, k=k, which='LM', tol=1e-8, maxiter=100000)
        evectors = evectors[:, np.argsort(evalues.real)].real
        evectors = pt.tensor(evectors, dtype=pt.float32)
        signs = pt.sign(evectors[-1])
        # A zero virtual-node entry would otherwise wipe out the whole vector.
        signs[signs == 0] = 1
        return evectors[:-1] * signs

    def __call__(self, data):
        """Compute ``gin_EVs``, ``gcn_EVs`` and ``gat_EVs``, set them on ``data`` and return it."""
        self.i = self.i + 1
        if self.make_undirected:
            data.edge_index = ptg.utils.to_undirected(data.edge_index)

        if self.resize_y:
            data.y = data.y.reshape((-1,))

        topology = self._augmented_adjacency(data)

        # Row sums. The virtual node makes every row sum >= 1e-5; the zero guard
        # is only a safety net.
        degree = topology @ np.ones(topology.shape[1])
        degree[degree == 0] = 1
        inv_sqrt_degree = diags(1 / np.sqrt(degree))

        gin_conv = topology
        gcn_conv = inv_sqrt_degree @ topology @ inv_sqrt_degree
        gat_conv = diags(1 / degree) @ topology

        for key, op in (('gin_EVs', gin_conv), ('gcn_EVs', gcn_conv)):
            evectors = self._leading_eigenvectors(op, min(self.num_evs, op.shape[0] - 2))
            # Normalised all-ones vector; each eigenvector is scaled by its
            # overlap with it (+1e-4).
            ones_vec = pt.ones(evectors.shape[0]) / np.sqrt(evectors.shape[0])
            evectors = ((evectors.transpose(0, 1) @ ones_vec) + 1e-4) * evectors
            # Residual column: makes each row of the stored matrix sum to ones_vec.
            rest = ones_vec - pt.sum(evectors, dim=1)
            padding = pt.zeros((evectors.shape[0], self.num_evs - evectors.shape[1]))
            setattr(data, key, pt.cat((evectors, padding, rest.reshape(-1, 1)), dim=1))

        gat_evectors = self._leading_eigenvectors(gat_conv, min(self.num_evs + 1, gat_conv.shape[0] - 2))
        gat_padding = pt.zeros((gat_evectors.shape[0], self.num_evs + 1 - gat_evectors.shape[1]))
        data.gat_EVs = pt.cat((gat_evectors, gat_padding), dim=1)

        if self.verbose:
            # Guard: datasets without labels have data.y == None.
            y = getattr(data, 'y', None)
            print(self.i, data.gin_EVs.shape, data.gcn_EVs.shape, None if y is None else y.shape)

        return data

class GraphNorm2(pt.nn.Module):
    r"""Graph normalisation that subtracts a learnable spectral "mean".

    For every graph ``g`` in the batch, with ``V_g`` the rows of ``v`` belonging
    to ``g`` and ``S`` the learnable ``EV_scales``:

        out_g = x_g - V_g ((1 + S) * (V_g^T x_g))

    ``out`` is then divided by its per-graph, per-channel root-mean-square
    (with ``eps`` added to the mean square) and ``weight * . + bias`` is applied.

    The projection ``V_g^T x_g`` is computed separately for each graph. A single
    ``v.T @ x`` over the whole batch would mix the coefficients of different
    graphs. For a single graph the two formulations coincide.

    Args:
        num_evs: Number of columns of ``v`` (eigenvectors per node).
        in_channels: Feature dimension of ``x``.
        eps: Added to the mean square for numerical stability.
    """

    def __init__(self, num_evs: int, in_channels: int, eps: float = 1e-5):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps

        self.weight = pt.nn.Parameter(pt.empty(in_channels))
        self.bias = pt.nn.Parameter(pt.empty(in_channels))
        # One learnable scale per (eigenvector, channel) pair.
        self.EV_scales = pt.nn.Parameter(pt.empty(num_evs, in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        pt.nn.init.ones_(self.weight)
        pt.nn.init.zeros_(self.bias)
        pt.nn.init.normal_(self.EV_scales, 0, 0.1)

    @staticmethod
    def _segment_sum(src: pt.Tensor, index: pt.Tensor, num_segments: int) -> pt.Tensor:
        """Sum the dim-0 slices of ``src`` that share an ``index`` value."""
        return src.new_zeros((num_segments, *src.shape[1:])).index_add_(0, index, src)

    def forward(self, x: pt.Tensor, v: pt.Tensor, batch: pt.Tensor) -> pt.Tensor:
        """Normalise node features.

        Args:
            x: Node features, ``(N, in_channels)``.
            v: Per-node eigenvector entries, ``(N, num_evs)``.
            batch: Graph index of every node, ``(N,)`` integer tensor.

        Returns:
            Normalised features with the same shape as ``x``.
        """
        batch_size = int(batch.max()) + 1

        # Per-graph coefficients V_g^T x_g -> (batch_size, num_evs, in_channels).
        contributions = self._segment_sum(v.unsqueeze(-1) * x.unsqueeze(1), batch, batch_size)
        scaled_contributions = (1 + self.EV_scales) * contributions
        # Back-projection per node using that node's own graph's coefficients.
        mean = pt.einsum('nk,nkc->nc', v, scaled_contributions.index_select(0, batch))
        out = x - mean

        # Per-graph mean of out^2 (empty graph slots are clamped to count 1).
        counts = pt.bincount(batch, minlength=batch_size).clamp(min=1).to(out.dtype)
        var = self._segment_sum(out.pow(2), batch, batch_size) / counts.unsqueeze(-1)
        std = (var + self.eps).sqrt().index_select(0, batch)

        return self.weight * out / std + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'
    

class GNN_Module(ptl.LightningModule):
    """Lightning module: linear encoder -> ``num_layers`` x (conv, activation,
    normalisation, dropout) -> linear decoder.

    ``config`` keys read: ``gnn_conv`` ('GIN' | 'GCN' | 'GAT'),
    ``embedding_size``, ``num_layers``, ``learning_rate``, ``regularization``
    (AdamW weight decay), ``num_inputs``, ``num_outputs``, ``num_evs``,
    ``graph_level``, ``activation_func`` ('relu' | 'leakyrelu' | 'id' |
    'sigmoid' | 'tanh' | 'gelu'), ``dropout`` and ``normalization``
    ('batch' | 'graph' | 'graph2' | 'pair' | anything else for none).

    For graph-level tasks the node embeddings are mean-pooled per graph before
    decoding. Otherwise, predictions are per node, and loss and metrics are
    restricted to ``batch.<split>_mask[0]``.
    """

    # Factories selected by config['activation_func'].
    _ACTIVATIONS = {
        'relu': ReLU,
        'leakyrelu': LeakyReLU,
        'id': Identity,
        'sigmoid': Sigmoid,
        'tanh': Tanh,
        'gelu': lambda: GELU(approximate='tanh'),
    }

    # Eigenvector attribute (as written by AppendEVs) consumed by 'graph2' norm.
    # NOTE(review): GAT maps to 'gin_EVs', matching the previous behaviour. The
    # old code also had an `elif gnn_conv == 'GAT': 'gat_EVs'` branch, but it was
    # unreachable (the first condition already matched GAT). Confirm intent.
    _EV_KEYS = {'GIN': 'gin_EVs', 'GAT': 'gin_EVs', 'GCN': 'gcn_EVs'}

    def __init__(self, config):
        super().__init__()

        # hyperparameters
        self.gnn_conv = config['gnn_conv']
        self.embedding_size = config['embedding_size']
        self.num_layers = config['num_layers']
        # Used only for featureless graphs (data.x is None): N(1, 1) embeddings.
        self.initialization = lambda x : pt.nn.init.normal_(x, mean=1, std=1)
        self.learning_rate = config['learning_rate']
        self.regularization = config['regularization']
        self.num_inputs = config['num_inputs']
        self.num_outputs = config['num_outputs']
        self.num_evs = config['num_evs']
        self.graph_Level = config['graph_level']
        self.loss = pt.nn.CrossEntropyLoss()
        self.out_function = pt.nn.Softmax(dim=1)

        activation = self._ACTIVATIONS.get(config['activation_func'])
        if activation is None:
            raise ValueError('Activation function not supported. ' + config['activation_func'])
        self.relu = activation()

        self.dropout = pt.nn.Dropout(config['dropout'])

        # NOTE(review): graphnorm_lambda, graphnorm2_lambda, norm_gamma, norm_beta
        # and p_n are not used in this file. They are kept so state_dict keys
        # (and thus saved checkpoints) stay unchanged.
        self.graphnorm_lambda = Parameter(pt.tensor(1.0, device=self.device, dtype=pt.float32))
        # requires_grad must be passed to Parameter itself; setting it on the
        # inner tensor was silently overridden (Parameter defaults to True).
        self.graphnorm2_lambda = Parameter(
            pt.ones((self.num_evs, self.embedding_size), device=self.device, dtype=pt.float32),
            requires_grad=False,
        )
        if self.gnn_conv not in self._EV_KEYS:
            raise ValueError('GNN type not supported.')
        self.EV_key = self._EV_KEYS[self.gnn_conv]

        self.norm_gamma = Parameter(pt.tensor(1.0, device=self.device, dtype=pt.float32))
        self.norm_beta = Parameter(pt.tensor(1.0, device=self.device, dtype=pt.float32))
        self.p_n = PairNorm()

        # extract_data builds the positional arguments of each normalisation call.
        norm_type = config['normalization']
        self.extract_data = lambda x, data: (x, data.batch)
        if norm_type == 'batch':
            # NOTE(review): 'batch' is implemented with InstanceNorm (no running stats).
            self.normalization = ModuleList([InstanceNorm(self.embedding_size, track_running_stats=False) for _ in range(self.num_layers)])
        elif norm_type == 'graph':
            self.normalization = ModuleList([GraphNorm(self.embedding_size) for _ in range(self.num_layers)])
        elif norm_type == 'graph2':
            # GraphNorm2 additionally needs the per-node eigenvector matrix.
            self.extract_data = lambda x, data: (x, data[self.EV_key], data.batch)
            self.normalization = ModuleList([GraphNorm2(self.num_evs, self.embedding_size) for _ in range(self.num_layers)])
        elif norm_type == 'pair':
            self.normalization = ModuleList([PairNorm() for _ in range(self.num_layers)])
        else:
            # No normalisation: return the embeddings, ignore the batch vector.
            self.normalization = [lambda x, y: x] * self.num_layers

        if self.gnn_conv == 'GCN':
            self.convs = ModuleList([GCNConv(self.embedding_size, self.embedding_size, cached=True, normalize=False) for _ in range(self.num_layers)])
        elif self.gnn_conv == 'GIN':
            self.convs = ModuleList([GINConv(Sequential(Linear(self.embedding_size, self.embedding_size),)) for _ in range(self.num_layers)])
        elif self.gnn_conv == 'GAT':
            self.convs = ModuleList([GATConv(self.embedding_size, self.embedding_size, cached=True) for _ in range(self.num_layers)])
        else:
            raise ValueError('GNN type not supported.')
        self.encoder = Linear(self.num_inputs, self.embedding_size, bias=True, device=self.device)
        self.decoder = Linear(self.embedding_size, self.num_outputs, bias=True, device=self.device)

        self.reset_parameters()

        self.train_mask = []
        self.val_mask = []
        self.test_mask = []

        # metrics
        self.save_hyperparameters(config, logger=False)
        self.train_accuracy = Accuracy(task='multiclass', num_classes=self.num_outputs)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=self.num_outputs)
        self.test_accuracy = Accuracy(task='multiclass', num_classes=self.num_outputs)
        # Redundant second reset, kept on purpose: removing it would change the
        # weights obtained under a fixed random seed.
        self.reset_parameters()

    def forward(self, data):
        """Return logits per node, or per graph when ``graph_level`` is set."""
        if data.x is not None:
            embeddings = self.encoder(data.x)
        else:
            # Featureless graphs: fresh random N(1, 1) embeddings on every call.
            embeddings = pt.empty((data.num_nodes, self.embedding_size), device=self.device)
            self.initialization(embeddings)
            embeddings.requires_grad = True

        for conv, norm in zip(self.convs, self.normalization):
            embeddings = conv(embeddings, data.edge_index)
            embeddings = self.relu(embeddings)
            embeddings = norm(*self.extract_data(embeddings, data))
            embeddings = self.dropout(embeddings)

        return self.decoder(embeddings if not self.graph_Level else global_mean_pool(embeddings, data.batch))

    def _predict_split(self, batch, mask_name):
        """Run the model and return ``(logits, targets)`` for one data split.

        Graph-level tasks use every graph in the batch. Node-level tasks keep
        only nodes selected by ``getattr(batch, mask_name)[0]``.
        """
        prediction = self(batch)
        if self.graph_Level:
            return prediction, batch.y
        # NOTE(review): uses the first entry of the mask attribute; presumably
        # masks are stored with a leading split dimension -- confirm.
        mask = getattr(batch, mask_name)[0]
        return prediction[mask, :], batch.y[mask]

    def training_step(self, batch):
        prediction, y = self._predict_split(batch, 'train_mask')
        loss = self.loss(prediction, y)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy(self.out_function(prediction), y), on_epoch=True, prog_bar=True)
        # Loss is passed to the backward pass.
        return loss

    def validation_step(self, batch):
        prediction, y = self._predict_split(batch, 'val_mask')
        loss = self.loss(prediction, y)
        acc = self.val_accuracy(self.out_function(prediction), y)
        # MSE between predicted class probabilities and one-hot targets.
        score = pt.nn.functional.mse_loss(self.out_function(prediction), pt.nn.functional.one_hot(y, self.num_outputs).float())
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)
        self.log('val_score', score, on_epoch=True, prog_bar=True)

    def test_step(self, batch):
        prediction, y = self._predict_split(batch, 'test_mask')
        loss = self.loss(prediction, y)
        self.log('test_loss', loss)
        self.log('test_acc', self.test_accuracy(self.out_function(prediction), y))

    def reset_parameters(self):
        """Re-initialise convolutions, encoder and decoder (normalisation layers are not reset)."""
        for comp in self.convs:
            comp.reset_parameters()
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau (factor 0.5, patience 50) monitoring ``train_loss``."""
        optimizer = AdamW(self.parameters(),
                          lr=self.learning_rate,
                          weight_decay=self.regularization,
                          fused=True)
        return {'optimizer':optimizer,
                'lr_scheduler': ReduceLROnPlateau(optimizer,
                                                  factor=0.5,
                                                  patience=50,
                                                  verbose=True,
                                                  mode='min'),
                'monitor':'train_loss'}

    def __module_name__(self=None):
        return 'GIN_Module'
    
    

