import torch
import torch.nn as nn
import torch.nn.functional as nnf
import torch_geometric
import torch_geometric.nn as gnn
import torch_geometric.utils as utils
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
import lightning as pl
import torchmetrics
import math
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from typing import List
from collections import defaultdict


# from catgnn.gcn_conv import GCNConv

# Seed PyTorch's global random number generator as soon as this module is
# imported, so subsequent random draws start from a fixed state.
torch.manual_seed(2023)



def adapt_size_by(tensor: torch.Tensor, *args, dim: int = 0) -> torch.Tensor:
    """Zero-pad or truncate ``tensor`` along ``dim`` to the largest reference size.

    The target length is the maximum of ``arg.shape[dim]`` over all ``args``.

    Args:
        tensor: Tensor to resize.
        *args: Reference objects exposing ``.shape``; at least one is needed,
            otherwise ``max`` raises ``ValueError``.
        dim: Dimension along which to resize.

    Returns:
        ``tensor`` itself when the size already matches, a narrowed view of it
        when it is too long, or a new tensor with zeros appended at the end
        when it is too short.
    """
    target_len = max(ref.shape[dim] for ref in args)
    missing = target_len - tensor.size(dim)

    if missing == 0:
        return tensor

    if missing < 0:
        # Too long: keep only the leading ``target_len`` entries.
        return tensor.narrow(dim, 0, target_len)

    # Too short: append a zero block (new_zeros keeps dtype and device).
    filler_shape = [*tensor.shape]
    filler_shape[dim] = missing
    filler = tensor.new_zeros(filler_shape)
    return torch.cat((tensor, filler), dim=dim)






import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
# Append the relative "models/" directory to the module search path. Because
# the path is relative, it is resolved against the current working directory
# of the process, not against the location of this file.
sys.path.append("models/")


