# import os
# from typing import Any, Dict, List, Optional, Union, Tuple
# from datasets.formatting import get_formatter, query_table, format_table
# import numpy as np
# from PIL import Image
# import time

# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader, Dataset
# from einops import rearrange, reduce, repeat

# from transformers import PreTrainedModel, BatchFeature
# from accelerate.utils import tqdm
# from absl import logging

# from .mllm import MLLM
# from ..utils.utils_speculative_decoding import (get_model_kwargs, init_outputs_dict, update_outputs_dict, finalize_sd_outputs, crop_past_key_values)

# class SpeculativeDecoding(object):
#     """
#     Module for Speculative Decoding
#     """
#     def __init__(
#             self,
#             _config,
#             drf_wrapper: MLLM,
#             tgt_wrapper: MLLM,
#             tokenizer,
#             drf_aux_tokenizer,
#             drf_image_processor,
#             tgt_image_processor,
#             **kwargs,
#         ):
#         self._config = _config

#         # wrappers: MLLM
#         self.drf_wrapper = drf_wrapper
#         self.tgt_wrapper = tgt_wrapper

#         # processors
#         self.tokenizer = tokenizer
#         self.drf_aux_tokenizer = drf_aux_tokenizer
#         self.drf_image_processor = drf_image_processor
#         self.tgt_image_processor = tgt_image_processor

#         # config
#         self.eos_token_id = kwargs['eos_token_id']
#         self.pad_token_id = kwargs['pad_token_id']
#         self.max_prompt_length = _config['max_prompt_length']
#         self.max_target_length = _config['max_target_length']
#         self.max_chunk_length = _config['max_chunk_length']

#         # Save tensor loading time for SpS
#         device = self.drf_wrapper.mllm.device
#         self.eos_token_id_tensor = torch.tensor(
#             ([self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id)
#         ).to(device) if self.eos_token_id is not None else None

#     def load_batch_to_device(self, batch):
#         for k, v in batch.items():
#             if isinstance(v, torch.Tensor):
#                 batch[k] = v.to(self.drf_wrapper.mllm.device)
#         return batch
    
#     def set_dtype(self, batch):
#         if self._config['is_drf_text_only'] and self._config['is_tgt_text_only']:
#             #  NO input image
#             return
#         if self._config['tgt_dtype'] == 'fp16':
#             batch['pixel_values'] = batch['pixel_values'].to(dtype=torch.float16)

#     def spec_decode(self, batch, do_sample=True) -> Dict:
#         drf_kwargs = get_model_kwargs()
#         tgt_kwargs = get_model_kwargs()
#         prompt_len = batch.input_ids.shape[1]
#         drf_kwargs['prompt_length'] = prompt_len
#         outputs_dict = init_outputs_dict(**drf_kwargs)
        
#         # load batch to device & set dtype
#         self.set_dtype(batch)
#         self.load_batch_to_device(batch)
#         # torch.cuda.empty_cache()
#         torch.cuda.synchronize()
#         start_time_spec_decode = time.time()
#         while True:
#             # draft
#             torch.cuda.synchronize()
#             start_time_draft = time.time()
#             outputs_drf = self.draft(batch, drf_kwargs)
#             torch.cuda.synchronize()
#             outputs_drf['time_drf_generate'] = time.time() - start_time_draft

#             # verify
#             torch.cuda.synchronize()
#             start_time_verify = time.time()
#             outputs_tgt = self.verify(batch, outputs_drf, tgt_kwargs)
#             torch.cuda.synchronize()
#             outputs_tgt['time_tgt_forward'] = time.time() - start_time_verify

#             if do_sample:
#                 if outputs_drf.get('logits') is not None:
#                     # speculative sampling
#                     valid_tokens, n_matches, first_rejected_token = self.speculative_sampling(batch, outputs_drf, outputs_tgt)
#                     update_outputs_dict(outputs_dict, n_matches, first_rejected_token, outputs_drf, outputs_tgt, batch)
#                 else:
#                     # multinomial sampling for the single last token
#                     probs = outputs_tgt.logits.softmax(dim=-1)
#                     valid_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
#             else: # greedy decoding
#                 pass
#                 # selected_tokens = new_logits.argmax(dim=-1)
#                 # candidate_new_tokens = candidate_input_ids[:, cur_len:]
#                 # n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

