import fire
import os
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import json
from utils import reformulate_with_sketch
from sqil_rephrase import generate_rephrasing_prompt, generate_catch_all_prompt, extract_code_snippet, extract_list_from_code


# Local Hugging Face checkpoint shared by the vLLM engine and the chat-template tokenizer.
_MODEL_PATH = ".../hf_models/Meta-Llama-3.1-8B-Instruct"

# vLLM engine; tensor parallelism spans every CUDA device torch can see.
llm = LLM(
    model=_MODEL_PATH,
    tokenizer=_MODEL_PATH,
    tokenizer_mode="slow",
    dtype="bfloat16",
    tensor_parallel_size=torch.cuda.device_count(),
)

# Decoding settings used by every batch_query_llm call in this script.
sampling_params = SamplingParams(temperature=0.2, top_p=1, max_tokens=2048)

# HF tokenizer; in this file it is used to render chat templates into prompt strings.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_PATH)


def batch_query_llm(prompts, use_tqdm=True):
    """
    Generate one completion per prompt with the shared vLLM engine.

    Parameters:
    prompts (list[str]): Prompt strings; callers in this file pass prompts that
        already went through the tokenizer's chat template.
    use_tqdm (bool): Whether vLLM displays its progress bar.

    Returns:
    list[str]: Text of the first sampled output for each prompt, in the same
    order as ``prompts``. An empty input yields an empty list.
    """
    # The retry / catch-all stages can legitimately end up with nothing left to
    # generate; skip the engine call entirely in that case.
    if not prompts:
        return []
    with torch.no_grad():
        completions = llm.generate(prompts, sampling_params, use_tqdm=use_tqdm)
    # vLLM returns RequestOutputs in input order; keep only the first sample's text.
    return [completion.outputs[0].text for completion in completions]


