import os
import sys
import time
import argparse
import json
from tqdm import tqdm

from modeling_llama import MultiHeadLlamaForCausalLM, LlamaForCausalLM
from transformers import AutoTokenizer
import numpy as np
import torch
import datasets

from safetensors.torch import save_file
from transformers import set_seed

# set visible GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Fix one seed for every random number generator this script touches so that
# repeated runs are comparable.
SEED = 42

set_seed(SEED)
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# CUDA generators are only seeded when a GPU is actually visible.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

def generate(model, tokenizer, prompts, args, mode, num_runs=2):
    """Benchmark greedy generation for a batch of prompts and print the results.

    Each prompt is wrapped as a single user turn with the tokenizer's chat
    template, the batch is tokenized with padding to the longest prompt, and
    ``model.generate`` is called ``1 + num_runs`` times. The first call is a
    warm-up and is excluded from the timing statistics whenever at least one
    further run exists. Token counts, early-exit counters, throughput and the
    decoded text of the last run are printed to stdout.

    Args:
        model: Causal LM whose ``generate`` accepts the ``enable_early_exit``
            and ``return_early_layer_logits`` keyword arguments and whose
            output exposes ``sequences``, ``early_exit_cnt`` and
            ``success_verify_cnt``.
        tokenizer: Tokenizer providing ``apply_chat_template``,
            ``convert_tokens_to_ids``, ``batch_decode`` and ``eos_token_id``.
        prompts: List of raw user prompt strings.
        args: Namespace of options; only ``args.max_tokens`` is read here.
        mode: ``'early_exit'`` enables early exit; any other value disables it.
        num_runs: Number of timed runs performed after the warm-up run.

    Raises:
        ValueError: If ``num_runs`` is negative.
    """
    if num_runs < 0:
        raise ValueError(f'num_runs must be >= 0, got {num_runs}')

    model.generation_config.pad_token_id = tokenizer.eos_token_id

    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{'role': 'user', 'content': prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    inputs = tokenizer(formatted_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
    inputs = {key: val.to(model.device) for key, val in inputs.items()}

    # All rows share this length because the batch is padded to the longest prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    print(f'inputs: {inputs["input_ids"].shape}')

    # Loop-invariant generation settings, built once instead of on every run.
    generation_kwargs = dict(
        max_new_tokens=args.max_tokens,
        eos_token_id=[
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ],
        do_sample=False,
        temperature=None,
        top_p=None,
        return_dict_in_generate=True,
        output_logits=False,
        output_hidden_states=True,
        use_cache=True,
        enable_early_exit=(mode == 'early_exit'),
        return_early_layer_logits=False,
    )

    times = []
    for _ in range(1 + num_runs):
        # perf_counter is monotonic; measure once so the printed and the
        # recorded durations are the same number.
        start_time = time.perf_counter()
        outputs = model.generate(**inputs, **generation_kwargs)
        elapsed = time.perf_counter() - start_time
        print(f'Generation time (model.generate): {elapsed}s')
        times.append(elapsed)

    # Drop the warm-up run (usually an outlier); if it is the only run, keep it
    # so the statistics below are not computed over an empty list.
    timed_runs = times[1:] or times
    mean_time = np.mean(timed_runs)
    print(f'Average time: {round(mean_time, 3)}s ± {round(np.std(timed_runs), 3)}s')

    # num of new tokens
    num_new_tokens = outputs.sequences.shape[-1] - prompt_len
    if num_new_tokens:
        early_exit_ratio = 100 * outputs.early_exit_cnt / num_new_tokens
        ms_per_token = 1000 * mean_time / num_new_tokens
    else:
        # Nothing was generated: per-token figures are undefined, not an error.
        early_exit_ratio = float('nan')
        ms_per_token = float('nan')

    print(f'Number of new tokens: {num_new_tokens}')
    print(f'Number of early exit tokens: {outputs.early_exit_cnt}, Ratio: {early_exit_ratio}%')
    # The tiny epsilon keeps the rate defined when no token exited early.
    print(f'Verification Success Rate ({outputs.success_verify_cnt} tokens): {100 * outputs.success_verify_cnt / (outputs.early_exit_cnt + 1e-20)}%')
    print(f'Inference time per token (ms/t): {ms_per_token}')
    print(f'# of tokens per second (t/s): {num_new_tokens / mean_time}')

    ##### Generated Tokens #####
    sequences = outputs.sequences
    print(f'sequences shape: {sequences.shape}')
    # Strip the (padded) prompt so only newly generated ids remain.
    response = [seq[prompt_len:] for seq in sequences]
    print(f'response ids: {response[0]}')

    gen_text = tokenizer.batch_decode(response, skip_special_tokens=True)
    print(f'len of generated tokens: {response[0].shape}')

    for idx, prompt in enumerate(formatted_prompts):
        print(f"Prompt: {prompt!r}\nGenerated text: {gen_text[idx]!r}\n")

if __name__ == "__main__":
    # Command-line entry point: parse options, load the tokenizer and model,
    # then benchmark generation on a single hard-coded prompt.
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name_or_path', type=str, default='llama_3.1_8b_instruct_lr_5e-3_epoch_50', help='name of the model in Hugging Face model hub or path to the model')

    # NOTE(review): --output_dir, --temperature and --max_instances are parsed
    # but not read anywhere in the visible code; kept for CLI compatibility.
    parser.add_argument('--output_dir', type=str, help='Path to the output file')
    parser.add_argument('--cache_dir', type=str, default=None, help='Directory to cached models')
    parser.add_argument('--temperature', type=float, default=0, help='Temperature for sampling')
    parser.add_argument('--max_tokens', type=int, default=2048, help='Maximum number of tokens')
    parser.add_argument('--max_instances', type=int, default=sys.maxsize)
    parser.add_argument('--mode', type=str, default='early_exit', choices=['vanilla', 'early_exit'], help='Generation mode: plain decoding or early exit')

    args = parser.parse_args()

    print(f'Loading model {args.model_name_or_path}...')

    # Left padding keeps every prompt adjacent to its generated continuation in
    # batched generation. The tokenizer uses the same cache directory as the model.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        padding_side="left",
        cache_dir=args.cache_dir,
    )

    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    mode = args.mode  # 'vanilla' or 'early_exit'

    model = MultiHeadLlamaForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        cache_dir=args.cache_dir,
        low_cpu_mem_usage=True,
    )

    prompts = ["Who was the first American to win the Nobel Prize, in which year?"]

    generate(model, tokenizer, prompts, args, mode)

        
        