from functools import wraps
import math
from einops import rearrange, repeat

import torch
from torch import nn, einsum
import torch.nn.functional as F

from torch_geometric.graphgym.config import cfg
import torch_geometric.graphgym.register as register


from torch_geometric.graphgym.register import register_network
from torch_geometric.utils import to_dense_batch
import torch_geometric
import copy

import warnings
from torch.cuda.amp import autocast

from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data import Batch

import numpy as np

from operator import itemgetter

from typing import Dict, List, Optional, Tuple, Union
from collections.abc import Iterable, Mapping
from torchtyping import TensorType

from custom_modules.loss.loss import compute_loss_multi_task
import time

try:
    import xformers.ops as xops
except ImportError:
    xops = None


# have a look at this before using
class GEGLU(nn.Module):
    """Gated GELU unit.

    The last axis of the input is split into two equal halves; the first half
    is multiplied element-wise by ``gelu`` of the second half, so the output
    is half as wide as the input.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return torch.mul(value, F.gelu(gate))


class FeedForward(nn.Module):
    """Feedforward MLP with input normalization.

    Layout: LayerNorm -> Linear -> GEGLU -> Dropout -> Linear.

    Args:
        dim: Width of the input features.
        dim_out: Optional output width. When omitted the block maps ``dim``
            back to ``dim`` through a hidden width of ``dim * mult``. When
            given, both the hidden width and the output width are ``dim_out``
            and ``mult`` is not used.
        mult: Hidden-width multiplier used only when ``dim_out`` is ``None``.
        dropout: Dropout probability applied after the gated activation.
    """

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.2):
        super().__init__()
        if dim_out is None:
            hidden_width, out_width = dim * mult, dim
        else:
            hidden_width, out_width = dim_out, dim_out

        # The first projection is twice as wide because GEGLU halves the
        # last axis (one half is the value, the other half is the gate).
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_width * 2),
            GEGLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_width, out_width),
        )

    def forward(self, x):
        return self.net(x)


# For node feature embedding - will be different for every dataset
class FeatureEmbedding(nn.Module):
    """Thin wrapper around a node-feature encoder from the GraphGym registry.

    Args:
        dim_in: Input feature width, forwarded to the encoder and kept on
            ``self.dim_in``.
        dim_latent: Second positional argument forwarded to the encoder.
        hidden_dim: Forwarded to the encoder as ``hidden_dim``.
        activate_fn: Forwarded to the encoder as ``activate_fn``.
        node_feat_encoder_name: Key into ``register.node_encoder_dict``. When
            ``None`` the key is read from ``cfg.model.node_feat_encoder_name``.
    """

    def __init__(
        self, dim_in, dim_latent, hidden_dim, activate_fn, node_feat_encoder_name=None
    ):
        super().__init__()
        self.dim_in = dim_in

        # An explicit name wins; the global config is only consulted as fallback.
        encoder_name = (
            cfg.model.node_feat_encoder_name
            if node_feat_encoder_name is None
            else node_feat_encoder_name
        )
        encoder_cls = register.node_encoder_dict[encoder_name]
        self.node_feat_encoder = encoder_cls(
            dim_in, dim_latent, hidden_dim=hidden_dim, activate_fn=activate_fn
        )

    def forward(self, inp_feat):
        return self.node_feat_encoder(inp_feat)


class PosEmbedding(nn.Module):
    """Positional encoder looked up by ``cfg.model.node_pos_encoder_name``.

    Args:
        dim_pos_emb: Passed unchanged as the single constructor argument of
            the registered positional encoder.
    """

    def __init__(self, dim_pos_emb=None):
        super().__init__()

        encoder_cls = register.node_encoder_dict[cfg.model.node_pos_encoder_name]
        self.node_pos_encoder = encoder_cls(dim_pos_emb)

    def forward(self, eigenvec, batch_index):
        # The encoder is always evaluated with CUDA autocast switched off.
        with autocast(enabled=False):
            encoded = self.node_pos_encoder(eigenvec, batch_index)
        return encoded


class EmbeddingWithVocab(nn.Module):
    """Embedding layer with a fixed, stored vocabulary.

    The vocabulary maps tokens to row indices of an ``nn.Embedding``. It is
    saved alongside the weights in ``state_dict()`` (under ``<prefix>vocab``)
    and restored by a load-state-dict pre-hook, so a checkpoint carries the
    token-to-index mapping it was trained with.

    Args:
        vocab: Either a sequence of tokens (indices are assigned by position)
            or a mapping ``token -> index`` (indices are used as given).
        embedding_dim: Width of each embedding vector.
        init_scale: Standard deviation used by ``reset_parameters``.

    Raises:
        ValueError: If ``vocab`` is a single string, or is neither a mapping
            nor an iterable.
    """

    def __init__(self, vocab: Union[Dict, List], embedding_dim, init_scale=0.02):
        super(EmbeddingWithVocab, self).__init__()

        # Create a mapping from words to indices
        if isinstance(vocab, str):
            raise ValueError("vocab cannot be a single string")
        elif isinstance(vocab, Mapping):
            # Must be tested before Iterable: every Mapping is also Iterable,
            # and enumerating it would silently discard the supplied indices.
            # Copy so that adding "NA" below never mutates the caller's object.
            self.vocab = dict(vocab)
        elif isinstance(vocab, Iterable):
            # OmegaConf wraps the list in omegaconf.listconfig.ListConfig
            self.vocab = {word: int(i) for i, word in enumerate(vocab)}
        else:
            raise ValueError("vocab must be a list or dict")

        # default=-1 lets an empty vocabulary yield a table holding only "NA".
        len_vocab = int(max(self.vocab.values(), default=-1) + 1)

        if "NA" not in self.vocab:
            # Always add a "not available" token
            self.vocab["NA"] = len_vocab

        # Create the reverse mapping from indices to words
        self.reverse_vocab = {i: word for word, i in self.vocab.items()}

        # Create the embedding layer (one extra row reserved for "NA")
        self.embedding = nn.Embedding(len_vocab + 1, embedding_dim)
        self.init_scale = init_scale

        # Unfortunately, this hook is private, though there has been a PR to make it
        # public: https://github.com/pytorch/pytorch/issues/75287
        self._register_load_state_dict_pre_hook(
            self._hook_vocab_on_load_state_dict, with_module=False
        )

    def forward(self, tokens):
        """Embed a flat list of tokens or a list of token lists.

        Args:
            tokens: ``[tok, ...]`` or ``[[tok, ...], ...]``; every token must
                be a key of ``self.vocab``.

        Returns:
            The embedding vectors, with one trailing axis of size
            ``embedding_dim`` appended to the shape of ``tokens``.

        Raises:
            KeyError: If a token is not in the vocabulary.
        """
        is_nested = (
            isinstance(tokens, list)
            and len(tokens) >= 1
            and isinstance(tokens[0], list)
        )
        if is_nested:
            indices = [
                [self.vocab[token] for token in token_list] for token_list in tokens
            ]
        else:
            indices = [self.vocab[token] for token in tokens]

        # Build the index tensor directly as int64 on the target device;
        # going through a float32 tensor first is inexact above 2**24.
        x = torch.tensor(
            indices, dtype=torch.long, device=self.embedding.weight.device
        )
        return self.embedding(x)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the regular module state plus the vocabulary under ``<prefix>vocab``."""
        state = super(EmbeddingWithVocab, self).state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        state[prefix + "vocab"] = self.vocab
        return state

    def _hook_vocab_on_load_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Pre-hook: pull the vocabulary out of ``state_dict`` before tensors are loaded.

        The entry is popped so the non-tensor value is not reported as an
        unexpected key. When it is absent the current vocabulary is kept and
        a warning is issued.
        """
        try:
            self.vocab = state_dict.pop(prefix + "vocab")
            self.reverse_vocab = {i: word for word, i in self.vocab.items()}
        except KeyError:
            warnings.warn("Could not find vocab in state_dict. Using existing vocab.")

    def reset_parameters(self) -> None:
        """Re-initialise the embedding weights from N(0, init_scale**2)."""
        torch.nn.init.normal_(self.embedding.weight, mean=0, std=self.init_scale)
        self.embedding._fill_padding_idx_with_zero()


# class MultiheadAttention_memory_efficient(nn.Module):

#     def __init__(self, embed_dim, num_heads, dropout=0., bias=True,context_dim: int = None,use_memory_efficient_attn: bool = True,) -> None:

#         super().__init__()
#         self.embed_dim = embed_dim
#         self.context_dim = context_dim if context_dim is not None else embed_dim
#         self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

#         self.num_heads = num_heads
#         self.dropout = dropout
#         self.head_dim = embed_dim // num_heads

#         inner_dim = embed_dim * num_heads
#         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

#         self.norm = nn.LayerNorm(embed_dim)

#         # calculate query, key, value
#         self.to_q = nn.Linear(embed_dim, inner_dim, bias=False)
#         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=True)
#         self.to_out = nn.Linear(inner_dim, dim, bias=False)


class Attention(nn.Module):
    """Pre-normalised wrapper around ``torch.nn.MultiheadAttention``.

    Query, key and value each get their own LayerNorm before being handed to
    a batch-first, bias-free multi-head attention module.

    Args:
        dim: Embedding width of queries, keys and values.
        heads: Number of attention heads.
        dropout: Dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, heads=8, dropout=0.0):
        super().__init__()

        self.norm_q, self.norm_k, self.norm_v = (nn.LayerNorm(dim) for _ in range(3))

        # Stored from the global config; not read anywhere else in this class.
        self.use_memory_efficient_attn = cfg.model.use_memory_efficient_attn

        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True,
            dropout=dropout,
            bias=False,
        )

    def forward(self, x, context=None, context_mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query input.
            context: Optional key/value input; defaults to ``x``.
            context_mask: Optional boolean mask. It is inverted before being
                passed as ``key_padding_mask``, so ``True`` here marks
                positions that are kept.

        Returns:
            The attention output (first element of the tuple returned by
            ``nn.MultiheadAttention``).
        """
        kv_source = x if context is None else context

        query = self.norm_q(x)
        key = self.norm_k(kv_source)
        value = self.norm_v(kv_source)

        mask_kwargs = {}
        if context_mask is not None:
            mask_kwargs["key_padding_mask"] = ~context_mask

        return self.attn(query, key, value, **mask_kwargs)[0]


def gather_feat(k_hop_neigh_idx, feat):
    """Gather per-neighbour feature vectors along axis 1 of ``feat``.

    A zero vector is prepended along axis 1 first, so index ``0`` selects
    zeros and index ``i >= 1`` selects ``feat[:, i - 1]``.

    Args:
        k_hop_neigh_idx: 2-D integer index tensor (it is flattened with
            ``view(-1)`` and restored from its two sizes).
        feat: 3-D tensor; its sizes 0 and 2 are carried through unchanged.

    Returns:
        Tensor of shape ``(feat.size(0), idx.size(0), idx.size(1), feat.size(2))``.
    """
    n_batch, _, n_channels = feat.size(0), feat.size(1), feat.size(2)
    n_rows, n_cols = k_hop_neigh_idx.size(0), k_hop_neigh_idx.size(1)

    pad_slot = torch.zeros(n_batch, 1, n_channels, device=feat.device)
    padded = torch.cat([pad_slot, feat], dim=1)

    gathered = torch.index_select(padded, 1, k_hop_neigh_idx.view(-1))
    return gathered.view(n_batch, n_rows, n_cols, n_channels)


from custom_modules.network.attention import Memoeryeff_Attention, SDPAttention


@register_network("PerceiverGraph_MultiDataset_NodeClass")
class PerceiverGraph_MultiDataset(nn.Module):
    """Latent-array network for node-level prediction over several graph datasets.

    Registered with GraphGym under the name
    ``"PerceiverGraph_MultiDataset_NodeClass"``. The forward pass:

    1. embeds node features with a per-dataset encoder and concatenates
       positional encodings to form the input tokens,
    2. updates a learned latent array from those tokens through ``enc_atn``
       (presumably cross-attention with the latents as queries -- the module
       is defined elsewhere; TODO confirm),
    3. refines the latents with ``cfg.sa.depth`` attention/feed-forward layers,
    4. builds one decoder sequence per entry of ``batch["node_dataset_indices"]``
       from gathered input tokens followed by the selected latents, and
    5. feeds token 0 of every decoder sequence to the task-specific readout.

    NOTE(review): ``forward`` unconditionally uses the modules created in the
    ``cfg.dataset.task == "node"`` branch of ``__init__``; with the ``"graph"``
    setting those attributes are never created.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
    ):
        """Build all sub-modules from the global GraphGym ``cfg``.

        Args:
            dim_in: Only printed; not used to size any layer here.
            dim_out: Output width of ``decoder_out``; used only when
                ``cfg.dataset.task == "graph"``.
        """
        super().__init__()

        # Embeddings
        # NOTE(review): looks like a leftover debug print -- confirm before removing.
        print(dim_in)

        self.pos_emb = PosEmbedding()

        # Collect the per-dataset config nodes named in cfg.dataset_multi.name_list
        # and flag each of them as enabled.
        self.dataset_config_list = []
        for dataset_name in cfg.dataset_multi.name_list:
            dataset_cfg = getattr(cfg, dataset_name)
            dataset_cfg.enable = True
            self.dataset_config_list.append(dataset_cfg)

        self.feat_emb = MultiDatasetFeatureProcessor(
            task_specs_list=self.dataset_config_list
        )

        # Width of an input token: feature embedding width plus the positional
        # encoder's dim_pe (forward concatenates the two along dim 1).
        self.latent_dim_ca = (
            cfg.feenc.dim_feat_emb + self.pos_emb.node_pos_encoder.dim_pe
        )
        # Learned latent array of shape (num_latents, latent_dim), shared by all graphs.
        self.latent_emb = nn.Parameter(
            # torch.randn(cfg.model.num_latents, self.latent_dim)
            torch.randn(cfg.model.num_latents, cfg.model.latent_dim)
        )

        self.latent_dim = cfg.model.latent_dim

        # Dropout applied to every residual branch in the processing/decoding stacks.
        self.dropout = nn.Dropout(p=cfg.model.lin_dropout)

        # self.latent_adapter_inv = nn.Linear(cfg.model.latent_dim, self.latent_dim_ca)

        # Encoder attention: called in forward as enc_atn(latents, input_tokens, ...).
        self.enc_atn = Memoeryeff_Attention(
            dim=self.latent_dim_ca,
            q_dim=cfg.model.latent_dim,
            # v_dim=cfg.model.latent_dim,
            heads=cfg.ca.cross_heads,
            dropout=0,
        )
        # Projects the encoder-attention output (assumed to be latent_dim_ca wide
        # -- TODO confirm against Memoeryeff_Attention) back to latent_dim.
        self.latent_adapter = nn.Linear(self.latent_dim_ca, cfg.model.latent_dim)

        self.enc_ffn = FeedForward(
            dim=cfg.model.latent_dim, dropout=cfg.model.ffn_dropout
        )

        # self.latent_ffn = nn.Linear(cfg.model.latent_dim, self.latent_dim)

        # Processing transfomers (qkv-latent)
        # Each entry is an [attention, feed-forward] pair applied residually in forward.
        self.proc_layers = nn.ModuleList([])
        for i in range(cfg.sa.depth):
            self.proc_layers.append(
                nn.ModuleList(
                    [
                        SDPAttention(
                            dim=self.latent_dim,
                            heads=cfg.sa.n_heads,
                            dropout=cfg.model.attn_dropout,
                            use_memory_efficient_attn=True,
                        ),
                        FeedForward(dim=self.latent_dim, dropout=cfg.model.ffn_dropout),
                    ]
                )
            )

        if cfg.dataset.task == "graph":
            # Output projection (graph classification)
            # NOTE(review): decoder_out is not referenced by forward in this file.
            self.decoder_out = nn.Linear(self.latent_dim, dim_out)

        elif cfg.dataset.task == "node":

            # Maps gathered input tokens (latent_dim_ca wide) to latent_dim so they
            # can be concatenated with the latents along the sequence axis.
            self.decoder_nodes_proj = FeedForward(
                dim=self.latent_dim_ca,
                dim_out=cfg.model.latent_dim,
                dropout=cfg.model.ffn_dropout,
            )

            # self.dec_token_type_emb = nn.Embedding(3, self.latent_dim)
            # Embedding with 3 entries, indexed by batch["token_type_id"]; it is
            # concatenated on the channel axis, hence the widened decoder dim below.
            self.dec_token_type_emb = nn.Embedding(3, cfg.model.tok_emb_dim)
            self.dec_proc_layers = nn.ModuleList([])
            for i in range(cfg.sa.node_decoder.depth):
                self.dec_proc_layers.append(
                    nn.ModuleList(
                        [
                            SDPAttention(
                                dim=self.latent_dim + cfg.model.tok_emb_dim,
                                heads=cfg.sa.n_heads,
                                dropout=cfg.model.attn_dropout,
                                use_memory_efficient_attn=True,
                            ),
                            FeedForward(
                                dim=self.latent_dim + cfg.model.tok_emb_dim,
                                dropout=cfg.model.ffn_dropout,
                            ),
                        ]
                    )
                )

            # One linear head per dataset/task; also computes the loss.
            self.readout = MultitaskReadout(
                latent_dim=self.latent_dim + cfg.model.tok_emb_dim,
                task_specs_list=self.dataset_config_list,
            )

    def _apply_index(self, batch):
        """Return ``(batch.graph_feature, batch.y)``; not called in the visible source."""
        return batch.graph_feature, batch.y

    def hops_gather_feat(
        self,
        k_hops_idx,
        feat,
        node_dataset_id,
    ):
        """Gather rows of ``feat`` by (first-axis id, padded second-axis index) pairs.

        A zero vector is prepended along dim 1 of ``feat``, so a value of 0 in
        ``k_hops_idx`` selects zeros (presumably a padding index -- TODO confirm).
        ``node_dataset_id`` is broadcast to the shape of ``k_hops_idx`` and used
        as the index into dim 0. Not called in the visible source.
        """

        zeros = torch.zeros(feat.size(0), 1, feat.size(2)).to(feat.device)
        feat = torch.cat([zeros, feat], dim=1)
        expanded_node_dataset_id = node_dataset_id[:, None].expand_as(k_hops_idx)
        k_hop_neigh_feat = feat[expanded_node_dataset_id, k_hops_idx]

        return k_hop_neigh_feat

    def forward(self, batch):
        """Run the full encode / process / decode / readout pipeline.

        Args:
            batch: Indexable container read with these keys: ``"x_dict"``,
                ``"graph_names"``, ``"pos"``, ``"batch_index"``,
                ``"graph_seq_len"``, ``"node_dataset_indices"``,
                ``"batch_wise_neigh_index"``, ``"token_type_id"``,
                ``"output_task_indices"`` and ``"output_values"``.

        Returns:
            Tuple ``(batch["output_values"], loss, losses_taskwise,
            pred_taskwise, num_elements_taskwise)``; the last four come from
            ``self.readout`` called with ``compute_loss=True``.
        """
        # create input embeddings
        # x_dict: {taskname: x}
        x_out_dict = self.feat_emb(batch["x_dict"])

        # todo stack x_out
        # Stack the per-graph embeddings in the order given by batch["graph_names"].
        x_out = torch.cat(
            [x_out_dict[taskname] for taskname in batch["graph_names"]], dim=0
        )

        pos_emb = self.pos_emb(batch["pos"], batch["batch_index"])

        # Input tokens = feature embedding concatenated with positional encoding
        # along dim 1.
        input_tokens = torch.cat((x_out, pos_emb), 1)

        # latents
        # b = number of graphs; every graph gets its own copy of the latent array,
        # laid out flat as (b * num_latents, latent_dim).
        b = len(batch["graph_seq_len"])
        latents_ptr = [self.latent_emb.shape[0] for _ in range(b)]
        latents_org = self.latent_emb
        # latents_org = self.latent_ffn(self.latent_emb)
        latents = repeat(latents_org, "n d -> (b n) d", b=b)

        ## Network

        # latents = self.latent_adapter_inv(latents)

        # Encode
        # context_mask carries the per-graph lengths of the latent and token
        # sequences (presumably to restrict attention to each graph's own tokens
        # -- TODO confirm against Memoeryeff_Attention).
        latents = latents + self.latent_adapter(
            self.enc_atn(
                latents,
                input_tokens,
                context_mask=[latents_ptr, batch["graph_seq_len"]],
            )
        )

        # latents = self.enc_atn(
        #     latents,
        #     input_tokens,
        #     context_mask=[latents_ptr, batch["graph_seq_len"]],
        # )

        # latents = self.latent_adapter(latents)

        latents = latents + self.enc_ffn(latents)

        # Un-flatten to (b, num_latents, latent_dim) for the processing stack.
        latents = latents.view(b, latents_org.shape[0], latents_org.shape[1])

        # Process
        for self_attn, self_ff in self.proc_layers:
            latents = latents + self.dropout(self_attn(latents))
            latents = latents + self.dropout(self_ff(latents))

        # Select one latent array per decoder sequence along dim 0.
        latents = torch.index_select(latents, 0, batch["node_dataset_indices"])

        input_neigh = input_tokens[
            batch["batch_wise_neigh_index"]
        ]  # (all_nodes, dim) --> (B, 20, dim)

        input_neigh = self.decoder_nodes_proj(input_neigh)

        token_type_emb = self.dec_token_type_emb(batch["token_type_id"])
        # Sequence axis: gathered input tokens first, then latents.
        # Channel axis: token-type embedding appended.
        decoder_input_tokens = torch.cat((input_neigh, latents), dim=1)
        decoder_input_tokens = torch.cat((decoder_input_tokens, token_type_emb), dim=-1)

        # Output projection
        output = decoder_input_tokens
        for self_attn, self_ff in self.dec_proc_layers:
            output = output + self.dropout(self_attn(output))
            output = output + self.dropout(self_ff(output))

        # Keep only the first token of every decoder sequence for the readout.
        output = output[:, 0, :]

        loss, losses_taskwise, pred_taskwise, num_elements_taskwise = self.readout(
            latents=output,
            output_task_indices=batch["output_task_indices"],
            output_values=batch["output_values"],
            compute_loss=True,
        )

        # return loss
        return (
            batch["output_values"],
            loss,
            losses_taskwise,
            pred_taskwise,
            num_elements_taskwise,
        )


