import numpy as np
from concurrent.futures import ProcessPoolExecutor
import tqdm
import json 
from vllm import LLM, SamplingParams
from utils_execution import BASE_IMPORTS, check_correctness

def make_direct_output_prompt(s):
    """Build the two-shot "predict the output" prompt for a single problem.

    Args:
        s: A ``(code, input)`` pair. ``code`` is the function source to show
            the model. ``input`` is the call expression placed on the left of
            ``== ??`` in the query assertion.

    Returns:
        The full prompt: instructions, two worked examples, then the query.
        It ends right after an opening ``[ANSWER]`` tag so the model continues
        with the completed assertion.
    """
    func_src, call_expr = s

    # Fixed instructions plus the two worked examples; no placeholders here.
    few_shot_header = """You are given a Python function and an assertion containing an input to the function. Complete the assertion with a literal (no unsimplified expressions, no function calls) containing the output when executing the provided code on the given input, even if the function is incorrect or incomplete. Do NOT output any extra information. Provide the full assertion with the correct output in [ANSWER] and [/ANSWER] tags, following the examples.

[PYTHON]
def repeatNumber(number : int) -> int:
    return number
assert repeatNumber(number = 17) == ??
[/PYTHON]
[ANSWER]
assert repeatNumber(number = 17) == 17
[/ANSWER]

[PYTHON]
def addCharacterA(string : str) -> str:
    return string + "a"
assert addCharacterA(string = "x9j") == ??
[/PYTHON]
[ANSWER]
assert addCharacterA(string = "x9j") == "x9ja"
[/ANSWER]

"""
    # The per-problem query, laid out exactly like the examples above.
    query_section = (
        f"[PYTHON]\n{func_src}\nassert {call_expr} == ??\n[/PYTHON]\n[ANSWER]\n"
    )
    return few_shot_header + query_section

def run_batch(
    prompts,
    *,
    model_tokenizer_path="Qwen/Qwen2.5-Coder-7B-Instruct",
    tensor_parallel_size=4,
):
    """Generate one completion per prompt with a vLLM model.

    Args:
        prompts: Iterable of prompt strings.
        model_tokenizer_path: Hugging Face id or local path used for both the
            model weights and the tokenizer.
        tensor_parallel_size: Number of GPUs to shard the model across.

    Returns:
        A list with one entry per prompt, in input order. Each entry is the
        list of completion texts for that prompt; ``SamplingParams(n=1)``
        makes it a single-element list.
    """
    prompts = list(prompts)
    if not prompts:
        # Avoid loading a multi-GPU model just to produce nothing.
        return []

    llm = LLM(
        model=model_tokenizer_path,
        tokenizer=model_tokenizer_path,
        dtype="bfloat16",
        enforce_eager=True,
        max_model_len=20000,
        trust_remote_code=True,
        gpu_memory_utilization=0.98,
        tensor_parallel_size=tensor_parallel_size,
    )
    sampling_params = SamplingParams(
        n=1,
        max_tokens=2048,
        temperature=0.2,
        top_p=1.0,
        frequency_penalty=0,
        presence_penalty=0,
    )

    # vLLM returns one RequestOutput per prompt, in the same order as the input.
    request_outputs = llm.generate(prompts, sampling_params)
    return [
        [completion.text for completion in request_output.outputs]
        for request_output in request_outputs
    ]

