

import torch
import torch.nn.functional as F
from transformers import AutoConfig
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM
from .dis_diff_loss import mp2sigma, p2score, score_entropy, DWDSE_loss

class Face(ModelMixin, ConfigMixin):
    """Multimodal model built around a ``PhiForCausalLM`` backbone.

    Provides a ``forward`` that returns logits (and per-task losses when
    labels are given), iterative masked-token image generation
    (``t2i_generate``) and autoregressive token generation (``mmu_generate``).
    """
    # NOTE(review): presumably read by the gradient-checkpointing helpers of
    # ModelMixin -- confirm against modeling_utils.
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
            self,
            w_clip_vit,
            vocab_size,
            llm_vocab_size,
            llm_model_path='',
            codebook_size=8192,
            num_vq_tokens=256,
            load_from_Face=True,
            **kwargs,
    ):
        """Build the Phi backbone and, optionally, the CLIP feature projector.

        Args:
            w_clip_vit: When truthy, an MLP ``mm_projector``
                (Linear 1024 -> 2048, GELU, Linear 2048 -> 2048) is created.
            vocab_size: Size the backbone token embeddings are resized to.
                The last id, ``vocab_size - 1``, is registered in the config
                as ``mask_token_id``.
            llm_vocab_size: Not used in this constructor body (presumably
                only recorded in the config by ``@register_to_config``).
            llm_model_path: Path or identifier passed to
                ``AutoConfig.from_pretrained`` or
                ``PhiForCausalLM.from_pretrained``.
            codebook_size: Not used in this constructor body.
            num_vq_tokens: Not used in this constructor body.
            load_from_Face: When true, the backbone is instantiated from its
                configuration only; otherwise ``from_pretrained`` is called
                with ``attn_implementation='sdpa'``.
            **kwargs: Not used in this constructor body.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.register_to_config(mask_token_id=vocab_size - 1)

        if load_from_Face:
            backbone_config = AutoConfig.from_pretrained(llm_model_path)
            self.Face = PhiForCausalLM(backbone_config)
        else:
            self.Face = PhiForCausalLM.from_pretrained(llm_model_path, attn_implementation='sdpa')
        self.Face.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        # Read the constructor argument directly: nothing in this method
        # assigns ``self.w_clip_vit``, so the attribute lookup only works if
        # the mixin happens to expose config values as instance attributes.
        if w_clip_vit:
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048)
            )
    def apply_moe_arch(
            self, 
            config_moe,
            **kwargs,
        ):
        """Forward the MoE configuration to the wrapped backbone and log it.

        Args:
            config_moe: Passed unchanged as the first argument of
                ``self.Face.apply_moe_arch``.
            **kwargs: Only ``bsz_t2i``, ``num_image_tokens`` and
                ``num_text_tokens`` are read; each is forwarded as ``None``
                when absent. Any other key is ignored.
        """
        forwarded_keys = ('bsz_t2i', 'num_image_tokens', 'num_text_tokens')
        moe_options = {key: kwargs.get(key, None) for key in forwarded_keys}
        self.Face.apply_moe_arch(config_moe, **moe_options)

        print("Have applied MoE architecture to the Face")
        print(f"config_moe: {config_moe}")

    def _set_gradient_checkpointing(self, module, value=False):
        """Record the requested gradient-checkpointing state on this model.

        Args:
            module: Unused; kept for signature compatibility with callers
                that pass the visited sub-module.
            value: Desired state. It is now stored as given -- previously the
                flag was unconditionally set to ``True``, so a request to
                disable checkpointing still left it enabled.
        """
        self.gradient_checkpointing = value

    def forward(
            self,
            input_ids,
            input_embeddings=None,
            attention_mask=None,
            labels=None,
            label_smoothing=0.0,
            batch_size_t2i=0,
            batch_size_lm=0,
            batch_size_mmu=0,
            max_seq_length=128,
            labels_mask_text=None,
            labels_mask_image=None,
            t2i_loss_type='cross_entropy',
            **kwargs,
    ):
        """Run the backbone and, when ``labels`` is given, compute the losses.

        The loss slicing treats the first ``batch_size_t2i`` rows of the batch
        as text-to-image (t2i) samples and the last ``batch_size_mmu`` rows as
        multimodal-understanding (mmu) samples.

        Args:
            input_ids: Token ids fed to the backbone when ``input_embeddings``
                is None. Also used as the noisy sequence of the DWDSE loss.
            input_embeddings: When given, fed to the backbone as
                ``inputs_embeds`` instead of ``input_ids``.
            attention_mask: Passed through to the backbone.
            labels: Target ids. When None only the logits are returned.
            label_smoothing: Unused.
            batch_size_t2i: Number of leading rows used for the t2i losses.
            batch_size_lm: Unused (the language-modeling loss is disabled).
            batch_size_mmu: Number of trailing rows used for the mmu loss.
            max_seq_length: Maximum number of text tokens; the t2i losses are
                computed on positions from ``max_seq_length + 1`` onwards.
            labels_mask_text: Unused.
            labels_mask_image: Unused.
            t2i_loss_type: ``'cross_entropy'``, ``'only_DWDSE'`` or
                ``'with_DWDSE'``. With ``'with_DWDSE'`` the sigma derived from
                ``mask_prob`` is additionally passed to the backbone.
            **kwargs: Keys read here are ``clip_feature``, ``face_feature``,
                ``bsz_t2i`` (forwarded to the backbone) and ``mask_prob``,
                ``mask_schedule_type``, ``mask_id`` (DWDSE inputs).

        Returns:
            ``logits`` when ``labels`` is None, otherwise the tuple
            ``(logits, loss_t2i, loss_lm, loss_mmu, loss_t2i_dwdse)``.

        Raises:
            ValueError: If a DWDSE input that is needed is missing or None, or
                if ``labels`` is given with an unknown ``t2i_loss_type``.
        """
        dwdse_loss_types = ('only_DWDSE', 'with_DWDSE')
        computes_loss = labels is not None
        if computes_loss and t2i_loss_type != 'cross_entropy' and t2i_loss_type not in dwdse_loss_types:
            # Previously an unknown type only surfaced as an UnboundLocalError
            # at the return statement, after the full forward pass.
            raise ValueError(
                f"Unsupported t2i_loss_type {t2i_loss_type!r}; expected 'cross_entropy', "
                "'only_DWDSE' or 'with_DWDSE'"
            )

        feeds_sigma = t2i_loss_type == 'with_DWDSE'
        needs_dwdse_loss = computes_loss and t2i_loss_type in dwdse_loss_types

        # Collect and validate the DWDSE inputs before the backbone runs so
        # that a missing argument fails fast.
        mp = sigma = dsigma = mask_id = None
        if feeds_sigma or needs_dwdse_loss:
            mp = self._require_dwdse_kwarg(kwargs, 'mask_prob')
            mask_schedule_type = self._require_dwdse_kwarg(kwargs, 'mask_schedule_type')
            if needs_dwdse_loss:
                mask_id = self._require_dwdse_kwarg(kwargs, 'mask_id')
            sigma, dsigma = mp2sigma(mp, mask_schedule_type)

        backbone_inputs = dict(
            attention_mask=attention_mask,
            clip_features=kwargs.get("clip_feature", None),
            face_features=kwargs.get("face_feature", None),
            bsz_t2i=kwargs.get('bsz_t2i', None),
            # Only the 'with_DWDSE' mode conditions the backbone on sigma.
            sigmas=sigma if feeds_sigma else None,
        )
        if input_embeddings is None:
            backbone_inputs['input_ids'] = input_ids
        else:
            backbone_inputs['inputs_embeds'] = input_embeddings
        logits = self.Face(**backbone_inputs)['logits']

        if labels is None:
            return logits

        # 1. Mask token prediction (discrete diffusion) for image generation.
        # Note: max_seq_length is the maximum number of *text* tokens, so the
        # image tokens start at position max_seq_length + 1.
        # NOTE(review): with batch_size_t2i == 0 these slices are empty.
        image_start = max_seq_length + 1
        t2i_logits = logits[:batch_size_t2i, image_start:].contiguous()
        t2i_labels = labels[:batch_size_t2i, image_start:]

        if t2i_loss_type == 'only_DWDSE':
            loss_t2i = torch.tensor(0.0).to(logits.device)
        else:
            loss_t2i = F.cross_entropy(
                t2i_logits.view(-1, self.output_size),
                t2i_labels.contiguous().view(-1), ignore_index=-100,
            )

        if needs_dwdse_loss:
            loss_t2i_dwdse = self._t2i_dwdse_loss(
                t2i_logits,
                input_ids[:batch_size_t2i, image_start:],
                t2i_labels,
                mask_id, mp, sigma, dsigma,
            )
        else:
            loss_t2i_dwdse = torch.tensor(0.0).to(logits.device)

        # 2. Next token prediction for language modeling is disabled; a zero
        # placeholder keeps the return tuple stable.
        loss_lm = torch.tensor(0.0).to(logits.device)

        # 3. Next token prediction for captioning/multimodal understanding.
        # NOTE(review): when batch_size_mmu == 0 the slice ``-0:`` selects
        # every row, so this loss is then computed over the whole batch --
        # confirm callers always pass a positive batch_size_mmu.
        loss_mmu = F.cross_entropy(
            logits[-batch_size_mmu:, :-1].contiguous().view(-1, self.output_size),
            labels[-batch_size_mmu:, 1:].contiguous().view(-1), ignore_index=-100,
        )

        return logits, loss_t2i, loss_lm, loss_mmu, loss_t2i_dwdse

    @staticmethod
    def _require_dwdse_kwarg(kwargs, name):
        """Return ``kwargs[name]``; raise ``ValueError`` if missing or None."""
        value = kwargs.get(name)
        if value is None:
            raise ValueError(f'{name} is required for DWDSE loss')
        return value

    @staticmethod
    def _t2i_dwdse_loss(image_logits, noisy_ids, clean_ids, mask_id, mp, sigma, dsigma):
        """Compute the DWDSE loss from the image-token logits of the t2i rows."""
        preds = image_logits.softmax(dim=-1)
        scores = p2score(preds, noisy_ids, mask_id, mp)
        loss_se = score_entropy(scores, noisy_ids, clean_ids, mask_id, sigma)
        return DWDSE_loss(loss_se, dsigma)

    @torch.no_grad()
    def t2i_generate(
            self,
            input_ids: torch.LongTensor = None,
            uncond_input_ids: torch.LongTensor = None,
            attention_mask=None,
            temperature=1.0,
            timesteps=18,  # ideal number of steps is 18 in maskgit paper
            guidance_scale=0,
            noise_schedule=cosine_schedule,
            generator: torch.Generator = None,
            config=None,
            **kwargs,
    ):
        """Iteratively decode masked image tokens.

        Generate 1:1 similar to the original MaskGit repo
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
        Note: once an iteration has restored the token at an index, later
        iterations do not mask that token again.

        Args:
            input_ids: Prompt ids whose slice ``[-(num_vq_tokens + 1):-1]``
                holds the image tokens. This tensor is modified in place.
            uncond_input_ids: Unconditional prompt ids; used for
                classifier-free guidance when ``guidance_scale > 0``.
            attention_mask: Passed through to ``forward``.
            temperature: Starting scale of the randomness added when choosing
                which tokens to re-mask.
            timesteps: Number of decoding iterations; must be at least 1.
            guidance_scale: Classifier-free guidance weight.
            noise_schedule: Callable mapping a ratio tensor to the mask ratio.
            generator: Optional ``torch.Generator`` used for sampling.
            config: Configuration object; ``config.model.Face`` and, for
                guidance, ``config.dataset.preprocessing.max_seq_length`` are
                read from it.
            **kwargs: ``mask_schedule_type`` and ``t2i_loss_type`` are
                forwarded to ``forward``.

        Returns:
            The token indices sampled in the final iteration, expressed
            relative to the image slice of the vocabulary.

        Raises:
            ValueError: If ``config`` is None or ``timesteps`` is below 1.
        """
        if config is None:
            raise ValueError('config is required for t2i_generate')
        if timesteps < 1:
            # With no iteration there would be nothing to return.
            raise ValueError('timesteps must be at least 1')

        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        face_cfg = config.model.Face
        num_vq_tokens = face_cfg.num_vq_tokens
        num_new_special_tokens = face_cfg.num_new_special_tokens
        # Ids written into input_ids are the sampled indices shifted by this.
        image_token_offset = face_cfg.llm_vocab_size + num_new_special_tokens
        image_slice = slice(-(num_vq_tokens + 1), -1)

        image_ids = input_ids[:, image_slice].clone()
        image_ids = torch.where(image_ids == mask_token_id,
                                mask_token_id,
                                image_ids - image_token_offset)

        # for classifier-free guidance
        use_guidance = uncond_input_ids is not None and guidance_scale > 0
        if uncond_input_ids is not None:
            prefix_len = config.dataset.preprocessing.max_seq_length + 1
            uncond_prefix = uncond_input_ids[:, :prefix_len]

        forwarded = dict(
            mask_schedule_type=kwargs.get('mask_schedule_type', None),
            t2i_loss_type=kwargs.get('t2i_loss_type', 'cross_entropy'),
        )

        ratio = 1.0
        mask_ratio = noise_schedule(torch.tensor(ratio).to(input_ids.device))
        for step in range(timesteps):
            if use_guidance:
                # The unconditional branch shares the current image tokens.
                uncond_input_ids = torch.cat([uncond_prefix, input_ids[:, prefix_len:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(
                    model_input,
                    attention_mask=attention_mask,
                    bsz_t2i=model_input.shape[0],
                    mask_prob=mask_ratio.repeat(model_input.shape[0]),
                    **forwarded,
                ).chunk(2)
                # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                # it seems that muse has a different cfg setting
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
            else:
                logits = self(
                    input_ids,
                    attention_mask=attention_mask,
                    bsz_t2i=input_ids.shape[0],
                    mask_prob=mask_ratio.repeat(input_ids.shape[0]),
                    **forwarded,
                )
            # Keep the image positions and the image part of the vocabulary
            # (the last vocabulary entry is excluded).
            logits = logits[:, image_slice, image_token_offset:-1]

            probs = logits.softmax(dim=-1)
            flat_probs = probs.reshape(-1, logits.size(-1))
            # 1. Sample according to the probabilities.
            sampled_ids = torch.multinomial(flat_probs, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            unknown_map = image_ids == mask_token_id
            # 2. Positions holding the mask token take the sampled token; the
            # rest are left unchanged.
            sampled_ids = torch.where(unknown_map, sampled_ids, image_ids)
            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * ((timesteps - 1) - step) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio).to(logits.device))
            # Computes the probabilities of each selected tokens.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their
            # confidence: once a token has been restored it is not masked
            # again in later iterations.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            # 3. Determine mask_len, ranging from num_tokens down to 0.
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            )
            # Adds noise for randomness
            # NOTE(review): ``temperature`` is overwritten here, so the scaling
            # compounds across iterations instead of being applied to the
            # initial value each step -- confirm this is intended.
            temperature = temperature * (1.0 - ratio)
            # 4. Positions where ``masking`` is True are set to the mask token
            # for this iteration.
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
            # Masks tokens with lower confidence.
            input_ids[:, image_slice] = torch.where(masking, mask_token_id,
                                                    sampled_ids + image_token_offset)
            image_ids = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids

    @torch.no_grad()
    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None, clip_feature=None, face_feature=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx: Conditioning token ids; may be None when ``input_embeddings``
                is given.
            input_embeddings: Conditioning embeddings; extended with the
                embedding of each new token when ``self.config.w_clip_vit``
                is truthy.
            attention_mask: Additive attention mask; required. It is squeezed
                and grown by one row and one column per generated token.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Divisor applied to the last-position logits.
            top_k: When given, only the ``top_k`` largest logits are kept.
            eot_token: When given, generation stops after this id is sampled.
            clip_feature: Forwarded to ``forward``.
            face_feature: Forwarded to ``forward``.

        Returns:
            A list with one tensor per generated step, holding the token
            sampled for the first sequence of the batch.

        Raises:
            ValueError: If both ``idx`` and ``input_embeddings`` are None, or
                if ``attention_mask`` is None.
        """
        if idx is None and input_embeddings is None:
            raise ValueError('either idx or input_embeddings must be provided')
        if attention_mask is None:
            # The mask is unconditionally extended below.
            raise ValueError('attention_mask is required for mmu_generate')

        # Explicit check instead of a bare ``except:`` that would also hide
        # unrelated errors.
        device = idx.device if idx is not None else input_embeddings.device

        result = []
        for _ in range(max_new_tokens):
            # forward the model to get the logits for the index in the sequence
            logits = self(
                idx,
                input_embeddings=input_embeddings,
                attention_mask=attention_mask,
                clip_feature=clip_feature,
                face_feature=face_feature,
                bsz_t2i=0,
            )

            # Grow the additive mask from (L, L) to (L + 1, L + 1).
            # NOTE(review): squeeze() drops every singleton dimension, so this
            # presumably assumes a single-sample mask -- confirm.
            seq_len = attention_mask.shape[-1]
            attention_mask = attention_mask.squeeze()
            # Existing rows get the most negative value for the new position.
            blocked_column = torch.zeros((seq_len, 1)).to(device) + torch.finfo(logits.dtype).min
            widened = torch.hstack([attention_mask, blocked_column])  # L, L+1
            # The new row copies the last row and gets 0 for its own position.
            new_row = torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0)
            attention_mask = torch.vstack([widened, new_row])

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # Only the first sequence's token is recorded.
            result.append(idx_next[0][0])
            # append sampled index to the running sequence and continue
            if self.config.w_clip_vit:
                idx_next_embeddings = self.Face.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break

        return result