#                 # # Ensure we don't generate beyond max_len or an EOS token
#                 # if is_done_candidate and n_matches == candidate_length:
#                 #     n_matches -= 1
#                 # valid_tokens = selected_tokens[:, : n_matches + 1]
            
            

#             # update input
#             batch, drf_kwargs, tgt_kwargs = self.update_inputs(batch, valid_tokens, outputs_drf, outputs_tgt, drf_kwargs)

#             # stopping criteria
#             if self.is_stop(batch, prompt_len):
#                 break
        
#         torch.cuda.synchronize()
#         outputs_dict['time_spec_decode'] = time.time() - start_time_spec_decode
#         finalize_sd_outputs(outputs_dict, batch, self.tokenizer, do_print=self._config['do_print'])
#         return outputs_dict

#     def is_stop(self, batch, prompt_len):        
#         """
#         1) Check whether the total number of generated tokens has reached or exceeded self.max_target_length
#         TODO: padding could become a problem in the future (there are no pad tokens for now, since this is the batch-size-1 case)
#         """
#         input_ids = batch.input_ids
#         num_cum_gen_tokens = batch.input_ids.shape[1] - prompt_len

#         if num_cum_gen_tokens >= self.max_target_length:
#             return True
        
#         """
#         2) Check whether the newly generated chunk has eos token
#         """
#         """eos_token_id_tensor = torch.tensor(
#             [self.eos_token_id]
#         ).to(batch.input_ids.device) if self.eos_token_id is not None else None"""
        
#         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
#         unfinished_sequences = unfinished_sequences.mul(
#             input_ids[:, -1]
#             .tile(self.eos_token_id_tensor.shape[0], 1)
#             .ne(self.eos_token_id_tensor.unsqueeze(1))
#             .prod(dim=0)
#         )
#         if unfinished_sequences.max() == 0:
#             return True

#         return False


#     @torch.no_grad()
#     def draft(
#             self, 
#             batch,
#             drf_kwargs,
#         ) -> Dict:        

#         if 'aux_input_ids' not in batch:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#         else:
#             input_ids_aux_tokenized = batch['aux_input_ids']
#             attention_mask_aux_tokenized = batch['aux_attention_mask']
#             input_ids = torch.cat([input_ids_aux_tokenized, batch['input_ids'][:, drf_kwargs.get('prompt_length'):]], dim=1)
#             attention_mask = torch.cat([attention_mask_aux_tokenized, batch['attention_mask'][:, drf_kwargs.get('prompt_length'):]], dim=1)
        
#         # draft generation config
#         cum_num_gen_tokens = batch.input_ids.shape[1] - drf_kwargs.get('prompt_length')
#         max_new_tokens = min(self.max_chunk_length, self.max_target_length - cum_num_gen_tokens - 1)
        
#         if max_new_tokens == 0:
#             # the single last token will be generated from target model
#             return dict(
#                 sequences=batch['input_ids'],
#                 logits=None,
#                 candidate_length=0,
#             )
        
#         # kwargs
#         kwargs = dict(
#             max_new_tokens=max_new_tokens,
#             do_sample=True,
#             pad_token_id=self.eos_token_id, # To prevent warnings: Setting pad_token_id to eos_token_id:151645 for open-end generation.
#         )

#         inputs = dict(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#         )
#         if not self._config['is_drf_text_only']:
#             inputs['pixel_values'] = batch['pixel_values']

#         kwargs['past_key_values'] = drf_kwargs['past_key_values']
        
#         # KV cache update for full acceptance: get kv cache for the token (last drafted token), right before the bonus token
#         if drf_kwargs.pop('num_accepted_tokens', None) == self.max_chunk_length:
#             inputs_get_last_kv = dict(
#                 input_ids=input_ids[:, :-1].to(self.drf_wrapper.mllm.device),
#                 attention_mask=attention_mask[:, :-1].to(self.drf_wrapper.mllm.device),
#             )
#             if not self._config['is_drf_text_only']:
#                 inputs_get_last_kv['pixel_values'] = inputs['pixel_values'].to(self.drf_wrapper.mllm.device)

#             kwargs['past_key_values'] = self.drf_wrapper.mllm.generate(
#                 **inputs_get_last_kv,
#                 past_key_values=kwargs['past_key_values'],
#                 max_new_tokens=1,
#                 use_cache=True,
#                 return_dict_in_generate=True,
#                 pad_token_id=self.eos_token_id, # To prevent warnings: Setting pad_token_id to eos_token_id:151645 for open-end generation.
#             ).past_key_values

