import numpy as np
import torch
import torch.nn as nn
import torch_geometric.data
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import MLP, new_layer_config
from torch_geometric.graphgym.register import register_head

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


class CustomSequential(nn.Module):
    """Sequential container whose layers may receive an extra argument.

    Layers in ``mlpmodules`` are applied in order.

    * If the input is a ``torch_geometric.data.Batch``, the layers are applied
      to ``batch.x``. ``batch.x`` is overwritten in place and the batch is
      returned.
    * Otherwise the input is treated as the feature tensor itself, and the
      transformed tensor is returned.

    When ``cfg.slt.encoder`` is true together with ``cfg.slt.sm`` or
    ``cfg.slt.mm``, every non-ReLU layer is called as ``module(h, y)``, so the
    extra argument ``y`` (a pruning threshold in the callers of this file) is
    forwarded. Otherwise layers are called as ``module(h)``.
    """

    def __init__(self, *args):
        super(CustomSequential, self).__init__()
        self.mlpmodules = nn.ModuleList(args)

    def forward(self, x, y=None):
        """Apply all layers in order.

        Args:
            x: A ``torch_geometric.data.Batch`` or a feature tensor.
            y: Optional extra argument forwarded to non-ReLU layers when the
                SLT ``sm``/``mm`` encoder mode is enabled.

        Returns:
            ``x`` with ``x.x`` replaced when ``x`` is a ``Batch``; otherwise
            the transformed tensor.
        """
        is_batch = isinstance(x, torch_geometric.data.Batch)
        # Always operate on the feature tensor. This covers plain-tensor
        # inputs (e.g. the concatenated pair embeddings) and an empty module
        # list, both of which left the features unbound before.
        h = x.x if is_batch else x
        pass_extra_arg = (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True

        for module in self.mlpmodules:
            if isinstance(module, nn.ReLU):
                # Activations act on the features, not on the Batch object.
                h = module(h)
            elif pass_extra_arg:
                h = module(h, y)
            else:
                h = module(h)

        if is_batch:
            x.x = h
            return x
        return h


@register_head("inductive_edge")
class GNNInductiveEdgeHead(nn.Module):
    """GNN prediction head for inductive edge/link prediction tasks.

    Implementation adapted from the transductive GraphGym's GNNEdgeHead.

    Behaviour depends on ``cfg.model.edge_decoding``:

    * ``"concat"``: the two endpoint embeddings are concatenated and scored
      by the post-message-passing network ``layer_post_mp``.
    * ``"dot"`` / ``"cosine_similarity"``: node embeddings are first
      transformed by ``layer_post_mp`` (in ``forward``). Each labeled node
      pair is then scored by a dot product / cosine similarity.

    When ``cfg.slt.encoder`` is true, ``layer_post_mp`` is a
    ``CustomSequential`` of bias-free layers. Their class is chosen from
    ``cfg.monarch.encoder`` / ``cfg.slt.srste`` / ``cfg.slt.sm`` /
    ``cfg.slt.mm``. Otherwise ``layer_post_mp`` is GraphGym's ``MLP``.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.

    Raises:
        ValueError: On multi-class output with a non-concat decoding, an
            unknown decoding, or SLT mode without a selected layer type.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        if cfg.model.edge_decoding == "concat":
            if cfg.slt.encoder is True:
                layer_class = self._select_layer_class()
                # The first layer receives the concatenation [v1 || v2],
                # i.e. 2 * dim_in features; hidden layers keep width dim_in.
                dims = (
                    [dim_in * 2]
                    + [dim_in] * (cfg.gnn.layers_post_mp - 1)
                    + [dim_out]
                )
                self.layer_post_mp = CustomSequential(
                    *(
                        layer_class(d_in, d_out, bias=False)
                        for d_in, d_out in zip(dims[:-1], dims[1:])
                    )
                )
            else:
                self.layer_post_mp = MLP(
                    new_layer_config(
                        dim_in * 2,
                        dim_out,
                        cfg.gnn.layers_post_mp,
                        has_act=False,
                        has_bias=True,
                        cfg=cfg,
                    )
                )
            # A bound method (rather than a lambda closing over ``self``)
            # keeps the head picklable. After copy.deepcopy, the copy also
            # uses its own ``layer_post_mp`` instead of the original's.
            # NOTE(review): the post-MP network is called here without the
            # ``y`` argument, so SLT sm/mm layers receive ``y=None`` in
            # concat mode. Confirm that this is intended.
            self.decode_module = self._decode_concat
        else:
            if dim_out > 1:
                raise ValueError(
                    "Binary edge decoding ({}) is used for multi-class "
                    "edge/link prediction.".format(cfg.model.edge_decoding)
                )
            if cfg.slt.encoder is True:
                layer_class = self._select_layer_class()
                self.layer_post_mp = CustomSequential(
                    *(
                        layer_class(dim_in, dim_in, bias=False)
                        for _ in range(cfg.gnn.layers_post_mp)
                    )
                )
            else:
                self.layer_post_mp = MLP(
                    new_layer_config(
                        dim_in,
                        dim_in,
                        cfg.gnn.layers_post_mp,
                        has_act=False,
                        has_bias=True,
                        cfg=cfg,
                    )
                )

            if cfg.model.edge_decoding == "dot":
                self.decode_module = self._decode_dot
            elif cfg.model.edge_decoding == "cosine_similarity":
                self.decode_module = nn.CosineSimilarity(dim=-1)
            else:
                raise ValueError(
                    f"Unknown edge decoding {cfg.model.edge_decoding}."
                )

    @staticmethod
    def _select_layer_class():
        """Return the linear-layer class selected by the monarch/slt config.

        Raises:
            ValueError: If no layer type is enabled in the config.
        """
        if cfg.monarch.encoder:
            return MonarchLinear
        if cfg.slt.srste:
            return NMSparseMultiLinear
        if cfg.slt.sm:
            return SparseLinear
        if cfg.slt.mm:
            return SparseLinearMulti_mask
        raise ValueError(
            "cfg.slt.encoder is enabled but none of cfg.monarch.encoder, "
            "cfg.slt.srste, cfg.slt.sm or cfg.slt.mm selects a layer type."
        )

    def _decode_concat(self, v1, v2):
        """Score pairs by running ``layer_post_mp`` on ``[v1 || v2]``."""
        return self.layer_post_mp(torch.cat((v1, v2), dim=-1))

    @staticmethod
    def _decode_dot(v1, v2):
        """Score pairs by the dot product over the last dimension."""
        return torch.sum(v1 * v2, dim=-1)

    def _apply_index(self, batch):
        """Gather endpoint embeddings of the labeled edges and their labels.

        The first axis of the returned embeddings indexes the two endpoints
        (it is split into ``pred[0]`` / ``pred[1]`` in ``forward``).
        """
        return batch.x[batch.edge_index_labeled], batch.edge_label

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Score the labeled node pairs of ``batch``.

        Only ``pred_th`` (``cfg.slt.pruning == "blockwise"``) or ``global_th``
        (``"global"``) is used, as the extra argument for the SLT sm/mm
        post-MP layers. The remaining keyword arguments are unused here.

        Returns:
            ``(pred, label)`` in training mode, or ``(pred, label, stats)``
            in evaluation mode, where ``stats`` comes from ``compute_mrr``
            (which only supports ``"dot"`` decoding).

        Raises:
            ValueError: If SLT sm/mm mode is on with an unknown
                ``cfg.slt.pruning``.
        """
        if cfg.model.edge_decoding != "concat":
            # Non-concat decodings transform node embeddings first; in
            # concat mode the post-MP network runs inside decode_module.
            if (
                cfg.slt.sm is True or cfg.slt.mm is True
            ) and cfg.slt.encoder is True:
                if cfg.slt.pruning == "blockwise":
                    threshold = pred_th
                elif cfg.slt.pruning == "global":
                    threshold = global_th
                else:
                    # Fail loudly instead of silently skipping the post-MP
                    # layers.
                    raise ValueError(
                        f"Unknown pruning mode {cfg.slt.pruning}."
                    )
                batch = self.layer_post_mp(batch, threshold)
            else:
                batch = self.layer_post_mp(batch)

        pred, label = self._apply_index(batch)
        pred = self.decode_module(pred[0], pred[1])
        if not self.training:  # Compute extra stats when in evaluation mode.
            return pred, label, self.compute_mrr(batch)
        return pred, label

    def compute_mrr(self, batch):
        """Compute Hits@1/3/10 and MRR averaged over the graphs in ``batch``.

        Each positive labeled edge (u, v) of a graph is ranked against all
        pairs (u, w) with w != v. Scores are dot products of the node
        embeddings in ``data.x``. A graph without positive edges contributes
        0.0 to every metric.

        Returns:
            dict: ``{"hits@1", "hits@3", "hits@10", "mrr"}`` -> mean value
            over graphs (empty if the batch has no graphs).

        Raises:
            ValueError: If ``cfg.model.edge_decoding`` is not ``"dot"``.
        """
        if cfg.model.edge_decoding != "dot":
            raise ValueError(
                f"Unsupported edge decoding {cfg.model.edge_decoding}."
            )

        stats = {}
        for data in batch.to_data_list():
            for key, val in self._graph_mrr_stats(data).items():
                stats.setdefault(key, []).append(val)
        return {key: sum(vals) / len(vals) for key, vals in stats.items()}

    def _graph_mrr_stats(self, data):
        """Return ``{metric: value}`` for one graph; NaN values become 0.0."""
        pred = data.x @ data.x.transpose(0, 1)

        pos_edge_index = data.edge_index_labeled[:, data.edge_label == 1]
        num_pos_edges = pos_edge_index.shape[1]
        pred_pos = pred[pos_edge_index[0], pos_edge_index[1]]

        if num_pos_edges > 0:
            # Candidates for (u, v): every (u, w) except the positive w == v.
            # Allocate on pred's device so CUDA indices can index the mask.
            neg_mask = torch.ones(
                [num_pos_edges, data.num_nodes],
                dtype=torch.bool,
                device=pred.device,
            )
            neg_mask[
                torch.arange(num_pos_edges, device=pred.device),
                pos_edge_index[1],
            ] = False
            pred_neg = pred[pos_edge_index[0]][neg_mask].view(
                num_pos_edges, -1
            )
        else:
            # No positives: all metric lists are empty, their means are NaN,
            # and they are reported as 0.0 below.
            pred_neg = pred_pos.new_empty((0, 0))

        graph_stats = {}
        for key, val in self._eval_mrr(pred_pos, pred_neg, "torch").items():
            if key.endswith("_list"):
                key = key[: -len("_list")]
                val = float(val.mean().item())
            if np.isnan(val):
                val = 0.0
            graph_stats[key] = val
        return graph_stats

    def _eval_mrr(self, y_pred_pos, y_pred_neg, type_info):
        """Compute Hits@k and Mean Reciprocal Rank (MRR).

        Implementation from OGB:
        https://github.com/snap-stanford/ogb/blob/master/ogb/linkproppred/evaluate.py

        The rank of each positive is its position in a descending argsort of
        ``[pos, negs...]``, so ties are ordered by the sort, not tie-aware.

        Args:
            y_pred_pos: array with shape (batch size, )
            y_pred_neg: array with shape (batch size, num_entities_neg).
            type_info: ``"torch"`` for tensors; anything else uses NumPy.

        Returns:
            dict with per-positive ``hits@1_list``, ``hits@3_list``,
            ``hits@10_list`` and ``mrr_list``.
        """

        if type_info == "torch":
            y_pred = torch.cat([y_pred_pos.view(-1, 1), y_pred_neg], dim=1)
            argsort = torch.argsort(y_pred, dim=1, descending=True)
            ranking_list = torch.nonzero(argsort == 0, as_tuple=False)
            ranking_list = ranking_list[:, 1] + 1
            hits1_list = (ranking_list <= 1).to(torch.float)
            hits3_list = (ranking_list <= 3).to(torch.float)
            hits10_list = (ranking_list <= 10).to(torch.float)
            mrr_list = 1.0 / ranking_list.to(torch.float)

            return {
                "hits@1_list": hits1_list,
                "hits@3_list": hits3_list,
                "hits@10_list": hits10_list,
                "mrr_list": mrr_list,
            }

        else:
            y_pred = np.concatenate(
                [y_pred_pos.reshape(-1, 1), y_pred_neg], axis=1
            )
            argsort = np.argsort(-y_pred, axis=1)
            ranking_list = (argsort == 0).nonzero()
            ranking_list = ranking_list[1] + 1
            hits1_list = (ranking_list <= 1).astype(np.float32)
            hits3_list = (ranking_list <= 3).astype(np.float32)
            hits10_list = (ranking_list <= 10).astype(np.float32)
            mrr_list = 1.0 / ranking_list.astype(np.float32)

            return {
                "hits@1_list": hits1_list,
                "hits@3_list": hits3_list,
                "hits@10_list": hits10_list,
                "mrr_list": mrr_list,
            }
