import main
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

import torch as pt
import torch_geometric as ptg
import pytorch_lightning as ptl

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import global_add_pool
from torch.nn import Linear, Sequential, ReLU, Sigmoid, Softmax, CrossEntropyLoss
from torch.optim import Adam   

from torchmetrics import Accuracy, Precision

from torch_geometric.nn.inits import reset

from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import spmm

from sklearn.preprocessing import normalize, minmax_scale
from typing import Any


def np_to_pt(A_list, num_features=32):
    """Convert dense adjacency matrices into torch_geometric graph objects.

    Each matrix is turned into a directed networkx graph with
    ``nx.from_numpy_array(..., create_using=nx.DiGraph)`` and then converted
    with ``from_networkx``. Every resulting graph additionally receives:

    * ``x``: random standard-normal node features of shape
      ``(A.shape[0], num_features)`` drawn from NumPy's global RNG,
    * ``A``: the dense adjacency matrix as a float32 tensor,
    * ``edge_weight``: the entry ``A[src, dst]`` for every column
      ``(src, dst)`` of ``edge_index``, as a float32 tensor.

    Args:
        A_list: Iterable of 2-D NumPy adjacency matrices.
        num_features: Number of random feature columns generated per node.
            Defaults to 32, the width that was previously hard-coded.

    Returns:
        A list with one converted graph per input matrix, in input order.
    """
    G_pt_list = []
    for A in A_list:
        G = nx.from_numpy_array(A, create_using=nx.DiGraph)
        G_pt = ptg.utils.convert.from_networkx(G)

        X = np.random.normal(size=(A.shape[0], num_features))
        G_pt.x = pt.tensor(X, dtype=pt.float32)
        G_pt.A = pt.tensor(A, dtype=pt.float32)

        # Look up all edge weights with a single fancy-indexing operation
        # instead of a Python-level loop over the edges.
        src, dst = G_pt.edge_index.numpy()
        G_pt.edge_weight = pt.tensor(np.asarray(A)[src, dst], dtype=pt.float32)

        G_pt_list.append(G_pt)

    return G_pt_list




def pt_ep_cost(H, A, device=None):
    """Return ``|| A H - H D^-1 H^T A H ||_F`` for an assignment matrix ``H``.

    ``D`` is the diagonal matrix of the column sums of ``H``. Columns whose
    sum is zero (which would give an infinite or NaN reciprocal) contribute
    nothing, because their reciprocal is replaced by 0.

    Args:
        H: 2-D tensor of shape ``(n, k)``; row ``i`` holds the (soft)
            assignment of node ``i`` to the ``k`` groups.
        A: 2-D tensor of shape ``(n, n)``, the dense adjacency matrix.
        device: Optional device for the column normaliser. ``None`` (the
            default) keeps it on the device of ``H``, so the function works
            on CPU as well as on GPU. An explicit value behaves as before and
            must match the device of ``H`` and ``A``.

    Returns:
        A scalar tensor holding the cost.
    """
    X = A @ H

    inv_col_sums = 1 / pt.sum(H, dim=0)
    if device is not None:
        inv_col_sums = inv_col_sums.to(device)
    inv_col_sums = pt.nan_to_num(inv_col_sums, nan=0, posinf=0, neginf=0)

    # Scaling the columns of H is the same as multiplying by diag(1 / sums),
    # and grouping the product as (H D^-1)(H^T X) avoids building an n x n
    # intermediate matrix.
    projected = (H * inv_col_sums) @ (H.T @ X)
    return pt.sqrt(pt.sum(pt.pow(X - projected, 2)))


 
