import os
from typing import Any, Dict, List, Optional, Union, Tuple
from datasets.formatting import get_formatter, query_table, format_table
import numpy as np
from PIL import Image
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from einops import rearrange, reduce, repeat

from transformers import PreTrainedModel, BatchFeature
from accelerate.utils import tqdm
from absl import logging

from .mllm import MLLM
from ..utils.utils_speculative_decoding import (get_model_kwargs, init_outputs_dict, update_outputs_dict, finalize_sd_outputs, crop_past_key_values)
from .decoding import Decoding

class AutoregressiveDecoding(Decoding):
    """
    Plain autoregressive decoding using only the draft ('drf') model.

    No speculation or verification step is performed here: `decode` makes a
    single `generate` call on `self.models['drf']` and attaches the prefill
    statistics reported by that call to the returned outputs.
    """
    def __init__(
            self,
            _config,
            models: Dict[str, MLLM],
            tokenizers: Dict[str, Any],
            image_processors: Dict[str, Any],
            **kwargs,
        ):
        """
        Forward all arguments unchanged to the `Decoding` base class.

        Args:
            _config: Experiment configuration; `decode` reads the
                'is_drf_text_only' entry from `self._config`.
            models: Mapping from role name to model; `decode` uses the
                'drf' entry.
            tokenizers: Mapping from role name to tokenizer.
            image_processors: Mapping from role name to image processor.
            **kwargs: Extra keyword arguments passed through to the base.
        """
        # NOTE(review): `self._config`, `self.models` and `self.generate_config`
        # are presumably populated by `Decoding.__init__` -- confirm in the base.
        super().__init__(
            _config=_config,
            models=models,
            tokenizers=tokenizers,
            image_processors=image_processors,
            **kwargs,
        )
        # No autoregressive-specific configuration is required.

    def decode(self, batch, **kwargs) -> Dict:
        """
        Generate with the draft model and record its prefill metrics.

        Args:
            batch: Mapping of model inputs, unpacked as keyword arguments into
                `generate`. It is modified in place: 'pixel_values' is removed
                when the config flag 'is_drf_text_only' is truthy.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The object returned by `self.models['drf'].generate`, with a
            `metrics` dict holding 'time_prefill_drf' and
            'num_prefill_tokens_drf'.
        """
        if self._config['is_drf_text_only']:
            # Use a default so a batch that is already text-only (no
            # 'pixel_values' key) does not raise KeyError.
            batch.pop('pixel_values', None)

        outputs_generate = self.models['drf'].generate(
            **batch,
            **self.generate_config,
        )

        # Expose prefill statistics under draft-specific metric names.
        outputs_generate.metrics = {
            'time_prefill_drf': outputs_generate.time_prefill,
            'num_prefill_tokens_drf': outputs_generate.num_prefill_tokens,
        }

        return outputs_generate