
from numpy import remainder
from compute_code_generation_metrics import codegen_metrics
from codeproblem import load_code_generation_dataset
from vllm import LLM, SamplingParams
import json 
import torch
from transformers import AutoTokenizer


class PromptConstants:
    """Fixed prompt fragments shared by the prompt builders in this module.

    The two ``SYSTEM_MESSAGE_*`` strings contain a complete system turn and
    end right after the opening ``<|im_start|>user`` marker, so the caller is
    expected to append the user-turn text.
    """

    # System turn for Qwen-Coder style prompts.
    SYSTEM_MESSAGE_QWENCODER = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user"
    )
    # System turn for OpenCoder style prompts.
    SYSTEM_MESSAGE_OPENCODER = (
        "<|im_start|>system\n"
        "You are OpenCoder, created by OpenCoder Team.<|im_end|>\n"
        "<|im_start|>user"
    )
    # Instruction used when the problem ships with starter code.
    FORMATTING_MESSAGE_WITH_STARTER_CODE = (
        "You will use the following starter code to write the solution "
        "to the problem and enclose your code within delimiters."
    )
    # Instruction used for stdin/stdout problems without starter code.
    # NOTE: wording (including "stdin solve") is kept verbatim because it is
    # part of the model prompt; editing it changes generations.
    FORMATTING_WITHOUT_STARTER_CODE = (
        "Read the inputs from stdin solve the problem and write the answer to stdout "
        "(do not directly test on the sample inputs). "
        "Enclose your code within delimiters as follows. "
        "Ensure that when the python program runs, it reads the inputs, "
        "runs the algorithm and writes output to STDOUT."
    )
    

