import torch
import torch.nn as nn
import numpy as np
# from networks.layers import *
import torch.nn.functional as F
import clip
from model.encode_text import T5TextEncoder
from einops import repeat
from functools import partial
from model.transformer.tools import *
from torch.distributions.categorical import Categorical


class InputProcess(nn.Module):
    """Project per-token input features into the latent space.

    A thin wrapper around a single ``nn.Linear`` that acts on the last axis
    of its input.
    """

    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(in_features=input_feats,
                                       out_features=latent_dim)

    def forward(self, x):
        """Map ``[..., input_feats]`` to ``[..., latent_dim]``."""
        return self.poseEmbedding(x)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout (borrowed from MDM).

    The table is stored as a non-trainable buffer ``pe`` of shape
    ``[max_len, 1, d_model]``: even feature columns hold ``sin`` terms and
    odd columns hold ``cos`` terms. Frequencies are computed through
    ``exp``/``log``, which may improve precision.

    Args:
        d_model: Feature size of the encoding. Odd values are supported; the
            last (unpaired) column then only carries a ``sin`` term.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=500):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine terms are truncated to fit (no-op for even d_model).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.shape[0]`` positions, then apply dropout.

        The table is indexed by ``x.shape[0]``, i.e. axis 0 of ``x`` is
        treated as the position axis and the singleton middle axis of ``pe``
        broadcasts over axis 1.
        """
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)

class OutputProcess_Bert(nn.Module):
    """BERT-style prediction head: dense -> GELU -> LayerNorm -> linear.

    The last two axes of the result are swapped so that the ``out_feats``
    axis comes before the sequence axis.
    """

    def __init__(self, out_feats, latent_dim):
        super().__init__()
        # Submodules are created in the order dense -> LayerNorm -> poseFinal
        # so that parameter registration and initialisation order stay fixed.
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # keeps its bias term

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return per-position scores shaped ``[bs, out_feats, seqlen]``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        scores = self.poseFinal(self.LayerNorm(activated))  # [bs, seqlen, out_feats]
        return scores.permute(0, 2, 1)  # [bs, out_feats, seqlen]
    
class VAR(nn.Module):
    def __init__(self, code_dim, quantizer, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, text_dim=512, cond_drop_prob=0.1,
                 device=None, cfg=None, full_length=80):
        """Build the multi-scale autoregressive transformer and its text encoder.

        Args:
            code_dim: Channel size of the quantizer codes; must equal
                ``quantizer.code_dim``.
            quantizer: Quantizer object. This constructor reads its
                ``scales``, ``code_dim`` and ``nb_code`` attributes.
            latent_dim: Transformer model width.
            ff_size: Feed-forward width of each transformer layer.
            num_layers: Number of transformer layers.
            num_heads: Number of attention heads.
            dropout: Dropout probability of the transformer layers.
            text_dim: Feature size of the text embeddings fed to ``cond_emb``.
            cond_drop_prob: Per-sample probability used by ``mask_cond`` to
                zero the condition during training.
            device: Forwarded to the text encoder.
            cfg: Config object. Required despite the ``None`` default:
                ``cfg.model.fuse_mode``, ``cfg.data.max_text_length`` and
                ``cfg.text_embedder.version`` are read here.
            full_length: Sequence length at the finest scale; the length of
                each level is ``full_length // scale``.

        Raises:
            AssertionError: If ``code_dim != quantizer.code_dim``.
            ValueError: If ``cfg.model.fuse_mode`` is neither ``'in_context'``
                nor ``'cross_attention'``.
        """
        super().__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}, scales: {quantizer.scales}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.dropout = dropout
        self.cfg = cfg
        self.device = device
        self.full_length = full_length
        self.scales = quantizer.scales
        self.patch_sizes = [int(full_length // scale) for scale in self.scales]
        self.cond_drop_prob = cond_drop_prob

        assert code_dim == quantizer.code_dim

        # ------------------------------------------------------------------
        # Preparing networks
        # ------------------------------------------------------------------
        init_std = math.sqrt(1 / self.latent_dim / 3)

        # 1. Input
        self.token_emb = InputProcess(self.code_dim, self.latent_dim)
        self.quantizer = quantizer

        self.cond_emb = InputProcess(self.text_dim, self.latent_dim)

        # Learned input that stands in for the (not yet generated) first level.
        self.pos_start = nn.Parameter(torch.empty(1, self.patch_sizes[0], self.code_dim))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 2. absolute position embedding: one learned slot per position of
        #    every level, concatenated along the sequence axis.
        pos_1LC = []
        for ps in self.patch_sizes:
            pe = torch.empty(1, ps, self.latent_dim)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)
        self.pos_1LC = nn.Parameter(pos_1LC)

        # 3. prepare backbones
        fuse_mode = self.cfg.model.fuse_mode
        if fuse_mode == 'in_context':
            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=num_heads,
                                                              batch_first=True,
                                                              dim_feedforward=ff_size,
                                                              dropout=dropout,
                                                              activation='gelu')

            self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                         num_layers=num_layers)

            # One extra level id (index 0) for the text tokens that are
            # prepended to the motion sequence in this mode.
            self.lvl_embed = nn.Embedding(len(self.patch_sizes) + 1, self.latent_dim)
            nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
            input_patch_size = [cfg.data.max_text_length] + self.patch_sizes
        elif fuse_mode == 'cross_attention':
            seqTransEncoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                              nhead=num_heads,
                                                              batch_first=True,
                                                              dim_feedforward=ff_size,
                                                              dropout=dropout,
                                                              activation='gelu')
            self.seqTransEncoder = nn.TransformerDecoder(seqTransEncoderLayer,
                                                         num_layers=num_layers)
            self.lvl_embed = nn.Embedding(len(self.patch_sizes), self.latent_dim)
            nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
            input_patch_size = self.patch_sizes
        else:
            # Previously this fell through and failed later with an unrelated
            # NameError on ``input_patch_size``.
            raise ValueError(
                f"Unsupported cfg.model.fuse_mode: {fuse_mode!r} "
                "(expected 'in_context' or 'cross_attention')"
            )
        self.output_process = OutputProcess_Bert(out_feats=self.quantizer.nb_code, latent_dim=latent_dim)
        self.nb_tokens = self.quantizer.nb_code

        # NOTE: this re-initialises every nn.Linear / nn.Embedding / nn.LayerNorm
        # created above (including ``lvl_embed``). It must run before the text
        # encoder is attached below so that the encoder is not touched.
        self.apply(self.__init_weights)

        # 4. prepare attention mask for training
        # d holds the level index of every input position, e.g.
        # [0, 1, 1, 2, 2, 2, 2, 3, ...] for level lengths 1, 2, 4, ...
        d = torch.cat([torch.full((ps,), i) for i, ps in enumerate(input_patch_size)]).view(1, -1, 1)
        dT = d.transpose(1, 2)  # d: (1, T, 1) dT: (1, 1, T)
        self.register_buffer('lvl_1L', dT[:, 0].contiguous())
        # True: disable, False: enable. A position may attend to every
        # position whose level index is <= its own (block-wise causal).
        attn_mask = d < dT
        self.register_buffer('attn_mask', attn_mask[0].contiguous())

        # ------------------------------------------------------------------
        # Preparing text encoder
        # ------------------------------------------------------------------
        self.text_emb = T5TextEncoder(
            device,
            local_files_only=False,
            from_pretrained=cfg.text_embedder.version,
            model_max_length=cfg.data.max_text_length
        )


    def __init_weights(self, module):
        """Initialiser meant to be passed to ``nn.Module.apply``.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        ``N(0, 0.02**2)`` and ``nn.Linear`` biases are zeroed.
        ``nn.LayerNorm`` is reset to weight 1 and bias 0. Any other module
        type is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            has_bias = isinstance(module, nn.Linear) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


    def encode_text(self, raw_text):
        """Run ``raw_text`` through the text encoder.

        Returns the ``(embedding, mask)`` pair produced by
        ``self.text_emb.get_text_embeddings``.
        """
        embedding, attention_mask = self.text_emb.get_text_embeddings(raw_text)
        return embedding, attention_mask
    
    def sinusoidal_encoding(self, t):
        """
        Compute a sinusoidal encoding of width ``self.latent_dim`` for every
        entry of ``t``.

        Args:
            t (Tensor): Values to encode, e.g. shape (B, L). The values need
                not be integers.

        Returns:
            Tensor of shape ``(*t.shape, D)`` with ``D = self.latent_dim``,
            created on ``t.device``. Even feature indices hold ``sin`` terms
            and odd indices hold ``cos`` terms. An odd ``D`` is supported:
            the last, unpaired index then only carries a ``sin`` term.
        """
        dim = self.latent_dim
        freq_index = torch.arange(0, dim, 2, dtype=torch.float32, device=t.device)
        div_term = torch.exp(freq_index * (-math.log(10000.0) / dim))
        angles = t.unsqueeze(-1) * div_term  # (..., ceil(D / 2))

        pe = torch.zeros(*t.shape, dim, device=t.device)  # (..., D)
        pe[..., 0::2] = torch.sin(angles)  # Apply sin to even indices
        # Truncate so the cosine terms fit when D is odd (no-op for even D).
        pe[..., 1::2] = torch.cos(angles[..., :dim // 2])  # Apply cos to odd indices

        return pe
    
    def get_pe_from_mlens(self, mlens, max_len):
        """Encode, for every position, how far it is from the end of its sequence.

        For a sequence of length ``T`` and position ``t`` in ``[0, max_len)``
        the encoded value is ``(T - t - 1) / (T - 1 + 1e-4) * 80``, clamped
        below at 0, i.e. roughly 80 at the first position and 0 at the last
        valid one.

        Args:
            mlens: 1-D tensor of per-sample lengths, shape (B,).
            max_len: Number of positions to produce per sample.

        Returns:
            Output of ``sinusoidal_encoding`` for the (B, max_len) progress
            values.
        """
        batch = len(mlens)
        # positions 0 .. max_len - 1, one row per sample
        steps = torch.arange(max_len, device=mlens.device).unsqueeze(0).expand(batch, max_len)
        # each sample's length repeated along the position axis
        lengths = mlens.unsqueeze(1).expand(batch, max_len)

        remaining = lengths - steps - 1
        span = lengths - 1 + 1e-4  # epsilon avoids division by zero when length == 1
        t_progress = (remaining / span) * 80
        t_progress.clamp_min_(0.)
        return self.sinusoidal_encoding(t_progress)

    
    def mask_cond(self, cond, force_mask=False):
        """Optionally zero out the condition (classifier-free guidance dropout).

        Args:
            cond: 3-D condition tensor whose first axis is the batch.
            force_mask: When True, return an all-zero tensor shaped like
                ``cond``.

        Returns:
            * zeros when ``force_mask`` is set;
            * during training with ``cond_drop_prob > 0``: ``cond`` with each
              sample independently zeroed with probability ``cond_drop_prob``;
            * ``cond`` unchanged otherwise.
        """
        batch_size, _, _ = cond.shape
        if force_mask:
            return torch.zeros_like(cond)

        dropout_active = self.training and self.cond_drop_prob > 0.
        if not dropout_active:
            return cond

        drop_prob = torch.ones(batch_size, device=cond.device) * self.cond_drop_prob
        dropped = torch.bernoulli(drop_prob).view(batch_size, 1, 1)
        return cond * (1. - dropped)
        
    def trans_forward(self, motion_seq, cond, toa_pe, cond_padding_mask, motion_padding_mask):
        '''
        Embed the inputs, run the transformer backbone and project to token logits.

        :param motion_seq: (b, seqlen, code_dim) motion inputs, embedded by ``token_emb``
        :param cond: (b, t_seqlen, text_dim) text features, embedded by ``cond_emb``
        :param toa_pe: (b, >= seqlen, latent_dim) encoding added to the motion
            embeddings; only used when ``cfg.model.enable_toa_pe`` is set
        :param cond_padding_mask: (b, t_seqlen), all pad positions are TRUE else FALSE
        :param motion_padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :return:
            -logits: (b, num_token, seqlen)
        :raises ValueError: if ``cfg.model.fuse_mode`` is not supported
        '''
        _, t_seqlen, _ = cond.shape

        x = self.token_emb(motion_seq)
        cond = self.cond_emb(cond)

        # add positional encoding
        x = x + self.pos_1LC[:, :x.shape[1]]

        if 'enable_toa_pe' in self.cfg.model and self.cfg.model.enable_toa_pe:
            x = x + toa_pe[:, :x.shape[1]]

        fuse_mode = self.cfg.model.fuse_mode
        if fuse_mode == "in_context":
            # Text tokens are prepended to the motion tokens and share the
            # same self-attention stack; their outputs are dropped afterwards.
            xseq = torch.cat([cond, x], dim=1)  # (b, t_seqlen + seqlen, latent_dim)
            xseq = xseq + self.lvl_embed(self.lvl_1L[:, :xseq.shape[1]])
            padding_mask = torch.cat([cond_padding_mask, motion_padding_mask], dim=1).bool()
            output = self.seqTransEncoder(xseq,
                                          mask=self.attn_mask[:xseq.shape[1], :xseq.shape[1]].bool(),
                                          src_key_padding_mask=padding_mask)[:, t_seqlen:]
        elif fuse_mode == 'cross_attention':
            # Motion tokens self-attend under the level mask and cross-attend
            # to the text tokens (passed as decoder memory).
            xseq = x + self.lvl_embed(self.lvl_1L[:, :x.shape[1]])
            output = self.seqTransEncoder(xseq,
                                          cond,
                                          tgt_mask=self.attn_mask[:xseq.shape[1], :xseq.shape[1]].bool(),
                                          tgt_key_padding_mask=motion_padding_mask.bool(),
                                          memory_key_padding_mask=cond_padding_mask.bool())
        else:
            # Previously this surfaced as an UnboundLocalError on ``output``.
            raise ValueError(
                f"Unsupported cfg.model.fuse_mode: {fuse_mode!r} "
                "(expected 'in_context' or 'cross_attention')"
            )

        logits = self.output_process(output)
        return logits
    

    def forward(self, id_list, y, m_lens, aug=None):
        '''
        Training step: teacher-forced prediction of the token ids of every level.

        :param id_list: per-level token id tensors, one entry per scale in
            ``self.scales``; the last entry must have ``full_length`` columns
        :param y: raw text, passed to ``encode_text``
        :param m_lens: (b,) motion lengths at full resolution
        :param aug: optional augmentation string. When it contains ``'perturb'``
            (e.g. ``'perturb:0.2'``) and the module is in training mode, that
            fraction of the non-padded input ids is replaced by random ids;
            a ratio of ``-1`` draws the fraction uniformly at random.
        :return: ``(ce_loss, pred_id, acc)`` as returned by ``cal_performance``
        '''
        assert self.full_length == id_list[-1].shape[1]
        assert len(id_list) == len(self.scales)

        level_masks, level_ids, level_toa_pe = [], [], []
        level_lens = [0]  # leading 0 so the cumulative sum yields slice boundaries

        for scale, ids_at_scale in zip(self.scales, id_list):
            level_len = int(self.full_length // scale)
            ds_mlens = (m_lens // scale).long()
            level_masks.append(lengths_to_mask(ds_mlens, level_len))
            level_ids.append(ids_at_scale)
            level_lens.append(level_len)
            level_toa_pe.append(self.get_pe_from_mlens(ds_mlens, level_len))

        boundaries = np.cumsum(level_lens)
        non_pad_mask = torch.cat(level_masks, dim=1)
        time_to_arrival_pe = torch.cat(level_toa_pe, dim=1)

        # Targets: padded positions are set to -1 so the loss ignores them.
        tgt_ids = torch.where(non_pad_mask, torch.cat(level_ids, dim=1), -1)
        assert tgt_ids.shape[:2] == non_pad_mask.shape[:2]

        input_ids = tgt_ids.clone()

        # Input-side augmentation only; the targets stay clean.
        if self.training and aug is not None and 'perturb' in aug:
            ratio = float(aug.split(":")[-1])
            if ratio == -1.:
                ratio = torch.rand(1)[0]
            replace_mask = (torch.rand_like(input_ids.float()) < ratio) & non_pad_mask
            rand_tokens = torch.randint_like(input_ids, high=self.nb_tokens)
            input_ids = torch.where(replace_mask, rand_tokens, input_ids)

        # Split the flat id sequence back into one tensor per level.
        ids_per_level = [input_ids[:, start:stop]
                         for start, stop in zip(boundaries[:-1], boundaries[1:])]

        x_seq_wo_first_l = self.quantizer.idx_to_var_input(ids_per_level)

        # The learned start input takes the place of the first level's input.
        pos_start = repeat(self.pos_start, '1 t c -> b t c', b=len(x_seq_wo_first_l))
        x_seq = torch.cat([pos_start, x_seq_wo_first_l], dim=1)

        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(y)
            cond_padding_mask = (cond_att_mask == 0)

        logits = self.trans_forward(x_seq, cond_embs, time_to_arrival_pe,
                                    cond_padding_mask, ~non_pad_mask)
        ce_loss, pred_id, acc = cal_performance(logits, tgt_ids, ignore_index=-1)

        return ce_loss, pred_id, acc
    
    def forward_with_cond_scale(self,
                                motion_seq,
                                cond_embs,
                                time_to_arrival_pe,
                                cond_padding_mask,
                                motion_padding_mask,
                                cond_scale):
        """Compute logits with classifier-free guidance.

        With ``cond_scale == 0`` a single conditional pass is run. Otherwise
        the batch is doubled (zeroed condition first, real condition second),
        run in one pass, and the two halves are combined as
        ``(1 + cond_scale) * cond_logits - cond_scale * uncond_logits``.
        """
        if cond_scale == 0:
            return self.trans_forward(motion_seq,
                                      cond_embs,
                                      time_to_arrival_pe,
                                      cond_padding_mask,
                                      motion_padding_mask)

        def doubled(tensor):
            # Stack a tensor on top of itself along the batch axis.
            return torch.cat([tensor, tensor], dim=0)

        uncond_embs = self.mask_cond(cond_embs, force_mask=True)
        kept_embs = self.mask_cond(cond_embs, force_mask=False)

        both_logits = self.trans_forward(doubled(motion_seq),
                                         torch.cat([uncond_embs, kept_embs], dim=0),
                                         doubled(time_to_arrival_pe),
                                         doubled(cond_padding_mask),
                                         doubled(motion_padding_mask))
        aux_logits, logits = both_logits.chunk(2, dim=0)

        # equivalent form: aux_logits + (logits - aux_logits) * (1 + cond_scale)
        return (1 + cond_scale) * logits - cond_scale * aux_logits
    
    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 cond_scale,
                 topk_filter_thres=0.9,
                 temperature=1,
                 gssample=False):
        """Sample token ids level by level, coarse to fine.

        Args:
            conds: Raw text, passed to ``encode_text``.
            m_lens: (b,) motion lengths at full resolution.
            cond_scale: Guidance scale reached at the last level. The scale
                applied at level ``i`` of ``n`` is ``cond_scale * i / (n - 1)``;
                with a single level the full ``cond_scale`` is used.
            topk_filter_thres: Threshold forwarded to ``top_k``.
            temperature: Sampling temperature.
            gssample: Use ``gumbel_sample`` instead of multinomial sampling.

        Returns:
            ``(code, output)``: ``code`` is the tensor accumulated by
            ``quantizer.get_next_var_input`` (initialised as zeros of shape
            ``(b, code_dim, full_length)``), and ``output`` is the list of
            sampled id tensors, one per level, with padded positions set to -1.
        """
        # Gradients are already disabled by the @torch.no_grad() decorator.
        batch_size = len(m_lens)
        num_levels = len(self.patch_sizes)

        level_masks, level_toa_pe, level_lens = [], [], []
        for scale in self.scales:
            level_len = int(self.full_length // scale)
            ds_mlens = (m_lens // scale).long()
            level_masks.append(lengths_to_mask(ds_mlens, level_len))
            level_lens.append(level_len)
            level_toa_pe.append(self.get_pe_from_mlens(ds_mlens, level_len))

        # range_ind[i]: total number of positions in levels 0..i (inclusive).
        range_ind = np.cumsum(level_lens)
        full_non_padding_mask = torch.cat(level_masks, dim=1)
        time_to_arrival_pe = torch.cat(level_toa_pe, dim=1)

        cond_embs, cond_att_mask = self.encode_text(conds)
        cond_padding_mask = (cond_att_mask == 0)

        xseq = repeat(self.pos_start, '1 t c -> b t c', b=batch_size)

        code = xseq.new_zeros(batch_size, self.code_dim, self.full_length)
        output = []

        for level, ps in enumerate(self.patch_sizes):
            # Linear guidance ramp 0 -> 1 over the levels. Guard the
            # single-level case, which used to raise ZeroDivisionError.
            ratio = level / (num_levels - 1) if num_levels > 1 else 1.0
            non_padding_mask = full_non_padding_mask[:, :range_ind[level]]
            logits = self.forward_with_cond_scale(motion_seq=xseq,
                                                  cond_embs=cond_embs,
                                                  time_to_arrival_pe=time_to_arrival_pe,
                                                  cond_padding_mask=cond_padding_mask,
                                                  motion_padding_mask=~non_padding_mask,
                                                  cond_scale=cond_scale * ratio)  # (b, num_tokens, range_ind[level])
            # Only the newest level's positions are sampled.
            logits = logits[:, :, -ps:].permute(0, 2, 1)  # (b, patch_size, ntoken)
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            if gssample:  # use gumbel_softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, patch_size)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits / temperature, dim=-1)  # (b, patch_size, ntoken)
                pred_ids = Categorical(probs).sample()  # (b, patch_size)

            pred_ids = torch.where(non_padding_mask[:, -ps:], pred_ids, -1)

            output.append(pred_ids)

            code, next_input = self.quantizer.get_next_var_input(level, pred_ids, code, self.full_length)
            # The last level has no successor, so its next input is not
            # appended (the old ``level != len(self.scales)`` test was always
            # true and extended the sequence one extra time).
            if level != num_levels - 1:
                xseq = torch.cat([xseq, next_input.permute(0, 2, 1)], dim=1)

        return code, output