import re


class MultiDatasetFeatureProcessor(nn.Module):
    """Per-dataset node-feature encoders behind a single dictionary interface.

    One ``FeatureEmbedding`` is created for every entry of ``task_specs_list``
    and, when ``cfg.dataset_multi.use_synthetic`` is set, for every entry of
    ``cfg.dataset_multi.syn_graph_config``. All of them project to
    ``cfg.feenc.dim_feat_emb``.

    Args:
        task_specs_list: Iterable of dataset config objects exposing
            ``dataset_name``, ``feat_dim``, ``hidden_dim`` and ``activate_fn``.
    """

    # Numeric ``_num_part_<n>`` / ``_graph_<n>`` suffixes are stripped from the
    # incoming names so that all such variants share one projection. Raw string:
    # ``\d`` is an invalid escape in a normal string literal. Compiled once
    # instead of on every loop iteration.
    _TASK_SUFFIX_PATTERN = re.compile(r"_num_part_\d+|_graph_\d+")

    def __init__(self, task_specs_list):
        super().__init__()

        # Create a bunch of projection layers. One for each task
        self.projections = nn.ModuleDict({})

        for data_cfg in task_specs_list:
            self.projections[f"{data_cfg.dataset_name}"] = FeatureEmbedding(
                data_cfg.feat_dim,
                cfg.feenc.dim_feat_emb,
                data_cfg.hidden_dim,
                data_cfg.activate_fn,
            )

        if cfg.dataset_multi.use_synthetic:
            for syn_cfg in cfg.dataset_multi.syn_graph_config:
                # Synthetic datasets name their encoder explicitly and pass
                # neither a hidden width nor an activation.
                self.projections[f"{syn_cfg['dataset_name']}"] = FeatureEmbedding(
                    syn_cfg["feat_dim"],
                    cfg.feenc.dim_feat_emb,
                    None,
                    None,
                    syn_cfg["node_feat_encoder_name"],
                )

    def forward(self, x_dict):
        """Encode every entry of ``x_dict`` with its dataset's projection.

        Args:
            x_dict: Mapping ``name -> features``. ``name`` may carry a
                ``_num_part_<n>`` or ``_graph_<n>`` suffix, which is removed
                before looking up the projection.

        Returns:
            Dict with the same keys (and order) as ``x_dict`` holding the
            encoded features.

        Raises:
            KeyError: If no projection exists for a (suffix-stripped) name.
        """
        strip_suffix = self._TASK_SUFFIX_PATTERN.sub
        return {
            taskname: self.projections[strip_suffix("", taskname)](x)
            for taskname, x in x_dict.items()
        }