def generate_correctness_prompt(problem, solution, reference_solution=None):
    """
    Build a chat-formatted prompt asking the model to judge a proposed solution.

    Parameters:
    problem (str): The problem statement.
    solution (str): The generated solution to evaluate.
    reference_solution (str or None): Optional reference solution. Any falsy
        value (None, "") selects the reference-free variant of the prompt.

    Returns:
    str: The prompt rendered through the tokenizer's chat template with the
    generation prompt appended.
    """
    # Closing instructions shared verbatim by both variants: they demand a fixed
    # verdict sentence at the very end of the response.
    verdict_instructions = (
        "**Conclude your response with EXACTLY ONE of the following statements:**\n"
        "- \"The solution is correct\" if the proposed solution meets the requirements.\n"
        "- \"The solution is incorrect\" if the proposed solution fails to meet the requirements.\n\n"
        "**This is NOT optional. Your response MUST end with either \"The solution is correct\" or \"The solution is incorrect.\"**\n"
        "**Limit your response to 200 words.**\n"
    )

    if not reference_solution:
        # No reference: judge the solution against the problem statement alone.
        header = (
            "You are given a problem and a proposed solution.\n\n"
            f"Problem:\n{problem}\n\n"
            f"Proposed Solution:\n{solution}\n\n"
            "Your task is to determine if the proposed solution is correct and addresses the problem accurately.\n\n"
        )
    else:
        # With a reference: ask for rough consistency, not an exact match.
        header = (
            "You are given a problem, a proposed solution, and a reference solution.\n\n"
            f"Problem:\n{problem}\n\n"
            f"Proposed Solution:\n{solution}\n\n"
            f"Reference Solution:\n{reference_solution}\n\n"
            "Your task is to determine if the proposed solution is correct based on the problem statement "
            "and is generally consistent with the reference solution. The proposed solution does not need to match "
            "the reference solution exactly, but it should address the problem correctly.\n\n"
        )

    messages = [{"role": "user", "content": header + verdict_instructions}]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def subgoal_rephrasing(problems_data, num_rephrasing_attempts, version="v10"):
    """
    Rephrase every subgoal of every sketch with the LLM, in place.

    Each subgoal is first sent with the structured rephrasing prompt. A reply is
    accepted only if ``extract_list_from_code(extract_code_snippet(reply))`` is
    a one-element list. Subgoals still unresolved after
    ``num_rephrasing_attempts`` rounds get a single pass of the "catch-all"
    prompt. That reply is taken as-is, or the text after the
    "Rephrased subgoal:" marker when the marker is present.

    Parameters:
    problems_data (list[dict]): Items with "problem" and "sketches" keys, where
        each sketch is a list of subgoals. Mutated in place.
    num_rephrasing_attempts (int): Maximum rounds of the structured prompt
        (0 or less sends everything straight to the catch-all prompt).
    version (str): Prompt version forwarded to both prompt builders.

    Side effects:
    Sets ``item["rephrased_sketches"]`` on every item, with the same nesting
    as ``item["sketches"]``.
    """
    def render(content):
        # Append the assistant header (as the catch-all prompt and every other
        # prompt in this file do) so the model answers as the assistant
        # instead of continuing the user turn.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": content}], tokenize=False, add_generation_prompt=True)

    # Flatten all subgoals into parallel prompt lists; index i refers to the
    # same subgoal in both lists.
    prompts = []
    catch_all_prompts = []
    for item in problems_data:
        for sketch in item["sketches"]:
            for subgoal in sketch:
                prompts.append(render(generate_rephrasing_prompt(item["problem"], subgoal, version=version)))
                catch_all_prompts.append(render(generate_catch_all_prompt(item["problem"], subgoal, version=version)))

    # Structured attempts: retry only the subgoals whose reply did not parse.
    completed = {}
    for _ in range(num_rephrasing_attempts):
        pending = [idx for idx in range(len(prompts)) if idx not in completed]
        if not pending:
            break
        completions = batch_query_llm([prompts[idx] for idx in pending])
        for idx, completion in zip(pending, completions):
            rephrased = extract_list_from_code(extract_code_snippet(completion))
            if isinstance(rephrased, list) and len(rephrased) == 1:
                completed[idx] = rephrased[0]

    # Catch-all pass for whatever is left; its output is always accepted.
    pending = [idx for idx in range(len(prompts)) if idx not in completed]
    if pending:
        marker = "Rephrased subgoal:"
        completions = batch_query_llm([catch_all_prompts[idx] for idx in pending])
        for idx, completion in zip(pending, completions):
            if marker in completion:
                completed[idx] = completion.split(marker)[1].strip()
            else:
                completed[idx] = completion

    rephrased_subgoals = [completed[idx] for idx in range(len(prompts))]

    # Regroup the flat list back into per-item, per-sketch lists.
    start_idx = 0
    for item in problems_data:
        rephrased_sketches = []
        for sketch in item["sketches"]:
            end_idx = start_idx + len(sketch)
            rephrased_sketches.append(rephrased_subgoals[start_idx:end_idx])
            start_idx = end_idx
        item["rephrased_sketches"] = rephrased_sketches