class GINConv(MessagePassing):
    """Graph Isomorphism Network convolution with optional edge weights.

    For every node the layer computes
    ``nn((1 + eps) * x_i + sum_j w_ji * x_j)``, where the sum runs over the
    incoming edges and ``w_ji`` is the corresponding entry of
    ``edge_weight`` (1 when no weights are given).

    Args:
        nn: Callable (typically a small MLP) applied to the aggregated
            features.
        eps: Initial value of ``eps``.
        train_eps: If ``True``, ``eps`` is a trainable parameter; otherwise
            it is a fixed buffer.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``'add'``.
    """

    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            # Non-persistent so the state_dict keys stay the same as before
            # and previously saved checkpoints keep loading.
            self.register_buffer('eps', torch.empty(1), persistent=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the wrapped network and restore ``eps`` to its initial value."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x, edge_index, edge_weight=None):
        """Run one round of message passing.

        Args:
            x: Node feature tensor, or a ``(source, target)`` pair of feature
                tensors whose second entry may be ``None``.
            edge_index: Graph connectivity passed on to ``propagate``.
            edge_weight: Optional tensor with one weight per edge.

        Returns:
            The output of ``nn`` applied to the combined features.
        """
        # A single tensor serves as both source and target features. Without
        # this, ``x[1]`` below would select the feature row of node 1 instead
        # of the target-node feature matrix.
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j, x_i, edge_weight=None) -> Tensor:
        """Return the per-edge message: the source features, optionally scaled.

        ``x_i`` is accepted for signature compatibility but is not used.
        """
        if edge_weight is None:
            return x_j
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x) -> Tensor:
        """Fused message and aggregation step for a sparse adjacency ``adj_t``."""
        return spmm(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'
 
 
 
 
 
 
    

class GIN_Module(ptl.LightningModule):
    """Lightning module stacking ``GINConv`` layers followed by an output MLP.

    Training is label-free: the row-wise softmax of the per-node output is
    used as a soft assignment matrix and scored with ``pt_ep_cost`` against
    the dense adjacency matrix stored on the batch as ``batch.A``.

    NOTE(review): ``batch.A`` is used as a single dense matrix, so the loss
    presumably expects one graph per batch -- confirm before batching several
    graphs together.

    Args:
        config: Mapping providing ``num_layers``, ``learning_rate``,
            ``regularization``, ``num_inputs`` and ``num_outputs``. An
            ``embedding_size`` entry, when present, takes precedence over the
            ``embedding_size`` argument.
        embedding_size: Width of the node embeddings, used when ``config``
            has no ``embedding_size`` entry.
        initialization: In-place initialiser applied to the embedding buffer
            before the node features are copied into it.
        out_classifier: Optional replacement for the default output MLP.
        gin_nn: Optional sequence holding one network per GIN layer.
    """

    def __init__(self, config, embedding_size: int = 32, initialization=pt.nn.init.ones_, out_classifier=None, gin_nn=None):
        super().__init__()

        # hyperparameters
        self.embedding_size = config['embedding_size'] if 'embedding_size' in config else embedding_size
        self.num_layers = config['num_layers']
        self.initialization = initialization
        self.learning_rate = config['learning_rate']
        self.regularization = config['regularization']
        self.num_inputs = config['num_inputs']
        self.num_outputs = config['num_outputs']
        self.ep_cost_func = pt_ep_cost
        self.activation_func = Softmax(dim=1)
        self.relu = ReLU()

        # sub-modules
        # One MLP per GIN layer unless the caller supplies its own.
        if gin_nn is None:
            gin_nn = [
                Sequential(
                    Linear(self.embedding_size, self.embedding_size),
                    self.relu,
                    Linear(self.embedding_size, self.embedding_size),
                )
                for _ in range(self.num_layers)
            ]
        # Convolutional layer that is different for each layer.
        self.convs = pt.nn.ModuleList([GINConv(gin_nn[i]) for i in range(self.num_layers)])
        # Output MLP
        if out_classifier is None:
            out_classifier = Sequential(
                Linear(self.embedding_size, self.embedding_size),
                self.relu,
                Linear(self.embedding_size, self.embedding_size),
                self.relu,
                Linear(self.embedding_size, self.num_outputs),
            )
        self.out = out_classifier

        # metrics
        self.train_acc = Accuracy(task='multiclass', num_classes=self.num_outputs)
        self.val_prec = Precision(task='multiclass', num_classes=self.num_outputs)
        self.val_acc = Accuracy(task='multiclass', num_classes=self.num_outputs)
        self.test_prec = Precision(task='multiclass', num_classes=self.num_outputs)
        self.test_acc = Accuracy(task='multiclass', num_classes=self.num_outputs)

        self.save_hyperparameters(config, logger=False)
        # A single reset is enough; it used to run twice.
        self.reset_parameters()

    def forward(self, data):
        """Return the raw (pre-softmax) output of the MLP for every node.

        Args:
            data: Graph batch providing ``num_nodes``, ``x``, ``edge_index``
                and, optionally, ``edge_weight``.
        """
        embeddings = pt.empty((data.num_nodes, self.embedding_size), device=self.device)
        # Standard initialization is constant
        self.initialization(embeddings)
        # If the GNN is informed, the data is written into the initialisation.
        # (If uninformed, data.x is the all-ones-vector.)
        embeddings[:, :data.x.shape[1]] = data.x
        # No requires_grad on this buffer: it is a local input, and the
        # parameters already make autograd track the graph.

        edge_weight = getattr(data, 'edge_weight', None)

        # Compute the Graph Convolutions
        for conv in self.convs:
            embeddings = self.relu(conv(embeddings, data.edge_index, edge_weight=edge_weight))

        # return the output of the MLP
        return self.out(embeddings)

    def _ep_loss(self, batch):
        """Return the partition cost of the softmax prediction for ``batch``."""
        prediction = self.activation_func(self(batch))
        # Evaluate the cost on the device the prediction lives on instead of
        # relying on the cost function's default device.
        return self.ep_cost_func(prediction, batch.A, device=prediction.device)

    def training_step(self, batch, batch_idx=None):
        """Compute, log and return the training loss."""
        loss = self._ep_loss(batch)
        # Logging
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        # Loss is passed to backward pass
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Compute and log the validation loss."""
        loss = self._ep_loss(batch)
        # Logging
        self.log('val_loss', loss, on_epoch=True)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return the row-wise softmax of the network output."""
        return self.activation_func(self(batch))

    def test_step(self, batch, batch_idx=None):
        """Compute and log the test loss.

        Uses the same cost as training and validation; the previous version
        called ``forward`` with an extra argument and referenced an undefined
        ``loss_func``.
        """
        loss = self._ep_loss(batch)
        # Logging
        self.log('test_loss', loss, on_epoch=True)

    def reset_parameters(self):
        """Reset all GIN layers and Xavier-initialise the output classifier."""
        for comp in self.convs:
            comp.reset_parameters()
        # Walk over the classifier instead of hard-coding indices 0/2/4 so a
        # custom ``out_classifier`` of any layout is handled as well.
        for layer in self.out.modules():
            if isinstance(layer, (pt.nn.Linear, Linear)):
                pt.nn.init.xavier_normal_(layer.weight)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.regularization)

    def __module_name__(self=None):
        return 'GIN_Module'
    
    
    
    
    
    
def get_GNN_embedding(A, k, *, max_epochs=1000, num_layers=1,
                      learning_rate=0.003, regularization=0.000000001):
    """Train a fresh ``GIN_Module`` on a single graph and return its prediction.

    The adjacency matrix is converted with ``np_to_pt``, a new model is
    fitted on that one graph, and the model's prediction for the same graph
    is returned.

    Args:
        A: Dense NumPy adjacency matrix of the graph.
        k: Number of output classes; converted with ``int``.
        max_epochs: Number of training epochs passed to the trainer.
        num_layers: Number of GIN layers in the model.
        learning_rate: Learning rate stored in the model configuration.
        regularization: Weight decay stored in the model configuration.

    Returns:
        NumPy array holding the first (and only) prediction batch, i.e. the
        row-wise softmax output with one row per node and ``k`` columns.
    """
    ds = np_to_pt([A])
    model_config = {
        'embedding_size': 32,
        'num_layers': num_layers,
        'num_outputs': int(k),
        'num_inputs': 250,
        'learning_rate': learning_rate,
        'regularization': regularization,
    }
    model = GIN_Module(model_config)
    train_loader = DataLoader(ds, batch_size=1)

    trainer = ptl.Trainer(max_epochs=max_epochs)
    trainer.fit(model, train_loader)

    predictions = trainer.predict(model, train_loader)
    return predictions[0].cpu().detach().numpy()