import torch
from torch import nn
import math
from torch.nn import Transformer
from dataclasses import dataclass
from ..mha import MultiheadAttention
from typing import Literal, Optional, List
from omegaconf import DictConfig


class TransformerBlock(nn.Module):
    """Post-norm transformer layer: attention then feed-forward, each wrapped as
    ``LayerNorm(x + dropout(sublayer(x)))``.

    Args:
        hidden_dim: model width (input and output feature size).
        ff_dim: width of the feed-forward hidden layer.
        num_heads: number of attention heads.
        dropout: dropout probability used inside attention, on the feed-forward
            hidden activations, and on both residual branches.
    """

    def __init__(self, hidden_dim: int, ff_dim: int, num_heads: int, dropout: float):
        super().__init__()
        # NOTE: keep this construction order; it fixes the order in which the
        # default initialisers draw from the global RNG and the state_dict order.
        self.linear1 = nn.Linear(hidden_dim, ff_dim)
        self.dropout = nn.Dropout(dropout)  # on the feed-forward hidden layer
        self.linear2 = nn.Linear(ff_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)

        self.mha = MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout1 = nn.Dropout(dropout)  # on the attention branch output
        self.dropout2 = nn.Dropout(dropout)  # on the feed-forward branch output

        self.activation = nn.GELU()

    def forward(
        self, src, kv=None, key_padding_mask=None, attn_mask=None, is_causal=False
    ):
        """Run the attention sub-layer and then the feed-forward sub-layer.

        Args:
            src: query sequence. The attention module is built with
                ``batch_first=True``, so presumably ``(batch, seq, hidden_dim)``.
            kv: optional key/value sequence; when ``None`` the layer does
                self-attention over ``src``.
            key_padding_mask, attn_mask, is_causal: forwarded unchanged to the
                attention module.

        Returns:
            Tensor with the same shape as ``src``.
        """
        x = self.norm1(
            src
            + self._sa_block(
                src, attn_mask, key_padding_mask, is_causal=is_causal, kv=kv
            )
        )
        return self.norm2(x + self._ff_block(x))

    def _sa_block(self, x, attn_mask, key_padding_mask, is_causal=False, kv=None):
        """Attention sub-layer: queries from ``x``, keys/values from ``kv`` (or ``x``)."""
        memory = x if kv is None else kv
        attended = self.mha(
            x,
            memory,
            memory,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            is_causal=is_causal,
        )[0]
        return self.dropout1(attended)

    def _ff_block(self, x):
        """Feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))


@dataclass
class Cfg:
    """Configuration consumed by ``TransformerModel``.

    NOTE(review): ``TransformerModel`` reads both these top-level fields and
    ``all_cfg.algo.{token_dim, hidden_dim, ff_dim, num_heads}``; the top-level
    and ``algo`` token/hidden dims are presumably meant to agree — confirm with
    whoever builds this object.
    """

    token_dim: int  # width of the per-token embeddings (positional / kind tables)
    hidden_dim: int  # transformer width used by every TransformerBlock
    n_layers: int  # number of TransformerBlock layers
    max_seq_len: int  # number of rows in the positional-embedding table
    device: str  # NOTE(review): not read by TransformerModel
    mem_len: Optional[int]  # "mem" variant only: memory length in steps (x4 tokens each)
    all_cfg: DictConfig  # full experiment config; its ``algo`` section is read


class TransformerModel(nn.Module):
    def __init__(self, cfg: Cfg, tokenizers, head, variant) -> None:
        super().__init__()
        self.cfg = cfg
        assert cfg.all_cfg.algo.token_dim <= cfg.all_cfg.algo.hidden_dim
        print(
            f"token_dim: {cfg.all_cfg.algo.token_dim}:{cfg.token_dim}, hidden_dim: {cfg.all_cfg.algo.hidden_dim}:{cfg.hidden_dim}"
        )
        self.token_dim = cfg.token_dim
        assert variant in ["mem", "ad", "xl", "ed"]
        self.variant: Literal["mem", "ad", "xl", "ed"] = variant
        self.transformer = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_dim=cfg.hidden_dim,
                    # ff_dim=cfg.hidden_dim * 4,
                    ff_dim=cfg.all_cfg.algo.ff_dim,
                    num_heads=cfg.all_cfg.algo.num_heads,
                    dropout=0.1,
                )
                for _ in range(cfg.n_layers)
            ]
        )
        self.proj = nn.Linear(cfg.token_dim, cfg.hidden_dim)
        # self.pre_norm = nn.LayerNorm(cfg.hidden_dim)

        self.tokn = tokenizers
        self.n_step_comp = 4
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.token_dim)
        self.kind_emb = nn.Embedding(self.n_step_comp, cfg.token_dim)
        self.head = head
        if self.variant == "mem":
            assert cfg.mem_len != 0
            self.mems = nn.Parameter(
                torch.randn(cfg.mem_len * 4, 4, cfg.hidden_dim), requires_grad=True
            )

        self.post_norm = nn.LayerNorm(cfg.hidden_dim)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        # GPT-NeoX  https://arxiv.org/pdf/2204.06745.pdf
        # print module name
        if isinstance(module, nn.Embedding):
            # RWKV: set it to 1e-4
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1))
            )
            # torch.nn.init.normal_(module.weight,  -1e-4, 1e-4)
        elif isinstance(module, nn.Linear):
            # fan-in variance scaling intializer
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1))
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        terms: torch.Tensor,
        _pos: torch.Tensor,
        emb_mems: Optional[List[torch.Tensor]] = None,
    ):
        b, t = rewards.shape[:2]
        T = t * self.n_step_comp

        assert (
            obs.shape[:2]
            == (b, t)
            == actions.shape[:2]
            == terms.shape[:2]
            == _pos.shape[:2]
        )

        _obs_emb, _act_emb, _rwd_emb, _term_emb = self.tokn(
            obs, actions, rewards, terms
        )

        # pos = self.pos_emb(_pos)
        pos = self.pos_emb(_pos)

        seq = (
            torch.stack(
                [_obs_emb + pos, _act_emb + pos, _rwd_emb + pos, _term_emb + pos],
                dim=1,
            )
        ).permute(0, 2, 1, 3)

        kind = self.kind_emb(torch.arange(self.n_step_comp, device=obs.device))

        seq += kind

        seq = seq.reshape(b, T, self.token_dim)

        x = (
            self.proj(seq)
            if self.cfg.all_cfg.algo.token_dim < self.cfg.all_cfg.algo.hidden_dim
            else seq
        )

        # if self.cfg.all_cfg.algo.pre_norm:
        #     x = self.pre_norm(x)

        if self.variant in ["xl", "ed"]:
            _new_emb_mems = [None for _ in range(self.cfg.n_layers)]

        for i, _mod in enumerate(self.transformer):
            if self.variant == "ad":
                x = _mod(
                    x,
                    attn_mask=Transformer.generate_square_subsequent_mask(
                        T, device=x.device, dtype=x.dtype
                    ),
                    is_causal=True,
                )
            elif self.variant == "mem":
                _kv = torch.cat((self.mems[:, i].repeat(b, 1, 1), x), dim=1)
                _src_len = _kv.shape[1]
                _tgt_len = x.shape[1]
                _mem_len = _src_len - _tgt_len
                x = _mod(
                    x,
                    kv=_kv,
                    attn_mask=torch.triu(
                        torch.full(
                            (_tgt_len, _src_len),
                            -torch.inf,
                            device=x.device,
                            dtype=x.dtype,
                        ),
                        _mem_len + 1,
                    ),  # This is not causal!
                )
            else:
                # XL and ED
                assert self.variant in ["xl", "ed"]

                _tgt_len = x.shape[1]
                if self.variant == "xl":
                    _new_emb_mems[i] = x.detach()

                if emb_mems is None:
                    x = _mod(
                        x,
                        attn_mask=Transformer.generate_square_subsequent_mask(
                            T, device=x.device, dtype=x.dtype
                        ),
                        is_causal=True,
                    )

                else:
                    _kv = torch.cat(
                        (
                            emb_mems[i],
                            x,
                        ),
                        dim=1,
                    )
                    _src_len = _kv.shape[1]
                    _mem_len = _src_len - _tgt_len
                    x = _mod(
                        x,
                        kv=_kv,
                        attn_mask=torch.triu(
                            torch.full(
                                (_tgt_len, _src_len),
                                -torch.inf,
                                device=x.device,
                                dtype=x.dtype,
                            ),
                            _mem_len + 1,
                        ),  # This is not causal!
                    )

                if self.variant == "ed":
                    _new_emb_mems[i] = x.detach()

        logits = self.head(self.post_norm(x))

        if self.variant in ["xl", "ed"]:
            return logits, _new_emb_mems

        return logits

    def save(self):

        to_save_dict = {
            "transformer": self.transformer,
            "proj": self.proj,
            "pos": self.pos_emb,
            "tokn": self.tokn,
            # "pre_norm": self.pre_norm,
            "post_norm": self.post_norm,
            "kind": self.kind_emb,
            "head": self.head,
        }
        if self.variant == "mem":
            to_save_dict["mems"] = self.mems

        return to_save_dict

    def load(self, load_dict):

        self.transformer = load_dict["transformer"]
        self.proj = load_dict["proj"]

        self.pos_emb = load_dict["pos"]
        self.tokn = load_dict["tokn"]
        # self.pre_norm = load_dict["pre_norm"]
        self.kind_emb = load_dict["kind"]
        self.head = load_dict["head"]
        self.post_norm = load_dict["post_norm"]

        if self.variant == "mem":
            self.mems = load_dict["mems"]