def get_qwencoder_question_template_answer(question):
    """Build the user-turn body for a Qwen-Coder prompt and open the assistant turn.

    Args:
        question: object exposing ``question_content`` (the problem text) and
            ``starter_code`` (falsy when the problem has no starter code).

    Returns:
        The prompt text from the task description up to and including the
        ``<|im_start|>assistant\\n`` marker.
    """
    header = (
        "You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests. You will NOT return anything except for the program.\n\n"
        f"Question: {question.question_content}\n\n"
    )

    # Starter-code problems show the provided stub; otherwise a placeholder stub.
    if question.starter_code:
        instructions = PromptConstants.FORMATTING_MESSAGE_WITH_STARTER_CODE
        code_stub = question.starter_code
    else:
        instructions = PromptConstants.FORMATTING_WITHOUT_STARTER_CODE
        code_stub = "# YOUR CODE HERE"

    pieces = [
        header,
        f"{instructions}\n",
        f"```python\n{code_stub}\n```\n\n<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    return "".join(pieces)

def get_opencoder_question_template_answer(question):
    """Build the user-turn body for an OpenCoder prompt and open the assistant turn.

    The OpenCoder user turn was a verbatim copy of the Qwen-Coder one. Delegating
    keeps a single source of truth, so a prompt fix cannot be applied to one
    model's template and silently missed in the other.

    Args:
        question: object exposing ``question_content`` and ``starter_code``
            (falsy when the problem has no starter code).

    Returns:
        The prompt text from the task description up to and including the
        ``<|im_start|>assistant\\n`` marker.
    """
    # Both templates use the same <|im_end|>/<|im_start|> turn markers and
    # instructions; the models differ only in their system message (see
    # format_prompt_opencoder). Diverge here if OpenCoder ever needs its own wording.
    return get_qwencoder_question_template_answer(question)

def extract_code(model_output: str):
    """Return the code inside the first fenced block of ``model_output``.

    Any line containing three backticks counts as a fence, so an opening line
    such as ```` ```python ```` is dropped together with its language tag.
    The lines strictly between the first and second fence lines are returned,
    joined with ``"\\n"``.

    Returns:
        The extracted code, or ``""`` when fewer than two fence lines exist
        (for example, when the completion was truncated before closing).
    """
    lines = model_output.split("\n")

    fences = []
    for position, text in enumerate(lines):
        if "```" in text:
            fences.append(position)
            if len(fences) == 2:
                break  # only the first opening/closing pair matters

    if len(fences) != 2:
        return ""

    opening, closing = fences
    return "\n".join(lines[opening + 1:closing])

# def format_prompt(question):
#     FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
#     FORMATTING_WITHOUT_STARTER_CODE = "You will read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters."
#     if question.starter_code:
#         prompt = f"### Instruction:\n{question.question_content}\n{FORMATTING_MESSAGE_WITH_STARTER_CODE}\n```python\n{question.starter_code}\n```\n\n### Response:\n" 
#     else:
#         prompt = f"### Instruction:\n{question.question_content}\n{FORMATTING_WITHOUT_STARTER_CODE}\n\n### Response:\n"
#     return prompt

def format_prompt_qwencoder(question):
    """Return the full Qwen-Coder prompt for ``question``.

    The prompt is the Qwen-Coder system message, a blank line, then the user
    turn and the opened assistant turn.
    """
    system_turn = PromptConstants.SYSTEM_MESSAGE_QWENCODER
    user_turn = get_qwencoder_question_template_answer(question)
    return "\n\n".join((system_turn, user_turn))

def format_prompt_opencoder(question):
    """Return the full OpenCoder prompt for ``question``.

    The prompt is the OpenCoder system message, a blank line, then the user
    turn and the opened assistant turn.
    """
    system_turn = PromptConstants.SYSTEM_MESSAGE_OPENCODER
    user_turn = get_opencoder_question_template_answer(question)
    return "\n\n".join((system_turn, user_turn))


def extract_instance_results(results):
    """Collapse per-test grades into one pass/fail flag per generation.

    Args:
        results: mapping ``task_id -> list of generations``, where each
            generation is an iterable of per-test grades. A generation passes
            only if every grade is ``> 0``. An empty grade list counts as
            passing, which is what ``all`` returns for an empty iterable.
            Task ids must be mutually orderable.

    Returns:
        A list with one entry per task, ordered by ``task_id``. Each entry is
        a list of booleans, one per generation.
    """
    ordered_tasks = sorted(results.items(), key=lambda item: item[0])
    # Generator inside all() short-circuits and avoids a throwaway list (C419).
    return [
        [all(grade > 0 for grade in generation) for generation in generations]
        for _, generations in ordered_tasks
    ]

def run_batch(
    prompts,
    *,
    model_tokenizer_path="/ainative/codefuse/user/448207/code_data_results/public_private_contrastive_prompt_tuning/sft_models/lcb-code-gen/0430-3-temp08-lora8-merge",
    n=10,
    temperature=0.2,
    max_tokens=4096,
    tensor_parallel_size=4,
):
    """Sample ``n`` completions for every prompt with vLLM.

    Args:
        prompts: iterable of prompt strings. It is materialized once, so
            generators and other one-shot iterators are handled correctly.
        model_tokenizer_path: directory holding both model weights and
            tokenizer. The default is the previously hard-coded checkpoint.
        n: completions sampled per prompt.
        temperature: sampling temperature.
        max_tokens: maximum generated tokens per completion.
        tensor_parallel_size: number of GPUs used by vLLM tensor parallelism.

    Returns:
        A list aligned with ``prompts``. Each element is the list of
        completion texts for that prompt. Returns ``[]`` for empty input
        without loading the model.
    """
    # Materialize once: the previous version iterated `prompts` twice, which
    # exhausted generators before anything was sent to the model.
    prompts = list(prompts)
    if not prompts:
        # Nothing to generate, so skip the expensive tokenizer and multi-GPU model load.
        return []

    tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
    llm = LLM(
        model=model_tokenizer_path,
        tokenizer=model_tokenizer_path,
        dtype="bfloat16",
        enforce_eager=True,
        max_model_len=8000,
        trust_remote_code=True,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Tokenizers without a pad (or eos) token report None; vLLM expects ints.
    stop_token_ids = [
        token_id
        for token_id in (tokenizer.eos_token_id, tokenizer.pad_token_id)
        if token_id is not None
    ]
    sampling_params = SamplingParams(
        n=n,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=1.0,
        frequency_penalty=0,
        presence_penalty=0,
        stop_token_ids=stop_token_ids,
    )

    # LLM.generate returns one RequestOutput per prompt, in input order.
    vllm_outputs = llm.generate(prompts, sampling_params)
    return [[completion.text for completion in request.outputs] for request in vllm_outputs]

def get_metrics(benchmark, combined_results, *, num_process_evaluate=12, timeout=6, ks=(1, 5, 10)):
    """Grade the extracted code with ``codegen_metrics`` and print pass@k.

    Args:
        benchmark: problems exposing ``get_evaluation_sample()``, in the same
            order as ``combined_results``.
        combined_results: ``(raw_outputs, extracted_codes)`` pairs, one per
            problem. Only ``extracted_codes`` is graded.
        num_process_evaluate: worker processes passed to ``codegen_metrics``.
        timeout: per-execution timeout passed to ``codegen_metrics``
            (unit as defined by ``codegen_metrics``; presumably seconds).
        ks: the k values whose ``pass@k`` summary entries are printed.

    Returns:
        The unmodified return value of ``codegen_metrics``. ``main`` uses
        index 0 as the summary dict with ``"pass@k"`` keys, index 1 as
        per-task results and index 2 as per-task metadata.
    """
    eval_samples = [instance.get_evaluation_sample() for instance in benchmark]
    generations = [extracted for _, extracted in combined_results]

    metrics = codegen_metrics(
        eval_samples,
        generations,
        num_process_evaluate=num_process_evaluate,
        timeout=timeout,
        debug=False,
    )

    summary = metrics[0]
    for k in ks:
        key = f"pass@{k}"
        # A missing key (presumably when fewer than k samples were drawn; this
        # depends on codegen_metrics) used to raise KeyError here, after the
        # whole evaluation had run and before main() saved anything.
        if key in summary:
            print(f"{key} = {summary[key]}")
        else:
            print(f"{key} = <not reported>")

    return metrics

def main():
    """Run generation and evaluation end to end and dump per-problem results.

    Steps:
      1. Load the code-generation dataset and order it by ``question_id``.
      2. Build Qwen-Coder prompts and sample completions via ``run_batch``.
      3. Extract the fenced code from every completion.
      4. Attach outputs to each problem (``insert_output``), grade them with
         ``get_metrics``, and attach grades and metadata
         (``insert_output_evaluation``).
      5. Write the evaluated records to
         ``check4save_result_removed2_temp10_eval_all.json`` in the current
         working directory.
    """
    problems = sorted(load_code_generation_dataset(), key=lambda problem: problem.question_id)

    completions_per_problem = run_batch(
        [format_prompt_qwencoder(problem) for problem in problems]
    )

    # (raw completions, extracted code) for each problem, in `problems` order.
    generations = []
    for completions in completions_per_problem:
        generations.append((completions, [extract_code(text) for text in completions]))

    records = []
    for problem, (completions, codes) in zip(problems, generations):
        records.append(problem.insert_output(completions, codes))
    records.sort(key=lambda record: record["question_id"])

    # Optional checkpoint of the generation stage (disabled):
    # output_path = "check4save_result_temp10.json"
    # with open(output_path, "w") as f:
    #     json.dump(records, f, indent=4)
    # with open(output_path, "r") as f:
    #     records = json.load(f)

    # Rebuild (outputs, code) pairs from the records, which is the same shape a
    # reloaded checkpoint would provide.
    generations = [(record["output_list"], record["code_list"]) for record in records]

    metrics = get_metrics(problems, generations)
    grades_per_problem = extract_instance_results(metrics[1])
    metadata_per_problem = metrics[2]

    evaluated = [
        problem.insert_output_evaluation(completions, codes, grades, metadata=metadata)
        for problem, (completions, codes), grades, metadata in zip(
            problems, generations, grades_per_problem, metadata_per_problem
        )
    ]

    # Optional raw-metrics dump (disabled):
    # with open("check4save_result_removed2_temp10_eval.json", "w") as f:
    #     json.dump(metrics, f, indent=4)

    with open("check4save_result_removed2_temp10_eval_all.json", "w") as handle:
        json.dump(evaluated, handle, indent=4)



if __name__ == "__main__":
    # Run the full generation + evaluation pipeline only when executed as a
    # script, not when this module is imported.
    main()
