import torch
import torch.nn as nn
import numpy as np

# from networks.layers import *
import torch.nn.functional as F
import clip
from einops import rearrange, repeat
import math
from random import random
from tqdm.auto import tqdm
from typing import Callable, Optional, List, Dict
from copy import deepcopy
from functools import partial

from mGPT.archs.tools.momask import (
    cal_performance,
    cosine_schedule,
    eval_decorator,
    get_mask_subset_prob,
    gumbel_sample,
    lengths_to_mask,
    q_schedule,
    top_k,
    uniform,
)
from torch.distributions.categorical import Categorical
from transformers import T5ForConditionalGeneration, T5Tokenizer


from .mstream_base import MaskTransformerBase

from .modules_dit_rope import FinalLayer, TimestepEmbedder


class InputProcess(nn.Module):
    """Project per-token input features into the transformer's latent space."""

    def __init__(self, input_feats, latent_dim):
        """
        :param input_feats: size of the trailing feature axis of the input
        :param latent_dim: size of the trailing axis after projection
        """
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        # One affine map, applied over the last dimension only.
        self.poseEmbedding = nn.Linear(input_feats, latent_dim)

    def forward(self, x):
        """Map ``(..., input_feats)`` to ``(..., latent_dim)``.

        Leading dimensions are passed through unchanged, e.g.
        ``(bs, seqlen, input_feats) -> (bs, seqlen, latent_dim)``.
        """
        return self.poseEmbedding(x)


class EmbeddingLayer(nn.Module):
    """Learned lookup table that maps integer ids to ``dim``-sized vectors."""

    def __init__(self, dim, vocab_dim):
        """
        :param dim: length of each embedding vector (table columns)
        :param vocab_dim: number of distinct ids (table rows)
        """
        super().__init__()
        table = torch.empty((vocab_dim, dim))
        self.embedding = nn.Parameter(table)
        nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def forward(self, x):
        """Return the table rows selected by ``x`` via plain tensor indexing."""
        return self.embedding[x]


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout (borrowed from MDM).

    The table is registered as a ``(max_len, 1, d_model)`` buffer named ``pe``
    and is sliced with ``x.shape[0]`` in ``forward``, so the input is expected
    to be sequence-first: ``(seqlen, bs, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, prob=0.0):
        """
        :param d_model: feature width of the encoding (odd widths are supported)
        :param dropout: dropout probability applied to the output
        :param max_len: number of positions pre-computed in the table
        :param prob: probability, per call, of skipping the positional term
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.prob = prob

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns;
        # trimming keeps the assignment shape-consistent (no-op when even).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]

        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the positional table to ``x`` (sequence-first) and apply dropout.

        With probability ``self.prob`` the positional term is skipped and only
        dropout is applied. This draw uses Python's ``random`` and is made on
        every call, regardless of train/eval mode.
        """
        # not used in the final model
        if random() < self.prob:
            return self.dropout(x)
        x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)


class OutputProcess(nn.Module):
    """Prediction head: dense -> GELU -> LayerNorm -> linear, then swap the last two axes."""

    def __init__(self, out_feats, latent_dim):
        """
        :param out_feats: number of output channels produced per position
        :param latent_dim: width of the incoming hidden states
        """
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # Bias!

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Turn ``(bs, seqlen, latent_dim)`` hidden states into ``(bs, out_feats, seqlen)``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        projected = self.poseFinal(self.LayerNorm(activated))  # [bs, seqlen, out_feats]
        # Channels go in front of the sequence axis: [bs, c, seqlen].
        return projected.permute(0, 2, 1)


