import torch
import torch.nn as nn
import copy
from torch_geometric.nn.conv import GINConv
from .convolutions import GRUMLPConv, BFS, GINMLPConv
import torchmetrics
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn.pool import global_mean_pool
import time

# Shared classification loss used by the training, validation and test steps.
# nn.CrossEntropyLoss with its default mean reduction over the batch.
criteria = nn.CrossEntropyLoss()

class FloodModel(pl.LightningModule):

    def init_logging(self, output_dim):
        """Register the torchmetrics accumulators used by the step/epoch hooks.

        Creates, for each split (train/valid/test), a multiclass accuracy, a
        multiclass F1 score, a sum of fully-correct graphs, a graph counter and
        a mean loss, plus timing accumulators for the test loop.

        Args:
            output_dim: number of classes for the multiclass metrics.
        """
        def accuracy():
            return torchmetrics.Accuracy(task="multiclass", num_classes=output_dim)

        def f1_score():
            return torchmetrics.classification.MulticlassF1Score(num_classes=output_dim)

        # Tuple targets are bound left to right, so the submodules are
        # registered in the same train -> valid -> test order as before.
        self.train_acc, self.valid_acc, self.test_acc = accuracy(), accuracy(), accuracy()
        self.train_f1, self.valid_f1, self.test_f1 = f1_score(), f1_score(), f1_score()

        # running count of fully-correct graphs per split
        self.train_graph_acc, self.valid_graph_acc, self.test_graph_acc = (
            torchmetrics.SumMetric(), torchmetrics.SumMetric(), torchmetrics.SumMetric())

        # running count of graphs seen per split
        self.num_train, self.num_valid, self.num_test = (
            torchmetrics.SumMetric(), torchmetrics.SumMetric(), torchmetrics.SumMetric())

        self.train_loss, self.valid_loss, self.test_loss = (
            torchmetrics.MeanMetric(), torchmetrics.MeanMetric(), torchmetrics.MeanMetric())

        # test-time inference timing (seconds): running sum, running mean, raw samples
        self.total_time = torchmetrics.SumMetric()
        self.avg_time = torchmetrics.MeanMetric()
        self.time = []
        
    def __init__(self, 
                 input_dim, output_dim=1, 
                 num_rounds=1, num_k = 1,
                 hidden_dim=3, hidden_state_factor=2, mlp_depth = 1, dropout = 0, normalization = None, activation = None,
                 aggregation = "add", conv = "gru", prediction_mode = "node", start_mode = "single", start_selection = "random", pool_mode = "sum", train_mode = "single"):
        """Build the encoder/decoder MLPs and the per-round flood convolutions.

        Args:
            input_dim: number of input node features.
            output_dim: number of output classes (also sizes the metrics).
            num_rounds: number of down/echo rounds; each round owns its own
                Down, Up and Cross convolution.
            num_k: number of flood runs (random start nodes per graph) that
                ``forward`` averages.
            hidden_dim: width of the node states.
            hidden_state_factor: multiplier for the hidden width inside MLPs.
            mlp_depth: number of extra hidden layers per MLP (see ``get_mlp``).
            dropout: dropout probability inside the MLPs.
            normalization: callable mapping a width to a module; forwarded
                to ``get_mlp``.
            activation: activation module instance; forwarded to ``get_mlp``.
            aggregation: ``aggr`` passed to every convolution.
            conv: one of ``"gru"``, ``"gin"``, ``"gin-mlp"``.
            prediction_mode: ``"node"`` or ``"graph"``.
            start_mode: ``"single"`` or ``"all"``.
            start_selection: ``"random"`` or ``"fixed"`` (for ``start_mode="single"``).
            pool_mode: stored only; not read by the methods visible in this file.
            train_mode: ``"single"`` or ``"all"`` (for ``start_mode="all"``).

        Raises:
            ValueError: if ``conv`` is not a known convolution type.
        """
        super().__init__()
        self.save_hyperparameters()
        print("init model")
        self.init_logging(output_dim)

        self.num_rounds = num_rounds
        self.num_k = num_k
        # read by get_mlp, so it must be set before any MLP is built
        self.dropout = dropout

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.encoder = self.get_mlp(input_dim,  hidden_state_factor*hidden_dim, hidden_dim, mlp_depth, normalization, activation, last_relu=True)
        self.decoder = self.get_mlp(hidden_dim, hidden_state_factor*hidden_dim, output_dim, mlp_depth, normalization, activation, last_relu=False)

        # Templates for the convolutions; every conv gets its own deep copy so
        # no weights are shared between rounds or directions.
        edge_mlp    = self.get_mlp(2*hidden_dim, hidden_state_factor * 2 * hidden_dim, hidden_dim, mlp_depth, normalization, activation, last_relu=True)
        state_mlp   = self.get_mlp(hidden_dim,   hidden_state_factor* hidden_dim, hidden_dim, mlp_depth, normalization, activation, last_relu=True)

        num_convs = num_rounds

        conv_factories = {
            "gru": lambda: GRUMLPConv(emb_dim = hidden_dim, mlp_edge = copy.deepcopy(edge_mlp), aggr=aggregation),
            "gin": lambda: GINConv(copy.deepcopy(state_mlp), aggr=aggregation),
            "gin-mlp": lambda: GINMLPConv(mlp=copy.deepcopy(state_mlp), mlp_edge=copy.deepcopy(edge_mlp), aggr=aggregation),
        }
        if conv not in conv_factories:
            raise ValueError("Unknown convolution " + conv)
        make_conv = conv_factories[conv]

        # forward() uses all three stacks for every conv type (previously
        # "gin-mlp" never created CrossConv and failed there).
        self.DownConv   = nn.ModuleList([make_conv() for _ in range(num_convs)])
        self.UpConv     = nn.ModuleList([make_conv() for _ in range(num_convs)])
        self.CrossConv  = nn.ModuleList([make_conv() for _ in range(num_convs)])

        self.prediction_mode = prediction_mode
        self.start_mode = start_mode
        self.start_selection = start_selection
        self.pool_mode = pool_mode
        self.train_mode = train_mode

    def dist_mask(self, edge_index, dist_edge_index, a, b):
        """Keep only the edges going from distance level ``a`` to level ``b``.

        Args:
            edge_index: ``(2, E)`` edge list.
            dist_edge_index: ``(2, E)`` distances of each edge's two endpoints,
                aligned with ``edge_index``.
            a: required distance of the row-0 endpoint.
            b: required distance of the row-1 endpoint.

        Returns:
            A ``(2, E')`` tensor with the matching columns of ``edge_index``.
        """
        keep = (dist_edge_index[0] == a) & (dist_edge_index[1] == b)
        return edge_index[:, keep]

    def masked_update(self, x, x_new, m):
        """Return ``x`` with the rows indexed by ``m`` taken from ``x_new``.

        Args:
            x: current node states, one row per node.
            x_new: candidate states, same shape as ``x``.
            m: 1-D index tensor of rows to update; duplicates are allowed.

        Returns:
            A new tensor; ``x`` itself is not modified.
        """
        take_new = torch.zeros(x.shape[0], dtype=torch.bool, device=x.device)
        take_new[m] = True
        # Select rather than blend with 0/1 weights: with a blend, an inf/NaN in
        # the row that is *not* chosen still leaks through (0 * inf = NaN).
        return torch.where(take_new.view(-1, 1), x_new, x)

    def get_mlp(self, input_dim, hidden_dim, output_dim, mlp_depth, normalization, activation, last_relu = True):
        """Build an MLP as a ``torch.nn.Sequential``.

        Layout: one input block, then ``mlp_depth`` hidden blocks (each block is
        ``Linear -> normalization -> activation -> Dropout(self.dropout)``),
        then a final ``Linear`` to ``output_dim``. If ``last_relu`` is true,
        ``normalization(output_dim)`` and ``activation`` are appended.

        Args:
            input_dim: input width.
            hidden_dim: hidden width (cast with ``int``).
            output_dim: output width.
            mlp_depth: number of extra hidden blocks (cast with ``int``).
            normalization: callable mapping a width to a module, or None for
                no normalization.
            activation: module instance reused at every activation position,
                or None for no activation.
            last_relu: whether to append normalization + activation after the
                final Linear.
        """
        # __init__ defaults both to None, which used to crash here; substitute
        # identities so the Sequential keeps the same layer indices either way.
        def norm(dim):
            return torch.nn.Identity() if normalization is None else normalization(dim)

        act = torch.nn.Identity() if activation is None else activation

        width = int(hidden_dim)
        modules = [torch.nn.Linear(input_dim, width), norm(width), act, torch.nn.Dropout(self.dropout)]
        for _ in range(int(mlp_depth)):
            modules += [torch.nn.Linear(width, width), norm(width), act, torch.nn.Dropout(self.dropout)]
        modules.append(torch.nn.Linear(width, output_dim))

        if last_relu:
            modules.append(norm(output_dim))
            modules.append(act)

        return torch.nn.Sequential(*modules)

    def step(self, edge_index, dist_e_index, a, b, conv, x):
        """Run ``conv`` over the level-``a`` -> level-``b`` edges only.

        Only the row-1 endpoints of the selected edges receive ``conv``'s
        output; every other node keeps its state from ``x``.
        """
        active_edges = self.dist_mask(edge_index, dist_e_index, a, b)
        return self.masked_update(x, conv(x, active_edges), active_edges[1])

    def forward(self, data):
        """Run the flood/echo message passing and decode the result.

        For every flood run, ``BFS`` is evaluated from that run's start nodes,
        the node features are encoded, and for each round states are pushed
        down the distance levels (level d -> d+1, then within d+1) and echoed
        back up (within d, then d -> d-1). The runs are averaged and decoded.

        Args:
            data: PyG-style batch; reads ``x``, ``edge_index``, ``num_nodes``,
                ``num_graphs``, ``ptr`` and optionally ``root`` (per-graph
                local start index).

        Returns:
            Decoded outputs for all nodes, or only for the start nodes when
            ``prediction_mode == "graph"`` or when ``prediction_mode == "node"``
            with ``start_mode == "all"``.

        Raises:
            ValueError: unknown ``start_mode``.
            NotImplementedError: graph prediction with more than one flood run.
        """
        # Get data and node features
        x, edge_index = data.x, data.edge_index
        x = x.to(torch.float)

        precomp = BFS()
        flood_start = self._select_flood_starts(data)
        # Root-based starts provide one column even when num_k > 1 (which used
        # to raise IndexError), so size the run axis from the starts chosen.
        num_runs = flood_start.shape[1]
        x_agg = torch.zeros((data.num_nodes, self.hidden_dim, num_runs), device=x.device)

        # Do the flooding phases for each starting node
        for k in range(num_runs):
            # precompute distances for the wave activation pattern
            starts = flood_start[:, k]
            D = precomp(data, starts).view((1, -1))[0]
            dist_e_index = D[edge_index]
            # deepest finite level; assumes BFS reports unreachable nodes as inf
            maxD = torch.max(D[D != float('Inf')]).long().item()

            # encode the input (re-applied for every run)
            x_k = self.encoder(x)

            for phase in range(self.num_rounds):
                # flood down
                for flood in range(maxD):
                    x_k = self.step(edge_index, dist_e_index, flood, flood+1, self.DownConv[phase], x_k)
                    x_k = self.step(edge_index, dist_e_index, flood+1, flood+1, self.CrossConv[phase], x_k)

                # echo back
                for echo in range(maxD, 0, -1):
                    x_k = self.step(edge_index, dist_e_index, echo, echo, self.CrossConv[phase], x_k)
                    x_k = self.step(edge_index, dist_e_index, echo, echo-1, self.UpConv[phase], x_k)

            x_agg[:, :, k] = x_k

        # aggregate the results of the runs to a single hidden representation
        x_agg = torch.mean(x_agg, 2)

        # either return all nodes or only the start-node embeddings of the last run
        if self.prediction_mode == "node":
            if self.start_mode == "all":
                x_agg = x_agg[starts]
        elif self.prediction_mode == "graph":
            if num_runs != 1:
                # should probably either pool all nodes, or pool all starting nodes
                raise NotImplementedError("undefined behaviour for graph prediction with more than a single starting node")
            x_agg = x_agg[starts]

        # decode the hidden representation to the output
        return self.decoder(x_agg)

    def _select_flood_starts(self, data):
        """Choose global start-node indices with shape ``(num_graphs, runs)``.

        ``runs`` is ``num_k`` for random or default starts, and 1 when starts
        come from ``data.root`` (shifted by ``data.ptr`` to global indices).

        Raises:
            ValueError: unknown ``start_mode``.
        """
        def from_root():
            return (data.root + data.ptr[:-1]).view((-1, 1))

        def from_first_node():
            # randint over the one-value range [ptr[i], ptr[i]+1) always yields
            # ptr[i]; kept in this form so dtype/placement match the random case.
            return torch.stack([torch.randint(data.ptr[i], data.ptr[i]+1, (self.num_k, )) for i in range(data.num_graphs)])

        if self.start_mode == "single":
            if self.start_selection == "random":
                return torch.stack([torch.randint(data.ptr[i], data.ptr[i+1], (self.num_k, )) for i in range(data.num_graphs)])
            if self.start_selection == "fixed" and getattr(data, "root", None) is not None:
                return from_root()
            print("WARNING: underspecified starting nodes for this task, defaulting to start at 0")
            return from_first_node()

        if self.start_mode == "all":  # a fixed start should be chosen anyway
            if getattr(data, "root", None) is not None:
                return from_root()
            print("WARNING: start_mode = all, but no starting nodes defined, defaulting to start at 0")
            return from_first_node()

        raise ValueError("Unknown start_mode " + self.start_mode)
    
    def configure_optimizers(self):
        """Adam (lr 4e-4) plus ReduceLROnPlateau driven by ``valid_loss_epoch``.

        The scheduler multiplies the learning rate by 0.8 after 3 epochs without
        improvement of the monitored loss, never going below 1e-6.

        NOTE(review): recent PyTorch releases deprecate the ``verbose`` argument
        of LR schedulers; confirm against the installed version.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.8,
            patience=3,
            min_lr=1e-6,
            verbose=True,
        )
        scheduler_config = {"scheduler": scheduler, "monitor": "valid_loss_epoch"}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
    def training_step(self, train_batch, batch_idx):
        """Compute the loss for one training batch and update training metrics.

        With ``start_mode == "all"``:
          * ``train_mode == "single"``: floods from one random node per graph.
            For node prediction, ``train_batch.y`` is replaced by the labels of
            those start nodes. ``train_batch.root`` is overwritten with the
            per-graph local start index.
          * ``train_mode == "all"``: floods from every node via ``inference_all``.
        Otherwise the plain forward pass is used.

        Returns:
            The batch cross-entropy loss.

        Raises:
            ValueError: unknown ``train_mode`` (checked only when ``start_mode == "all"``).
        """
        if self.start_mode == "all":
            if self.train_mode == "single":
                # choose a random (global) node index for each graph
                start_node = torch.cat([torch.randint(train_batch.ptr[i], train_batch.ptr[i+1], (1, )) for i in range(train_batch.num_graphs)])
                start_node = start_node.to(train_batch.ptr.device)

                # if groundtruth is not on graph level
                if self.prediction_mode == "node":
                    train_batch.y = train_batch.y[start_node]
                # forward() expects per-graph local root indices
                train_batch.root = start_node - train_batch.ptr[:-1]

                pred = self(train_batch)
            elif self.train_mode == "all":
                pred = self.inference_all(train_batch)
            else:
                raise ValueError("Unknown train_mode " + self.train_mode)
        else:
            pred = self(train_batch)

        # use the same integer targets for the loss and every metric
        target = train_batch.y.to(torch.long)
        loss = criteria(pred, target)

        # compute metrics (calling a metric accumulates it and returns the batch value)
        class_pred = torch.argmax(pred, dim=1)
        batch_value = self.train_acc(class_pred, target)
        self.train_f1(class_pred, target)
        self.train_graph_acc(self.get_graph_accuracy(pred, target, train_batch.ptr))
        self.train_loss.update(loss)
        self.num_train(train_batch.ptr.shape[0]-1)

        # logging
        self.log('train_loss_step', loss, prog_bar=True)
        self.log('train_acc_step', batch_value, prog_bar=True)
        return loss

    def get_graph_accuracy(self, pred, y, ptr, override=False):
        """Count the graphs in a batch whose nodes are all classified correctly.

        Args:
            pred: per-node logits; ``argmax(dim=1)`` gives the predicted class.
            y: per-node labels aligned with the rows of ``pred``.
            ptr: graph offsets; graph ``i`` owns rows ``ptr[i]:ptr[i+1]``.
            override: compute the count for node prediction even when
                ``train_mode == "single"``.

        Returns:
            A 0-d tensor with the number of fully-correct graphs, or ``-1`` when
            the metric is not applicable (graph prediction, or node prediction
            with ``train_mode == "single"`` and no ``override``).
        """
        if self.prediction_mode == "graph" or (self.prediction_mode == "node" and self.train_mode == "single" and not override):
            return -1
        wrong = torch.argmax(pred, dim=1).ne(y).long()
        # Prefix sums with a leading 0: cum[j] = number of wrong rows in [0, j),
        # so each graph's count includes its first node.
        cum = torch.cat((wrong.new_zeros(1), torch.cumsum(wrong, dim=0)))
        wrong_per_graph = cum[ptr[1:]] - cum[ptr[:-1]]
        return torch.sum(wrong_per_graph == 0)


    def on_train_epoch_end(self):
        """Log the epoch's accumulated training metrics, then reset them.

        NOTE(review): batches where get_graph_accuracy returned its ``-1``
        "not applicable" marker add -1 to train_graph_acc, so the logged
        graph accuracy can be negative rather than meaningful.
        """
        epoch_values = {
            'train_acc_epoch': self.train_acc.compute(),
            'train_graph_acc_epoch': self.train_graph_acc.compute() / self.num_train.compute(),
            'train_loss_epoch': self.train_loss.compute(),
            'train_f1_epoch': self.train_f1.compute(),
        }
        for key, value in epoch_values.items():
            self.log(key, value, prog_bar=True)

        for metric in (self.train_acc, self.train_f1, self.train_loss, self.train_graph_acc, self.num_train):
            metric.reset()

    def inference_all(self,batch):
        """Predict by flooding from every node position of every graph.

        For each local position ``node_i`` up to the largest graph size,
        ``batch.root`` is set to that index and ``forward`` is run. Graphs with
        fewer nodes are clamped to their last node. The per-graph output is
        stored at the global row of that graph's ``node_i``-th node. Only graphs
        that actually have such a node write a result.

        Note: overwrites ``batch.root`` as a side effect.

        Returns:
            ``(num_nodes, output_dim)`` per-node outputs, or for
            ``prediction_mode == "graph"`` their per-graph sum,
            ``(num_graphs, output_dim)``.
        """
        ptr = batch.ptr
        device = ptr.device
        graph_sizes = ptr[1:] - ptr[:-1]
        max_size = int(graph_sizes.max()) if graph_sizes.numel() > 0 else 0

        pred = torch.zeros((batch.x.shape[0], self.output_dim), device=device)

        for node_i in range(max_size):
            # dont go out of graph: every graph needs a valid start for the batched pass
            start_node = torch.clamp(graph_sizes - 1, max=node_i)
            batch.root = start_node

            pred_i = self(batch)

            # Without this mask, ptr[j] + node_i would land in graph j+1's rows
            # (or past the last row) whenever graph j has <= node_i nodes.
            has_node = graph_sizes > node_i
            pred[(ptr[:-1] + node_i)[has_node]] = pred_i[has_node]

        if self.prediction_mode == "graph":
            index = batch.batch.unsqueeze(1).expand(-1, self.output_dim).to(device)
            pred = torch.zeros(batch.num_graphs, self.output_dim, device=device).scatter_add_(0, index, pred)

        return pred

    def validation_step(self, val_batch, batch_idx):
        """Evaluate one validation batch, accumulate epoch metrics, log step values.

        Returns:
            The batch cross-entropy loss.
        """
        if self.start_mode == "all":
            pred = self.inference_all(val_batch)
        else:
            pred = self(val_batch)

        # use the same integer targets for the loss and every metric
        target = val_batch.y.to(torch.long)
        loss = criteria(pred, target)

        # compute metrics (calling the metric accumulates it and returns the batch value)
        batch_acc = self.valid_acc(torch.argmax(pred, dim=1), target)
        self.valid_graph_acc(self.get_graph_accuracy(pred, target, val_batch.ptr, override=(self.prediction_mode == "node")))
        self.valid_loss.update(loss)
        self.num_valid(val_batch.ptr.shape[0]-1)

        # logging: log the batch tensor, not the Metric object. on_validation_epoch_end
        # computes and resets valid_acc itself; passing the object to self.log
        # would make Lightning compute and reset it a second time.
        self.log('val_loss_step', loss, prog_bar=True)
        self.log('val_acc_step', batch_acc, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        """Log epoch-level validation accuracy, graph accuracy and loss, then reset.

        ``valid_loss_epoch`` is the key monitored by the LR scheduler set up in
        ``configure_optimizers``.
        """
        epoch_values = {
            'valid_acc_epoch': self.valid_acc.compute(),
            'valid_graph_acc_epoch': self.valid_graph_acc.compute() / self.num_valid.compute(),
            'valid_loss_epoch': self.valid_loss.compute(),
        }
        for key, value in epoch_values.items():
            self.log(key, value, prog_bar=True)

        for metric in (self.valid_acc, self.valid_loss, self.valid_graph_acc, self.num_valid):
            metric.reset()

    def test_step(self, test_batch, batch_idx):
        """Evaluate and time one test batch, accumulating test metrics.

        The wall-clock time (seconds, ``time.perf_counter``) of the prediction
        is added to ``total_time``, ``avg_time`` and the raw ``self.time`` list.

        Returns:
            The batch cross-entropy loss.
        """
        # CUDA kernels run asynchronously; synchronize around the timed region
        # so the measurement covers the actual computation, not just launches.
        device = test_batch.ptr.device
        on_cuda = device.type == "cuda"
        if on_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        if self.start_mode == "all":
            pred = self.inference_all(test_batch)
        else:
            pred = self(test_batch)

        if on_cuda:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start_time

        self.total_time.update(elapsed)
        self.avg_time.update(elapsed)
        self.time.append(elapsed)

        # actual loss computation; same integer targets for loss and metrics
        target = test_batch.y.to(torch.long)
        loss = criteria(pred, target)
        # compute metrics
        self.test_graph_acc(self.get_graph_accuracy(pred, target, test_batch.ptr, override=(self.prediction_mode == "node")))
        self.test_acc.update(torch.argmax(pred, dim=1), target)
        self.test_loss.update(loss)
        self.num_test(test_batch.ptr.shape[0]-1)

        # logging
        self.log('test_loss', loss, prog_bar=True)
        return loss

    def on_test_epoch_end(self):
        """Log test metrics and timing statistics for the epoch, then reset.

        ``test_loss`` is a MeanMetric over per-batch losses, so its ``compute()``
        is already an average and is logged as-is, consistent with the train
        and validation epoch hooks. Dividing it by the number of test graphs
        produced a value that shrank with dataset size.
        """
        self.log('test_acc', self.test_acc.compute())
        self.log('test_graph_acc', self.test_graph_acc.compute()/self.num_test.compute())
        self.log('test_loss', self.test_loss.compute())
        self.log('total_time', self.total_time.compute())
        self.log('avg_time', self.avg_time.compute())
        # NOTE: torch.std is unbiased by default, so this is NaN with fewer than two timed batches.
        self.log('std_time', torch.std(torch.tensor(self.time)))
        self.time = []
        for metric in (self.total_time, self.avg_time, self.test_acc, self.test_graph_acc, self.test_loss, self.num_test):
            metric.reset()