import os
import torch
import fire
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import json
import re
import ast
import random
from utils import load_from_hdf5


def extract_code_snippet(text):
    """Return the body of a fenced code block found in ``text``.

    A fence opened with ```python is preferred wherever it appears; if there
    is none, the first generic ``` ... ``` fence is used instead. In the
    generic case anything written directly after the opening backticks (for
    example a language tag) remains part of the returned text. Leading and
    trailing whitespace is stripped from the result.

    Returns ``None`` when ``text`` contains no complete fenced block.
    """
    # Ordered from most to least specific; the first pattern that hits wins.
    fence_patterns = (r'```python(.*?)```', r'```(.*?)```')
    for fence_pattern in fence_patterns:
        found = re.search(fence_pattern, text, re.DOTALL)
        if found is not None:
            return found.group(1).strip()
    return None


def extract_list_from_code(code_str):
    """Pull the first top-level list literal out of a snippet of Python code.

    The snippet is parsed with ``ast`` (never executed). Top-level statements
    are scanned in order, and the first one that is either an assignment whose
    right-hand side is a list display (``name = [...]``) or a bare list
    expression (``[...]``) provides the result: each element is converted with
    ``ast.literal_eval``.

    Returns the list of element values, or ``None`` when the snippet cannot be
    parsed, contains no such statement, or the chosen list holds an element
    that is not a literal.
    """
    try:
        tree = ast.parse(code_str)
        for stmt in tree.body:
            # Assignments and expression statements both keep the candidate
            # list in ``.value``; every other statement kind is irrelevant.
            if not isinstance(stmt, (ast.Assign, ast.Expr)):
                continue
            if isinstance(stmt.value, ast.List):
                return [ast.literal_eval(element) for element in stmt.value.elts]
    except Exception:
        # Best effort on model-written text: any parse or evaluation failure
        # (including a non-string input) simply means "no list found".
        return None
    return None

# Single checkpoint location shared by the vLLM engine and the tokenizer.
_MODEL_PATH = ".../Meta-Llama-3.1-70B-Instruct"

# The engine is constructed when this module is imported and is sharded
# across every CUDA device that torch reports.
llm = LLM(
    model=_MODEL_PATH,
    tokenizer=_MODEL_PATH,
    tokenizer_mode="slow",
    dtype="bfloat16",
    tensor_parallel_size=torch.cuda.device_count(),
)

# The seed is drawn once, at import time, so every request issued through
# this object carries the same seed value for the life of the process.
# NOTE(review): presumably a retry of an identical prompt therefore reuses
# the same seed -- confirm against vLLM's per-request seed semantics.
sampling_params = SamplingParams(
    seed=random.randint(0, 2**16 - 1),
    temperature=0.6,
    max_tokens=2048,
)

# Used by generate_sketch_prompt to render the chat template.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_PATH)


def batch_query_llm(prompts, use_tqdm=True):
    """Generate one completion per prompt with the module-level vLLM engine.

    Args:
        prompts: Prompts handed to ``llm.generate`` in a single call.
        use_tqdm: Forwarded to ``llm.generate``.

    Returns:
        A list holding, for each returned request, the text of its first
        output, in the order ``llm.generate`` returned the requests.
    """
    with torch.no_grad():
        generations = llm.generate(prompts, sampling_params, use_tqdm=use_tqdm)
        texts = []
        for generation in generations:
            # Only the first candidate of each request is kept.
            texts.append(generation.outputs[0].text)
    return texts


