import pickle
import vllm
import torch
from tqdm import tqdm
import time
import argparse
import random
import math
from datasets import load_dataset
from vllm.config import JudgeConfig


from livecodebench_v5 import (
    CodeGenerationProblem,
    unpack_lcb_data,
    codegen_metrics,
    extract_code,
)

# Seed Python's global RNG for reproducibility. NOTE(review): nothing else in
# this file calls the `random` module; dataset shuffling in load_lcb uses its
# own fixed seed (43), so this call currently only matters if imported code
# draws from the global RNG — confirm it is still needed.
random.seed(42)

def remove_bos(text, bos="<|begin_of_text|>"):
    """Strip a leading BOS marker from a rendered prompt string.

    Args:
        text: prompt text that is expected to start with ``bos``.
        bos: marker to remove; defaults to ``"<|begin_of_text|>"``.

    Returns:
        ``text`` with the leading ``bos`` removed.

    Raises:
        ValueError: if ``text`` does not start with ``bos``. An explicit check
            is used instead of ``assert`` so validation still happens under
            ``python -O`` (otherwise real prompt characters would be dropped).
    """
    if not text.startswith(bos):
        raise ValueError(
            f"expected prompt to start with {bos!r}, got {text[:40]!r}"
        )
    return text[len(bos):]


