import numpy as np
from concurrent.futures import ProcessPoolExecutor
import tqdm
import json 
from vllm import LLM, SamplingParams
from utils_execution import BASE_IMPORTS, check_correctness
from transformers import AutoTokenizer

def make_direct_output_prompt(s):
    code, input = s
    return f"""You are given a Python function and an assertion containing an input to the function. Complete the assertion with a literal (no unsimplified expressions, no function calls) containing the output when executing the provided code on the given input, even if the function is incorrect or incomplete. Do NOT output any extra information. Provide the full assertion with the correct output in [ANSWER] and [/ANSWER] tags, following the examples.

[PYTHON]
def repeatNumber(number : int) -> int:
    return number
assert repeatNumber(number = 17) == ??
[/PYTHON]
[ANSWER]
assert repeatNumber(number = 17) == 17
[/ANSWER]

[PYTHON]
def addCharacterA(string : str) -> str:
    return string + "a"
assert addCharacterA(string = "x9j") == ??
[/PYTHON]
[ANSWER]
assert addCharacterA(string = "x9j") == "x9ja"
[/ANSWER]

[PYTHON]
{code}
assert {input} == ??
[/PYTHON]
[ANSWER]
"""

def make_cot_output_prompt(s):
    """Build the one-shot chain-of-thought prompt for output prediction.

    The worked example shows a ``[THOUGHT]`` section that traces execution
    before the ``[ANSWER]`` block.

    Args:
        s: a ``(code, input)`` pair: the function source, and the call
            expression that is placed on the left-hand side of the assertion.

    Returns:
        The prompt text. It ends with an open ``[THOUGHT]`` tag so the model
        starts by reasoning step by step.
    """
    source, call_expr = s
    prompt = f"""You are given a Python function and an assertion containing an input to the function. Complete the assertion with a literal (no unsimplified expressions, no function calls) containing the output when executing the provided code on the given input, even if the function is incorrect or incomplete. Do NOT output any extra information. Execute the program step by step before arriving at an answer, and provide the full assertion with the correct output in [ANSWER] and [/ANSWER] tags, following the examples.

[PYTHON]
def performOperation(s):
    s = s + s
    return "b" + s + "a"
assert performOperation(s = "hi") == ??
[/PYTHON]
[THOUGHT]
Let's execute the code step by step:

1. The function performOperation is defined, which takes a single argument s.
2. The function is called with the argument "hi", so within the function, s is initially "hi".
3. Inside the function, s is concatenated with itself, so s becomes "hihi".
4. The function then returns a new string that starts with "b", followed by the value of s (which is now "hihi"), and ends with "a".
5. The return value of the function is therefore "bhihia".
[/THOUGHT]
[ANSWER]
assert performOperation(s = "hi") == "bhihia"
[/ANSWER]

[PYTHON]
{source}
assert {call_expr} == ??
[/PYTHON]
[THOUGHT]
"""
    return prompt

def run_batch(
    prompts,
    model_tokenizer_path="qwen-7b-05131",
    tensor_parallel_size=4,
    max_tokens=2048,
):
    """Greedily decode every prompt with a vLLM engine and return the completions.

    Args:
        prompts: iterable of prompt strings.
        model_tokenizer_path: path or hub id used for both the model weights
            and the tokenizer.
        tensor_parallel_size: forwarded to vLLM's ``LLM`` constructor.
        max_tokens: maximum number of tokens generated per prompt.

    Returns:
        A list aligned with ``prompts``. Entry ``i`` is the list of generated
        texts for prompt ``i`` (a single text, because ``n=1``).
    """
    # Materialize once so a generator argument is not exhausted before generation.
    prompts = list(prompts)
    if not prompts:
        # Nothing to generate: skip loading the tokenizer and the sharded model.
        return []

    tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
    llm = LLM(
        model=model_tokenizer_path,
        tokenizer=model_tokenizer_path,
        dtype="bfloat16",
        enforce_eager=True,
        max_model_len=8000,
        trust_remote_code=True,
        gpu_memory_utilization=0.98,
        tensor_parallel_size=tensor_parallel_size,
    )
    # A tokenizer may leave eos/pad ids unset (None); pass only real ids as stops.
    stop_token_ids = [
        token_id
        for token_id in (tokenizer.eos_token_id, tokenizer.pad_token_id)
        if token_id is not None
    ]
    sampling_params = SamplingParams(
        n=1,
        max_tokens=max_tokens,
        temperature=0,
        top_p=1.0,
        frequency_penalty=0,
        presence_penalty=0,
        stop_token_ids=stop_token_ids,
    )

    # Results are matched to prompts by position; this relies on
    # llm.generate returning one output per prompt, in prompt order.
    vllm_outputs = llm.generate(prompts, sampling_params)
    return [
        [completion.text for completion in request_output.outputs]
        for request_output in vllm_outputs
    ]