class PathModel(nn.Module):
    """Stack of ``GCNConv`` layers run over a list of per-level graphs.

    ``forward`` receives one feature tensor per level plus one edge set per
    transition between consecutive levels. For every transition the current
    hidden state is resized to fit both levels, convolved, cut down to the
    node count of the next level and, except in the last layer, merged with
    an embedding of the next level's own features.
    """

    # Default vocabulary size of each categorical input column.
    EMBED_NUM_MAX_COUNTS = [32, 32, 32, 32, 32, 32]

    def __init__(
        self,
        model_opts=None,
    ):
        """
        Args:
            model_opts (dict | None): Model options; missing keys fall back to
                the defaults listed here. ``None`` means "all defaults".
                - layers (int): Number of conv layers when ``width`` is an
                  int (default 6). Also limits how many entries of the
                  ``forward`` inputs are used.
                - mlp_layers (int): Linear layers in the per-layer MLP;
                  0 disables the MLP (default 0).
                - input_dim (int): Upper bound on how many entries of
                  ``embed_nums`` are turned into embeddings (default 29).
                - output_dim (int): Width of the last conv layer (default 1).
                - width (int or List[int]): Hidden width, or an explicit list
                  of hidden widths (default 17).
                - embed_dim (int): Embedding dimension and input width of the
                  first conv layer (default 8).
                - embed_dropout (float): Dropout applied to the concatenated
                  hidden state and embedding (default 0.0).
                - embed_nums (List[int]): Vocabulary size per categorical
                  column (default ``EMBED_NUM_MAX_COUNTS``).
                - final_dropout (float): Dropout on the final output
                  (default 0.0).
                - batch_norm (bool): Add ``BatchNorm1d`` after every layer
                  (default False).
                - categorical (bool): Treat every input column as a category
                  index; otherwise inputs are only batch-normalised
                  (default False).
                - force_final_reshape (bool): Flatten the output to 1-D
                  (default False).
        """
        super().__init__()
        # A fresh dict per call: a ``{}`` default would be shared by all instances.
        if model_opts is None:
            model_opts = {}

        self.layers = model_opts.get("layers", 6)
        self.mlp_layers = model_opts.get("mlp_layers", 0)
        self.width = model_opts.get("width", 17)
        self.input_dim = model_opts.get("input_dim", 29)
        self.output_dim = model_opts.get("output_dim", 1)
        self.embed_dim = model_opts.get("embed_dim", 8)
        embed_dropout_p = model_opts.get("embed_dropout", 0.0)
        self.embed_nums = model_opts.get("embed_nums", self.EMBED_NUM_MAX_COUNTS)
        final_dropout_p = model_opts.get("final_dropout", 0.0)
        self.batch_norm = model_opts.get("batch_norm", False)
        self.categorical = model_opts.get("categorical", False)
        self.force_final_reshape = model_opts.get("force_final_reshape", False)

        if isinstance(self.width, list):
            hidden = self.width
        else:
            hidden = [self.width] * (self.layers - 1)
        widths = [self.embed_dim] + hidden + [self.output_dim]

        # (input width, output width) of every layer
        follow_widths = list(zip(widths[:-1], widths[1:]))

        # create graph conv layers
        self.convs = nn.ModuleList(
            gnn.GCNConv(w0, w1, normalize=False) for w0, w1 in follow_widths
        )

        # create mlp layers
        self.mlps = nn.ModuleList(
            self._build_mlp(w1, self.mlp_layers) for _, w1 in follow_widths
        )

        # create batch norm layers
        self.batch_norms = nn.ModuleList(
            nn.BatchNorm1d(w1) if self.batch_norm else nn.Identity()
            for _, w1 in follow_widths
        )

        # create the input embedding
        if self.categorical:
            # every dimension is a categorical feature with its own table
            self.embed = nn.ModuleList(
                nn.Embedding(max_cnt, self.embed_dim)
                for max_cnt in self.embed_nums[:self.input_dim]
            )
        else:
            # Continuous inputs are only normalised. Built once, outside any
            # loop, so it exists even when ``embed_nums[:input_dim]`` is empty.
            self.embed = nn.BatchNorm1d(self.embed_dim)

        # create embedding projection layers
        self.embed_linears = nn.ModuleList(
            nn.Linear(w1 + self.embed_dim, w1, bias=False) for _, w1 in follow_widths
        )

        # create final dropout layer
        self.final_dropout = (
            nn.Dropout(final_dropout_p) if final_dropout_p > 0 else nn.Identity()
        )
        self.embed_dropout = (
            nn.Dropout(embed_dropout_p) if embed_dropout_p > 0 else nn.Identity()
        )

    @staticmethod
    def _build_mlp(width, depth):
        """Return ``depth`` square Linear layers with ReLU in between, or Identity if ``depth`` is 0."""
        if depth <= 0:
            return nn.Identity()
        modules = []
        for _ in range(depth - 1):
            modules.append(nn.Linear(width, width, bias=True))
            modules.append(nn.ReLU())
        modules.append(nn.Linear(width, width, bias=True))
        return nn.Sequential(*modules)

    def __embed(self, x, embed_layer):
        """
        Embed node features.

        In categorical mode every column of ``x`` is looked up in its own
        table (indices above the table size are clamped to the last entry)
        and the per-column embeddings are averaged. Otherwise ``embed_layer``
        is a batch norm that is applied to ``x`` directly.
        """
        if not self.categorical:
            # embed is a batch norm
            return embed_layer(x)

        x = x.long()
        per_feat = [
            table(torch.clamp(x[:, feat_idx], max=table.num_embeddings - 1))
            for feat_idx, table in enumerate(embed_layer)
        ]
        return torch.stack(per_feat, dim=0).mean(dim=0)

    def __embed_linear(self, c, x, embed_layer, linear_layer):
        """
        Concatenate the hidden state ``c`` with the embedding of ``x``, apply
        the embedding dropout and project back to the layer width with
        ``linear_layer``.
        """
        c = torch.cat([c, self.__embed(x, embed_layer)], dim=-1)
        c = self.embed_dropout(c)
        return linear_layer(c)

    def forward(self, lx, ledge_index, ledge_attr, batch):
        """
        Args:
            lx (List[torch.Tensor]): List of node features, one per level.
            ledge_index (List[torch.Tensor]): List of edge indices.
            ledge_attr (List[torch.Tensor]): List of edge attributes; each is
                passed to ``GCNConv`` as its third positional argument.
            batch (torch.Tensor): Batch indices (not used by this model).
        Returns:
            torch.Tensor: Output of the last layer after the final dropout,
            flattened to 1-D when ``force_final_reshape`` is set.
        Note:
            we need to reverse the layers, so that the last one
            is the global-mean-pooling-type of layer
        """
        lx = lx[:self.layers + 1][::-1]
        ledge_index = ledge_index[:self.layers][::-1]
        ledge_attr = ledge_attr[:self.layers][::-1]

        assert len(self.convs) == len(lx) - 1, "Check if dataset is generated with enough layers"
        assert len(self.convs) == len(ledge_index)
        assert len(self.convs) == len(ledge_attr)

        # add a 'fake feature'
        # IMDB-BINARY, REDDIT-BINARY
        if lx[0].shape[1] == 0:
            # new_ones keeps dtype AND device of the (empty) original features,
            # so the placeholder lives on the same device as the model inputs.
            lx = [feat.new_ones((feat.shape[0], 1)) for feat in lx]

        layer_setups = zip(
            self.embed_linears,
            self.convs,
            self.mlps,
            self.batch_norms,
            ledge_index,
            ledge_attr,
        )
        last_layer = len(self.convs) - 1

        # run through the layers
        c = self.__embed(lx[0], self.embed)
        for i, (emb_lin, conv, mlp, batch_norm, edge_index, edge_attr) in enumerate(layer_setups):
            # layer from lx[i] -> lx[i+1]
            is_not_last_layer = i < last_layer

            # The conv runs on a node set large enough for both levels; the
            # result is then cut down to the node count of the next level.
            c = adapt_size_by(c, lx[i], lx[i + 1], dim=0)
            c = conv(c, edge_index, edge_attr)
            c = c[:lx[i + 1].shape[0]]

            if is_not_last_layer:
                c = self.__embed_linear(c, lx[i + 1], self.embed, emb_lin)
                c = c.relu()

            c = mlp(c)
            c = batch_norm(c)

            if is_not_last_layer:
                c = c.relu()

        c = self.final_dropout(c)
        if self.force_final_reshape:
            c = c.view(-1)
        return c


