import torch

import numpy as np
import copy
# DEVICE_TYPE = os.environ.get("DEVICE_TYPE", "ascend")
# if DEVICE_TYPE == "ascend":
#     import torch_npu
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.models.llama.configuration_llama import LlamaConfig

import time
from transformers import LlamaForCausalLM


class SelftokModel(LlamaForCausalLM):
    """Llama causal LM with an extra autoregressive image-token sampler.

    The only additions over ``LlamaForCausalLM`` are the ``cfg`` attribute and
    :meth:`generate_image`, which samples tokens restricted to a slice of the
    vocabulary.
    """

    def __init__(self, config):
        super().__init__(config)
        # Same object that was handed to the parent constructor.
        self.cfg = config

    @staticmethod
    def _selftok_guided_logits(logits, cfg_type, guidance_scale,
                               image_vocab_slice, entropy_bound, min_cfg):
        """Combine conditional/unconditional last-step logits with CFG.

        ``logits`` is ``(2 * N, vocab)``; the first half of the batch is
        treated as conditional and the second half as unconditional.
        Returns guided logits of shape ``(N, vocab)``.
        """
        cond_logits, uncond_logits = logits.chunk(2)

        if cfg_type == 'adaptive':
            lo, hi = image_vocab_slice[0], image_vocab_slice[1]
            # Softmax is taken over the full vocabulary; only the image-token
            # slice contributes to the entropy sum.
            probs = F.softmax(cond_logits, dim=-1)
            plogp = probs * torch.log(probs)
            entropy = -torch.sum(plogp[:, lo:hi], dim=-1, keepdim=True)
            # NOTE(review): 0 * log(0) yields NaN, and the NaN is cleaned up
            # only *after* the sum, so a single underflowed probability makes
            # the whole entropy 0 (which then selects `min_cfg`). Kept as-is to
            # preserve existing sampling behaviour -- confirm whether intended.
            entropy = torch.where(torch.isnan(entropy), 0, entropy)
            # Low-entropy (confident) steps fall back to the weaker scale.
            scale = torch.where(entropy < entropy_bound, min_cfg, guidance_scale)
        else:  # 'fix' (validated by the caller)
            scale = guidance_scale

        return uncond_logits + scale * (cond_logits - uncond_logits)

    @staticmethod
    def _selftok_next_token(logits, greedy, temperature, top_k, top_p, beam_size):
        """Pick one token index per row from ``(N, image_vocab)`` logits.

        Returns a ``(N, 1)`` long tensor of indices relative to the start of
        the image-vocabulary slice.
        """
        if greedy:
            return torch.argmax(logits, dim=-1, keepdim=True)

        logits = logits.float() / temperature

        assert not (top_k and top_p and beam_size)
        if top_k:
            # Mask everything below the k-th largest logit in each row.
            kth_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < kth_values[:, [-1]]] = float("-inf")
        probs = F.softmax(logits, dim=-1)

        if not top_p:
            return torch.multinomial(probs, num_samples=1)

        # Nucleus sampling: keep the smallest prefix of the sorted
        # distribution whose mass (excluding the current entry) is <= top_p.
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return torch.gather(sorted_idx, -1, choice)

    @torch.no_grad()
    def generate_image(
        self,
        input_ids,
        cfg_type='fix',
        img_seq_len=512,
        cfg_token_id=128256 + 32768 +2,
        image_vocab_slice=(128256, 128256 + 32768),
        temperature=1.0,
        ori_attention_mask=None,
        top_k=None,
        top_p=None,
        guidance_scale=None,
        beam_size=None,
        use_past=False,
        samples_per_prompt=1,
        entropy_bound=2,
        min_cfg=1,
    ):
        """Autoregressively sample ``img_seq_len`` image tokens.

        Args:
            input_ids: Prompt token ids, ``(batch, prompt_len)``. A list or
                NumPy array is converted to a long tensor on the model device.
                When ``guidance_scale`` is set, the batch is split in two
                halves: the first is used as the conditional branch and the
                second as the unconditional branch.
            cfg_type: ``'fix'`` applies ``guidance_scale`` at every step;
                ``'adaptive'`` uses ``min_cfg`` on steps whose conditional
                entropy over the image slice is below ``entropy_bound``.
                Only consulted when ``guidance_scale`` is not ``None``.
            img_seq_len: Number of tokens to generate.
            cfg_token_id: Unused by this method.
            image_vocab_slice: ``(start, end)`` vocabulary range that tokens
                are drawn from.
            temperature: Divides the logits in sampling mode; ignored for
                greedy decoding.
            ori_attention_mask: Attention mask for the prompt; defaults to
                all ones. It is extended with ones for generated tokens.
            top_k: Optional top-k filter for sampling.
            top_p: Optional nucleus threshold for sampling.
            guidance_scale: Classifier-free guidance scale, or ``None`` to
                disable guidance.
            beam_size: Beam search is not implemented; a non-``None`` value
                only switches decoding from greedy to sampling.
            use_past: Reuse the key/value cache between steps.
            samples_per_prompt: Unused by this method.
            entropy_bound: Entropy threshold for ``cfg_type='adaptive'``.
            min_cfg: Guidance scale used below ``entropy_bound``.

        Returns:
            Long tensor ``(batch, img_seq_len)`` -- or ``(batch // 2,
            img_seq_len)`` with guidance -- of sampled token ids in the full
            vocabulary (i.e. already offset by ``image_vocab_slice[0]``).

        Raises:
            ValueError: If guidance is enabled with an odd batch size or an
                unknown ``cfg_type``.
        """
        if not isinstance(input_ids, torch.Tensor):
            # Lists / arrays have no `.device`; move them next to the weights.
            input_ids = torch.as_tensor(input_ids, dtype=torch.long, device=self.device)

        device = input_ids.device
        bsz, prompt_len = input_ids.shape[0], input_ids.shape[1]
        use_guidance = guidance_scale is not None

        if use_guidance:
            if bsz % 2 != 0:
                raise ValueError(
                    "guidance_scale requires an even batch "
                    "(conditional half followed by unconditional half), "
                    f"got batch size {bsz}")
            if cfg_type not in ('fix', 'adaptive'):
                raise ValueError(
                    f"cfg_type must be 'fix' or 'adaptive', got {cfg_type!r}")
            num_images = bsz // 2
        else:
            num_images = bsz

        # Greedy decoding is used only when no sampling option is supplied.
        greedy = top_k is None and top_p is None and beam_size is None
        vocab_lo, vocab_hi = image_vocab_slice[0], image_vocab_slice[1]

        image_ids = torch.empty((num_images, 0), dtype=torch.long, device=device)
        past_key_values = None
        cache_position = torch.arange(prompt_len, device=device) if use_past else None

        if ori_attention_mask is None:
            ori_attention_mask = torch.ones((bsz, prompt_len), dtype=torch.long, device=device)
        mask_len = ori_attention_mask.shape[1]

        for step in range(img_seq_len):
            if step > 0:
                # Append only the token produced by the previous step;
                # `image_ids` holds the full history and must not be
                # re-appended wholesale.
                newest = image_ids[:, -1:]
                if use_guidance:
                    # Both CFG branches continue with the same image tokens.
                    newest = torch.cat([newest, newest], dim=0)
                input_ids = torch.cat([input_ids, newest], dim=-1)

            cur_bsz, cur_len = input_ids.shape[0], input_ids.shape[1]
            generated_mask = torch.ones((cur_bsz, cur_len - mask_len),
                                        dtype=torch.long, device=device)
            attention_mask = torch.cat((ori_attention_mask, generated_mask), dim=1)

            if use_past and step > 0:
                # Incremental step: feed just the newest token with the cache.
                outputs = self.forward(input_ids[:, -1:],
                                       attention_mask=attention_mask,
                                       past_key_values=past_key_values,
                                       cache_position=cache_position,
                                       use_cache=True)
            else:
                # Prefill (cached mode) or full re-encode (uncached mode).
                outputs = self.forward(input_ids,
                                       attention_mask=attention_mask,
                                       past_key_values=past_key_values,
                                       cache_position=cache_position,
                                       use_cache=use_past)

            if use_past:
                past_key_values = outputs.past_key_values
                # Position of the next single token; derived from the existing
                # tensor so it stays on the same device.
                cache_position = cache_position[-1:] + 1

            # Only the last position is needed to choose the next token.
            logits = outputs.logits[:, -1, :]

            if use_guidance:
                logits = self._selftok_guided_logits(
                    logits, cfg_type, guidance_scale,
                    image_vocab_slice, entropy_bound, min_cfg)

            idx_next = self._selftok_next_token(
                logits[:, vocab_lo:vocab_hi], greedy,
                temperature, top_k, top_p, beam_size)

            # Shift slice-relative indices back into the full vocabulary.
            image_ids = torch.cat((image_ids, idx_next + vocab_lo), dim=1)

        return image_ids