def generate_sketch_prompt(problems, n=3, version="v6"):
    """Build the chat-formatted prompt asking for a universal sketch of ``problems``.

    The problems are listed as a numbered block inside one of six instruction
    templates, and the result is wrapped as a single user turn with the
    module-level tokenizer's chat template (untokenized, with the generation
    prompt appended).

    Args:
        problems: Problem statements to show the model; numbered from 1.
        n: Upper bound on the number of subgoals the model is asked for.
        version: Template to use, one of ``"v1"`` .. ``"v6"``.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` for the
        selected prompt.

    Raises:
        ValueError: If ``version`` is not one of the known templates.
    """
    problem_text = "\n".join(f"{i + 1}. {p}" for i, p in enumerate(problems))

    # NOTE(review): in v3, v5 and v6 the "Focus on the core action" bullet ends
    # with a blank line and is followed by one more bullet, so the blank line
    # splits the bullet list instead of separating it from the "Format your
    # response" sentence. The strings are left untouched so prompts stay
    # byte-identical to earlier runs -- confirm whether this is intended.
    if version == "v1":  # baseline
        prompt = (
            "Imagine you are a problem-solving expert tasked with creating a universal problem-solving sketch.\n"
            f"You will be shown the following {len(problems)} problems:\n\n{problem_text}\n\n"
            f"From these problems, extract up to {n} essential subgoals that form a universal sketch. The subgoals should:\n"
            "- Apply broadly across different types of problems and disciplines\n"
            "- Use casual, everyday language from the first person perspective\n"
            "- Avoid sequential markers like 'first', 'next', 'then', 'finally'\n"
            "- Focus on the core action or insight needed at each stage\n\n"
            f"Format your response as a Python list with 1-{n} items, where each item expresses one subgoal.\n"
            "Make each subgoal self-contained so it can be applied flexibly rather than in a fixed sequence."
        )
    elif version == "v2":  # fewer requirements: only two bullets
        prompt = (
            "Imagine you are a problem-solving expert tasked with creating a universal problem-solving sketch.\n"
            f"You will be shown the following {len(problems)} problems:\n\n{problem_text}\n\n"
            f"From these problems, extract up to {n} essential subgoals that form a universal sketch. The subgoals should:\n"
            "- Apply broadly across different types of problems and disciplines\n"
            "- Avoid sequential markers like 'first', 'next', 'then', 'finally'\n"
            f"Format your response as a Python list with 1-{n} items, where each item expresses one subgoal.\n"
            "Make each subgoal self-contained so it can be applied flexibly rather than in a fixed sequence."
        )
    elif version == "v3":  # more creative: v1 plus the creativity bullet
        prompt = (
            "Imagine you are a problem-solving expert tasked with creating a universal problem-solving sketch.\n"
            f"You will be shown the following {len(problems)} problems:\n\n{problem_text}\n\n"
            f"From these problems, extract up to {n} essential subgoals that form a universal sketch. The subgoals should:\n"
            "- Apply broadly across different types of problems and disciplines\n"
            "- Use casual, everyday language from the first person perspective\n"
            "- Avoid sequential markers like 'first', 'next', 'then', 'finally'\n"
            "- Focus on the core action or insight needed at each stage\n\n"
            "- Be as creative as possible, going beyond what you think is intuitively correct\n"
            f"Format your response as a Python list with 1-{n} items, where each item expresses one subgoal.\n"
            "Make each subgoal self-contained so it can be applied flexibly rather than in a fixed sequence."
        )
    elif version == "v4":  # varying length: pushes for as few subgoals as possible
        prompt = (
            "Imagine you are a problem-solving expert tasked with creating a universal problem-solving sketch.\n"
            f"You will be shown the following {len(problems)} problems:\n\n{problem_text}\n\n"
            f"From these problems, extract between 1 to {n} essential subgoals that form a universal sketch. Your primary challenge is to capture the problems' essence in as few subgoals as possible - being comprehensive with 1-2 subgoals is far more impressive than having many. The subgoals should:\n"
            "- Apply broadly across different types of problems and disciplines\n"
            "- Use casual, everyday language from the first person perspective\n"
            "- Avoid sequential markers like 'first', 'next', 'then', 'finally'\n"
            "- Focus on the core action or insight needed at each stage\n\n"
            f"Format your response as a Python list with 1-{n} items, where each item expresses one subgoal.\n"
            "Make each subgoal self-contained so it can be applied flexibly rather than in a fixed sequence.\n"
            "Remember: Adding a subgoal is admitting the previous ones failed to capture something essential."
        )
    elif version == "v5":  # creative + varying length: v4 plus the creativity bullet
        prompt = (
            "Imagine you are a problem-solving expert tasked with creating a universal problem-solving sketch.\n"
            f"You will be shown the following {len(problems)} problems:\n\n{problem_text}\n\n"
            f"From these problems, extract between 1 to {n} essential subgoals that form a universal sketch. Your primary challenge is to capture the problems' essence in as few subgoals as possible - being comprehensive with 1-2 subgoals is far more impressive than having many. The subgoals should:\n"
            "- Apply broadly across different types of problems and disciplines\n"
            "- Use casual, everyday language from the first person perspective\n"
            "- Avoid sequential markers like 'first', 'next', 'then', 'finally'\n"
            "- Focus on the core action or insight needed at each stage\n\n"
            "- Be as creative as possible, going beyond what you think is intuitively correct\n"
            f"Format your response as a Python list with 1-{n} items, where each item expresses one subgoal.\n"
            "Make each subgoal self-contained so it can be applied flexibly rather than in a fixed sequence.\n"
            "Remember: Adding a subgoal is admitting the previous ones failed to capture something essential."
        )
    elif version == "v6":  # v5 with a shorter "fewer subgoals" sentence
        prompt = (
            "Imagine you are a problem-solving expert tasked with creating a universal problem-solving sketch.\n"
            f"You will be shown the following {len(problems)} problems:\n\n{problem_text}\n\n"
            f"From these problems, extract between 1 to {n} essential subgoals that form a universal sketch. Being comprehensive with 1-2 subgoals is far more impressive than having many. The subgoals should:\n"
            "- Apply broadly across different types of problems and disciplines\n"
            "- Use casual, everyday language from the first person perspective\n"
            "- Avoid sequential markers like 'first', 'next', 'then', 'finally'\n"
            "- Focus on the core action or insight needed at each stage\n\n"
            "- Be as creative as possible, going beyond what you think is intuitively correct\n"
            f"Format your response as a Python list with 1-{n} items, where each item expresses one subgoal.\n"
            "Make each subgoal self-contained so it can be applied flexibly rather than in a fixed sequence.\n"
            "Remember: Adding a subgoal is admitting the previous ones failed to capture something essential."
        )
    else:
        # Without this branch an unknown version would surface later as an
        # UnboundLocalError on ``prompt``, hiding the real mistake.
        raise ValueError(
            f"Unknown prompt version {version!r}; expected one of 'v1' through 'v6'."
        )

    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    return prompt


