"""This file contains implementation for MaskGIT model.

Copyright (2024) Bytedance Ltd. and/or its affiliates

Licensed under the Apache License, Version 2.0 (the "License"); 
you may not use this file except in compliance with the License. 
You may obtain a copy of the License at 

    http://www.apache.org/licenses/LICENSE-2.0 

Unless required by applicable law or agreed to in writing, software 
distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and 
limitations under the License. 

Reference: 
    https://github.com/huggingface/open-muse
    https://github.com/baaivision/MUSE-Pytorch
    https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py
"""

import torch
from torch import nn
import numpy as np
import math
import torch.utils.checkpoint
from transformers import BertConfig, BertModel
from einops import rearrange

import json
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf
from pathlib import Path

from deps.rar.modeling.modules.base_model import BaseModel
from deps.rar.modeling.modules.blocks import UViTBlock


class ImageBert(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-generation"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"):
    """Class-conditional masked image-token generator built on a BERT encoder.

    The model consumes a sequence of discrete image tokens (some of them
    replaced by the mask token) preceded by a class-condition token, and
    predicts logits over the image codebook for every image position.

    Token space of the shared embedding table:
        [0, codebook_size - 1]                      : quantized image tokens
        codebook_size                               : mask token
        [codebook_size + 1, codebook_size + nclass] : class tokens
        codebook_size + nclass + 1                  : "class dropped" token
    """

    def __init__(self, config):
        """Build the BERT backbone and the codebook prediction head.

        Args:
            config: A plain ``dict`` or an OmegaConf config. It must provide
                ``model.vq_model.codebook_size`` and, under
                ``model.generator``: ``condition_num_classes``,
                ``image_seq_len``, ``hidden_size``, ``num_hidden_layers``,
                ``num_attention_heads``, ``intermediate_size``, ``dropout``
                and ``attn_drop``.
        """
        # Plain dicts are wrapped so attribute-style access works below.
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        self.target_codebook_size = config.model.vq_model.codebook_size
        self.condition_num_classes = config.model.generator.condition_num_classes
        self.image_seq_len = config.model.generator.image_seq_len
        # The mask token takes the first id after the image codebook.
        self.mask_token_id = self.target_codebook_size
        self.hidden_size = config.model.generator.hidden_size
        self.num_hidden_layers = config.model.generator.num_hidden_layers
        self.num_attention_heads = config.model.generator.num_attention_heads
        self.intermediate_size = config.model.generator.intermediate_size

        self.model = BertModel(BertConfig(
            # image tokens + class tokens + mask token + "class dropped" token
            vocab_size=self.target_codebook_size + self.condition_num_classes + 2,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act='gelu',
            hidden_dropout_prob=config.model.generator.dropout,
            attention_probs_dropout_prob=config.model.generator.attn_drop,
            # +1 position for the prepended condition token
            max_position_embeddings=config.model.generator.image_seq_len + 1,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=None,
            position_embedding_type="absolute",
            use_cache=True
        ), add_pooling_layer=False)
        # The head only predicts image tokens; it is attached before
        # post_init(), presumably so that it is covered by the backbone's
        # weight initialization as well (confirm against transformers).
        self.model.lm_head = nn.Linear(self.hidden_size, self.target_codebook_size, bias=True)

        self.model.post_init()

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory.

        The config is converted to a plain container and written to
        ``config.json`` inside ``save_directory``; the weights are then saved
        by the parent class implementation.
        """
        dict_config = OmegaConf.to_container(self.config)
        config_path = Path(save_directory) / "config.json"
        with open(config_path, 'w', encoding="utf-8") as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)

    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        """Predict codebook logits for every image-token position.

        Args:
            input_ids: Image token ids (possibly containing the mask token);
                flattened to ``(batch, -1)`` before use. Required.
            condition: Class ids; flattened to ``(batch, -1)`` and prepended
                to the image tokens. Assumed to hold one integer class id in
                ``[0, nclass - 1]`` per sample -- TODO confirm with callers.
            cond_drop_prob: Probability with which each class id is replaced
                by the "class dropped" token. ``0.0`` never drops, ``1.0``
                always drops.

        Returns:
            Logits over the image codebook for all positions except the first
            (condition) one; the last dimension has ``codebook_size`` entries.

        Raises:
            NotImplementedError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # At least the (fully masked) image tokens must be supplied.
            raise NotImplementedError("input_ids must be provided")

        # Token space:
        #  [0, codebook_size - 1]                       : those are the learned quantized image tokens
        #  codebook_size                                : the mask token used to mask image tokens
        #  [codebook_size + 1, codebook_size + nclass]  : the imagenet class tokens
        #  codebook_size + 1 + nclass                   : the class drop label
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        # Shift the classes: [0, nclass - 1] -> [codebook_size + 1, codebook_size + nclass].
        # The addition allocates a new tensor, so the in-place write below
        # does not touch the tensor owned by the caller.
        condition = condition + self.target_codebook_size + 1
        condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
        # prepend condition token
        input_ids = torch.cat([condition.view(condition.shape[0], -1),
                               input_ids.view(input_ids.shape[0], -1)], dim=1)
        hidden_states = self.model(input_ids=input_ids)[0]
        return self.model.lm_head(hidden_states[:, 1:])  # remove cond

    @staticmethod
    def _add_gumbel_noise(t, temperature, eps=1e-20):
        """Return ``t`` plus ``temperature``-scaled Gumbel noise of the same shape."""
        uniform = torch.zeros_like(t).uniform_(0, 1)
        # Both logarithms are clamped at eps to avoid log(0).
        inner = -torch.log(uniform.clamp(min=eps))
        gumbel = -torch.log(inner.clamp(min=eps))
        return t + temperature * gumbel

    # ref: https://github.com/baaivision/MUSE-Pytorch/blob/master/libs/muse.py#L40
    @torch.no_grad()
    def generate(self,
                 condition,
                 guidance_scale=3.0,
                 guidance_decay="constant",
                 guidance_scale_pow=3.0,
                 randomize_temperature=4.5,
                 softmax_temperature_annealing=False,
                 num_sample_steps=8):
        """Iteratively decode image tokens starting from an all-mask sequence.

        Every step predicts all positions, samples a token per position with
        Gumbel noise, keeps already decoded tokens, and re-masks the least
        confident predictions according to an arccos schedule. On the last
        step nothing is re-masked.

        Args:
            condition: Class ids; ``condition.shape[0]`` is the batch size and
                its device is used for all created tensors.
            guidance_scale: Classifier-free guidance strength.
            guidance_decay: One of
                ``"constant"``     - constant guidance scale,
                ``"linear"``       - linearly increasing guidance scale as in
                                     MUSE (no guidance on the first step),
                ``"power-cosine"`` - the guidance schedule from MDT.
            guidance_scale_pow: Exponent of the power-cosine schedule.
            randomize_temperature: Initial Gumbel-noise temperature; it is
                annealed linearly to 0 over the sampling steps.
            softmax_temperature_annealing: If true, logits are divided by
                ``0.5 + 0.8 * (1 - ratio)`` at every step.
            num_sample_steps: Number of decoding iterations (>= 1).

        Returns:
            Tensor of shape ``(batch, image_seq_len)`` with the decoded ids.

        Raises:
            ValueError: If ``guidance_decay`` is unsupported or
                ``num_sample_steps`` is smaller than 1.
        """
        if guidance_decay not in ["constant", "linear", "power-cosine"]:
            raise ValueError(f"Unsupported guidance decay {guidance_decay}")
        if num_sample_steps < 1:
            # With zero steps the loop would never run and the all-mask
            # sequence would be returned as if it were a valid sample.
            raise ValueError(f"num_sample_steps must be >= 1, got {num_sample_steps}")

        device = condition.device
        ids = torch.full((condition.shape[0], self.image_seq_len),
                          self.mask_token_id, device=device)

        cfg_scale = guidance_scale if guidance_decay == "constant" else 0.
        if guidance_decay == "power-cosine":
            # Loop invariant: the exponent as a one-element device tensor.
            guidance_scale_pow = torch.ones((1), device=device) * guidance_scale_pow

        for step in range(num_sample_steps):
            ratio = 1. * (step + 1) / num_sample_steps
            annealed_temp = randomize_temperature * (1.0 - ratio)
            is_mask = (ids == self.mask_token_id)

            if guidance_decay == "power-cosine":
                # ref: https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L501
                scale_step = (1 - torch.cos(((step / num_sample_steps) ** guidance_scale_pow) * torch.pi)) * 1/2
                cfg_scale = (guidance_scale - 1) * scale_step + 1

            if cfg_scale != 0:
                # Conditional pass (class never dropped) and unconditional
                # pass (class always dropped).
                cond_logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )
                uncond_logits = self.forward(
                    ids, condition, cond_drop_prob=1.0
                )
                if guidance_decay == "power-cosine":
                    logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
                else:
                    logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale
            else:
                logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )

            if softmax_temperature_annealing:
                softmax_temperature = 0.5 + 0.8 * (1 - ratio)
                logits = logits / softmax_temperature

            # Sample one token per position by perturbing the logits with
            # Gumbel noise, and remember the clean logit of the sampled token.
            sampled_ids = self._add_gumbel_noise(logits, annealed_temp).argmax(dim=-1)
            sampled_logits = torch.squeeze(
                torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            # Positions decoded in earlier steps keep their token and get
            # infinite confidence so they are never re-masked.
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()

            # masking (arccos schedule)
            mask_ratio = np.arccos(ratio) / (math.pi * 0.5)
            scheduled_len = int(np.floor(self.image_seq_len * mask_ratio))
            # As in the reference implementation, the count of still-masked
            # tokens is read from the first sample and applied to the whole
            # batch. At least one token stays masked and at least one new
            # token is revealed per step.
            num_masked = int(is_mask[0].sum())
            mask_len = max(1, min(num_masked - 1, scheduled_len))

            # The noise is drawn on the final step too (although unused
            # there) so the random stream matches earlier behaviour.
            confidence = self._add_gumbel_noise(sampled_logits, annealed_temp)
            sorted_confidence, _ = torch.sort(confidence, dim=-1)
            # Per-row threshold: the mask_len-th smallest confidence.
            cut_off = sorted_confidence[:, mask_len - 1:mask_len]
            masking = (confidence <= cut_off)
            if step == num_sample_steps - 1:
                ids = sampled_ids
            else:
                ids = torch.where(masking, self.mask_token_id, sampled_ids)

            if guidance_decay == "linear":
                # Takes effect from the next step on.
                cfg_scale = ratio * guidance_scale
        return ids

    def masking_input_tokens(self, input_tokens):
        """Randomly mask tokens of each sequence using an arccos schedule.

        Args:
            input_tokens: Token ids of shape ``(batch_size, seq_len)``.

        Returns:
            A tuple ``(masked_tokens, masks)``: the input with the selected
            positions replaced by the mask token, and the boolean mask of the
            selected positions. At least one position per row is masked.
        """
        batch_size, seq_len = input_tokens.shape
        device = input_tokens.device

        # One uniformly drawn timestep per sample decides its mask ratio.
        timesteps = torch.zeros((batch_size,), device=device).float().uniform_(0, 1.0)
        mask_ratio = torch.acos(timesteps) / (math.pi * 0.5)  # arccos schedule
        mask_ratio = torch.clamp(mask_ratio, min=1e-6, max=1.)
        num_token_masked = (seq_len * mask_ratio).round().clamp(min=1)
        # argsort of random values gives a random permutation per row; the
        # positions whose rank is below the per-row count get masked.
        batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1)
        masks = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
        masked_tokens = torch.where(masks, self.mask_token_id, input_tokens)
        return masked_tokens, masks