#         # generate draft w/ MLLM
#         outputs_drf = self.drf_wrapper.generate(
#             **inputs,
#             **kwargs,
#         )

#         # update candidate_length
#         outputs_drf['candidate_length'] = len(outputs_drf.logits) + 1

#         if 'aux_input_ids' in batch:
#             # drf tokenizer
#             input_ids_drf_tokenized = batch['input_ids'][:, :drf_kwargs.get('prompt_length')]
#             # aux tokenizer
#             prompt_length_aux = batch['aux_input_ids'].shape[1]
#             sequences_generated = outputs_drf['sequences'][:, prompt_length_aux:]   
#             outputs_drf['sequences'] = torch.cat([input_ids_drf_tokenized, sequences_generated], dim=1)
        
#         return outputs_drf

#     @torch.no_grad()
#     def verify(
#         self,
#         batch,
#         outputs_drf,
#         tgt_kwargs,
#     ) -> Dict:
        
#         # input_ids
#         inputs = dict(
#             input_ids=outputs_drf['sequences'],
#             past_key_values=tgt_kwargs.get('past_key_values'),
#         )

#         if inputs.get('past_key_values') is None:
#             # first itr
#             inputs['attention_mask'] = (inputs['input_ids'] != self.pad_token_id)
#         else:
#             # subsequent itr
#             candidate_length = 1 if outputs_drf['candidate_length']==0 else outputs_drf['candidate_length']
#             input_ids_addendum = inputs.pop('input_ids')[:, -candidate_length:].to(self.tgt_wrapper.mllm.device)
#             inputs['inputs_embeds'] = self.tgt_wrapper.mllm.get_input_embeddings()(input_ids_addendum).to(self.tgt_wrapper.mllm.dtype)

#             # Todo
#             cache_len = inputs.get('past_key_values')[0][0].shape[2]
#             attention_mask_addendum = (input_ids_addendum != self.pad_token_id).to(self.tgt_wrapper.mllm.device)
#             attention_mask_past = torch.ones((input_ids_addendum.shape[0], cache_len)).to(self.tgt_wrapper.mllm.device)
#             inputs['attention_mask'] = torch.cat([attention_mask_past, attention_mask_addendum], dim=-1)
#             inputs['position_ids'] = (attention_mask_addendum.cumsum(-1) -1) + cache_len
#             """
#             # In the current setting the draft never produces pad tokens > should switch to F.pad and manage the device
#             cache_len = inputs.get('past_key_values')[0][0].shape[2]
#             input_len = inputs.get('inputs_embeds').shape[1]

#             inputs['attention_mask'] = torch.ones((input_ids_addendum.shape[0], cache_len + input_len)).to(self.tgt_wrapper.mllm.device)
#             inputs['position_ids'] = (inputs['attention_mask'][-input_len: ].cumsum(-1) -1) + cache_len
#             """

#         pixel_values = batch['pixel_values']

#         # forward
#         outputs_tgt = self.tgt_wrapper(
#             **inputs,
#             pixel_values=pixel_values,
#         )

#         return outputs_tgt

#     def speculative_sampling(self, batch, outputs_drf, outputs_tgt):
        
#         # the number of newly generated tokens in this itr / newly generated tokens
#         candidate_length = len(outputs_drf.logits) 
#         labels_drf = outputs_drf['sequences'][:, -candidate_length:] 
        
#         # the number of generated tokens so far
#         cur_len = outputs_drf['sequences'].shape[1] - batch.input_ids.shape[1] 

#         # the maximum number of tokens that can be added 
#         max_matches = self.max_target_length - cur_len - 1

#         # logits for newly generated/forwarded tokens
#         q_logits = rearrange(torch.stack(outputs_drf.logits), 's b v -> b s v')
#         p_logits = outputs_tgt.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
        
#         # [MLLMSD] adjust vocab size 
#         if not self._config['is_drf_from_mllm']:
#             p_logits = p_logits[:, :, :q_logits.shape[-1]]

#         """eos_token_id = [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
#         eos_token_id_tensor = torch.tensor(
#             eos_token_id
#         ).to(labels_drf.device) if eos_token_id is not None else None"""
#         # ).to(batch.input_ids.device) if eos_token_id is not None else None
        
#         # Check whether the last token is eos token
#         last_assistant_token_is_eos = (
#             ~labels_drf[:, -1]
#             .tile(self.eos_token_id_tensor.shape[0], 1)
#             .ne(self.eos_token_id_tensor.unsqueeze(1))
#             .prod(dim=0)
#             .bool()
#         )
        
