import torch
import evaluate
import datasets
from my_utils.perplexity import Perplexity
from my_utils.load_model import load_model
# from my_utils.load_input import read_context_files
import argparse
from eval.needle.utils import load_context
from my_utils.my_generation import set_topk

def perplexity_hf(input_texts, model_id, batch_size=8):
    """Print the mean perplexity of ``input_texts`` using the HF ``evaluate`` metric.

    Loads the ``perplexity`` metric via ``evaluate.load`` and forwards
    ``model_id`` to its ``compute`` call, which uses that identifier to select
    the scoring model.

    Args:
        input_texts: texts to score; passed to ``compute`` as ``predictions``.
        model_id: model identifier forwarded to ``compute``.
        batch_size: batch size forwarded to ``compute`` (default 8, the value
            previously hard-coded).

    Returns:
        None. The mean perplexity, rounded to two decimals, is only printed.
    """
    perplexity = evaluate.load("perplexity", module_type="metric")
    # ``results`` holds the keys 'perplexities' and 'mean_perplexity'.
    results = perplexity.compute(model_id=model_id,
                                 predictions=input_texts,
                                 batch_size=batch_size)
    print('mean_perplexity: ', round(results["mean_perplexity"], 2))


def perplexity_my(input_texts, model, model_id, tokenizer, device=None, torch_dtype=torch.float32, 
                  eval_standard=False, debug_mode=False, max_length=3200, generation_length=40):
    """Score ``input_texts`` with the project's ``Perplexity`` metric.

    Thin wrapper that forwards an already-loaded model and tokenizer to
    ``Perplexity.compute`` with a fixed batch size of 1.

    Args:
        input_texts: texts to score; passed as ``predictions``.
        model: loaded model instance.
        model_id: model identifier, forwarded unchanged.
        tokenizer: tokenizer matching ``model``.
        device: forwarded as-is (``None`` by default).
        torch_dtype: dtype forwarded to ``compute``.
        eval_standard: flag forwarded to ``compute``.
        debug_mode: flag forwarded to ``compute``.
        max_length: forwarded as ``max_length``.
        generation_length: forwarded to ``compute`` under the keyword ``generation``.

    Returns:
        Whatever ``Perplexity.compute`` returns, unmodified.
    """
    compute_kwargs = {
        'model': model,
        'model_id': model_id,
        'tokenizer': tokenizer,
        'predictions': input_texts,
        'batch_size': 1,
        'max_length': max_length,
        'generation': generation_length,
        'device': device,
        'torch_dtype': torch_dtype,
        'eval_standard': eval_standard,
        'debug_mode': debug_mode,
    }
    metric = Perplexity()
    return metric.compute(**compute_kwargs)


parser = argparse.ArgumentParser()
# Models used with this script so far: microsoft/Phi-3.5-mini-instruct,
# mistralai/Mistral-Nemo-Instruct-2407, meta-llama/Meta-Llama-3.1-8B-Instruct.
parser.add_argument('--model', type=str, default='meta-llama/Meta-Llama-3.1-8B-Instruct')
parser.add_argument('--modified', type=str, default=None,
                    choices=['gemfilter', 'snapkv', 'h2o', 'knn'])  # None for standard attention
# Default None: with --modified, sweep a fixed list of sizes unless one is given.
parser.add_argument('--topk', type=int, default=None,
                    help='KV cache size; with --modified, evaluate only this size '
                         '(default: sweep 4, 16, 64, 256, 1024, 4096)')
parser.add_argument('--ctx_len', type=int, default=32768,
                    help='haystack context token length')
parser.add_argument('--gen_len', type=int, default=1024,
                    help='generation length')
parser.add_argument('--debug', action='store_true', help='whether show full perplexity')
# NOTE(review): this help text duplicates --debug; confirm what eval_standard
# controls inside Perplexity.compute and update the wording.
parser.add_argument('--eval_standard', action='store_true',
                    help='whether show full perplexity')
args = parser.parse_args()


model_id = args.model
modified = args.modified
flash_attention_2 = True

eval_standard = args.eval_standard
debug_mode = args.debug
max_length = args.ctx_len
generation_length = args.gen_len
topk = args.topk

torch_dtype = torch.float16

# Requests twice the target length of haystack text; presumably to leave
# headroom for truncation to max_length downstream — confirm against load_context.
context = load_context(
    fpath="eval/needle/PaulGrahamEssays/*.txt", ctx_len=2*max_length)
texts = [context]

# Reference implementation via HF evaluate (loads its own model):
# perplexity_hf(texts, model_id)

model, tokenizer = load_model(model_id, modified=modified, torch_dtype=torch_dtype,
                              flash_attention_2=flash_attention_2)
print(model.model)

# Shared evaluation settings. generation_length was previously never passed, so
# --gen_len was silently ignored and perplexity_my's default of 40 was used.
eval_kwargs = dict(device=None, torch_dtype=torch_dtype, eval_standard=eval_standard,
                   debug_mode=debug_mode, max_length=max_length,
                   generation_length=generation_length)

if modified is None:
    # Standard attention: a single evaluation. Print it, since perplexity_my
    # does not print anything itself.
    result = perplexity_my(texts, model, model_id, tokenizer, **eval_kwargs)
    print(result)
else:
    # Modified attention: honor an explicit --topk, otherwise sweep KV cache sizes.
    if topk is not None:
        topk_list = [topk]
    else:
        topk_list = [2**2, 2**4, 2**6, 2**8, 2**10, 2**12]
    results_list = []
    for topk in topk_list:
        set_topk(model, topk, mode=modified)
        result = perplexity_my(texts, model, model_id, tokenizer, **eval_kwargs)
        results_list.append(result)
    print(results_list)