class UViTBert(ImageBert):
    """``ImageBert`` variant whose backbone is a U-ViT style block stack.

    The first half of the blocks record their outputs, a middle block follows,
    and the second half of the blocks consume those outputs as skip inputs in
    reverse order.
    """

    def __init__(self, config):
        """Build the token/position embeddings, the U-ViT blocks and the head.

        Args:
            config: Forwarded unchanged to ``ImageBert.__init__``.
        """
        super().__init__(config=config)

        # The parent constructor creates a backbone under ``self.model``;
        # it is discarded here in favour of the modules built below.
        del self.model

        # image tokens + class tokens + mask token + "class dropped" token
        vocab_size = self.target_codebook_size + self.condition_num_classes + 2
        self.embeddings = nn.Embedding(vocab_size, self.hidden_size)

        # Learned positions for the condition token plus every image token.
        num_positions = self.config.model.generator.image_seq_len + 1
        pos_embed = nn.Parameter(torch.zeros(1, num_positions, self.hidden_size))
        self.pos_embed = nn.init.trunc_normal_(pos_embed, 0., 0.02)

        # Arguments shared by every block; only the out blocks add skip=True.
        block_kwargs = dict(
            dim=self.hidden_size,
            num_heads=self.num_attention_heads,
            mlp_ratio=(self.intermediate_size / self.hidden_size),
            qkv_bias=False,
            qk_scale=None,
            norm_layer=nn.LayerNorm,
            use_checkpoint=False,
        )
        half_depth = self.num_hidden_layers // 2

        self.in_blocks = nn.ModuleList(
            [UViTBlock(**block_kwargs) for _ in range(half_depth)])

        self.mid_block = UViTBlock(**block_kwargs)

        self.out_blocks = nn.ModuleList(
            [UViTBlock(skip=True, **block_kwargs) for _ in range(half_depth)])

        self.norm = nn.LayerNorm(self.hidden_size)
        self.lm_head = nn.Linear(self.hidden_size,
                                 self.target_codebook_size, bias=True)
        # Re-initialize every Linear / Embedding / LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one submodule; other module types are left untouched."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.Embedding):
            nn.init.trunc_normal_(m.weight.data, mean=0.0, std=0.02)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        """Predict codebook logits for every image-token position.

        Args:
            input_ids: Image token ids (possibly containing the mask token);
                flattened to ``(batch, -1)`` before use. Required.
            condition: Class ids; flattened to ``(batch, -1)`` and prepended
                to the image tokens.
            cond_drop_prob: Probability with which each class id is replaced
                by the "class dropped" token.

        Returns:
            Logits over the image codebook for all positions except the first
            (condition) one.

        Raises:
            NotImplementedError: If ``input_ids`` is ``None``.
        """
        # Token layout of the embedding table:
        #   [0, codebook_size - 1]                      image tokens
        #   codebook_size                               mask token
        #   [codebook_size + 1, codebook_size + nclass] class tokens
        #   codebook_size + nclass + 1                  "class dropped" token
        dropped = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        class_offset = self.target_codebook_size + 1
        # Out-of-place shift, so the caller's tensor is not modified by the
        # indexed assignment that follows.
        condition = condition + class_offset
        condition[dropped] = self.condition_num_classes + class_offset

        if input_ids is None:
            # at least there should be masked token
            raise NotImplementedError

        # Condition token first, image tokens after it.
        cond_tokens = condition.view(condition.shape[0], -1)
        image_tokens = input_ids.view(input_ids.shape[0], -1)
        tokens = torch.cat([cond_tokens, image_tokens], dim=1)

        # UViT forward
        x = self.embeddings(tokens)
        x = x + self.pos_embed[:, :x.shape[1]]

        skips = []
        for block in self.in_blocks:
            x = block(x)
            skips.append(x)

        x = self.mid_block(x)

        # Out blocks consume the recorded activations last-in, first-out.
        for block, skip in zip(self.out_blocks, reversed(skips)):
            x = block(x, skip)

        x = self.norm(x)
        return self.lm_head(x[:, 1:])  # remove cond