class OutputProcess_DiT(nn.Module):
    """Conditioned prediction head: delegates to ``FinalLayer`` and swaps the last two axes."""

    def __init__(self, out_feats, latent_dim, cond_dim):
        """
        All three sizes are forwarded positionally to ``FinalLayer`` in the
        order ``(out_feats, latent_dim, cond_dim)``.
        """
        super().__init__()
        self.final_layer = FinalLayer(out_feats, latent_dim, cond_dim)

    def forward(self, hidden_states: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Apply the conditioned final layer, then return the result as ``[bs, c, seqlen]``."""
        projected = self.final_layer(hidden_states, c)  # [bs, seqlen, c]
        return projected.permute(0, 2, 1)


class MaskTransformer(MaskTransformerBase):
    def __init__(
        self,
        code_dim,
        cond_mode,
        num_tokens,
        num_quantizers,
        latent_dim=256,
        num_layers=8,
        num_heads=4,
        dropout=0.1,
        mlp_ratio=4.0,
        clip_dim=512,
        cond_drop_prob=0.1,
        clip_version=None,
        activation="gelu",
        poe_type="absolute",
        cond_emb_type="clip",
        cond_in_type="default",
        dit_type="rope",
        out_type="default",
        emb_type="default",
        qk_norm=True,
        cond_time=True,
        t5_path=None,
        timesetps=10,
        **kargs,
    ):
        """
        Build the embedding, conditioning, encoder and output modules.

        :param code_dim: token embedding width when ``emb_type == "default"``
        :param cond_mode: "text", "action" or "uncond"
        :param num_tokens: number of real token ids; ``mask_id = num_tokens`` and
            ``pad_id = num_tokens + 1`` are reserved on top of them
        :param num_quantizers: stored on ``self`` only
        :param latent_dim: model width of the encoder
        :param num_layers: number of encoder layers
        :param num_heads: attention heads per encoder layer
        :param dropout: dropout rate for the encoder layers and positional module
        :param mlp_ratio: accepted but not read in this constructor
        :param clip_dim: input width of the text-condition projection
        :param cond_drop_prob: stored on ``self`` only (presumably consumed by the
            base class's ``mask_cond`` -- confirm)
        :param clip_version: passed to ``load_and_freeze_clip`` in text mode
        :param activation: accepted but not read in this constructor
        :param poe_type: "absolute", "rotary" or "abs+rotary"
        :param cond_emb_type: "clip" or "t5"; only validated when ``cond_mode == "text"``
        :param cond_in_type: "default", "cat" or "none"; selects ``self.cond_emb``
            and is not stored on ``self``
        :param dit_type: "default", "rope" or "hunyuan"; selects which ``DiTBlock``
            module is imported
        :param out_type: "default" or "dit"; selects the output head
        :param emb_type: "default" or "ddit"; selects the token embedding
        :param qk_norm: accepted but not read in this constructor
        :param cond_time: whether to create ``self.t_embedder``
        :param t5_path: passed to ``load_and_freeze_t5`` when ``cond_emb_type == "t5"``
        :param timesetps: (sic) stored as ``self.timesteps``
        :param kargs: must contain ``num_actions`` when ``cond_mode == "action"``
        :raises KeyError: for an unsupported ``dit_type``, ``emb_type``, ``poe_type``,
            ``cond_mode``, ``cond_in_type``, ``out_type`` or ``cond_emb_type``
        :raises AssertionError: if ``cond_mode == "action"`` without ``num_actions``
        """
        super().__init__()

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.num_quantizers = num_quantizers
        self.cond_mode = cond_mode
        self.cond_drop_prob = cond_drop_prob
        self.num_tokens = num_tokens
        self.poe_type = poe_type
        self.out_type = out_type
        self.cond_emb_type = cond_emb_type
        self.cond_time = cond_time
        self.emb_type = emb_type
        self.timesteps = timesetps

        if self.cond_mode == "action":
            assert "num_actions" in kargs
        self.num_actions = kargs.get("num_actions", 1)

        # NOTE(review): DiTBlock is imported here but not referenced again in
        # this constructor; the encoder below is a plain nn.TransformerEncoder.
        # The branch still validates dit_type.
        if dit_type == "default":
            from .modules_dit import DiTBlock
        elif dit_type == "rope":
            from .modules_dit_rope import DiTBlock
        elif dit_type == "hunyuan":
            from .modules_dit_hunyuan import DiTBlock
        else:
            raise KeyError("Unsupported DiT type!!!")

        _num_tokens = (
            num_tokens + 2
        )  # two dummy tokens, one for masking, one for padding
        self.mask_id = num_tokens
        self.pad_id = num_tokens + 1

        """
        Embedding layers
        """

        if self.emb_type == "default":
            # Embed into code_dim first, then project to latent_dim.
            self.token_emb = nn.Embedding(_num_tokens, self.code_dim)
            self.input_process = InputProcess(self.code_dim, self.latent_dim)
        elif self.emb_type == "ddit":
            # Embed straight into latent_dim; no extra projection.
            self.token_emb = EmbeddingLayer(self.latent_dim, _num_tokens)
            self.input_process = nn.Identity()
        else:
            raise KeyError("Unsupported embedding type!!!")

        """
        Positional Encoding
        """

        # NOTE(review): poe_args is assigned but not read again in this constructor.
        poe_args = {}
        if self.poe_type == "absolute":
            self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
            poe_args = {"apply_rotary_pos_emb": False}
        elif self.poe_type == "rotary":
            # Only dropout is applied at the input in this mode.
            self.position_enc = nn.Dropout(self.dropout)
        elif self.poe_type == "abs+rotary":
            # The sinusoidal term is skipped with probability 0.5 on each call.
            self.position_enc = PositionalEncoding(
                self.latent_dim, self.dropout, prob=0.5
            )
        else:
            raise KeyError("Unsupported positional encoding type!!!")

        """
        Conditional Embedding
        """

        # NOTE(review): forward()/generate() call self.enc_action, not this
        # attribute -- confirm enc_action is provided by MaskTransformerBase.
        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        if cond_in_type in ["default", "cat"]:
            if self.cond_mode == "text":
                self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
            elif self.cond_mode == "action":
                self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
            elif self.cond_mode == "uncond":
                self.cond_emb = nn.Identity()
            else:
                raise KeyError("Unsupported condition mode!!!")
            # The local clip_dim is reused below as cond_dim of the "dit" output
            # head; self.clip_dim keeps the original value.
            clip_dim = self.latent_dim
        elif cond_in_type == "none":
            self.cond_emb = nn.Identity()
        else:
            raise KeyError("Unsupported condition embedding type!!!")

        if self.cond_time:
            self.t_embedder = TimestepEmbedder(latent_dim)

        """
        Transformer layers
        """

        # NOTE(review): dim_feedforward and activation are hard-coded here; the
        # mlp_ratio and activation arguments are not consulted.
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=1024,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )

        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=num_layers
        )

        """
        Output Process
        """

        if out_type == "default":
            self.output_process = OutputProcess(
                out_feats=num_tokens, latent_dim=latent_dim
            )
        elif out_type == "dit":
            self.output_process = OutputProcess_DiT(
                out_feats=num_tokens, latent_dim=latent_dim, cond_dim=clip_dim
            )
        else:
            raise KeyError("Unsupported output type!!!")

        # Called before the frozen text encoder is attached below, so the
        # self.apply() inside it only visits the modules created above.
        self.__init_weights()

        """
        Preparing frozen weights
        """

        if self.cond_mode == "text":
            self.clip_version = clip_version
            if cond_emb_type == "clip":
                self.clip_model = self.load_and_freeze_clip(clip_version)
            elif cond_emb_type == "t5":
                self.t5_model = self.load_and_freeze_t5(t5_path)
            else:
                raise KeyError("Unsupported condition embedding type!!!")

        # Maps a time in [0, 1] to a masking ratio in forward()/generate().
        self.noise_schedule = cosine_schedule

    def trans_forward(self, motion_ids, cond, padding_mask, timestep, force_mask=False):
        """
        Embed the token ids, run the transformer encoder and project to logits.

        :param motion_ids: (b, seqlen)
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :param padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param timestep: fed to ``self.t_embedder`` when ``self.cond_time`` is set
        :param force_mask: boolean, forwarded to ``self.mask_cond``
        :return:
            -logits: (b, num_token, seqlen)
        :raises KeyError: if ``self.out_type`` is neither "default" nor "dit"
        """
        # mask_cond is not defined in this class (presumably inherited from
        # MaskTransformerBase -- confirm).
        cond = self.mask_cond(cond, force_mask=force_mask)

        # (b, seqlen) -> (b, seqlen, latent_dim)
        x = self.input_process(self.token_emb(motion_ids))

        # Conditional embedding, optionally shifted by the timestep embedding.
        cond = self.cond_emb(cond)  # (b, latent_dim)
        if self.cond_time:
            cond = cond + self.t_embedder(timestep)

        if isinstance(self.position_enc, PositionalEncoding):
            # PositionalEncoding slices its table with x.shape[0], i.e. it works
            # on (seqlen, b, d). x is batch-first here, so go through a
            # sequence-first view; adding the table to x directly would index
            # positions by batch element instead of by token.
            # NOTE(review): this changes activations relative to checkpoints
            # trained with the previous axis handling.
            x = self.position_enc(x.transpose(0, 1)).transpose(0, 1)
        else:
            x = self.position_enc(x)

        # NOTE(review): "cat" is an accepted value of the constructor's
        # cond_in_type, whereas cond_emb_type selects "clip"/"t5" -- confirm
        # which attribute this branch is meant to test.
        prepend_cond = self.cond_emb_type == "cat"
        if prepend_cond:
            # The condition becomes an extra, never-padded leading token.
            x = torch.cat([cond.unsqueeze(1), x], dim=1)  # (b, seqlen+1, latent_dim)
            padding_mask = torch.cat(
                [torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1
            )  # (b, seqlen+1)

        x = self.seqTransEncoder(x, src_key_padding_mask=padding_mask)

        if prepend_cond:
            x = x[:, 1:]  # drop the condition token again: (b, seqlen, latent_dim)

        if self.out_type == "default":
            logits = self.output_process(x)
        elif self.out_type == "dit":
            logits = self.output_process(x, cond)
        else:
            # Same error the constructor raises for an unknown output head.
            raise KeyError("Unsupported output type!!!")

        return logits

    def __init_weights(self):
        """Initialise the trainable layers registered on the module so far.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias. The
        condition projection and the timestep MLP are then re-drawn from
        ``N(0, 0.02^2)``, and the "dit" output head is zero-initialised.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Condition projection. cond_emb is nn.Identity for cond_mode="uncond"
        # or cond_in_type="none" and then has no weight to initialise.
        if isinstance(self.cond_emb, nn.Linear):
            nn.init.normal_(self.cond_emb.weight, std=0.02)

        # Timestep embedding MLP (entries 0 and 2 of t_embedder.mlp).
        if self.cond_time:
            nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out the output layers of the DiT head.
        if self.out_type == "dit":
            final_layer = self.output_process.final_layer
            nn.init.constant_(final_layer.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(final_layer.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(final_layer.linear.weight, 0)
            nn.init.constant_(final_layer.linear.bias, 0)

    def forward(self, ids, y, m_lens):
        """
        Training pass: corrupt a random subset of tokens and score their reconstruction.

        :param ids: (b, n)
        :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
        :param m_lens: (b,)
        :return: (ce_loss, pred_id, acc) as returned by ``cal_performance``
        :raises NotImplementedError: for an unknown ``cond_mode``
        """
        batch_size, seq_len = ids.shape
        device = ids.device

        # True on real tokens, False on padding; padded slots get the pad id.
        valid = lengths_to_mask(m_lens, seq_len)  # (b, n)
        ids = torch.where(valid, ids, self.pad_id)

        mode = self.cond_mode
        if mode == "text":
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif mode == "action":
            # NOTE(review): enc_action is not defined in this class -- confirm
            # it is provided by the base class.
            cond_vector = self.enc_action(y).to(device).float()
        elif mode == "uncond":
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")
        # Only the unconditional mode forces the condition mask.
        force_mask = mode == "uncond"

        # One random time per sample -> masking ratio -> number of tokens to
        # mask (at least one, computed against the padded length n).
        rand_time = uniform((batch_size,), device=device)
        mask_ratio = self.noise_schedule(rand_time)
        n_to_mask = (seq_len * mask_ratio).round().clamp(min=1)
        timestep = rand_time * self.timesteps

        # Give every position a random rank; the lowest ranks are selected,
        # and a selected position must also be a real (non-padded) token.
        ranks = torch.rand((batch_size, seq_len), device=device).argsort(dim=-1)
        mask = (ranks < n_to_mask.unsqueeze(-1)) & valid

        # Training target: the original id where selected, mask_id elsewhere
        # (mask_id is the ignore_index of the loss below).
        labels = torch.where(mask, ids, self.mask_id)

        # BERT-style corruption of the selected positions.
        # A subset drawn with p=0.1 is replaced by a random token id ...
        mask_rid = get_mask_subset_prob(mask, 0.1)
        rand_id = torch.randint_like(ids, high=self.num_tokens)
        x_ids = torch.where(mask_rid, rand_id, ids)

        # ... and of the rest, a subset drawn with p=0.88 becomes the mask
        # token; whatever remains keeps its correct id.
        mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88)
        x_ids = torch.where(mask_mid, self.mask_id, x_ids)

        logits = self.trans_forward(x_ids, cond_vector, ~valid, timestep, force_mask)
        ce_loss, pred_id, acc = cal_performance(
            logits, labels, ignore_index=self.mask_id
        )
        return ce_loss, pred_id, acc

    def forward_with_cond_scale(
        self,
        motion_ids,
        cond_vector,
        padding_mask,
        timestep,
        cond_scale=3,
        force_mask=False,
    ):
        """
        Compute logits, optionally blending runs with and without ``force_mask``.

        :param motion_ids: token ids passed to ``trans_forward``
        :param cond_vector: condition passed to ``trans_forward``
        :param padding_mask: padding mask passed to ``trans_forward``
        :param timestep: timestep passed to ``trans_forward``
        :param cond_scale: weight of the (plain - forced) difference; a value
            of 1 returns the plain logits without a second pass
        :param force_mask: when true, a single forced-mask pass is returned
        :return: logits with the same shape as ``trans_forward`` returns
        """
        inputs = (motion_ids, cond_vector, padding_mask, timestep)

        if force_mask:
            return self.trans_forward(*inputs, force_mask=True)

        plain_logits = self.trans_forward(*inputs)
        if cond_scale == 1:
            return plain_logits

        forced_logits = self.trans_forward(*inputs, force_mask=True)

        # Start from the forced-mask logits and move cond_scale times along
        # the direction towards the plain ones.
        return forced_logits + (plain_logits - forced_logits) * cond_scale

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        conds,
        m_lens,
        timesteps: int,
        cond_scale: int,
        temperature=1,
        topk_filter_thres=0.9,
        gsample=False,
        force_mask=False,
    ):
        """
        Iteratively decode token ids, starting from a fully masked sequence.

        Each step re-masks the lowest-scoring tokens, predicts them again and
        keeps the new predictions; the share that is re-masked follows
        ``self.noise_schedule``.

        :param conds: raw text for cond_mode=text, action labels for
            cond_mode=action; not read for cond_mode=uncond
        :param m_lens: (b,) sequence lengths (multiplied with a tensor on the
            model device below -- presumably a tensor on that device; confirm)
        :param timesteps: number of decoding iterations
        :param cond_scale: guidance scale passed to ``forward_with_cond_scale``
        :param temperature: sampling temperature, constant over all steps
        :param topk_filter_thres: threshold passed to ``top_k``
        :param gsample: use ``gumbel_sample`` instead of categorical sampling
        :param force_mask: passed to ``forward_with_cond_scale``; unlike
            ``forward``, it is not switched on automatically for cond_mode=uncond
        :return: (b, max(m_lens)) token ids, with padded positions set to -1
        :raises NotImplementedError: for an unknown ``cond_mode``
        """
        device = next(self.parameters()).device
        seq_len = max(m_lens)
        batch_size = len(m_lens)

        if self.cond_mode == "text":
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == "action":
            # NOTE(review): enc_action is not defined in this class -- confirm
            # it is provided by the base class.
            cond_vector = self.enc_action(conds).to(device)
        elif self.cond_mode == "uncond":
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        # True at padded positions.
        padding_mask = ~lengths_to_mask(m_lens, seq_len)

        # Start from all tokens being masked; padding gets a huge score so it
        # is never picked for re-masking.
        ids = torch.where(padding_mask, self.pad_id, self.mask_id)
        scores = torch.where(padding_mask, 1e5, 0.0)
        starting_temperature = temperature

        # steps_until_x0 is only needed by the disabled temperature annealing.
        for timestep, steps_until_x0 in zip(
            torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))
        ):
            # 0 <= timestep <= 1 (linspace includes both end points)
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor
            # NOTE(review): the time fed to the model is scaled by the
            # `timesteps` argument and rounded, whereas forward() scales by
            # self.timesteps without rounding -- confirm this is intended.
            timestep = (timestep * timesteps).round()[None]

            """
            Maskout, and cope with variable length
            """
            # fix: the ratio regarding lengths, instead of seq_len
            num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(
                min=1
            )  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(
                dim=1
            )  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(
                dim=1
            )  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = ranks < num_token_masked.unsqueeze(-1)
            ids = torch.where(is_mask, self.mask_id, ids)

            """
            Preparing input
            """
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(
                ids,
                cond_vector=cond_vector,
                padding_mask=padding_mask,
                timestep=timestep,
                cond_scale=cond_scale,
                force_mask=force_mask,
            )

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)

            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            """
            Update ids
            """
            # Temperature annealing is disabled: every step samples at the
            # starting temperature.
            temperature = starting_temperature
            if gsample:  # use gumbel_softmax sampling
                pred_ids = gumbel_sample(
                    filtered_logits, temperature=temperature, dim=-1
                )  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(
                    filtered_logits / temperature, dim=-1
                )  # (b, seqlen, ntoken)

                pred_ids = Categorical(probs).sample()  # (b, seqlen)

            # Only the positions masked in this step take the new prediction.
            ids = torch.where(is_mask, pred_ids, ids)

            """
            Updating scores
            """
            # Score = probability of the sampled id under the unfiltered,
            # temperature-free distribution.
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(
                2, pred_ids.unsqueeze(dim=-1)
            )  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        # Padded positions are reported as -1.
        ids = torch.where(padding_mask, -1, ids)
        return ids