class MultitaskReadout(nn.Module):
    """Task-specific linear heads with an aggregated multi-task loss.

    Heads and loss specifications are keyed by
    ``"<dataset_name>_<task>_<task_type>"``.

    Args:
        latent_dim: Input width of every head.
        task_specs_list: Iterable of dataset configs exposing ``dataset_name``,
            ``task``, ``task_type``, ``task_dim`` and ``loss_fun``. Entries of
            ``cfg.dataset_multi.syn_graph_config`` are added as well when
            ``cfg.dataset_multi.use_synthetic`` is set.
    """

    def __init__(self, latent_dim: int, task_specs_list):
        super().__init__()

        # One projection head and one loss spec per task key
        self.projections = nn.ModuleDict({})
        loss_by_task = {}

        for data_cfg in task_specs_list:
            key = f"{data_cfg.dataset_name}_{data_cfg.task}_{data_cfg.task_type}"
            self.projections[key] = nn.Linear(latent_dim, data_cfg.task_dim)
            loss_by_task[key] = data_cfg.loss_fun

        if cfg.dataset_multi.use_synthetic:
            for syn_cfg in cfg.dataset_multi.syn_graph_config:
                key = f"{syn_cfg['dataset_name']}_{syn_cfg['task']}_{syn_cfg['task_type']}"
                self.projections[key] = nn.Linear(latent_dim, syn_cfg["task_dim"])
                loss_by_task[key] = syn_cfg["loss_fun"]

        # Despite the attribute name, this is the task-key -> loss spec mapping
        # consulted in forward.
        self.task_specs_list = loss_by_task

    def forward(
        self,
        latents,
        output_task_indices,
        output_values,
        compute_loss: bool,
    ):
        """Project the latents of every task and optionally compute the loss.

        Args:
            latents: Tensor whose first axis is indexed by the entries of
                ``output_task_indices``.
            output_task_indices: Dictionary keyed by task name; each value
                indexes the rows of ``latents`` belonging to that task.
            output_values: Ground-truth values for loss computation.
                ``output_values[task][i]`` is the target for the latent
                selected by ``output_task_indices[task][i]``.
            compute_loss: When false, only the projections are computed.

        Returns:
            If ``compute_loss`` is false:
                ``(outputs, loss, losses_taskwise)`` where ``loss`` is a zero
                scalar and ``losses_taskwise`` is empty.
            Otherwise:
                ``(loss, losses_taskwise, pred_taskwise, num_elements_taskwise)``.
                Each task loss is multiplied by that task's element count and
                the summed loss is divided by ``latents.shape[0]``.
        """
        # Apply task specific projections to each task's slice of the latents
        outputs: Dict[str, TensorType["batch", "*out_dim"]] = {
            taskname: self.projections[taskname](latents[linear_indices])
            for taskname, linear_indices in output_task_indices.items()
        }

        loss = torch.tensor(0, device=latents.device, dtype=torch.float32)
        losses_taskwise = {}

        # Pure forward passes stop here
        if not compute_loss:
            return outputs, loss, losses_taskwise

        pred_taskwise = {}
        num_elements_taskwise = {}

        for taskname, output in outputs.items():
            task_loss, pred_taskwise[taskname] = compute_loss_multi_task(
                output, output_values[taskname], self.task_specs_list[taskname]
            )

            # Each task loss is a mean over its elements, so weight it by the
            # element count to avoid wild swings between large and small tasks.
            nbatch_el = len(set(output_task_indices[taskname]))
            weighted_loss = task_loss * nbatch_el

            loss = loss + weighted_loss
            losses_taskwise[taskname] = weighted_loss
            num_elements_taskwise[taskname] = nbatch_el

        loss = loss / latents.shape[0]

        return loss, losses_taskwise, pred_taskwise, num_elements_taskwise