def evaluate_score(args) -> list[bool]:
    """Execute each predicted output literal for one sample and report pass/fail.

    Args:
        args: a pair ``(gs, (c, i, o))``. ``gs`` is the list of predicted
            output literals; ``c`` is the function source; ``i`` is the call
            expression from the sample's assertion; ``o`` is the reference
            output literal.

    Returns:
        One bool per entry of ``gs``, in the same order. ``True`` means
        ``assert o == g`` ran successfully (via ``check_correctness`` with a
        3-second limit) after ``BASE_IMPORTS`` and ``c``.
    """
    gs, (c, i, o) = args

    execution_results = []
    for g in gs:
        if i in g:
            # A prediction that embeds the call expression itself would just
            # re-run the function rather than state a literal. Count it as a
            # failure. Do not drop it, so results stay aligned with ``gs`` and
            # pass@k is not inflated.
            execution_results.append(False)
            continue
        code_to_execute = f"{BASE_IMPORTS}\n{c}\nassert {o} == {g}"
        execution_results.append(check_correctness(code_to_execute, 3))
    return execution_results

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate from ``n`` samples of which ``c`` are correct.

    Computes ``1 - C(n - c, k) / C(n, k)`` as a running product, which avoids
    large binomial coefficients.
    """
    failures = n - c
    # Fewer than k failing samples: every size-k draw contains a pass.
    if failures < k:
        return 1.0
    denominators = np.arange(failures + 1, n + 1)
    all_fail_probability = np.prod(1.0 - k / denominators)
    return 1.0 - all_fail_probability

def code_execution_metrics(
    samples,
    generations,
):
    """Execute every sample's predictions and aggregate pass@1.

    Args:
        samples: sequence of dicts with ``"code"``, ``"input"`` and
            ``"output"`` keys.
        generations: one list of predicted output literals per sample, in
            the same order as ``samples``.

    Returns:
        ``[metrics, results]``:
        - ``metrics`` is ``{"pass@1": percentage}``. It is ``0.0`` when there
          are no samples.
        - ``results`` maps each sample index to a list with one ``[bool]`` per
          evaluated prediction.

    Raises:
        ValueError: if ``samples`` and ``generations`` differ in length.
            Zipping them would otherwise silently drop samples from the score.
    """
    samples = list(samples)
    generations = list(generations)
    if len(samples) != len(generations):
        raise ValueError(
            f"got {len(generations)} generation lists for {len(samples)} samples"
        )
    if not samples:
        # Avoid spinning up worker processes and dividing by zero below.
        return [{"pass@1": 0.0}, {}]

    # Execute the predictions in parallel worker processes.
    references = [(doc["code"], doc["input"], doc["output"]) for doc in samples]
    with ProcessPoolExecutor() as executor:
        all_results = list(executor.map(evaluate_score, zip(generations, references)))

    # Per-sample pass@1, averaged and expressed as a percentage.
    pass_at_1s = [
        pass_at_k(len(execution_result), execution_result.count(True), 1)
        for execution_result in all_results
    ]
    metrics = {"pass@1": sum(pass_at_1s) / len(pass_at_1s) * 100}

    results = {
        index: [[passed] for passed in execution_result]
        for index, execution_result in enumerate(all_results)
    }
    return [metrics, results]


def _extract_answer(model_output):
    """Return the predicted output literal from one model completion.

    The text is narrowed in this order:
    1. Text after the first ``[ANSWER]`` tag, so ``==`` comparisons in the
       ``[THOUGHT]`` section are never mistaken for the answer.
    2. Text after the first ``==`` (the right-hand side of the assertion).
    3. Text up to ``[/ANSWER]`` if present, otherwise up to the first newline.
    """
    if "[ANSWER]" in model_output:
        model_output = model_output.split("[ANSWER]", 1)[1].strip()
    if "==" in model_output:
        # maxsplit=1 keeps literals that themselves contain "==" intact.
        model_output = model_output.split("==", 1)[1].strip()
    if "[/ANSWER]" in model_output:
        model_output = model_output.split("[/ANSWER]", 1)[0].strip()
    else:
        model_output = model_output.split("\n", 1)[0].strip()
    return model_output.strip()


def main(test_path="livecodebench/execution2/test.jsonl"):
    """Run the CoT output-prediction evaluation end to end and print metrics.

    The steps are:
    1. Read JSONL samples (each with ``"code"``, ``"input"`` and ``"output"``
       keys) from ``test_path``.
    2. Build chain-of-thought prompts and generate completions with
       ``run_batch``.
    3. Extract the predicted literal from each completion.
    4. Score the predictions with ``code_execution_metrics``.
    """
    with open(test_path, "r", encoding="utf-8") as f:
        # Skip blank lines, which json.loads would reject.
        test_samples = [json.loads(line) for line in f if line.strip()]

    print(f"test samples length = {len(test_samples)}")

    prompts = [make_cot_output_prompt((item["code"], item["input"])) for item in test_samples]
    outputs = run_batch(prompts)

    # Only the first completion of each prompt is scored.
    generations = [[_extract_answer(model_output[0])] for model_output in outputs]

    print(generations)

    metrics, results = code_execution_metrics(test_samples, generations)

    print(metrics)
    print(results)


if __name__ == "__main__":
    main()