#         # filter only valid tokens
#         valid_tokens, n_matches, first_rejected_token = self._speculative_sampling(
#             candidate_input_ids=labels_drf,
#             candidate_logits=q_logits,
#             candidate_length=candidate_length,
#             new_logits=p_logits,
#             last_assistant_token_is_eos=last_assistant_token_is_eos,
#             max_matches=max_matches,
#         )

#         return valid_tokens, n_matches, first_rejected_token
    
#     def update_inputs(self, batch, valid_tokens, outputs_drf, outputs_tgt, drf_kwargs):
#         batch['input_ids'] = torch.cat(
#             (
#                 batch.input_ids, 
#                 valid_tokens.to(batch.input_ids.device)
#             ), 
#             dim=-1 
#         )
#         if outputs_drf.get('logits') is None:
#             return batch, None, None

#         # expand attention mask
#         attention_mask_addendum = (valid_tokens != self.pad_token_id).to(batch.attention_mask.device)
#         attention_mask = torch.cat((batch.attention_mask, attention_mask_addendum), dim=-1)
#         batch['attention_mask'] = attention_mask

#         # Update past_key_values
#         num_generated_tokens = len(outputs_drf.logits) 
#         num_accepted_tokens = valid_tokens.shape[-1] - 1 # excludes bonus token
#         drf_new_cache_size = outputs_drf.past_key_values[0][0].shape[2] - (num_generated_tokens - num_accepted_tokens) # input_ids.shape[1] - 1 # len(prev_input_ids) + len(valid_tokens) - # bonus token
#         tgt_new_cache_size = outputs_tgt.past_key_values[0][0].shape[2] - (num_generated_tokens - num_accepted_tokens) # input_ids.shape[1] - 1 # len(prev_input_ids) + len(valid_tokens) - # bonus token

#         drf_kwargs.update({
#             'num_accepted_tokens': num_accepted_tokens,
#             'past_key_values': crop_past_key_values(self.drf_wrapper.mllm, outputs_drf.past_key_values, drf_new_cache_size),
#         })
#         tgt_kwargs = {
#             'past_key_values': crop_past_key_values(self.tgt_wrapper.mllm, outputs_tgt.past_key_values, tgt_new_cache_size),
#         }

#         return batch, drf_kwargs, tgt_kwargs


#     def _speculative_sampling(
#             self,
#             candidate_input_ids,
#             candidate_logits,
#             candidate_length,
#             new_logits,
#             last_assistant_token_is_eos,
#             max_matches,
#         ):
#         """
#         Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
#         the selected tokens, as well as the number of candidate matches.

#         NOTE: Unless otherwise stated, the variable names match those in the paper.
#         """
#         # Newly generated token ids
#         new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
#         # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
#         # selected by the assistant, respectively.
#         q = candidate_logits.softmax(dim=-1)
#         q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
#         p = new_logits.softmax(dim=-1)
#         p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
#         probability_ratio = p_i / q_i
#         if self._config['do_print']:
#             print(probability_ratio)

#         # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
#         # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
#         # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
        
#         r_i = torch.rand_like(probability_ratio)
#         is_accepted = r_i <= probability_ratio # length gamma
#         n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

#         # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
#         if last_assistant_token_is_eos and n_matches == candidate_length:
#             # Case where the whole max_chunk_length was filled and accepted
#             # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
#             # due to acceptance on EOS we fix `n_matches`
#             n_matches -= 1
#             valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
#         else:
#             n_matches = min(n_matches, max_matches)

#             # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
#             gamma = min(candidate_logits.shape[1], max_matches)
#             p_n_plus_1 = p[:, n_matches, :]
#             if n_matches < gamma:
#                 # Resample from adjusted distribution
#                 q_n_plus_1 = q[:, n_matches, :]
#                 p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
#                 p_prime.div_(p_prime.sum())
#             else:
#                 # Sample directly
#                 p_prime = p_n_plus_1
#             t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

#             # The selected tokens include the matches (if any) plus the next sampled tokens
#             if n_matches > 0:
#                 valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
#             else:
#                 valid_tokens = t

#         # include resampled token
#         first_rejected_token = new_candidate_input_ids[:, n_matches] if n_matches < candidate_length else None
#         return valid_tokens, n_matches, first_rejected_token