def solution_generation(problems_data, num_solutions_per_sketch, verbosity):
    """
    Sample several solutions per rephrased sketch and attach them to each item.

    Parameters:
    problems_data (list[dict]): Items carrying "problem" and "rephrased_sketches".
        Mutated in place.
    num_solutions_per_sketch (int): Number of samples requested per sketch.
    verbosity: Forwarded unchanged to ``reformulate_with_sketch``.

    Side effects:
    Sets ``item["list_of_solutions"]``: one list of completion strings per
    rephrased sketch, each of length ``num_solutions_per_sketch``.
    """
    # The same prompt is repeated once per requested sample; the outputs differ
    # only through sampling.
    prompts = []
    for item in problems_data:
        for sketch in item["rephrased_sketches"]:
            prompts.extend(
                tokenizer.apply_chat_template(
                    [{"role": "user", "content": reformulate_with_sketch(item["problem"], sketch, verbosity=verbosity)}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for _ in range(num_solutions_per_sketch)
            )

    completions = batch_query_llm(prompts)

    # Completions follow prompt order, so each consecutive run of
    # num_solutions_per_sketch completions belongs to one sketch.
    cursor = 0
    for item in problems_data:
        grouped = []
        for _ in item["rephrased_sketches"]:
            grouped.append(completions[cursor:cursor + num_solutions_per_sketch])
            cursor += num_solutions_per_sketch
        item["list_of_solutions"] = grouped


def correctness_generation(problems_data, num_solutions_per_sketch):
    """
    Ask the LLM to judge every generated solution and attach the raw verdicts.

    Parameters:
    problems_data (list[dict]): Items carrying "problem", "list_of_solutions"
        and, optionally, "reference_solution". Mutated in place.
    num_solutions_per_sketch (int): Kept for backward compatibility. Grouping
        follows the actual number of solutions stored for each sketch.

    Side effects:
    Sets ``item["list_of_correctness"]``: per sketch, the list of judge
    completions, aligned one-to-one with ``item["list_of_solutions"]``.
    """
    prompts = []
    for item in problems_data:
        # Items without a reference use the reference-free prompt variant
        # instead of raising KeyError.
        reference_solution = item.get("reference_solution")
        for solutions in item["list_of_solutions"]:  # one list per sketch
            for solution in solutions:
                prompts.append(
                    generate_correctness_prompt(item["problem"], solution=solution, reference_solution=reference_solution)
                )

    completions = batch_query_llm(prompts)

    # Regroup using the real per-sketch counts, so verdicts stay aligned with
    # list_of_solutions even if a sketch holds a different number of solutions.
    start_idx = 0
    for item in problems_data:
        list_of_correctness = []
        for solutions in item["list_of_solutions"]:
            end_idx = start_idx + len(solutions)
            list_of_correctness.append(completions[start_idx:end_idx])
            start_idx = end_idx
        item["list_of_correctness"] = list_of_correctness




def solve_problems(
        input_file_path,
        output_file_path,
        num_rephrasing_attempts=1,
        rephrasing_version="v10",
        num_solutions_per_sketch=4,
        solution_verbosity=4,
        start_idx=0,
        end_idx=10,
):
    """
    Run the rephrase -> solve -> judge pipeline over a slice of a JSONL file.

    Parameters:
    input_file_path (str): JSONL file, one problem dict per line; blank lines
        are ignored.
    output_file_path (str): Destination JSONL file; its parent directory is
        created when the path has one.
    num_rephrasing_attempts (int): Forwarded to subgoal_rephrasing.
    rephrasing_version (str): Prompt version forwarded to subgoal_rephrasing.
    num_solutions_per_sketch (int): Samples per sketch, forwarded to generation
        and judging.
    solution_verbosity: Forwarded to solution_generation as ``verbosity``.
    start_idx, end_idx (int or None): Python slice bounds selecting which
        parsed problems to process.

    Output:
    Each processed item, enriched by the three pipeline stages, is written as
    one JSON object per line.
    """
    # os.path.dirname("out.jsonl") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Skip blank lines (e.g. a stray trailing newline) that json.loads rejects.
    with open(input_file_path, encoding="utf-8") as f:
        problems_data = [json.loads(line) for line in f if line.strip()]

    problems_data = problems_data[start_idx:end_idx]

    subgoal_rephrasing(
        problems_data=problems_data,
        num_rephrasing_attempts=num_rephrasing_attempts,
        version=rephrasing_version,
    )

    solution_generation(
        problems_data=problems_data,
        num_solutions_per_sketch=num_solutions_per_sketch,
        verbosity=solution_verbosity,
    )

    correctness_generation(
        problems_data=problems_data,
        num_solutions_per_sketch=num_solutions_per_sketch,
    )

    with open(output_file_path, "w", encoding="utf-8") as f:
        for item in problems_data:
            f.write(json.dumps(item) + "\n")


# Fire entry point for command-line argument processing
if __name__ == '__main__':
    fire.Fire(solve_problems)
