import torch
import pytorch_lightning as pl
import torchmetrics

from .convolutions import GRUMLPConv
import copy


# Shared classification loss (mean reduction over the batch). NOTE: this name is
# rebound to an identical instance further down the module, before any use.
criteria = torch.nn.CrossEntropyLoss(weight=None, reduction="mean")


import torch
import torch.nn as nn
import copy
from torch_geometric.nn.conv import GINConv
from .convolutions import GRUMLPConv, BFS, GINMLPConv
import torchmetrics
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn.pool import global_mean_pool
import time

# Module-wide loss used by every train/valid/test step: cross-entropy between
# logits and integer class targets, averaged over the batch.
criteria = nn.CrossEntropyLoss()

class RecGNN(pl.LightningModule):
    """Recurrent message-passing GNN classifier trained with PyTorch Lightning.

    Node features are encoded by an MLP and then refined for a data-dependent
    number of rounds. Each round concatenates the current node state with the
    raw input features (a skip connection), mixes them through an MLP and
    applies one shared ``GRUMLPConv`` layer. A decoder MLP produces per-node
    logits. When ``prediction_mode == "graph"`` those logits are summed per graph.

    ``num_rounds``, ``num_k`` and ``conv`` are recorded via
    ``save_hyperparameters`` but are not otherwise read by this class.
    """

    def init_logging(self, output_dim):
        """Create the torchmetrics objects used for per-split logging.

        Args:
            output_dim: number of classes for the accuracy / F1 metrics.
        """
        # Per-split node-level classification metrics.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=output_dim)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=output_dim)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=output_dim)

        self.train_f1 = torchmetrics.classification.MulticlassF1Score(num_classes=output_dim)
        self.valid_f1 = torchmetrics.classification.MulticlassF1Score(num_classes=output_dim)
        self.test_f1 = torchmetrics.classification.MulticlassF1Score(num_classes=output_dim)

        # Running count of graphs whose nodes were all classified correctly
        # (see ``get_graph_accuracy``). Divided by ``num_*`` at epoch end.
        self.train_graph_acc = torchmetrics.SumMetric()
        self.valid_graph_acc = torchmetrics.SumMetric()
        self.test_graph_acc = torchmetrics.SumMetric()

        # Number of graphs seen per split.
        self.num_train = torchmetrics.SumMetric()
        self.num_valid = torchmetrics.SumMetric()
        self.num_test = torchmetrics.SumMetric()

        # Mean of the per-batch losses.
        self.train_loss = torchmetrics.MeanMetric()
        self.valid_loss = torchmetrics.MeanMetric()
        self.test_loss = torchmetrics.MeanMetric()

        # Wall-clock timing (seconds) of test-time forward passes.
        self.total_time = torchmetrics.SumMetric()
        self.avg_time = torchmetrics.MeanMetric()
        self.time = []

    def __init__(self,
                 input_dim=2, output_dim=1,
                 num_rounds=1, num_k=1,
                 hidden_dim=3, hidden_state_factor=2, mlp_depth=1, dropout=0,
                 normalization=None, activation=None,
                 aggregation="add", conv="gru", prediction_mode="node"):
        """Build the encoder, decoder, skip-connection MLP and shared conv layer.

        Args:
            input_dim: number of raw node features.
            output_dim: number of output classes (logits per node).
            num_rounds: stored as ``self.num_rounds``. ``forward`` derives the
                actual number of rounds from the graph size instead.
            num_k: recorded in the hyperparameters only.
            hidden_dim: width of the recurrent node state.
            hidden_state_factor: multiplier applied to the hidden width of every MLP.
            mlp_depth: number of additional hidden layers in every MLP.
            dropout: dropout probability used inside the MLPs.
            normalization: callable ``width -> nn.Module`` (e.g. a norm-layer
                class), or ``None`` for no normalization.
            activation: ``nn.Module`` instance used as the nonlinearity, or ``None``.
            aggregation: neighbourhood aggregation passed to ``GRUMLPConv`` as ``aggr``.
            conv: recorded in the hyperparameters only; ``GRUMLPConv`` is always used.
            prediction_mode: ``"node"`` for per-node logits, ``"graph"`` to sum them per graph.
        """
        super().__init__()

        self.save_hyperparameters()
        print("init model")
        self.init_logging(output_dim)

        self.num_rounds = num_rounds
        # Must be set before get_mlp(), which reads self.dropout.
        self.dropout = dropout

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.encoder = self.get_mlp(input_dim, hidden_state_factor * hidden_dim, hidden_dim,
                                    mlp_depth, normalization, activation, last_relu=True)
        self.decoder = self.get_mlp(hidden_dim, hidden_state_factor * hidden_dim, output_dim,
                                    mlp_depth, normalization, activation, last_relu=False)

        # Mixes [state, raw input] back down to hidden_dim at the start of every round.
        self.skip_input = self.get_mlp(hidden_dim + input_dim,
                                       (hidden_dim + input_dim) * hidden_state_factor, hidden_dim,
                                       mlp_depth, normalization, activation, last_relu=True)

        edge_mlp = self.get_mlp(2 * hidden_dim, hidden_state_factor * 2 * hidden_dim, hidden_dim,
                                mlp_depth, normalization, activation, last_relu=True)
        # NOTE(review): state_mlp is never used. It is kept so that parameter
        # initialisation (RNG consumption order) matches earlier runs.
        state_mlp = self.get_mlp(hidden_dim, hidden_state_factor * hidden_dim, hidden_dim,
                                 mlp_depth, normalization, activation, last_relu=True)

        self.conv = GRUMLPConv(emb_dim=hidden_dim, mlp_edge=copy.deepcopy(edge_mlp), aggr=aggregation)
        self.prediction_mode = prediction_mode

    def forward(self, data):
        """Run the recurrent GNN and return per-node logits.

        ``data`` must expose ``x``, ``edge_index`` and ``num_graphs``; it is
        presumably a torch_geometric ``Batch`` (confirm against the data loader).
        The number of rounds is ``int(1.2 * num_nodes / num_graphs)``, i.e.
        1.2x the average number of nodes per graph in the batch.
        """
        x_orig, edge_index = data.x, data.edge_index

        x = self.encoder(x_orig)
        iters = int(1.2 * len(data.x) / data.num_graphs)
        for _ in range(iters):
            # Re-inject the raw input features every round (skip connection).
            x = self.skip_input(torch.cat([x, x_orig], dim=1))
            x = self.conv(x, edge_index)

        return self.decoder(x)

    def get_mlp(self, input_dim, hidden_dim, output_dim, mlp_depth, normalization, activation, last_relu=True):
        """Build ``Linear -> norm -> act -> Dropout`` blocks followed by a final ``Linear``.

        Args:
            input_dim: input width.
            hidden_dim: hidden width (cast to ``int``).
            output_dim: output width.
            mlp_depth: number of extra hidden blocks after the first one.
            normalization: callable ``width -> nn.Module``, or ``None`` (identity).
            activation: ``nn.Module`` instance, or ``None`` (identity). The same
                instance is reused at every position.
            last_relu: if True, append ``norm(output_dim)`` and the activation
                after the final ``Linear``.

        Returns:
            ``torch.nn.Sequential`` with the layers above.
        """
        hidden = int(hidden_dim)
        # The constructor defaults are None. Treat that as "no normalization /
        # no activation" instead of crashing on None(...) or None inside Sequential.
        norm = normalization if normalization is not None else (lambda _width: torch.nn.Identity())
        act = activation if activation is not None else torch.nn.Identity()

        def hidden_block(in_features):
            # Layers are built in the same order as before (Linear, norm, Dropout),
            # so RNG-driven initialisation is unchanged.
            return [torch.nn.Linear(in_features, hidden), norm(hidden), act,
                    torch.nn.Dropout(self.dropout)]

        modules = hidden_block(input_dim)
        for _ in range(int(mlp_depth)):
            modules += hidden_block(hidden)
        modules.append(torch.nn.Linear(hidden, output_dim))

        if last_relu:
            modules.append(norm(output_dim))
            modules.append(act)

        return torch.nn.Sequential(*modules)

    def configure_optimizers(self):
        """Adam (lr=4e-4) with ReduceLROnPlateau monitoring ``valid_loss_epoch``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3,
                                               min_lr=1e-6, verbose=True),
                "monitor": "valid_loss_epoch",
            },
        }

    def gin_predict(self, batch):
        """Return logits: per node, or summed per graph when ``prediction_mode == "graph"``."""
        pred = self(batch)
        if self.prediction_mode == "graph":
            # Sum node logits into one row per graph; batch.batch maps node -> graph index.
            # new_zeros matches pred's dtype/device, so this also works under mixed
            # precision (a float32 buffer would reject a half-precision src).
            pooled = pred.new_zeros((batch.num_graphs, self.output_dim))
            pred = pooled.index_add_(0, batch.batch.to(pred.device), pred)
        return pred

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss, update the train metrics and log step values."""
        pred = self.gin_predict(train_batch)
        loss = criteria(pred, train_batch.y.to(torch.long))

        # compute metrics
        class_pred = torch.argmax(pred, dim=1)
        batch_value = self.train_acc(class_pred, train_batch.y)
        self.train_f1.update(class_pred, train_batch.y)
        # In node mode this receives the -1 placeholder (override is False here).
        self.train_graph_acc(self.get_graph_accuracy(pred, train_batch.y, train_batch.ptr))
        self.train_loss.update(loss)
        self.num_train(train_batch.ptr.shape[0] - 1)

        # logging
        self.log('train_loss_step', loss, prog_bar=True)
        self.log('train_acc_step', batch_value, prog_bar=True)
        return loss

    def get_graph_accuracy(self, pred, y, ptr, override=False):
        """Count the graphs in a batch whose nodes are *all* predicted correctly.

        Args:
            pred: node logits, shape ``(num_nodes, num_classes)``.
            y: node labels, shape ``(num_nodes,)``.
            ptr: graph offsets; graph ``i`` owns nodes ``ptr[i]:ptr[i+1]``
                (presumably torch_geometric ``Batch.ptr``).
            override: in ``"node"`` mode the count is only computed when True.

        Returns:
            A 0-d tensor with the number of fully correct graphs. Returns ``-1``
            as a placeholder when the metric is not computed (graph mode, or
            node mode without ``override``).
        """
        if self.prediction_mode == "graph" or (self.prediction_mode == "node" and not override):
            return -1
        wrong = (torch.argmax(pred, dim=1) != y).long()
        # Prepend a zero so that csum[j] counts mistakes in nodes [0, j). The
        # per-graph mistake count is then csum[end] - csum[start], and the first
        # node of each graph is included (csum[end-1] - csum[start] would miss it).
        csum = torch.cat([wrong.new_zeros(1), torch.cumsum(wrong, dim=0)])
        mistakes_per_graph = csum[ptr[1:]] - csum[ptr[:-1]]
        return torch.sum(mistakes_per_graph == 0)

    def on_train_epoch_end(self):
        """Log epoch-level train metrics and reset their state."""
        self.log('train_acc_epoch', self.train_acc.compute(), prog_bar=True)
        self.log('train_graph_acc_epoch', self.train_graph_acc.compute() / self.num_train.compute(),
                 prog_bar=True)
        self.log('train_loss_epoch', self.train_loss.compute(), prog_bar=True)
        self.log('train_f1_epoch', self.train_f1.compute(), prog_bar=True)
        self.train_acc.reset()
        self.train_f1.reset()
        self.train_loss.reset()
        self.train_graph_acc.reset()
        self.num_train.reset()

    def validation_step(self, val_batch, batch_idx):
        """Compute the validation loss and update the validation metrics."""
        pred = self.gin_predict(val_batch)

        loss = criteria(pred, val_batch.y.to(torch.long))

        # compute metrics
        # Log the batch value as a tensor, not the Metric object: on_validation_epoch_end
        # already computes/resets valid_acc, so Lightning-managed logging of the
        # object would compute on an already-reset metric.
        batch_acc = self.valid_acc(torch.argmax(pred, dim=1), val_batch.y)
        self.valid_graph_acc(self.get_graph_accuracy(pred, val_batch.y, val_batch.ptr,
                                                     override=(self.prediction_mode == "node")))
        self.valid_loss.update(loss)
        self.num_valid(val_batch.ptr.shape[0] - 1)

        # logging
        self.log('val_loss_step', loss, prog_bar=True)
        self.log('val_acc_step', batch_acc, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        """Log epoch-level validation metrics and reset their state."""
        self.log('valid_acc_epoch', self.valid_acc.compute(), prog_bar=True)
        self.log('valid_graph_acc_epoch', self.valid_graph_acc.compute() / self.num_valid.compute(),
                 prog_bar=True)
        self.log('valid_loss_epoch', self.valid_loss.compute(), prog_bar=True)
        self.valid_acc.reset()
        self.valid_loss.reset()
        self.valid_graph_acc.reset()
        self.num_valid.reset()

    def test_step(self, test_batch, batch_idx):
        """Time the forward pass, compute the test loss and update the test metrics.

        NOTE(review): the timer has no CUDA synchronisation around it, so on GPU
        it may undercount asynchronous kernel time. Confirm this if timings matter.
        """
        start_time = time.perf_counter()
        pred = self.gin_predict(test_batch)
        elapsed = time.perf_counter() - start_time

        self.total_time.update(elapsed)
        self.avg_time.update(elapsed)
        self.time.append(elapsed)

        # actual loss computation
        loss = criteria(pred, test_batch.y.to(torch.long))
        # compute metrics
        self.test_graph_acc(self.get_graph_accuracy(pred, test_batch.y, test_batch.ptr,
                                                    override=(self.prediction_mode == "node")))
        self.test_acc.update(torch.argmax(pred, dim=1), test_batch.y)
        self.test_loss.update(loss)
        self.num_test(test_batch.ptr.shape[0] - 1)

        # logging
        self.log('test_loss', loss, prog_bar=True)
        return loss

    def on_test_epoch_end(self):
        """Log epoch-level test metrics and timing statistics, then reset their state."""
        self.log('test_acc', self.test_acc.compute())
        self.log('test_graph_acc', self.test_graph_acc.compute() / self.num_test.compute())
        # test_loss is a MeanMetric, i.e. already the mean over batches, as for
        # train/valid. It must not be divided again by the number of graphs.
        self.log('test_loss', self.test_loss.compute())
        self.log('total_time', self.total_time.compute())
        self.log('avg_time', self.avg_time.compute())
        self.log('std_time', torch.std(torch.tensor(self.time)))
        self.time = []
        self.total_time.reset()
        self.avg_time.reset()
        self.test_acc.reset()
        self.test_graph_acc.reset()
        self.test_loss.reset()
        self.num_test.reset()
