import torch
from any_precision import AnyPrecisionForCausalLM_3456
from transformers import AutoTokenizer, TextStreamer, LlamaForCausalLM
from torch.utils.data import DataLoader
from tqdm import tqdm
import lm_eval
import argparse
import os
from filelock import Timeout, FileLock
import sys, subprocess

# Command-line interface. Every option is read once here and unpacked into
# module-level names that the rest of the script uses.
parser = argparse.ArgumentParser(
    description="Evaluate an AnyPrecision-quantized llama3 model on GSM8K with lm_eval."
)
parser.add_argument("--batch_size", type=int, default=1,
                    help="Batch size for lm_eval evaluation.")
parser.add_argument("--limit", type=int, default=None,
                    help="Maximum number of GSM8K examples to evaluate (default: all).")
parser.add_argument("--num_fewshot", type=int, default=5,
                    help="Number of few-shot examples per prompt.")
parser.add_argument("--wbits", type=int, default=6,
                    help="Precision passed to model.set_precision().")
# Not decode-specific: it bounds the loaded precisions (3..max_wbits) in
# every mode and is part of the output file name.
parser.add_argument("--max_wbits", type=int, default=6,
                    help="Highest precision to load; precisions 3..max_wbits are loaded.")
parser.add_argument("--model_type", choices=["ap"], default="ap",
                    help="Model backend; only 'ap' (AnyPrecision) is supported.")
parser.add_argument("--mode", choices=["orig", "jl", "mqdecode", "decode", "oracle"], default="orig",
                    help="Mode forwarded to AnyPrecisionForCausalLM_3456.from_quantized().")
parser.add_argument("--suffix", default="",
                    help="Optional tag appended (prefixed with '_') to the output file name.")
parser.add_argument("--targ_setup", type=str, default="",
                    help="If non-empty, space-separated arguments run with the current "
                         "Python interpreter while holding the file lock, before model loading.")
# Required: the path is used unconditionally to load both the tokenizer and
# the model, so failing early gives a clear argparse error.
parser.add_argument("--model_path", type=str, required=True,
                    help="Path to the quantized model; also used to load the tokenizer.")
parser.add_argument("--corr_arr_path", type=str,
                    help="Forwarded to from_quantized() as path_dict['corr_arr_path'].")
parser.add_argument("--max_mem_dict_path", type=str,
                    help="Forwarded to from_quantized() as path_dict['max_mem_dict_path'].")
parser.add_argument("--targ_path", type=str,
                    help="Directory holding '<a>_<b>_targ.pt' files.")
parser.add_argument("--jl_path", type=str,
                    help="Directory holding '<a>_<b>_jl.pt' files.")

args = parser.parse_args()

batch_size = args.batch_size
limit = args.limit
num_fewshot = args.num_fewshot
wbits = args.wbits
model_type = args.model_type
mode = args.mode
# Non-empty suffixes get a leading underscore so they read as a separate
# field in the output file name.
suffix = f"_{args.suffix}" if args.suffix != "" else ""

# Cross-process file lock. When --targ_setup is given, the setup subprocess
# and the model load below run while this lock is held, so concurrent runs
# sharing "lock_llama3.lock" do not interleave them. timeout=-1 makes
# acquire() wait indefinitely.
lock_path = "lock_llama3.lock"
lock = FileLock(lock_path, timeout=-1)
use_lock = args.targ_setup != ""

if use_lock:
    print("Acquiring lock...")
    lock.acquire()
try:
    if use_lock:
        # split() with no argument collapses runs of whitespace, so stray or
        # doubled spaces do not become empty-string argv entries.
        call_arr = [sys.executable, *args.targ_setup.split()]
        print(f"Calling with {call_arr}")
        ret_code = subprocess.call(call_arr)
        if ret_code != 0:
            # Continue as before, but make a failed setup step visible.
            print(f"Warning: target setup exited with code {ret_code}")

    model_path = args.model_path

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    if model_type == "ap":
        # Load every precision from 3 up to --max_wbits (inclusive).
        prec_arr = list(range(3, args.max_wbits + 1))
        # File-path configuration for from_quantized(); the *_fn entries map
        # an indexable key x to a per-(x[0], x[1]) file path.
        path_dict = {
            "corr_arr_path": args.corr_arr_path,
            "max_mem_dict_path": args.max_mem_dict_path,
            "targ_path_fn": lambda x: f"{args.targ_path}/{x[0]}_{x[1]}_targ.pt",
            "jl_path_fn": lambda x: f"{args.jl_path}/{x[0]}_{x[1]}_jl.pt",
        }
        model = AnyPrecisionForCausalLM_3456.from_quantized(
            model_path, precisions=prec_arr, mode=mode, model="llama3", path_dict=path_dict
        )
        model.set_precision(wbits)
        print("Using AP")
        print(f"Using {wbits}bits")
    else:
        raise RuntimeError(f"Non ap model type {model_type}")

    model = model.eval().cuda()
    if model_type == "ap":
        model.setMotherLayer()
finally:
    # Release explicitly even if setup or model loading raises, instead of
    # only on the success path.
    if use_lock:
        print("Releasing Lock.")
        lock.release()

# Wrap the already-loaded model for lm_eval. simple_evaluate() only uses its
# batch_size argument when it constructs the model itself from a string; for
# a pre-built LM object that argument is ignored. The batch size therefore
# has to be given to HFLM for --batch_size to take effect.
lm_obj = lm_eval.models.huggingface.HFLM(
    pretrained=model, tokenizer=tokenizer, batch_size=batch_size
)

task_manager = lm_eval.tasks.TaskManager()

# Run GSM8K with generation capped at 256 new tokens. log_samples=True keeps
# the per-sample records in `res`.
res = lm_eval.simple_evaluate(
    model=lm_obj,
    tasks=["gsm8k"],
    num_fewshot=num_fewshot,
    task_manager=task_manager,
    limit=limit,
    batch_size=batch_size,
    gen_kwargs="max_gen_toks=256",
    log_samples=True,
)

# Query the AnyPrecision model's effective bit count after evaluation.
if model_type == "ap":
    ebits = model.get_effective_bits()
    print(f"{ebits} effective bits")

# Persist the raw lm_eval result dict. The file name encodes the run config:
# few-shot count, example range, max/selected bits, model type and suffix.
save_dir = "out_llama3_3456"
os.makedirs(save_dir, exist_ok=True)
out_path = (
    f"{save_dir}/gsm8k_{num_fewshot}shot_0-{(limit if limit else 'all')}"
    f"_max{args.max_wbits}_{wbits}b_{model_type}{suffix}.txt"
)
# `with` closes the file even if a write fails. Explicit UTF-8 avoids
# locale-dependent encode errors on non-ASCII text in the logged samples.
with open(out_path, "w", encoding="utf-8") as of:
    if model_type == "ap":
        of.write(f"{ebits} effective bits\n")
    of.write(f"{res}")
print(f"acc={res['results']['gsm8k']['exact_match,strict-match']}")