def parse_args(argv=None):
    """Parse command-line options for the LiveCodeBench evaluation.

    Args:
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``out_file``, ``target_model``, ``draft_model``,
        ``no_spec_dec``, ``no_judge``, ``data_size``, ``window_size``,
        ``judge_path``, ``judge_threshold`` and ``fold``.

    Exits (status 2, via ``parser.error``) if the judge is enabled — i.e.
    ``--no_judge`` is not given — but ``--judge_path`` is missing: the script
    opens that pickle in that case, so failing here gives a clear message
    instead of ``open(None)`` crashing later.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", dest="out_file", required=True)
    parser.add_argument("--target_model", required=True)
    parser.add_argument("--draft_model", required=False)
    parser.add_argument("--no_spec_dec", action="store_true")
    parser.add_argument("--no_judge", action="store_true")
    parser.add_argument("--data_size", type=float, default=1.0)
    parser.add_argument("--window_size", type=int, default=32)
    parser.add_argument("--judge_path", required=False)
    parser.add_argument("--judge_threshold", type=float, required=False)
    parser.add_argument("--fold", type=int)
    args = parser.parse_args(argv)
    if not args.no_judge and args.judge_path is None:
        parser.error("--judge_path is required unless --no_judge is given")
    return args


def load_lcb(release_version="release_v5", shuffle=False, data_size=1.0, task_ids=None):
    """Load LiveCodeBench code-generation problems as ``CodeGenerationProblem``s.

    Args:
        release_version: ``version_tag`` of ``livecodebench/code_generation_lite``.
        shuffle: if True, shuffle the test split with a fixed seed (43) before
            subsetting.
        data_size: fraction of the split to keep. When it is not 1.0, the first
            ``ceil(data_size * len(split))`` rows are kept, capped at the split
            size.
        task_ids: if not None, keep only problems whose ``question_id`` is in
            this collection (an empty collection keeps nothing).

    Returns:
        list of ``CodeGenerationProblem``.
    """
    dataset = load_dataset(
        "livecodebench/code_generation_lite",
        split="test",
        version_tag=release_version,
        trust_remote_code=True,
    )
    if shuffle:
        dataset = dataset.shuffle(seed=43)
    if data_size != 1.0:
        # Cap so data_size > 1 cannot request rows past the end of the split.
        n_rows = min(len(dataset), math.ceil(data_size * len(dataset)))
        dataset = dataset.select(range(n_rows))
    problems = [CodeGenerationProblem(**p) for p in tqdm(dataset, desc="Loading dataset")]  # type: ignore
    # Compare against None rather than truthiness: an empty fold must select
    # nothing (not the whole dataset), and array-like id containers have no
    # unambiguous truth value.
    if task_ids is not None:
        problems = filter_by_ids(problems, task_ids)

    return problems


def filter_by_ids(dataset, ids):
    """Return the items of ``dataset`` whose ``question_id`` appears in ``ids``.

    The order of ``dataset`` is preserved. ``ids`` is materialized into a set
    once, so each membership test is O(1) instead of a scan of ``ids`` per item.
    Assumes the ids are hashable (e.g. strings).
    """
    wanted = set(ids)
    return [problem for problem in dataset if problem.question_id in wanted]


def eval(
    target_model,
    draft_model,
    use_spec_dec,
    data_size,
    window_size,
    judge_data,
    judge_threshold,
    fold
):
    """Evaluate each fold and yield its metrics as soon as it finishes.

    NOTE: this function shadows the builtin ``eval`` at module level; the name
    is kept for compatibility with existing callers.

    Args:
        target_model: model name/path for ``vllm.LLM``.
        draft_model: draft model used when speculative decoding is enabled.
        use_spec_dec: whether to enable speculative decoding.
        data_size: fraction of the dataset to keep (see ``load_lcb``).
        window_size: number of speculative tokens per step.
        judge_data: iterable of ``(task_ids, judge_params)`` pairs, one per
            fold, where ``judge_params`` has "weights", "mean", "scale" and
            "bias" entries; or ``None`` to run a single fold (index 0) over
            the whole dataset without a judge.
        judge_threshold: threshold stored on each ``JudgeConfig``.
        fold: if not None, only the fold with this index is evaluated.

    Yields:
        ``(fold_idx, fold_res)`` where ``fold_res`` is the tuple returned by
        ``eval_fold``.
    """
    if judge_data is None:
        # --no_judge path: previously ``enumerate(None)`` raised TypeError.
        # Evaluate everything (task_ids=None) with no judge parameters.
        # NOTE(review): with speculative decoding on, this passes
        # judge_config=None to the speculative config — confirm the vLLM fork
        # treats that as plain speculative decoding.
        folds = [(None, None)]
    else:
        folds = judge_data

    for fold_idx, (task_ids, judge_params) in enumerate(folds):
        if fold is not None and fold_idx != fold:
            continue
        dataset = load_lcb(data_size=data_size, task_ids=task_ids)
        if use_spec_dec and judge_params is not None:
            judge_config = JudgeConfig(
                weights=judge_params["weights"],
                mean=judge_params["mean"],
                scale=judge_params["scale"],
                bias=judge_params["bias"],
                threshold=judge_threshold,
            )
        else:
            judge_config = None
        fold_res = eval_fold(dataset, target_model, draft_model, use_spec_dec, window_size, judge_config)
        yield fold_idx, fold_res
        
def eval_fold(
    dataset,
    target_model,
    draft_model,
    use_spec_dec,
    window_size,
    judge_config
):
    """Generate one solution per problem in ``dataset`` and score it.

    A fresh ``vllm.LLM`` is built for the fold with these settings:
    - temperature 0.0 and at most 4096 new tokens;
    - bfloat16 weights;
    - tensor parallelism over all visible CUDA devices.

    Prompts are generated one request at a time, so ``total_time`` is the
    sum of per-request latencies rather than batched throughput.

    Args:
        dataset: list of ``CodeGenerationProblem`` (as returned by ``load_lcb``).
        target_model: model name/path for ``vllm.LLM``.
        draft_model: draft model for the speculative config.
        use_spec_dec: whether to pass a speculative config to vLLM.
        window_size: ``num_speculative_tokens`` for the speculative config.
        judge_config: ``JudgeConfig`` or ``None``; forwarded in the speculative
            config.

    Returns:
        ``(pass_at_1, total_time, prompt_tokens, output_tokens)``:
        - ``total_time``: seconds spent inside ``llm.generate``;
        - the token counts: summed over all problems.
    """
    sampling_params = vllm.SamplingParams(temperature=0.0, top_p=1.0, max_tokens=4096)
    if use_spec_dec:
        speculative_config = {
            "model": draft_model,
            "num_speculative_tokens": window_size,
            "judge_config": judge_config,
        }
    else:
        speculative_config = None
    llm = vllm.LLM(
        model=target_model,
        dtype="bfloat16",
        gpu_memory_utilization=0.8,
        tensor_parallel_size=torch.cuda.device_count(),
        speculative_config=speculative_config,
    )
    # NOTE(review): a new LLM is built for every fold in the same process;
    # confirm GPU memory held by the previous fold's engine is released.

    tokenizer = llm.get_tokenizer()
    lcb_dataset, lcb_inputs_outputs = unpack_lcb_data(dataset, tokenizer)

    total_time = 0.0
    outputs = []
    for task in tqdm(lcb_dataset, total=len(lcb_dataset)):
        # The rendered prompt starts with the BOS text; it is stripped here,
        # presumably so vLLM's tokenization does not add a second BOS — confirm.
        prompt = remove_bos(task["prompt"])
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the system clock and can jump, corrupting latency measurements.
        start = time.perf_counter()
        output = llm.generate([prompt], sampling_params=sampling_params, use_tqdm=False)
        total_time += time.perf_counter() - start
        outputs.extend(output)

    # SamplingParams above does not request multiple samples, so each request
    # should yield exactly one completion.
    assert all(len(output.outputs) == 1 for output in outputs)
    progs = [extract_code(out.text) for output in outputs for out in output.outputs]
    prompt_tokens = sum(len(out.prompt_token_ids) for out in outputs)
    output_tokens = sum(
        len(out.token_ids) for output in outputs for out in output.outputs
    )

    result = codegen_metrics(
        tqdm(lcb_inputs_outputs), [[prog] for prog in progs], num_process_evaluate=16
    )
    pass_at_1 = result[0]["pass@1"].item()

    return pass_at_1, total_time, prompt_tokens, output_tokens

def get_judge_config(args):
    """Build a ``JudgeConfig`` from the pickle stored at ``args.judge_path``.

    The unpickled object must be subscriptable by "weights", "mean", "scale"
    and "bias"; the decision threshold is taken from ``args.judge_threshold``.

    NOTE: ``pickle.load`` can execute arbitrary code — only point
    ``judge_path`` at files you trust.
    """
    with open(args.judge_path, "rb") as judge_file:
        judge_params = pickle.load(judge_file)
    # File is closed before the config is constructed; keys are read in the
    # same order as the keyword arguments JudgeConfig receives.
    config_kwargs = {key: judge_params[key] for key in ("weights", "mean", "scale", "bias")}
    return JudgeConfig(**config_kwargs, threshold=args.judge_threshold)


if __name__ == "__main__":
    args = parse_args()
    out_file = args.out_file

    # Per-fold judge data is read from a pickle unless the judge is disabled.
    # NOTE: pickle.load can execute arbitrary code — only load trusted files.
    judge_data = None
    if not args.no_judge:
        with open(args.judge_path, "rb") as judge_file:
            judge_data = pickle.load(judge_file)

    autojudge_threshold = args.judge_threshold
    fold_results = eval(
        target_model=args.target_model,
        draft_model=args.draft_model,
        use_spec_dec=not args.no_spec_dec,
        data_size=args.data_size,
        window_size=args.window_size,
        judge_data=judge_data,
        judge_threshold=autojudge_threshold,
        fold=args.fold,
    )
    for fold_idx, res in fold_results:
        # Reopen in append mode per fold so each finished fold is persisted
        # immediately, even if a later fold fails.
        fields = (autojudge_threshold, fold_idx, *res)
        with open(out_file, "a") as out:
            out.write(" ".join(str(field) for field in fields) + "\n")
