from typing import Dict, List, Any
import time
from functools import wraps

import torch

# def timing_decorator(timing_dict_key):
#     def decorator(method):
#         @wraps(method)
#         def timed(*args, **kwargs):
#             start_time = time.time()
#             result = method(*args, **kwargs)
#             end_time = time.time()
#             elapsed_time = end_time - start_time
#             if 'outputs_dict' in kwargs:
#                 kwargs['outputs_dict'][timing_dict_key] = elapsed_time
#             return result
#         return timed
#     return decorator

def init_outputs_dict(**kwargs) -> Dict[str, List]:
    """
    Create a fresh accumulator for speculative-decoding statistics.

    Per-step fields start as new empty lists. ``num_prefill_tokens`` and
    ``time_spec_decode`` start as ``None``. ``prompt_length`` is copied from
    the required ``prompt_length`` keyword; other keywords are ignored.

    Raises:
        KeyError: if ``prompt_length`` is not supplied.
    """
    return dict(
        num_accepted_tokens=[],
        num_generated_tokens=[],
        num_prefill_tokens=None,
        ids_accepted_tokens=[],
        ids_first_rejected_tokens=[],
        first_rejected_tokens=[],
        prompt_length=kwargs['prompt_length'],
        time_drf_generate=[],
        time_tgt_forward=[],
        time_spec_decode=None,
    )

def update_outputs_dict(outputs_dict: Dict[str, List], n_matches: torch.Tensor, first_rejected_token: torch.Tensor, outputs_drf, outputs_tgt, batch):
    """
    Record the statistics of one speculative-decoding step in ``outputs_dict``.

    Appends the number of accepted draft tokens, the number of draft logits,
    the positions of the accepted tokens, the first rejected token (when one
    is given) and the per-step timings. ``num_prefill_tokens`` is set only on
    the first call, i.e. while it is still ``None``.

    Args:
        outputs_dict: Accumulator whose list values are extended in place.
        n_matches: Number of accepted draft tokens this step, as a
            single-element tensor or a plain ``int``.
        first_rejected_token: Single-element tensor with the first rejected
            token, or ``None`` to skip recording a rejection.
        outputs_drf: Draft output, read via ``.logits``, ``.past_key_values``
            and ``['time_drf_generate']``.
        outputs_tgt: Target output, read via ``['time_tgt_forward']``.
        batch: Mapping holding ``'input_ids'``; its second dimension is the
            starting position recorded for accepted tokens.
    """
    # Convert once: int() accepts both a single-element tensor and a plain int,
    # instead of mixing .item() with range(tensor) (which needs Tensor.__index__).
    n_accepted = int(n_matches)
    start_pos = batch['input_ids'].shape[1]

    outputs_dict['num_accepted_tokens'].append(n_accepted)
    outputs_dict['num_generated_tokens'].append(len(outputs_drf.logits))
    if outputs_dict['num_prefill_tokens'] is None:
        # NOTE(review): assumes past_key_values[0][0] has the cached sequence
        # length on dim 2; the "- len(logits) + 1" correction is kept as-is — confirm.
        outputs_dict['num_prefill_tokens'] = outputs_drf.past_key_values[0][0].shape[2] - len(outputs_drf.logits) + 1
    # Accepted tokens are recorded at positions start_pos .. start_pos + n_accepted - 1.
    outputs_dict['ids_accepted_tokens'].extend(range(start_pos, start_pos + n_accepted))
    if first_rejected_token is not None:
        # The first rejected token's position is the one right after the accepted run.
        outputs_dict['ids_first_rejected_tokens'].append(start_pos + n_accepted)
        outputs_dict['first_rejected_tokens'].append(first_rejected_token.item())
    outputs_dict['time_drf_generate'].append(outputs_drf['time_drf_generate'])
    outputs_dict['time_tgt_forward'].append(outputs_tgt['time_tgt_forward'])