class TreeModel(nn.Module):
    """GCN over a single graph followed by global mean pooling and a linear head.

    Node features are a single integer column that is embedded, passed
    through the ``GCNConv`` + ReLU stack, mean-pooled per graph and mapped
    to one output value.

    NOTE(review): one ``nn.Embedding`` is created per conv layer, but
    ``forward`` only ever uses ``self.embed[0]``. The extra tables are kept
    so parameter names and initialisation order stay as they were.
    NOTE(review): with ``layers <= 2`` no conv and no embedding is created,
    and ``forward`` then fails on ``self.embed[0]``.
    """

    def __init__(self, model_opts=None):
        """
        Args:
            model_opts (dict | None): Model options; ``None`` means defaults.
                - layers (int): Layer count including the mean pooling
                  layer (default 6).
                - width (int): Hidden width of the conv layers (default 17).
                - input_dim (int): Stored on the instance only (default 1).
                - embed_dim (int): Embedding dimension (default 8).
        """
        super().__init__()
        # A fresh dict per call: a ``{}`` default would be shared by all instances.
        if model_opts is None:
            model_opts = {}

        # -1 we do not count the mean pooling layer
        self.layers = model_opts.get("layers", 6) - 1
        self.width = model_opts.get("width", 17)
        self.input_dim = model_opts.get("input_dim", 1)
        self.embed_dim = model_opts.get("embed_dim", 8)

        widths = [self.embed_dim] + [self.width] * (self.layers - 1)
        self.lin = nn.Linear(widths[-1], 1)

        layer_dims = list(zip(widths, widths[1:]))

        # normalize=True is the default,
        # as we do not normalize in datamodule
        self.convs = nn.ModuleList(
            gnn.GCNConv(w_in, w_out, normalize=True) for w_in, w_out in layer_dims
        )

        # One table with 32 entries per conv layer, sized to that layer's input width.
        self.embed = nn.ModuleList(
            nn.Embedding(32, w_in) for w_in, _ in layer_dims
        )

    def forward(self, x, edge_index, batch):
        """
        Args:
            x (torch.Tensor): Node features with exactly one column, used as
                indices into the first embedding table.
            edge_index (torch.Tensor): Edge indices passed to every conv.
            batch (torch.Tensor): Graph assignment of every node, used for
                the mean pooling.
        Returns:
            torch.Tensor: Output of the linear head applied to the pooled
            features (last dimension of size 1).
        """
        assert x.shape[1] == 1
        t = self.embed[0](x[:, 0])
        for conv in self.convs:
            t = conv(t, edge_index).relu()

        t = gnn.global_mean_pool(t, batch)
        return self.lin(t)

