from einops import repeat

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.graphgym.config import cfg
import torch_geometric.graphgym.register as register


from torch_geometric.graphgym.register import register_network

import warnings
from torch.cuda.amp import autocast

from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data import Batch

import numpy as np

from operator import itemgetter

from typing import Dict, List, Optional, Tuple, Union
from collections.abc import Iterable, Mapping
from torchtyping import TensorType

from custom_modules.loss.loss import compute_loss_multi_task
from custom_modules.network.attention import Memoeryeff_Attention, SDPAttention

try:
    import xformers.ops as xops
except ImportError:
    xops = None
import re


class GEGLU(nn.Module):
    """GELU-gated linear unit.

    The input is split into two chunks along its last axis. The first chunk
    is the value, the second is passed through GELU and used as a
    multiplicative gate, so the last dimension of the output is half that of
    the input.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.gelu(gate) * value


class FeedForward(nn.Module):
    """Feed-forward MLP with input normalization and a GEGLU activation.

    The stack is LayerNorm -> Linear -> GEGLU -> Dropout -> Linear.

    Args:
        dim: Width of the input (and of the LayerNorm).
        dim_out: When ``None`` the hidden width is ``dim * mult`` and the
            output width is ``dim``. Otherwise both the hidden and the output
            width are ``dim_out`` and ``mult`` is not used.
        mult: Hidden-width multiplier for the ``dim_out is None`` case.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.2):
        super().__init__()
        if dim_out is None:
            hidden_width, out_width = dim * mult, dim
        else:
            hidden_width, out_width = dim_out, dim_out

        layers = [
            nn.LayerNorm(dim),
            # Twice the hidden width: GEGLU consumes one half as the gate.
            nn.Linear(dim, hidden_width * 2),
            GEGLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_width, out_width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class FeatureEmbedding(nn.Module):
    """Node-feature encoder looked up by name in the GraphGym registry.

    Args:
        dim_in: Input feature width; forwarded to the encoder and stored on
            the module as ``dim_in``.
        dim_latent: Forwarded to the encoder as its second positional argument.
        hidden_dim: Forwarded to the encoder as ``hidden_dim``.
        activate_fn: Forwarded to the encoder as ``activate_fn``.
        node_feat_encoder_name: Key into ``register.node_encoder_dict``. When
            ``None`` the key is read from ``cfg.model.node_feat_encoder_name``.
    """

    def __init__(
        self, dim_in, dim_latent, hidden_dim, activate_fn, node_feat_encoder_name=None
    ):
        super().__init__()
        self.dim_in = dim_in

        # The config is only consulted when no explicit name was given.
        encoder_key = (
            cfg.model.node_feat_encoder_name
            if node_feat_encoder_name is None
            else node_feat_encoder_name
        )
        encoder_cls = register.node_encoder_dict[encoder_key]
        self.node_feat_encoder = encoder_cls(
            dim_in, dim_latent, hidden_dim=hidden_dim, activate_fn=activate_fn
        )

    def forward(self, inp_feat):
        return self.node_feat_encoder(inp_feat)


class PosEmbedding(nn.Module):
    """Positional embedding that dispatches on a per-row position type.

    Rows whose ``pos_type`` equals 1 go through the encoder registered under
    ``cfg.model.node_pos_encoder_name``; rows of any other type go through a
    ``FeedForward`` mapping ``cfg.posenc_SignNet.eigen.max_freqs`` inputs to
    ``cfg.posenc_SignNet.dim_pos_emb`` outputs.

    Args:
        dim_pos_emb: Passed unchanged to the registered positional encoder.
    """

    def __init__(self, dim_pos_emb=None):
        super().__init__()

        encoder_cls = register.node_encoder_dict[cfg.model.node_pos_encoder_name]
        self.node_pos_encoder = encoder_cls(dim_pos_emb)

        signnet_cfg = cfg.posenc_SignNet
        self.mesh_pos_encoder = FeedForward(
            signnet_cfg.eigen.max_freqs, signnet_cfg.dim_pos_emb
        )

    def forward(self, eigenvec, batch_index, edge_index, pos_type, batch_pos_type):
        """Encode every row of ``eigenvec`` according to its position type.

        Args:
            eigenvec: Per-row positional input; its first dimension sets the
                number of output rows.
            batch_index: Indexed with ``batch_pos_type == 1`` and handed to
                the registered encoder.
            edge_index: Accepted but not used; the registered encoder is
                always called with ``edge_index=None``.
            pos_type: Per-row type id used to select rows of ``eigenvec``.
            batch_pos_type: Type id used to select entries of ``batch_index``.

        Returns:
            Tensor with ``eigenvec.size(0)`` rows and
            ``cfg.posenc_SignNet.dim_pos_emb`` columns; rows start at zero and
            are overwritten per type.
        """
        out = torch.zeros(
            eigenvec.size(0),
            cfg.posenc_SignNet.dim_pos_emb,
            device=eigenvec.device,
        )
        # Autocast is explicitly disabled for the positional encoders.
        with autocast(enabled=False):
            for kind in torch.unique(pos_type):
                row_mask = pos_type == kind
                if kind == 1:
                    encoded = self.node_pos_encoder(
                        eigenvec[row_mask],
                        batch_index[batch_pos_type == kind],
                        edge_index=None,
                    )
                else:
                    encoded = self.mesh_pos_encoder(eigenvec[row_mask])
                out[row_mask] = encoded
        return out


class EmbeddingWithVocab(nn.Module):
    """Embedding layer with a fixed, stored vocabulary.

    Tokens are translated to integer indices through ``self.vocab`` and then
    looked up in an ``nn.Embedding``. The vocabulary is written into the
    module's ``state_dict`` under ``<prefix>vocab`` and restored from there by
    a load-state-dict pre-hook.

    Args:
        vocab: Either an iterable of tokens (indices are assigned in iteration
            order) or a mapping from token to integer index. A ``"NA"`` token
            is appended when the vocabulary does not already contain one.
        embedding_dim: Size of each embedding vector.
        init_scale: Standard deviation used by ``reset_parameters``.

    Raises:
        ValueError: If ``vocab`` is a single string, or is neither a mapping
            nor an iterable.
    """

    def __init__(self, vocab: Union[Dict, List], embedding_dim, init_scale=0.02):
        super().__init__()

        # Create a mapping from words to indices
        if isinstance(vocab, str):
            raise ValueError("vocab cannot be a single string")
        # Mappings are iterable too, so they have to be tested first;
        # otherwise a dict would be enumerated by key and the indices it
        # carries would be discarded.
        if isinstance(vocab, Mapping):
            # Build a fresh dict so that adding "NA" below never mutates the
            # caller's object.
            self.vocab = {word: int(index) for word, index in vocab.items()}
        elif isinstance(vocab, Iterable):
            # OmegaConf wraps the list in omegaconf.listconfig.ListConfig
            self.vocab = {word: index for index, word in enumerate(vocab)}
        else:
            raise ValueError("vocab must be a list or dict")

        # One past the largest index in use; 0 for an empty vocabulary.
        len_vocab = int(max(self.vocab.values(), default=-1) + 1)

        if "NA" not in self.vocab:
            # Always add a "not available" token
            self.vocab["NA"] = len_vocab

        # Create the reverse mapping from indices to words
        self.reverse_vocab = {index: word for word, index in self.vocab.items()}

        # The table always has one row more than the largest supplied index,
        # which is the row used by "NA" when it was added above.
        self.embedding = nn.Embedding(len_vocab + 1, embedding_dim)
        self.init_scale = init_scale

        # Unfortunately, this hook is private, though there has been a PR to make it
        # public: https://github.com/pytorch/pytorch/issues/75287
        self._register_load_state_dict_pre_hook(
            self._hook_vocab_on_load_state_dict, with_module=False
        )

    def forward(self, tokens):
        """Embed a flat list of tokens or a list of token lists.

        Args:
            tokens: A sequence of vocabulary tokens, or a list whose first
                element is a list, in which case every element is treated as
                a list of tokens.

        Returns:
            The embedding vectors, with one trailing ``embedding_dim`` axis
            appended to the shape of the index structure.

        Raises:
            KeyError: If a token is not in the vocabulary.
        """
        is_nested = (
            isinstance(tokens, list)
            and len(tokens) >= 1
            and isinstance(tokens[0], list)
        )
        if is_nested:
            indices = [
                [self.vocab[token] for token in token_list] for token_list in tokens
            ]
        else:
            indices = [self.vocab[token] for token in tokens]

        # Build the index tensor directly as int64 on the target device;
        # going through a float tensor first would round large indices.
        x = torch.tensor(
            indices, dtype=torch.long, device=self.embedding.weight.device
        )
        return self.embedding(x)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the regular state dict plus the vocabulary under ``<prefix>vocab``."""
        state = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        state[prefix + "vocab"] = self.vocab
        return state

    def _hook_vocab_on_load_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Pop the stored vocabulary out of ``state_dict`` before loading.

        When the entry is missing a warning is emitted and the current
        vocabulary is kept.
        """
        try:
            stored_vocab = state_dict.pop(prefix + "vocab")
        except KeyError:
            warnings.warn("Could not find vocab in state_dict. Using existing vocab.")
        else:
            self.vocab = stored_vocab
            self.reverse_vocab = {index: word for word, index in stored_vocab.items()}

    def reset_parameters(self) -> None:
        """Re-draw the embedding weights from N(0, ``init_scale``**2).

        The constructor does not call this method.
        """
        torch.nn.init.normal_(self.embedding.weight, mean=0, std=self.init_scale)
        self.embedding._fill_padding_idx_with_zero()


class Attention(nn.Module):
    """Wrapper around torch.nn.MultiheadAttention with pre-normalization.

    Queries, keys and values each pass through their own LayerNorm before
    being handed to a batch-first, bias-free ``nn.MultiheadAttention``.

    Args:
        dim: Embedding width of the inputs.
        heads: Number of attention heads.
        dropout: Dropout probability inside the attention module.
    """

    def __init__(self, dim, heads=8, dropout=0.0):
        super().__init__()

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        # Stored from the config; not read by forward().
        self.use_memory_efficient_attn = cfg.model.use_memory_efficient_attn

        mha_kwargs = dict(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True,
            dropout=dropout,
            bias=False,
        )
        self.attn = nn.MultiheadAttention(**mha_kwargs)

    def forward(self, x, context=None, context_mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query input.
            context: Key/value input; ``x`` is used when this is ``None``.
            context_mask: Optional boolean mask. Its logical inverse is passed
                as ``key_padding_mask``.

        Returns:
            The attention output (first element of the MultiheadAttention
            result).
        """
        source = x if context is None else context
        queries = self.norm_q(x)
        keys = self.norm_k(source)
        values = self.norm_v(source)

        padding_mask = None if context_mask is None else ~context_mask
        attended, _ = self.attn(queries, keys, values, key_padding_mask=padding_mask)
        return attended


def gather_feat(k_hop_neigh_idx, feat):
    """Gather neighbour features along dim 1 of ``feat``.

    A row of zeros is prepended along dim 1 before indexing, so index 0
    selects that zero row and index ``i`` selects original position ``i - 1``.

    Args:
        k_hop_neigh_idx: 2-D integer index tensor.
        feat: 3-D feature tensor; dim 1 is the axis that gets indexed.

    Returns:
        Tensor of shape ``(feat.size(0), k_hop_neigh_idx.size(0),
        k_hop_neigh_idx.size(1), feat.size(2))``.
    """
    n_batch, n_chan = feat.size(0), feat.size(2)
    n_rows, n_neigh = k_hop_neigh_idx.size(0), k_hop_neigh_idx.size(1)

    zero_row = torch.zeros(n_batch, 1, n_chan, device=feat.device)
    padded = torch.cat((zero_row, feat), dim=1)

    picked = padded.index_select(1, k_hop_neigh_idx.view(-1))
    return picked.view(n_batch, n_rows, n_neigh, n_chan)


@register_network("GraphFM")
class GraphFM(nn.Module):
    """Latent-bottleneck multi-dataset graph network, registered as "GraphFM".

    The forward pass
      1. embeds node features with a per-dataset projection and concatenates
         a positional embedding,
      2. cross-attends a learned latent array to those tokens,
      3. refines the latents with ``cfg.sa.depth`` attention/feed-forward
         layers,
      4. decodes one vector per output row (a learned query attending to the
         latents for ``"graph_classification"``; neighbour tokens plus
         latents run through the decoder stack for ``"node_classification"``),
      5. hands those vectors to ``MultitaskReadout`` for projection and loss.

    Args:
        dim_in: Not read by this class; every size comes from ``cfg``.
        dim_out: Not read by this class.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
    ):
        super().__init__()

        # Embeddings
        self.pos_emb = PosEmbedding()

        self.dataset_config_list = []
        for dataset_name in cfg.dataset_multi.name_list:
            dataset_cfg = getattr(cfg, dataset_name)
            dataset_cfg.enable = True
            self.dataset_config_list.append(dataset_cfg)

        self.feat_emb = MultiDatasetFeatureProcessor(
            task_specs_list=self.dataset_config_list
        )

        # Width of an input token: feature embedding + positional embedding.
        self.latent_dim_ca = (
            cfg.feenc.dim_feat_emb + self.pos_emb.node_pos_encoder.dim_pe
        )
        self.latent_emb = nn.Parameter(
            torch.randn(cfg.model.num_latents, cfg.model.latent_dim)
        )

        self.latent_dim = cfg.model.latent_dim
        # Width of every decoder-side token: latent + token-type embedding.
        self.decoder_dim = self.latent_dim + cfg.model.tok_emb_dim

        self.dropout = nn.Dropout(p=cfg.model.lin_dropout)

        self.enc_atn = Memoeryeff_Attention(
            dim=self.latent_dim_ca,
            q_dim=cfg.model.latent_dim,
            heads=cfg.ca.cross_heads,
            dropout=0,
        )
        self.latent_adapter = nn.Linear(self.latent_dim_ca, cfg.model.latent_dim)

        self.enc_ffn = FeedForward(
            dim=cfg.model.latent_dim, dropout=cfg.model.ffn_dropout
        )

        self.proc_layers = nn.ModuleList([])
        for _ in range(cfg.sa.depth):
            self.proc_layers.append(
                nn.ModuleList(
                    [
                        SDPAttention(
                            dim=self.latent_dim,
                            heads=cfg.sa.n_heads,
                            dropout=cfg.model.attn_dropout,
                            use_memory_efficient_attn=True,
                        ),
                        FeedForward(dim=self.latent_dim, dropout=cfg.model.ffn_dropout),
                    ]
                )
            )

        self.task_emb = nn.Parameter(torch.randn(1, self.decoder_dim))

        self.dec_atn = SDPAttention(
            dim=self.decoder_dim,
            heads=cfg.ca.cross_heads,
            dropout=cfg.model.attn_dropout,
            use_memory_efficient_attn=True,
        )
        self.dec_ffn = FeedForward(dim=self.decoder_dim, dropout=cfg.model.ffn_dropout)

        self.decoder_nodes_proj = FeedForward(
            dim=self.latent_dim_ca,
            dim_out=cfg.model.latent_dim,
            dropout=cfg.model.ffn_dropout,
        )

        self.dec_token_type_emb = nn.Embedding(3, cfg.model.tok_emb_dim)
        self.dec_proc_layers = nn.ModuleList([])
        for _ in range(cfg.sa.node_decoder.depth):
            self.dec_proc_layers.append(
                nn.ModuleList(
                    [
                        SDPAttention(
                            dim=self.decoder_dim,
                            heads=cfg.sa.n_heads,
                            dropout=cfg.model.attn_dropout,
                            use_memory_efficient_attn=True,
                        ),
                        FeedForward(
                            dim=self.decoder_dim,
                            dropout=cfg.model.ffn_dropout,
                        ),
                    ]
                )
            )

        self.readout = MultitaskReadout(
            latent_dim=self.decoder_dim,
            task_specs_list=self.dataset_config_list,
        )

    def _apply_index(self, batch):
        """Return ``(batch.graph_feature, batch.y)``."""
        return batch.graph_feature, batch.y

    def hops_gather_feat(
        self,
        k_hops_idx,
        feat,
        node_dataset_id,
    ):
        """Gather per-node neighbour features from a 3-D feature tensor.

        A row of zeros is prepended along dim 1 of ``feat``, so hop index 0
        selects that zero row. Entry ``[i, j]`` of the result is
        ``feat[node_dataset_id[i], k_hops_idx[i, j]]`` in the padded tensor.

        Args:
            k_hops_idx: 2-D integer tensor of dim-1 indices into the padded
                ``feat``.
            feat: 3-D feature tensor.
            node_dataset_id: 1-D integer tensor of dim-0 indices, one per row
                of ``k_hops_idx``.
        """
        zeros = torch.zeros(feat.size(0), 1, feat.size(2), device=feat.device)
        feat = torch.cat([zeros, feat], dim=1)
        expanded_node_dataset_id = node_dataset_id[:, None].expand_as(k_hops_idx)
        k_hop_neigh_feat = feat[expanded_node_dataset_id, k_hops_idx]

        return k_hop_neigh_feat

    def forward(self, batch):
        """Run the network on one collated batch and compute the loss.

        ``batch`` is indexed by string key. Keys read here: ``x_dict``,
        ``graph_name_both``, ``pos``, ``batch_index``, ``edge_index``,
        ``pos_type``, ``batch_pos_type``, ``graph_seq_len``,
        ``node_dataset_indices``, ``task_indices``, ``token_type_id_graph``,
        ``batch_wise_neigh_index``, ``token_type_id``,
        ``output_task_indices`` and ``output_values``.

        Returns:
            Tuple ``(output_values, loss, losses_taskwise, pred_taskwise,
            num_elements_taskwise)``; the first element is
            ``batch["output_values"]`` and the rest come from the readout.
        """
        x_out_dict = self.feat_emb(batch["x_dict"])

        x_out = torch.cat(
            [x_out_dict[taskname] for taskname in batch["graph_name_both"]], dim=0
        )

        pos_emb = self.pos_emb(
            batch["pos"],
            batch["batch_index"],
            batch["edge_index"],
            batch["pos_type"],
            batch["batch_pos_type"],
        )
        input_tokens = torch.cat((x_out, pos_emb), 1)

        # latents: one copy of the learned latent array per graph, flattened
        b = len(batch["graph_seq_len"])
        num_latents, latent_width = self.latent_emb.shape
        latents_ptr = [num_latents] * b
        latents = repeat(self.latent_emb, "n d -> (b n) d", b=b)

        ## Network

        # Encode
        latents = latents + self.latent_adapter(
            self.enc_atn(
                latents,
                input_tokens,
                context_mask=[latents_ptr, batch["graph_seq_len"]],
            )
        )

        latents = latents + self.enc_ffn(latents)

        latents = latents.view(b, num_latents, latent_width)
        # Process
        for self_attn, self_ff in self.proc_layers:
            latents = latents + self.dropout(self_attn(latents))
            latents = latents + self.dropout(self_ff(latents))

        latents = torch.index_select(latents, 0, batch["node_dataset_indices"])

        # One decoder-width vector per selected row. Zero-initialised so that
        # rows not written by any task branch hold defined values, and sized
        # from the config instead of a fixed constant.
        final_output = torch.zeros(
            latents.shape[0], self.decoder_dim, device=latents.device
        )
        task_indices = batch["task_indices"]
        for key in task_indices:
            rows = task_indices[key]
            if key == "graph_classification" and len(rows) > 0:
                op_query = repeat(self.task_emb, "n d -> b n d", b=len(rows))
                token_type_emb = self.dec_token_type_emb(batch["token_type_id_graph"])
                latents_graph = torch.cat((latents[rows], token_type_emb), dim=-1)

                op_query = op_query + self.dec_atn(op_query, latents_graph)
                op_query = op_query + self.dec_ffn(op_query)
                final_output[rows] = op_query[:, 0, :]

            elif key == "node_classification" and len(rows) > 0:
                latents_node = latents[rows]

                input_neigh = input_tokens[batch["batch_wise_neigh_index"]]
                input_neigh = self.decoder_nodes_proj(input_neigh)
                token_type_emb = self.dec_token_type_emb(batch["token_type_id"])

                # Neighbour tokens first, then latents, along the token axis;
                # the token-type embedding is appended on the feature axis.
                decoder_input_tokens = torch.cat((input_neigh, latents_node), dim=1)
                decoder_input_tokens = torch.cat(
                    (decoder_input_tokens, token_type_emb), dim=-1
                )

                output = decoder_input_tokens
                for self_attn, self_ff in self.dec_proc_layers:
                    output = output + self.dropout(self_attn(output))
                    output = output + self.dropout(self_ff(output))
                # The first token of each sequence is the read-out position.
                final_output[rows] = output[:, 0, :]

        loss, losses_taskwise, pred_taskwise, num_elements_taskwise = self.readout(
            latents=final_output,
            output_task_indices=batch["output_task_indices"],
            output_values=batch["output_values"],
            compute_loss=True,
        )
        return (
            batch["output_values"],
            loss,
            losses_taskwise,
            pred_taskwise,
            num_elements_taskwise,
        )


class MultiDatasetFeatureProcessor(nn.Module):
    """Holds one ``FeatureEmbedding`` per dataset and applies it by name.

    Args:
        task_specs_list: Iterable of dataset configs. Each must expose
            ``dataset_name``, ``feat_dim``, ``hidden_dim`` and
            ``activate_fn``. When ``cfg.dataset_multi.use_synthetic`` is set,
            a projection is also created for every entry of
            ``cfg.dataset_multi.syn_graph_config``.
    """

    # Suffixes removed from a task name to obtain its projection key.
    # Written as a raw string: "\d" in a plain literal is an invalid escape
    # sequence. Compiled once instead of on every forward iteration.
    _TASK_SUFFIX_RE = re.compile(r"_num_part_\d+|_graph_\d+")

    def __init__(self, task_specs_list):
        super().__init__()

        # Create a bunch of projection layers. One for each task
        self.projections = nn.ModuleDict({})

        for data_cfg in task_specs_list:
            self.projections[str(data_cfg.dataset_name)] = FeatureEmbedding(
                data_cfg.feat_dim,
                cfg.feenc.dim_feat_emb,
                data_cfg.hidden_dim,
                data_cfg.activate_fn,
            )

        if cfg.dataset_multi.use_synthetic:
            for syn_cfg in cfg.dataset_multi.syn_graph_config:
                # Synthetic datasets name their encoder explicitly and pass
                # no hidden_dim / activate_fn.
                self.projections[str(syn_cfg["dataset_name"])] = FeatureEmbedding(
                    syn_cfg["feat_dim"],
                    cfg.feenc.dim_feat_emb,
                    None,
                    None,
                    syn_cfg["node_feat_encoder_name"],
                )

    def forward(self, x_dict):
        """Project every entry of ``x_dict`` with its dataset's embedding.

        Args:
            x_dict: Mapping from task name to feature tensor. Every
                ``_num_part_<digits>`` and ``_graph_<digits>`` substring is
                removed from the name before the projection is looked up.

        Returns:
            Dict with the same keys (and order) as ``x_dict`` holding the
            projected features.

        Raises:
            KeyError: If no projection exists for a stripped task name.
        """
        return {
            taskname: self.projections[self._TASK_SUFFIX_RE.sub("", taskname)](x)
            for taskname, x in x_dict.items()
        }


class MultitaskReadout(nn.Module):
    """Per-task linear heads plus a weighted multi-task loss.

    One ``nn.Linear(latent_dim, task_dim)`` is created per task, keyed by
    ``"<dataset_name>_<task>_<task_type>"``. The loss function configured for
    each key is kept in ``self.task_specs_list``.

    Args:
        latent_dim: Input width of every head.
        task_specs_list: Iterable of dataset configs exposing
            ``dataset_name``, ``task``, ``task_type``, ``task_dim`` and
            ``loss_fun``. Entries of ``cfg.dataset_multi.syn_graph_config``
            are added as well when ``cfg.dataset_multi.use_synthetic`` is set.
    """

    def __init__(self, latent_dim: int, task_specs_list):
        super().__init__()

        # Create a bunch of projection layers. One for each task
        self.projections = nn.ModuleDict({})
        loss_by_task = {}

        def _add_head(task_key, out_dim, loss_fun):
            # Register the head and remember which loss belongs to it.
            self.projections[task_key] = nn.Linear(latent_dim, out_dim)
            loss_by_task[task_key] = loss_fun

        for data_cfg in task_specs_list:
            _add_head(
                f"{data_cfg.dataset_name}_{data_cfg.task}_{data_cfg.task_type}",
                data_cfg.task_dim,
                data_cfg.loss_fun,
            )

        if cfg.dataset_multi.use_synthetic:
            for syn_cfg in cfg.dataset_multi.syn_graph_config:
                _add_head(
                    f"{syn_cfg['dataset_name']}_{syn_cfg['task']}_{syn_cfg['task_type']}",
                    syn_cfg["task_dim"],
                    syn_cfg["loss_fun"],
                )

        # Despite the name, this maps task key -> loss function.
        self.task_specs_list = loss_by_task

    def forward(
        self,
        latents,
        output_task_indices,
        output_values,
        compute_loss: bool,
    ):
        """Project the latents of each task and optionally compute the loss.

        Args:
            latents: Tensor indexed along its first axis by the entries of
                ``output_task_indices``.
            output_task_indices: Mapping from task key to the indices into
                ``latents`` that belong to that task.
            output_values: Mapping from task key to ground-truth values;
                ``output_values[task][i]`` is the target for the latent
                selected by ``output_task_indices[task][i]``. Only read when
                ``compute_loss`` is true.
            compute_loss: Whether to evaluate the per-task losses.

        Returns:
            When ``compute_loss`` is false: ``(outputs, loss,
            losses_taskwise)`` with a zero loss and an empty dict.
            When ``compute_loss`` is true: ``(loss, losses_taskwise,
            pred_taskwise, num_elements_taskwise)``. Each per-task loss is
            multiplied by that task's element count, and the summed loss is
            divided by ``latents.shape[0]``.
        """
        # Apply the task-specific projection to each task's own latents.
        outputs: Dict[str, TensorType["batch", "*out_dim"]] = {
            taskname: self.projections[taskname](latents[linear_indices])
            for taskname, linear_indices in output_task_indices.items()
        }

        losses_taskwise = {}
        num_elements_taskwise = {}
        pred_taskwise = {}
        loss = torch.tensor(0, device=latents.device, dtype=torch.float32)

        # Pure forward passes stop here, before any loss is computed.
        if not compute_loss:
            return outputs, loss, losses_taskwise

        for taskname, output in outputs.items():
            task_loss, pred_taskwise[taskname] = compute_loss_multi_task(
                output, output_values[taskname], self.task_specs_list[taskname]
            )

            # The task loss is a mean over its elements, so weight it by the
            # number of elements to keep small tasks from swinging the total.
            nbatch_el = len(set(output_task_indices[taskname]))
            weighted_loss = task_loss * nbatch_el

            loss = loss + weighted_loss
            losses_taskwise[taskname] = weighted_loss
            num_elements_taskwise[taskname] = nbatch_el

        loss = loss / latents.shape[0]

        return loss, losses_taskwise, pred_taskwise, num_elements_taskwise