def evaluate_score(args) -> list[bool]:
    """Execute every candidate answer for one problem and report pass/fail.

    Args:
        args: ``(gs, (c, i, o))`` packed into one tuple so the function can be
            used with ``Executor.map``.
            ``gs`` is the list of candidate output literals for this problem.
            A bare string is treated as a single candidate.
            ``c`` is the function source, ``i`` the call expression, and
            ``o`` the reference output literal.

    Returns:
        One result per candidate, in the same order as ``gs``.
    """
    gs, (c, i, o) = args
    if isinstance(gs, str):
        # Iterating a str would score each character as a separate candidate.
        gs = [gs]

    execution_results = []
    for g in gs:
        if i in g:
            # A candidate that repeats the call expression would make the
            # executed assertion below recompute the answer rather than check
            # a predicted literal. Count it as wrong; do not drop it, because
            # dropping would shrink n and inflate pass@k.
            execution_results.append(False)
            continue
        code_to_execute = f"{BASE_IMPORTS}\n{c}\nassert {o} == {g}"
        # NOTE(review): 3 is presumably a timeout in seconds; confirm in
        # utils_execution.check_correctness.
        execution_results.append(check_correctness(code_to_execute, 3))
    return execution_results

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate from ``n`` samples of which ``c`` are correct.

    Computes ``1 - C(n - c, k) / C(n, k)`` using the numerically stable
    product form ``1 - prod_{j=n-c+1}^{n} (1 - k / j)``.

    Args:
        n: Total number of samples evaluated.
        c: Number of correct samples (``0 <= c <= n``).
        k: Number of draws being scored.

    Returns:
        The estimate as a float in ``[0, 1]``.
    """
    if c == 0:
        # With no correct sample no draw can pass. This also covers n == 0
        # (nothing executed), which would otherwise fall into the next branch
        # and report 1.0.
        return 0.0
    if n - c < k:
        # Every size-k draw must contain at least one correct sample.
        return 1.0
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

def code_execution_metrics(
    samples,
    generations,
):
    """Score generated output predictions by executing them, in parallel.

    Args:
        samples: Problem records; each must have ``"code"``, ``"input"`` and
            ``"output"`` keys.
        generations: One list of candidate output literals per sample, in the
            same order as ``samples``.

    Returns:
        ``[metrics, results]``.
        ``metrics`` is ``{"pass@1": <mean pass@1 as a percentage>}``; it is
        0.0 when there are no samples.
        ``results`` maps each sample index to ``[[bool], ...]``, one
        single-element list per candidate.

    Raises:
        ValueError: If ``samples`` and ``generations`` differ in length.
    """
    samples = list(samples)
    generations = list(generations)
    if len(samples) != len(generations):
        # zip() would silently drop the tail and score a subset.
        raise ValueError(
            f"got {len(samples)} samples but {len(generations)} generation lists"
        )
    if not samples:
        # Avoid dividing by zero when averaging below.
        return [{"pass@1": 0.0}, {}]

    references = [(doc["code"], doc["input"], doc["output"]) for doc in samples]
    with ProcessPoolExecutor() as executor:
        all_results = list(executor.map(evaluate_score, zip(generations, references)))

    pass_at_1s = [
        pass_at_k(len(execution_result), execution_result.count(True), 1)
        for execution_result in all_results
    ]
    metrics = {"pass@1": sum(pass_at_1s) / len(pass_at_1s) * 100}

    results = {
        index: [[outcome] for outcome in execution_result]
        for index, execution_result in enumerate(all_results)
    }
    return [metrics, results]


def _extract_answer(model_output, input_expr):
    """Return the predicted output literal from one raw model completion."""
    text = model_output
    # Prefer cutting right after the exact "<input> ==" the prompt showed, so
    # an "==" inside the input expression is not mistaken for the comparison.
    marker = f"{input_expr} =="
    if marker in text:
        text = text.split(marker, 1)[1].strip()
    elif "==" in text:
        # maxsplit=1 keeps any "==" that is part of the predicted literal.
        text = text.split("==", 1)[1].strip()
    if "[/ANSWER]" in text:
        return text.split("[/ANSWER]")[0].strip()
    # No closing tag: assume the answer is the first line.
    return text.split("\n")[0].strip()


def main():
    """Run output prediction on the LiveCodeBench execution test set and print scores."""
    test_path = "livecodebench/execution2/test.jsonl"
    with open(test_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
        test_samples = [json.loads(line) for line in f if line.strip()]

    print(f"test samples length = {len(test_samples)}")

    prompts = [
        make_direct_output_prompt((item["code"], item["input"]))
        for item in test_samples
    ]
    outputs = run_batch(prompts)

    # evaluate_score iterates over a *list* of candidates per problem. Passing
    # a bare string would score every character as a separate candidate, so
    # wrap each extracted answer in a one-element list.
    generations = [
        [_extract_answer(model_output[0], item["input"])]
        for model_output, item in zip(outputs, test_samples)
    ]

    print(generations)

    metrics, results = code_execution_metrics(test_samples, generations)

    print(metrics)
    print(results)



if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0,
    # just as reaching the end of the module would.
    raise SystemExit(main())