class GraphNetModel(pl.LightningModule):
    """Lightning module that trains a ``PathModel`` or ``TreeModel``.

    ``task`` has the form ``"<name>"`` or ``"<name>:<num_classes>"`` where
    ``<name>`` is one of ``regression``, ``classification``, ``auc`` or
    ``multiclass_auc``. When the class count is omitted it defaults to 2.
    """

    # Target value the multi-label AUROC metric is told to skip; NaN targets
    # are mapped to it before the metric is updated.
    _AUROC_IGNORE_INDEX = -1

    def __init__(self, task="regression", model_type="path", model_opts=None, optim_opts=None):
        """
        Args:
            task (str): Task specification, see the class docstring.
            model_type (str): ``"path"`` or ``"tree"``.
            model_opts (dict | None): Options forwarded to the model.
            optim_opts (dict | None): Keyword arguments for ``torch.optim.Adam``.
        Raises:
            ValueError: On an unknown ``model_type`` or task name.
        """
        super().__init__()

        # Fresh dicts per instance: ``{}`` defaults would be shared between
        # all instances, and ``model_opts`` is stored on ``self``.
        model_opts = {} if model_opts is None else model_opts
        optim_opts = {} if optim_opts is None else optim_opts

        self.task = task.split(":")
        self.model_type = model_type
        self.model_opts = model_opts
        self.optim_opts = optim_opts

        # save model opts
        self.save_hyperparameters({"model_opts": model_opts})
        self.save_hyperparameters({"optim_opts": optim_opts})
        self.save_hyperparameters({"model_type": model_type})
        self.save_hyperparameters({"task": task})

        if model_type == "path":
            self.model = PathModel(model_opts)
        elif model_type == "tree":
            self.model = TreeModel(model_opts)
        else:
            raise ValueError(f"Invalid model type: {model_type}")

        self.num_classes = int(self.task[1]) if len(self.task) > 1 else 2
        self.val_acc = None
        self.val_mae = None
        self.val_auc = None
        self.__setup_metrics(self.task[0])

    def __setup_metrics(self, task_name):
        """Create the loss function and the validation metric for ``task_name``."""
        if task_name == "regression":
            self.val_mae = torchmetrics.MeanAbsoluteError()
            self.criterion = nn.MSELoss()
        elif task_name == "classification":
            self.criterion = nn.CrossEntropyLoss()
            self.val_acc = torchmetrics.Accuracy(
                num_classes=self.num_classes,
                task="multiclass" if self.num_classes > 2 else "binary",
            )
        elif task_name == "auc":
            self.criterion = nn.BCEWithLogitsLoss()
            self.val_auc = torchmetrics.classification.BinaryAUROC()
        elif task_name == "multiclass_auc":
            self.criterion = nn.BCEWithLogitsLoss()
            # Stored under ``val_acc`` (and logged as "val_acc") although it
            # is an AUROC; the name is kept for existing monitors.
            self.val_acc = torchmetrics.classification.MultilabelAUROC(
                num_labels=self.num_classes,
                average="macro",
                ignore_index=self._AUROC_IGNORE_INDEX,
            )
        else:
            raise ValueError(f"Invalid task: {task_name}")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, configured by ``optim_opts``."""
        return torch.optim.Adam(self.parameters(), **self.optim_opts)

    def forward(self, data):
        """Run the wrapped model on a batch object exposing ``x``, ``edge_index``, ``batch`` (and ``edge_attr`` for the path model)."""
        if self.model_type == "path":
            return self.model(data.x, data.edge_index, data.edge_attr, data.batch)
        elif self.model_type == "tree":
            return self.model(data.x, data.edge_index, data.batch)
        else:
            raise ValueError(f"Invalid model type: {self.model_type}")

    def on_train_start(self):
        """Send the model options and model type to the logger, if there is one."""
        if self.logger is None:
            # Training without a logger: nothing to report to.
            return
        self.logger.log_hyperparams(self.model_opts)
        self.logger.log_hyperparams({"model_type": self.model_type})

    @staticmethod
    def _match_target_shape(y_hat, y):
        """Reshape ``y_hat`` to ``y``'s shape when both hold the same number of elements.

        Prevents e.g. an ``(N, 1)`` prediction from being broadcast against
        an ``(N,)`` target into an ``(N, N)`` element-wise loss.
        """
        if y_hat.shape != y.shape and y_hat.numel() == y.numel():
            return y_hat.reshape(y.shape)
        return y_hat

    def compute_loss(self, y_hat, y):
        """Compute the task loss of predictions ``y_hat`` against targets ``y``.

        Raises:
            ValueError: If the task name is unknown.
        """
        task_name = self.task[0]
        if task_name == "regression":
            return self.criterion(self._match_target_shape(y_hat, y), y)
        elif task_name == "classification":
            y_hat = y_hat.view(-1, self.num_classes)
            return self.criterion(y_hat, y.long())
        elif task_name == "auc":
            return self.criterion(y_hat, y.float())
        elif task_name == "multiclass_auc":
            y_hat = y_hat.view(-1, self.num_classes)
            # just on the non-nan values (in training set)
            mask = ~y.isnan()
            return self.criterion(y_hat[mask], y[mask])
        else:
            raise ValueError(f"Invalid task: {self.task}")

    def training_step(self, data, batch_idx):
        """Compute and log the training loss for one batch."""
        y_hat = self.forward(data)
        loss = self.compute_loss(y_hat, data.y)

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=data.y.size(0)
        )
        return loss

    def validation_step(self, data, batch_idx):
        """Compute the validation loss for one batch and log it with the task metric."""
        y_hat = self.forward(data)
        loss = self.compute_loss(y_hat, data.y)
        self.__log_metrics(loss, y_hat, data.y)

    def _log_val(self, name, value, batch_size):
        """Log one validation value with the settings shared by all validation logs."""
        self.log(
            name,
            value,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=batch_size
        )

    def __log_metrics(self, loss, y_hat, y):
        """Log ``val_loss`` and the validation metric of the configured task; returns ``loss``."""
        self._log_val("val_loss", loss, y.size(0))

        task_name = self.task[0]
        if task_name == "regression":
            mae = self.val_mae(self._match_target_shape(y_hat, y), y)
            self._log_val("val_mae", mae, y.size(0))
        elif task_name == "classification":
            preds = torch.argmax(y_hat.view(-1, self.num_classes), dim=1)
            acc = self.val_acc(preds, y)
            self._log_val("val_acc", acc, y.size(0))
        elif task_name == "auc":
            auc = self.val_auc(y_hat, y)
            self._log_val("val_auc", auc, y.size(0))
        elif task_name == "multiclass_auc":
            # AUROC ranks continuous scores, so pass probabilities rather
            # than logits thresholded to 0/1.
            scores = torch.sigmoid(y_hat.view(-1, self.num_classes))
            # Keep the (N, num_classes) layout the multi-label metric needs:
            # instead of masking NaN targets out (which flattens to 1-D),
            # mark them with the metric's ignore_index.
            target = y.reshape(-1, self.num_classes)
            target = target.masked_fill(target.isnan(), self._AUROC_IGNORE_INDEX).long()

            auc = self.val_acc(scores, target)
            self._log_val("val_acc", auc, target.size(0))
        else:
            raise ValueError(f"Invalid task: {self.task}")

        return loss