
# Import necessary modules
import argparse
import torch
import argparse
from tqdm import tqdm
import lm_eval
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import make_table
import os
import json

from src.utils import set_all_seeds
from src.utils import load_yaml_as_dict
from src.model import load_model_and_tokenizer

def run_lm_eval_zero_shot(model, tokenizer, batch_size=64, max_length=4096, seed=0, task_list=None, limit=None):
    """Evaluate ``model`` on lm-eval-harness tasks and return the per-task metrics.

    Args:
        model: An already-loaded model; it is wrapped in ``HFLM`` for lm-eval.
        tokenizer: Tokenizer handed to ``HFLM`` together with ``model``.
        batch_size: Batch size forwarded to ``HFLM``.
        max_length: Stored as ``model.seqlen``. NOTE(review): this value is not
            passed to ``HFLM``, so lm-eval picks its own context length —
            confirm that is intended.
        seed: If negative, lm-eval's default seeds are used; otherwise this
            value is used for the python, numpy, torch and few-shot seeds.
        task_list: lm-eval task names to run. ``None`` (the default) means
            ``["arc_easy", "hellaswag"]``.
        limit: Optional cap on examples per task, forwarded to
            ``simple_evaluate``.

    Returns:
        The ``"results"`` entry of the dict returned by
        ``lm_eval.simple_evaluate`` (per-task metrics).
    """
    # Avoid a shared mutable default list across calls.
    if task_list is None:
        task_list = ["arc_easy", "hellaswag"]

    model.seqlen = max_length
    lm_obj = HFLM(pretrained=model, tokenizer=tokenizer, add_bos_token=False, batch_size=batch_size)

    # TaskManager() indexes all tasks from the lm_eval/tasks subdirectory.
    # Alternatively, TaskManager(include_path="path/to/my/custom/task/configs")
    # also picks up task configs from a separate directory.
    # Passing task_manager to simple_evaluate is optional; it instantiates its
    # own task manager if it is set to None.
    task_manager = lm_eval.tasks.TaskManager()

    eval_kwargs = dict(
        model=lm_obj,
        tasks=task_list,
        task_manager=task_manager,
        log_samples=False,
        limit=limit,
        confirm_run_unsafe_code=True,
    )
    if seed >= 0:
        # Pin every RNG that lm-eval seeds; for negative seeds keep its defaults.
        eval_kwargs.update(
            random_seed=seed,
            numpy_random_seed=seed,
            torch_random_seed=seed,
            fewshot_random_seed=seed,
        )

    with torch.no_grad():
        results = lm_eval.simple_evaluate(**eval_kwargs)

    print(make_table(results))

    return results['results']


def main():
    """Command-line entry point: evaluate a quantized model on one lm-eval task.

    Use argparse here (instead of hydra) because of the dependency conflict
    between hydra and lm eval.

    Loads the quantizer config from ``src/conf/quantizer/<quantizer>.yaml``,
    loads the model and tokenizer, runs ``run_lm_eval_zero_shot`` on
    ``--task`` and writes the returned per-task results as JSON into
    ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task',
        type=str,
        required=True,
        help='lm_eval task name to evaluate'
    )
    parser.add_argument(
        '--batch_size',
        default=1,
        type=int,
        help='batch size for lm_eval tasks'
    )
    parser.add_argument(
        '--model_name_or_path',
        default="meta-llama/Llama-2-7b-chat-hf",
        type=str,
        help='A model to test'
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory to save the .json results.",
        default="./lm_eval_output"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=-1,
        help="Seed for lm_eval RNGs; a negative value keeps lm_eval's defaults."
    )
    parser.add_argument(
        "--quantizer",
        type=str,
        default="NSNquantizer",
        help="Name of the quantizer config file in src/conf/quantizer/ (without .yaml)."
    )

    args = parser.parse_args()

    # Create the output directory up front so a bad path fails before the
    # (potentially long) evaluation instead of after it.
    os.makedirs(args.output_dir, exist_ok=True)

    tasks = [args.task]
    quantizer_conf = load_yaml_as_dict(f"src/conf/quantizer/{args.quantizer}.yaml")
    # NOTE(review): the third positional argument is True only for "mmlu"; its
    # meaning is defined by load_model_and_tokenizer — confirm there.
    model, tokenizer = load_model_and_tokenizer(
        args.model_name_or_path, quantizer_conf, args.task == "mmlu", None
    )
    res = run_lm_eval_zero_shot(
        model, tokenizer, batch_size=args.batch_size, seed=args.seed, task_list=tasks
    )

    # Last path component as the model name; strip trailing slashes so a local
    # directory such as "models/llama/" still yields "llama".
    model_name = args.model_name_or_path.rstrip("/").split("/")[-1]
    # The seed appears in the filename only when it was explicitly set (>= 0).
    seed_part = f"_{args.seed}" if args.seed >= 0 else ""
    output_file = os.path.join(
        args.output_dir,
        f"{model_name}_{quantizer_conf['save_postfix']}{seed_part}_{args.task}.json",
    )
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(res, f, indent=4)

    print(f"Results saved to {output_file}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
