import os
from typing import Any, Dict, List, Optional, Union, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from absl import logging

import transformers
from transformers import AutoProcessor, PreTrainedModel, GenerationMixin
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

from .mllm import MLLM
from .promptsetter import PromptSetter
from .sample import _sample
from .utils import GenerateDecoderOnlyOutput
from ..utils.util import get_pseudo_image_text_token_ids, get_image_escape_token_num, get_caption_prefix_ids, patch_function

class Decoding:
    """
    Base class for speculative decoding.

    Stores the models, tokenizers and image processors, builds the
    ``PromptSetter`` for the ``'drf'`` model and prepares the generation
    config. ``models`` is indexed with the keys ``'drf'`` and ``'tgt'``
    (presumably the draft and target models -- confirm against the caller).
    Subclasses are expected to implement :meth:`decode`.
    """

    def __init__(
            self,
            _config,
            models: Dict[str, "MLLM"],
            tokenizers: Dict[str, Any],
            image_processors: Dict[str, Any],
            **kwargs,
        ):
        """
        Set up prompt handling and generation settings.

        Args:
            _config: Mapping-like config. The keys ``'drf'``,
                ``'temperature'`` and ``'max_target_length'`` are read here;
                ``'drf_dtype'`` may be read by ``load_batch_to_device``.
            models: Models keyed by name; ``'drf'`` and ``'tgt'`` must exist.
            tokenizers: Tokenizers keyed by name; ``'drf'`` must exist. Its
                ``.input_ids`` is used as a 2-D tensor (``.size(1)`` and
                ``[0, -1]`` indexing), so it must return tensors.
            image_processors: Image processors keyed by name (only stored).
            **kwargs: Must contain ``eos_token_id``, which is used as the
                ``pad_token_id`` for generation.

        Raises:
            KeyError: If ``eos_token_id`` or a required config key is missing.
            AssertionError: If tokenizing ``"<image>"`` or a newline does not
                yield the token count reported by
                ``get_image_escape_token_num``.

        Side effects:
            Replaces ``_sample`` on ``Qwen2ForCausalLM``, ``LlamaForCausalLM``
            and ``LlavaForConditionalGeneration`` at class level, so every
            instance of those classes in the process is affected.
        """
        self._config = _config

        self.models = models

        # processors
        self.tokenizers = tokenizers
        self.image_processors = image_processors

        # config
        drf_tokenizer = self.tokenizers['drf']
        device = self.models['drf'].device
        image_tokenized_ids = drf_tokenizer("<image>").input_ids
        escape_tokenized_ids = drf_tokenizer("\n").input_ids # 0x0A for \n (hexadecimal)
        num_image_str_tokenized, num_escape_tokenized = get_image_escape_token_num(_config['drf'])
        assert image_tokenized_ids.size(1) == num_image_str_tokenized, f"Tokenizing a '<image>' should result in {num_image_str_tokenized} tokens"
        # The backslash is escaped so the message shows "\n" instead of
        # breaking the error text across two lines.
        assert escape_tokenized_ids.size(1) == num_escape_tokenized, f"Tokenizing a '\\n' should result in {num_escape_tokenized} tokens"
        caption_prefix_ids = get_caption_prefix_ids(_config['drf']).to(device)

        # The last token of each tokenization is taken as the special id.
        self.image_token_id = image_tokenized_ids[0, -1].item()
        self.escape_token_id = escape_tokenized_ids[0, -1].item()
        self.pseudo_image_text_token_ids = get_pseudo_image_text_token_ids(_config['drf']).to(device)
        # 529, 3027, 29958 for llava-llama

        prompt_kwargs = dict(
            image_token_id=self.image_token_id,
            escape_token_id=self.escape_token_id,
            pseudo_image_text_token_ids=self.pseudo_image_text_token_ids,
            caption_prefix_ids=caption_prefix_ids,
            device=device,
        )

        self.prompt_setter = PromptSetter(_config,
                                          tokenizer=drf_tokenizer,
                                          **prompt_kwargs)
        self.models['drf'].prompt_setter = self.prompt_setter
        self.models['drf']._config = _config
        self.models['tgt']._config = _config

        # generation config
        # NOTE(review): sampling is enabled only when temperature is exactly 1;
        # confirm this is the intended switch between greedy and sampling.
        self.generate_config = dict(
            do_sample=(_config['temperature'] == 1),
            use_cache=True,
            max_new_tokens=_config['max_target_length'],
            return_dict_in_generate=True,
            pad_token_id=kwargs['eos_token_id'],
            # output_logits=True,
            # output_hidden_states=False,
        )

        # Install the project's `_sample` on the model classes (class-level
        # patch, see "Side effects" in the docstring).
        Qwen2ForCausalLM._sample = _sample
        LlamaForCausalLM._sample = _sample
        LlavaForConditionalGeneration._sample = _sample
        # transformers.generation.utils.GenerateDecoderOnlyOutput = GenerateDecoderOnlyOutput

        logging.info(f"[Decoding] image_token_id: {self.image_token_id}")
        logging.info(f"[Decoding] escape_token_id: {self.escape_token_id}")

    def load_batch_to_device(self, batch):
        """
        Move every tensor in ``batch`` to the ``'drf'`` model's device.

        ``batch`` is updated in place and also returned. Non-tensor values
        are left untouched. A tensor stored under ``'pixel_values'`` is
        additionally cast to half precision when the ``'drf'`` config entry
        contains ``'llava-hf/llava-interleave-qwen'`` and ``'drf_dtype'`` is
        ``'fp16'``.

        Args:
            batch: Dict mapping field names to values.

        Returns:
            The same ``batch`` dict.
        """
        device = self.models['drf'].device
        # Only existing keys are reassigned, so mutating during iteration is safe.
        for key, value in batch.items():
            if not isinstance(value, torch.Tensor):
                # Includes a non-tensor 'pixel_values' (e.g. None for a
                # text-only batch), which must not be passed to `.half()`.
                continue
            value = value.to(device)
            if (
                key == 'pixel_values'
                and 'llava-hf/llava-interleave-qwen' in self._config['drf']
                and self._config['drf_dtype'] == 'fp16'
            ):
                value = value.half()
            batch[key] = value
        return batch

    def decode(self, **kwargs) -> Dict:
        """
        Run decoding; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError