import argparse
import json
import torch
from tqdm import tqdm
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils_exp import *

def parse_args():
    """Define and parse the command-line options for the generation script.

    Returns:
        argparse.Namespace with attributes dataset, model, cache_dir, hf_token,
        cot, load_from_cpu, device, result_dir, save_suffix, start_idx and
        num_points. start_idx and num_points are kept as strings; callers
        convert them.
    """
    dataset_choices = ["mmlu", "medmcqa", "commonsenseqa", "hellaswag"]
    model_choices = ["gemma-2-2b-it", "gemma-2-9b-it", "gemma-2-27b-it", "Meta-Llama-3.1-8B-Instruct"]

    # (flag, add_argument keyword options), registered in this exact order so
    # the generated --help output is unchanged.
    option_specs = [
        ("--dataset", dict(type=str, choices=dataset_choices, default="mmlu")),
        ("--model", dict(type=str, choices=model_choices, default="gemma-2-2b-it")),
        ("--cache_dir", dict(type=str, default="******")),
        ("--hf_token", dict(type=str, default="******")),
        ("--cot", dict(action='store_true')),
        ("--load_from_cpu", dict(action='store_true', help="whether or not to first load model to cpu then move to cuda device")),
        ("--device", dict(type=str, choices=["0", "1", "2"], default="0")),
        ("--result_dir", dict(type=str, default="******")),
        ("--save_suffix", dict(type=str, default="")),
        ("--start_idx", dict(type=str, default="0")),
        ("--num_points", dict(type=str, default="0")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def _resolve_model(model_key):
    """Map a ``--model`` choice to ``(hugging_face_repo_id, load_dtype)``.

    Any key not listed falls back to gemma-2-2b-it in float16, which is the
    default used when no other branch matches.
    """
    specs = {
        "gemma-2-9b-it": ("google/gemma-2-9b-it", torch.float16),
        # 27b is loaded in bfloat16; every other model uses float16.
        "gemma-2-27b-it": ("google/gemma-2-27b-it", torch.bfloat16),
        "Meta-Llama-3.1-8B-Instruct": ("meta-llama/Meta-Llama-3.1-8B-Instruct", torch.float16),
    }
    return specs.get(model_key, ("google/gemma-2-2b-it", torch.float16))


def _max_new_tokens(cot, model_key):
    """Return the ``max_new_tokens`` budget: 50 normally, 1000 with CoT, 600 for 27b models."""
    max_token = 1000 if cot else 50
    if "27" in model_key:
        # NOTE(review): this also raises the non-CoT budget from 50 to 600 for
        # 27b models; confirm whether the cap was meant for CoT runs only.
        max_token = 600
    return max_token


def _point_range(num_rows, start_idx, num_points):
    """Return the dataset indices to process.

    Args:
        num_rows: Number of rows in the dataset.
        start_idx: First index, as the CLI string.
        num_points: How many points to process, as the CLI string; "0" means
            "everything from start_idx to the end".

    The end is clamped to ``num_rows`` so a non-zero start (or an oversized
    count) never indexes past the end of the dataset.
    """
    start = int(start_idx)
    if num_points == "0":
        stop = num_rows
    else:
        stop = min(start + int(num_points), num_rows)
    return range(start, stop)


def _generate_gemma(model, tokenizer, ds, ds_util, idx, dname, cot, max_token):
    """Sample Gemma responses for data point ``idx``.

    Returns a dict mapping prompt_type (0-11) to a list of decoded responses.
    Prompt types 0-1 get 12 samples each, 2-5 get 4, and 6-11 get 1. Prompts
    in a group share one batched ``generate`` call.
    """
    # (first prompt_type, number of prompts in the batch, samples per prompt)
    schedule = ((0, 2, 12), (2, 4, 4), (6, 6, 1))
    responses = {}
    for first_type, n_prompts, n_samples in schedule:
        clean_gpus()
        prompts = [
            get_prompt(ds, idx, prompt_type=first_type + k, gemma_model=True, dname=dname, d_util=ds_util, cot=cot)
            for k in range(n_prompts)
        ]
        input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=False).to(model.device)
        outputs = model.generate(**input_ids, max_new_tokens=max_token, do_sample=True, temperature=0.7, num_return_sequences=n_samples)
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # Samples are assumed to be grouped per prompt: rows [k*n, (k+1)*n)
        # belong to prompt k. The text after the last "\nmodel\n" marker is
        # kept as the model's answer.
        for k in range(n_prompts):
            chunk = decoded[k * n_samples:(k + 1) * n_samples]
            responses[first_type + k] = [text.split('\nmodel\n')[-1] for text in chunk]
    return responses


def _generate_llama(model, tokenizer, ds, ds_util, idx, dname, cot, max_token):
    """Sample Llama responses for data point ``idx``, one ``generate`` call per prompt type.

    Returns a dict mapping prompt_type (0-11) to its list of decoded responses.
    The number of samples per prompt type decreases from 12 down to 1.
    """
    samples_per_prompt = [12, 6, 4, 3, 2, 2, 1, 1, 1, 1, 1, 1]
    responses = {}
    for prompt_type, n_samples in enumerate(samples_per_prompt):
        prompts = [get_prompt(ds, idx, prompt_type=prompt_type, gemma_model=False, llama_model=True, dname=dname, d_util=ds_util, cot=cot)]
        input_ids = tokenizer(prompts, return_tensors="pt", padding=False, truncation=False).to(model.device)
        outputs = model.generate(**input_ids, max_new_tokens=max_token, do_sample=True, temperature=0.7, top_p=0.9, num_return_sequences=n_samples)
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # Keep only the text after the last "assistant\n\n" marker.
        responses[prompt_type] = [text.split("assistant\n\n")[-1] for text in decoded[:n_samples]]
    return responses


def _output_path(result_dir, dname, model_key, cot, save_suffix):
    """Build the JSON output path.

    ``result_dir`` is used as a raw string prefix (concatenated, not joined),
    so it is expected to end with a path separator.
    """
    if save_suffix is None:
        save_suffix = ""
    cot_tag = "_cot" if cot else ""
    return result_dir + f"gen_{dname}_{model_key}{cot_tag}{save_suffix}_12x12.json"


def main():
    """Parse CLI args, sample model responses for a dataset slice, and save them as JSON.

    For every selected data point, several prompt variants are generated with
    sampling. The results are collected in ``{index: {prompt_type: [responses]}}``.
    For MMLU, each entry also stores the question, candidates and answer. The
    collection is dumped to a JSON file under ``--result_dir``.
    """
    args = parse_args()
    torch.set_grad_enabled(False)
    # Inductor setting named "skip cudagraphs for dynamic graphs"; presumably
    # avoids CUDA-graph capture with varying generation shapes. TODO confirm.
    torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True

    # load dataset and model
    login(args.hf_token)
    print("LOADING DATASET AND MODEL...")
    if args.cot:
        print("COT PROMPTS VERSION...")
    ds, ds_util = load_ds(args.dataset, args.cache_dir, cot=args.cot)

    target_device = f"cuda:{args.device}"
    # Loading directly onto cuda:1 / cuda:2 was reported as unreliable, so
    # --load_from_cpu loads on CPU first and moves to the GPU afterwards.
    load_device = "cpu" if args.load_from_cpu else target_device
    mname, dtype = _resolve_model(args.model)
    tokenizer = AutoTokenizer.from_pretrained(mname, cache_dir=args.cache_dir)
    model = AutoModelForCausalLM.from_pretrained(
        mname,
        device_map=load_device,
        cache_dir=args.cache_dir,
        torch_dtype=dtype,
    )
    if args.load_from_cpu:
        # Previously the model was never moved, so the whole run stayed on CPU.
        model = model.to(target_device)
    if args.model == "Meta-Llama-3.1-8B-Instruct":
        tokenizer.pad_token = tokenizer.eos_token
    print(model.dtype)

    # predict and save results
    model.eval()
    generations = {}
    dname = args.dataset
    print("STARTING MODEL PREDICTIONS...")
    max_token = _max_new_tokens(args.cot, args.model)

    for i in tqdm(_point_range(ds.num_rows, args.start_idx, args.num_points)):
        gen = {}
        clean_gpus()
        if dname == "mmlu":
            gen["question"] = ds[i]["question"]
            gen["candidates"] = ds[i]["choices"]
            gen["answer"] = ds[i]["answer"]
        if "gemma" in args.model:
            gen.update(_generate_gemma(model, tokenizer, ds, ds_util, i, dname, args.cot, max_token))
            generations[i] = gen
        elif "Llama" in args.model:
            gen.update(_generate_llama(model, tokenizer, ds, ds_util, i, dname, args.cot, max_token))
            generations[i] = gen

    print("SAVING RESULTS...")
    # json converts the int keys (data indices / prompt types) to strings.
    fname_gen = _output_path(args.result_dir, dname, args.model, args.cot, args.save_suffix)
    with open(fname_gen, "w") as f:
        json.dump(generations, f)

    print(f"GENERATION DONE WITH {args.model} on {args.dataset} dataset!")

if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not on import.
    main()