def finalize_sd_outputs(outputs_dict: Dict[str, List], batch, tokenizer, do_print=False):
    """
    Finalize the speculative decoding outputs in place.

    Stores the final token ids as nested Python lists under ``'sequences'``
    and the count of tokens past the prompt under ``'num_target_tokens'``.
    Optionally prints a summary.

    Args:
        outputs_dict: Accumulator; must already contain ``'prompt_length'``.
        batch: Mapping holding the final ``'input_ids'`` tensor. It is read by
            key, matching ``update_outputs_dict``, so plain dicts work as well as
            objects that support both key and attribute access.
        tokenizer: Forwarded to ``print_sd_outputs`` when ``do_print`` is true.
        do_print: If true, print a summary via ``print_sd_outputs``.
    """
    input_ids = batch['input_ids']
    outputs_dict['sequences'] = input_ids.tolist()
    # Tokens beyond the prompt: second dim of input_ids minus the prompt length.
    outputs_dict['num_target_tokens'] = input_ids.shape[1] - outputs_dict['prompt_length']

    if do_print:
        print_sd_outputs(tokenizer, outputs_dict)

def get_model_kwargs() -> Dict[str, Any]:
    """
    Return a new dict of model keyword arguments, each set to ``None``.

    The keys are ``num_accepted_tokens`` and ``past_key_values``, in that order.
    A fresh dict is built on every call, so callers may mutate it freely.
    """
    return dict.fromkeys(('num_accepted_tokens', 'past_key_values'))

def crop_past_key_values(model, past_key_values, maximum_length):
    """
    Truncate every cached layer to its first ``maximum_length`` positions.

    Entries ``layer[0]`` and ``layer[1]`` are sliced as
    ``[:, :, :maximum_length, :]``. Presumably these are key/value tensors
    with the sequence on dim 2; confirm against the model's cache format.

    When ``model.config.is_encoder_decoder`` is true, ``layer[2]`` and
    ``layer[3]`` are passed through unsliced. These are presumably
    cross-attention caches. Any further entries are dropped. Otherwise only
    the two sliced entries are kept.

    Returns:
        A tuple containing one tuple per layer.
    """
    keep_cross = model.config.is_encoder_decoder

    def _trim(cache_tensor):
        return cache_tensor[:, :, :maximum_length, :]

    if keep_cross:
        return tuple(
            (_trim(layer[0]), _trim(layer[1]), layer[2], layer[3])
            for layer in past_key_values
        )
    return tuple((_trim(layer[0]), _trim(layer[1])) for layer in past_key_values)

def print_sd_outputs(tokenizer, outputs_dict):
    """
    Utility function to print speculative decoding outputs.

    Args:
        tokenizer: Object providing
            ``batch_decode(sequences, skip_special_tokens=...)``. Only the
            first decoded sequence is printed.
        outputs_dict: Must contain ``'sequences'`` (a list of token-id lists).
            The statistics ``'num_accepted_tokens'``,
            ``'num_generated_tokens'`` and ``'ids_accepted_tokens'`` are
            optional. A missing key or ``None`` value is treated as an empty
            list, so the sums below do not raise ``TypeError``.
    """

    def _list_or_empty(key):
        # .get() alone yields None for a missing key, which sum() cannot handle.
        value = outputs_dict.get(key)
        return [] if value is None else value

    sequences = outputs_dict['sequences']
    num_accepted_tokens = _list_or_empty('num_accepted_tokens')
    num_generated_tokens = _list_or_empty('num_generated_tokens')
    ids_accepted_tokens = _list_or_empty('ids_accepted_tokens')

    print("#" * 50)
    print(f"Decoded sequence: {tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]}")
    print(f"Number of accepted tokens: {num_accepted_tokens}")
    print(f"Number of generated tokens: {num_generated_tokens}")
    # Despite the "Sum" label, this is the length of the first sequence.
    print(f"Sum of total tokens: {len(sequences[0])}")
    print(f"Sum of accepted tokens: {sum(num_accepted_tokens)}")
    print(f"Sum of generated tokens: {sum(num_generated_tokens)}")
    print(f"Accepted token ids: {ids_accepted_tokens}")
    print("#" * 50)