def get_sketch_batch(problem_groups, max_len=6, version="v1"):
    """Request one sketch completion for every group of problems.

    Args:
        problem_groups: Iterable of problem lists; each list becomes one prompt.
        max_len: Passed to ``generate_sketch_prompt`` as ``n``.
        version: Prompt template version passed to ``generate_sketch_prompt``.

    Returns:
        Whatever ``batch_query_llm`` returns for the prompts built here.
    """
    chat_prompts = [
        generate_sketch_prompt(group, n=max_len, version=version)
        for group in problem_groups
    ]
    return batch_query_llm(chat_prompts)


def solve_problems(
    input_file_path="demo/platypus_64sketches_4attempts.hdf5",
    output_dir="./demo",
    problems_per_group=10,
    max_sketch_len=6,
    version="v1",
    num_sketches_per_problem=16,  # todo: 4
    start_from=0,
    max_rounds=10,
):
    """Generate universal sketches for shuffled problem groups and save them as JSONL.

    For each pass ``i`` in ``range(start_from, num_sketches_per_problem)`` the
    loaded problems are shuffled in place, cut into consecutive groups of
    ``problems_per_group``, and one sketch is requested per group. Groups whose
    response does not contain a parseable Python list are re-queried until at
    least 99% of the groups have one or ``max_rounds`` rounds have run; groups
    that are still unparsed keep their most recent response. Every problem is
    then written, sorted by ``idx``, together with its group's response to
    ``<output_dir>/universal_sketch_v<i>.jsonl``.

    Args:
        input_file_path: File passed to ``load_from_hdf5``. The loaded value is
            expected to be a mutable sequence of dicts with the keys
            ``"problem"``, ``"reference_solution"`` and ``"idx"``.
        output_dir: Directory for the output files; created if missing.
        problems_per_group: Number of problems shown to the model per prompt.
        max_sketch_len: Maximum number of subgoals requested per sketch.
        version: Prompt template version forwarded to ``get_sketch_batch``.
        num_sketches_per_problem: Exclusive upper bound of the pass index.
        start_from: First pass index to run.
        max_rounds: Maximum number of generation rounds per pass, counting the
            initial attempt. ``None`` removes the limit, in which case the
            retry loop only ends once 99% of the groups parse.

    Raises:
        ValueError: If ``max_rounds`` is neither ``None`` nor at least 1.
    """
    if max_rounds is not None and max_rounds < 1:
        raise ValueError(f"max_rounds must be at least 1 or None, got {max_rounds!r}")

    os.makedirs(output_dir, exist_ok=True)
    problems_data = load_from_hdf5(input_file_path, use_tqdm=True)
    for i in range(start_from, num_sketches_per_problem):
        output_file_path = os.path.join(output_dir, f"universal_sketch_v{i}.jsonl")
        random.shuffle(problems_data)

        # Preview the first few shuffled items; slicing keeps this safe when
        # fewer than three problems were loaded.
        print("\n---------------------------------------")
        for preview_no, item in enumerate(problems_data[:3], start=1):
            print(f"Q{preview_no}:", item["problem"])
            print(f"A{preview_no}:", item["reference_solution"])
        print("---------------------------------------\n")

        problem_groups = []
        solution_groups = []
        index_groups = []
        for j in range(0, len(problems_data), problems_per_group):
            chunk = problems_data[j:j + problems_per_group]
            problem_groups.append([item["problem"] for item in chunk])
            solution_groups.append([item["reference_solution"] for item in chunk])
            index_groups.append([item["idx"] for item in chunk])

        completed = {}    # group position -> response containing a parseable list
        last_failed = {}  # group position -> unparseable response from the latest round
        required = 0.99 * len(problem_groups)
        rounds_run = 0
        while len(completed) < required:
            if max_rounds is not None and rounds_run >= max_rounds:
                # Without a bound this loop never ends when more than 1% of
                # the groups keep producing unparseable responses.
                print(
                    f"Warning: stopping after {rounds_run} generation rounds; "
                    f"{len(last_failed)} group(s) still have no parseable sketch."
                )
                break
            rounds_run += 1
            pending = [g for g in range(len(problem_groups)) if g not in completed]
            round_responses = get_sketch_batch(
                [problem_groups[g] for g in pending],
                max_len=max_sketch_len,
                version=version,
            )
            last_failed = {}
            for g, response in zip(pending, round_responses):
                if extract_list_from_code(extract_code_snippet(response)) is not None:
                    completed[g] = response
                else:
                    last_failed[g] = response

        # Groups that never parsed keep their most recent raw response, so
        # every problem still receives a "response" field.
        completed.update(last_failed)
        responses = [completed[g] for g in range(len(problem_groups))]

        results = []
        for problems, solutions, group_indices, response in zip(
            problem_groups, solution_groups, index_groups, responses
        ):
            for problem, solution, index in zip(problems, solutions, group_indices):
                results.append({
                    "idx": index,
                    "problem": problem,
                    "solution": solution,
                    "response": response,
                })
        results.sort(key=lambda record: record["idx"])

        # The file is opened only once the results exist, so a failed
        # generation run does not leave a truncated output file behind, and
        # the context manager closes it even if writing raises.
        with open(output_file_path, "w", encoding="utf-8") as writer:
            for result in results:
                try:
                    writer.write(json.dumps(result) + "\n")
                except Exception as e:
                    # Best effort: report the record that failed and keep going.
                    print(f"Error writing result: {result}")
                    print(f"Error message: {e}")


# Script entry point: hand solve_problems to python-fire when run directly.
if __name__ == '__main__':
    fire.Fire(solve_problems)
