import argparse
import torch
import os

import lm_eval
from tqdm import tqdm
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import make_table

from utils.data import get_loaders, set_seed
from utils.evaluator import PPLMetric, eval_ppl, evaluate_model, eval_zero_shot, llama_eval
from datasets import load_dataset


def get_gsm8k_promp():
    """Build a few-shot GSM8K prompt that ends with an unanswered question.

    The first five rows of the ``train`` split of the locally stored GSM8K
    ``main`` configuration are rendered as question/answer pairs. A fixed
    question is then appended with no answer after it.

    Returns:
        The assembled prompt string.
    """
    train_rows = load_dataset('/mnt/bd/pretraining/mjl_work/data/gsm8k', 'main')['train']

    # One "Question/Answer" pair per in-context example, each newline-terminated.
    shots = (
        'Question: ' + train_rows[idx]['question'] + '\nAnswer: ' + train_rows[idx]['answer'] + '\n'
        for idx in range(5)
    )
    final_question = "Question: John takes care of 10 dogs. Each dog takes .5 hours a day to walk and take care of their business. How many hours a week does he spend taking care of dogs?"
    return ''.join(shots) + final_question

def chat(model, tokenizer, prompt):
    """Generate a continuation of ``prompt`` without sampling and print it.

    Prints an empty line, a separator with an ``Output:`` header, the decoded
    continuation, and finally the raw ids of the continuation.

    Args:
        model: Object whose ``generate`` method accepts the input ids plus
            ``max_new_tokens`` and ``do_sample``.
        tokenizer: Callable returning an object with ``input_ids``; it must
            also provide ``decode``.
        prompt: Text to feed to the model.
    """
    print("")
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    generated = model.generate(prompt_ids, max_new_tokens=256, do_sample=False)

    # Drop the first `prompt_len` ids so only what follows the prompt is shown.
    prompt_len = prompt_ids.shape[1]
    new_token_ids = generated[0].tolist()[prompt_len:]

    print("=" * 20 + "\nOutput:")
    print(tokenizer.decode(new_token_ids, skip_special_tokens=True))
    print(new_token_ids)


def test(model, tokenizer):
    """Run ``chat`` on the prompt built by ``get_gsm8k_promp``."""
    chat(model, tokenizer, get_gsm8k_promp())

def eval(model, tokenizer, tasks, withppl=False, batch_size=1):
    """Optionally print wikitext2 perplexity, then run lm-eval tasks.

    Args:
        model: Model passed to ``llama_eval`` and wrapped in ``HFLM``.
        tokenizer: Tokenizer used to build the data loaders and given to
            ``HFLM``.
        tasks: Tasks forwarded to ``lm_eval.simple_evaluate``; when falsy,
            no task evaluation is run.
        withppl: When true, evaluate perplexity with ``llama_eval`` on the
            second loader returned by ``get_loaders('wikitext2', ...)`` and
            print it.
        batch_size: Batch size given to ``HFLM``.

    Returns:
        The value of ``make_table`` applied to the evaluation results, or
        ``None`` when ``tasks`` is falsy.
    """
    if withppl:
        _, test_loader = get_loaders('wikitext2', tokenizer, seq_len=4096, batch_size=1)
        # Only the perplexity is reported; the second returned value is unused.
        ppl_test, _ = llama_eval(model, test_loader, 'cuda')
        print("ppl after prune", ppl_test)

    if not tasks:
        return None

    lm_obj = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        add_bos_token=False,
        batch_size=batch_size,
    )

    # The default TaskManager indexes the tasks shipped in lm_eval/tasks.
    # Use TaskManager(include_path="path/to/my/custom/task/configs") to add
    # tasks from another directory. Passing a manager is optional:
    # simple_evaluate creates its own when task_manager is None.
    task_manager = lm_eval.tasks.TaskManager()

    with torch.no_grad():
        results = lm_eval.simple_evaluate(
            model=lm_obj,
            tasks=tasks,
            task_manager=task_manager,
            log_samples=False,
        )
    return make_table(results)
    
    