import os
from typing import Any, Dict, List, Optional, Union, Tuple
from datasets.formatting import get_formatter, query_table, format_table
import numpy as np
from PIL import Image
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from einops import rearrange, reduce, repeat

from transformers import AutoProcessor, PreTrainedModel, GenerationMixin
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration
from absl import logging

from .mllm import MLLM
from .decoding import Decoding
from .eval_specbench import measure_time
from .sps.decoding import _assisted_decoding
from .sps.utils import (
    _get_candidate_generator_vlm, 
    _validate_assistant_vlm,
)
from .sps.modeling_llava_vlmsd import (
        forward, 
        prepare_inputs_for_generation, 
        _update_model_kwargs_for_generation, 
        _get_initial_cache_position, 
    )

class SpeculativeDecoding(Decoding):
    """
    Module for Speculative Decoding.

    The target model (``models['tgt']``) runs ``generate`` with the draft
    model (``models['drf']``) passed as ``assistant_model``, using the
    VLM-aware assisted-decoding overrides imported from ``.sps``.
    """

    # Drafting modes for which a missing ``time_prompt_process`` metric from
    # ``generate`` is treated as an error (see ``decode``). The original check
    # listed the misspelled 'mulimodal'; both spellings are accepted so the
    # correctly spelled mode is no longer silently let through.
    _SELF_TIMED_DRAFTING_MODES = ('multimodal', 'mulimodal', 'image-pool')

    def __init__(
            self,
            _config,
            models: Dict[str, MLLM],
            tokenizers: Dict[str, Any],
            image_processors: Dict[str, Any],
            **kwargs,
        ):
        """
        Args:
            _config: experiment config; ``_config['max_chunk_length']`` becomes
                ``generate_config['num_assistant_tokens']``.
            models: model registry; ``decode`` reads the 'tgt' (target) and
                'drf' (draft) entries.
            tokenizers: forwarded to ``Decoding.__init__``.
            image_processors: forwarded to ``Decoding.__init__``.
            **kwargs: forwarded to ``Decoding.__init__``.
        """
        super().__init__(
            _config=_config,
            models=models,
            tokenizers=tokenizers,
            image_processors=image_processors,
            **kwargs,
        )

        # sd config: number of draft tokens proposed per assisted-decoding step
        self.generate_config['num_assistant_tokens'] = _config['max_chunk_length']

        self._patch_generation_methods()

    @staticmethod
    def _patch_generation_methods():
        """Install the VLM speculative-decoding overrides into transformers.

        NOTE: this assigns attributes on the ``LlavaForConditionalGeneration``
        and ``GenerationMixin`` classes themselves, so it affects every model
        in the process, not only this instance. Calling it again reassigns
        the same functions.
        """
        LlavaForConditionalGeneration.forward = forward
        LlavaForConditionalGeneration.prepare_inputs_for_generation = prepare_inputs_for_generation
        LlavaForConditionalGeneration._get_initial_cache_position = _get_initial_cache_position
        LlavaForConditionalGeneration._update_model_kwargs_for_generation = _update_model_kwargs_for_generation
        GenerationMixin._assisted_decoding = _assisted_decoding
        GenerationMixin._get_candidate_generator = _get_candidate_generator_vlm
        GenerationMixin._validate_assistant = _validate_assistant_vlm

    def decode(self, batch, **kwargs) -> Dict:
        """Run speculative generation on one batch.

        Args:
            batch: mapping unpacked as keyword arguments into ``generate``;
                it is first passed to ``self.prompt_setter.set_batch``.
            **kwargs: unused.

        Returns:
            The object returned by the target model's ``generate``. If its
            ``metrics['time_prompt_process']`` is None, it is filled with the
            time measured around ``prompt_setter.set_batch``.

        Raises:
            AssertionError: if ``generate`` left ``time_prompt_process`` as
                None while the drafting mode is one of
                ``_SELF_TIMED_DRAFTING_MODES``.
        """
        # Set prompt; the elapsed time is kept as a fallback metric below.
        _, time_prompt_process = measure_time(self.prompt_setter.set_batch, batch)

        # Generate
        outputs_generate = self.models['tgt'].generate(
            **batch,
            **self.generate_config,
            assistant_model=self.models['drf'],
        )

        metrics = outputs_generate.metrics
        if metrics['time_prompt_process'] is None:
            drafting = self.prompt_setter.drafting
            # Explicit raise instead of `assert`, so the check survives `python -O`.
            if drafting in self._SELF_TIMED_DRAFTING_MODES:
                raise AssertionError(
                    f"generate() returned no time_prompt_process for drafting={drafting!r}"
                )
            metrics['time_prompt_process'] = time_prompt_process
        return outputs_generate