import torch
import torch.nn as nn
import copy
from torch_geometric.nn.conv import GINConv
from .convolutions import GRUMLPConv, BFS, GINMLPConv
import torchmetrics
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn.pool import global_mean_pool
import time
# Module-level cross-entropy loss shared by the training, validation and test
# steps of the model defined below.
criteria = torch.nn.CrossEntropyLoss()

class GIN(pl.LightningModule):

    def init_logging(self, output_dim):
        """Create the metric objects and timing accumulators used for logging.

        For each of the ``train``/``valid``/``test`` stages this sets up a
        multiclass accuracy, a multiclass F1 score, a ``SumMetric`` for the
        graph-level accuracy, a ``SumMetric`` counter (``num_<stage>``) and a
        ``MeanMetric`` for the loss.

        Args:
            output_dim: Number of classes handed to the multiclass metrics.
        """
        stages = ("train", "valid", "test")

        # Attributes are created group by group, each in train/valid/test
        # order, so the registration order of the metric modules is stable.
        for stage in stages:
            setattr(self, f"{stage}_acc",
                    torchmetrics.Accuracy(task="multiclass", num_classes=output_dim))
        for stage in stages:
            setattr(self, f"{stage}_f1",
                    torchmetrics.classification.MulticlassF1Score(num_classes=output_dim))
        for stage in stages:
            setattr(self, f"{stage}_graph_acc", torchmetrics.SumMetric())
        for stage in stages:
            setattr(self, f"num_{stage}", torchmetrics.SumMetric())
        for stage in stages:
            setattr(self, f"{stage}_loss", torchmetrics.MeanMetric())

        # Timing bookkeeping filled in by the test loop.
        self.total_time = torchmetrics.SumMetric()
        self.avg_time = torchmetrics.MeanMetric()
        self.time = []
        
        
    def __init__(self, 
                 input_dim, output_dim=1, 
                 num_rounds=1, hidden_dim=3, hidden_state_factor=2, mlp_depth = 1, dropout = 0, normalization = None, activation = None,
                 aggregation = "add", conv = "gru", random_init = False, prediction_mode = "node"):
        """Build the encoder, the stack of GIN convolutions and the decoder.

        Args:
            input_dim: Size of the raw node features fed to the encoder.
            output_dim: Size of the decoder output; also the number of classes
                given to the metrics.
            num_rounds: Number of convolution layers applied in ``forward``.
            hidden_dim: Width of the node representation between layers.
            hidden_state_factor: Multiplier applied to ``hidden_dim`` to get
                the inner width of every MLP.
            mlp_depth: Number of additional hidden blocks per MLP.
            dropout: Dropout probability used inside the MLPs.
            normalization: Passed through to ``get_mlp``, where it is called
                with a feature size.
            activation: Passed through to ``get_mlp``, where it is inserted
                into the MLPs as a module.
            aggregation: Forwarded to ``GINConv`` as ``aggr``.
            conv: Convolution type. Only ``"gin"`` is implemented; note that
                the default value ``"gru"`` is therefore rejected.
            random_init: If true, ``forward`` replaces the encoded features by
                uniform random values.
            prediction_mode: ``"node"`` or ``"graph"``; selects per-node or
                per-graph predictions.

        Raises:
            ValueError: If ``conv`` is not a supported convolution type.
        """
        super().__init__()
        self.save_hyperparameters()
        print("init model")
        self.init_logging(output_dim)

        self.num_rounds = num_rounds
        self.dropout = dropout

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.prediction_mode = prediction_mode
        self.random_init = random_init

        self.encoder = self.get_mlp(input_dim,  hidden_state_factor*hidden_dim, hidden_dim, mlp_depth, normalization, activation, last_relu=True)
        self.decoder = self.get_mlp(hidden_dim, hidden_state_factor*hidden_dim, output_dim, mlp_depth, normalization, activation, last_relu=False)

        # NOTE(review): edge_mlp is never used afterwards. It is still built so
        # that the random numbers drawn while initialising its weights are
        # consumed exactly as before, keeping seeded initial weights of
        # state_mlp (and hence of the convolutions) unchanged.
        edge_mlp    = self.get_mlp(2*hidden_dim, hidden_state_factor * 2 * hidden_dim, hidden_dim, mlp_depth, normalization, activation, last_relu=True)
        state_mlp   = self.get_mlp(hidden_dim,   hidden_state_factor* hidden_dim, hidden_dim, mlp_depth, normalization, activation, last_relu=True)

        num_convs = num_rounds

        if conv == "gin":
            # Every round gets its own copy of the state MLP (no weight sharing
            # between rounds).
            self.Conv   = nn.ModuleList([GINConv(copy.deepcopy(state_mlp), aggr=aggregation) for _ in range(num_convs)])
        else:
            # Raise instead of calling exit(): terminating the interpreter from
            # a constructor takes down notebooks and sweeps and cannot be
            # handled by the caller.
            raise ValueError(f"Unknown convolution {conv!r}; only 'gin' is supported")

    def get_mlp(self, input_dim, hidden_dim, output_dim, mlp_depth, normalization, activation, last_relu = True):
        modules = [torch.nn.Linear(input_dim, int(hidden_dim)), normalization(int(hidden_dim)), activation, torch.nn.Dropout(self.dropout)]
        for i in range(0, int(mlp_depth)):
            modules = modules + [torch.nn.Linear(int(hidden_dim), int(hidden_dim)), normalization(int(hidden_dim)), activation, torch.nn.Dropout(self.dropout)]
        modules = modules + [torch.nn.Linear(int(hidden_dim), output_dim)]
        
        if last_relu:
            modules.append(normalization(output_dim))
            modules.append(activation)

        return torch.nn.Sequential(*modules)

    def forward(self, data):
        # Get data and node features
        x, edge_index = data.x, data.edge_index
        x = x.to(torch.float)


        x = self.encoder(x)
        if self.random_init: 
            x = torch.rand(x.shape).to(x.device)

        # do the flooding phases
        for phases in range(self.num_rounds):
            x = self.Conv[phases](x, edge_index)
        # decode the hidden representation to the output
        x = self.decoder(x)
        return x
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3, min_lr=1e-6, verbose=True),
                "monitor": "valid_loss_epoch",
                }
            }


    def gin_predict(self, batch):
        if self.prediction_mode == "graph":
            pred = self(batch)
            pred = torch.zeros(batch.num_graphs,self.output_dim).to(batch.ptr.device).scatter_add_(0, batch.batch.unsqueeze(1).expand(-1, self.output_dim).to(batch.ptr.device), pred)
        else:
            #actual loss compuation
            pred = self(batch)
        return pred


    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one batch and update the train metrics.

        Args:
            train_batch: Batch providing ``y`` (targets) and ``ptr``.
            batch_idx: Index of the batch (unused).

        Returns:
            The cross-entropy loss of the batch.
        """
        logits = self.gin_predict(train_batch)
        targets = train_batch.y
        loss = criteria(logits, targets.to(torch.long))

        # Metric updates; calling the accuracy metric returns the batch value.
        predicted_classes = torch.argmax(logits, dim=1)
        step_accuracy = self.train_acc(predicted_classes, targets)
        self.train_f1(predicted_classes, targets)
        graph_hits = self.get_graph_accuracy(logits, targets, train_batch.ptr)
        self.train_graph_acc(graph_hits)
        self.train_loss.update(loss)
        # ptr has one more entry than there are graphs in the batch.
        graphs_in_batch = train_batch.ptr.shape[0] - 1
        self.num_train(graphs_in_batch)

        # Per-step logging.
        self.log('train_loss_step', loss, prog_bar=True)
        self.log('train_acc_step', step_accuracy, prog_bar=True)
        return loss

    def get_graph_accuracy(self, pred, y, ptr, override=False):
        if self.prediction_mode == "graph" or (self.prediction_mode == "node" and not override):
            return -1
        dev = torch.abs(torch.argmax(pred,dim=1) - y)
        cum_sum = torch.cumsum(dev, dim=0)
        sum_between = cum_sum[ptr[1:]-1] - cum_sum[ptr[:-1]]
        acc = torch.sum(sum_between == 0)
        return acc


    def on_train_epoch_end(self):
        self.log('train_acc_epoch', self.train_acc.compute(), prog_bar=True)
        self.log('train_graph_acc_epoch', self.train_graph_acc.compute()/self.num_train.compute(), prog_bar=True)
        self.log('train_loss_epoch', self.train_loss.compute(), prog_bar=True)
        self.log('train_f1_epoch', self.train_f1.compute(), prog_bar=True)
        self.train_acc.reset()
        self.train_f1.reset()
        self.train_loss.reset()
        self.train_graph_acc.reset()
        self.num_train.reset()

    def validation_step(self, val_batch, batch_idx):
        """Compute the validation loss for one batch and update the valid metrics.

        Args:
            val_batch: Batch providing ``y`` (targets) and ``ptr``.
            batch_idx: Index of the batch (unused).

        Returns:
            The cross-entropy loss of the batch.
        """
        pred = self.gin_predict(val_batch)
        loss = criteria(pred, val_batch.y.to(torch.long))

        # compute metrics
        # Calling the metric accumulates into its epoch state and returns the
        # value for this batch, the same pattern the training step uses.
        batch_acc = self.valid_acc(torch.argmax(pred,dim=1), val_batch.y)
        self.valid_graph_acc(self.get_graph_accuracy(pred, val_batch.y, val_batch.ptr, override=(self.prediction_mode == "node")))
        self.valid_loss.update(loss)
        self.num_valid(val_batch.ptr.shape[0]-1)

        # logging
        self.log('val_loss_step', loss, prog_bar=True)
        # Log the batch value, not the Metric object: valid_acc is computed and
        # reset by hand in on_validation_epoch_end, and handing the object to
        # self.log as well lets Lightning compute/reset it too, which the
        # TorchMetrics documentation lists as a source of wrong results.
        self.log('val_acc_step', batch_acc, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute(), prog_bar=True)
        self.log('valid_graph_acc_epoch', self.valid_graph_acc.compute()/self.num_valid.compute(), prog_bar=True)
        self.log('valid_loss_epoch', self.valid_loss.compute(), prog_bar=True)
        self.valid_acc.reset()
        self.valid_loss.reset()
        self.valid_graph_acc.reset()
        self.num_valid.reset()

    def test_step(self, test_batch, batch_idx):
        """Time the prediction for one test batch and update the test metrics.

        Args:
            test_batch: Batch providing ``y`` (targets) and ``ptr``.
            batch_idx: Index of the batch (unused).

        Returns:
            The cross-entropy loss of the batch.
        """
        # CUDA work is queued asynchronously, so without synchronisation the
        # wall clock would mostly measure how long it takes to launch the
        # kernels. Drain the device before each clock read when on a GPU.
        device = test_batch.ptr.device
        on_cuda = device.type == "cuda"

        if on_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        pred = self.gin_predict(test_batch)

        if on_cuda:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start_time

        self.total_time.update(elapsed)
        self.avg_time.update(elapsed)
        self.time.append(elapsed)

        # actual loss computation
        loss = criteria(pred, test_batch.y.to(torch.long))

        # compute metrics
        self.test_graph_acc(self.get_graph_accuracy(pred, test_batch.y, test_batch.ptr, override=(self.prediction_mode == "node")))
        self.test_acc.update(torch.argmax(pred,dim=1), test_batch.y)
        self.test_loss.update(loss)
        self.num_test(test_batch.ptr.shape[0]-1)

        # logging
        self.log('test_loss', loss, prog_bar=True)
        return loss

    def on_test_epoch_end(self):
        self.log('test_acc', self.test_acc.compute())
        self.log('test_graph_acc', self.test_graph_acc.compute()/self.num_test.compute())
        self.log('test_loss', self.test_loss.compute()/self.num_test.compute())
        self.log('total_time', self.total_time.compute())
        self.log('avg_time', self.avg_time.compute())
        self.log('std_time', torch.std(torch.tensor(self.time)))
        self.time = []
        self.total_time.reset()
        self.avg_time.reset()
        self.test_acc.reset()
        self.test_graph_acc.reset()
        self.test_loss.reset()
        self.num_test.reset()