import fire
import os
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import h5py
import json
from tqdm import tqdm
from utils import load_from_hdf5, load_available_actions
import random

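# Module-level setup: load the model, then draw and print a random sampling seed.
# NOTE: the seed differs on every run and is only printed, not configurable.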

# Checkpoint location shared by the vLLM engine and the Hugging Face tokenizer.
MODEL_PATH = "/.../Meta-Llama-3.1-70B-Instruct"

# The engine is constructed at import time; its tensor-parallel degree equals
# the number of CUDA devices torch reports.
llm = LLM(
    model=MODEL_PATH,
    tokenizer=MODEL_PATH,
    tokenizer_mode="slow",
    dtype="bfloat16",
    tensor_parallel_size=torch.cuda.device_count(),
)

# Draw a seed in [0, 2**16 - 1] and echo it to stdout.
seed = random.randrange(2**16)
print(f"set random seed to {seed}")

# Alternative settings left by the author but not active: temperature=0.2, top_p=1.
sampling_params = SamplingParams(temperature=0.6, seed=seed, max_tokens=2048)

# Used by generate_prompt() to render the chat template as text.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


def batch_query_llm(prompts, use_tqdm=True):
    """Generate one completion per prompt with the module-level vLLM engine.

    Args:
        prompts: A list of already chat-formatted prompt strings (a single
            string is passed through to ``llm.generate`` unchanged).
        use_tqdm: Forwarded to ``llm.generate`` to toggle its progress bar.

    Returns:
        A list with the text of the first output of every completion, in the
        order ``llm.generate`` returns them.
    """
    base_seed = sampling_params.seed

    # vLLM applies ``SamplingParams.seed`` per request. Sharing one seeded
    # object across the batch therefore makes identical prompts (the repeated
    # attempts built by get_judgement) sample from identically seeded
    # generators and return near-duplicate text. Give each prompt its own seed,
    # derived deterministically from the base seed, so a run is still
    # reproducible from the printed seed.
    if base_seed is None or isinstance(prompts, str) or not prompts:
        request_params = sampling_params
    else:
        request_params = []
        for offset in range(len(prompts)):
            params = sampling_params.clone()
            params.seed = base_seed + offset
            request_params.append(params)

    with torch.no_grad():
        completions = llm.generate(prompts, request_params, use_tqdm=use_tqdm)
    return [completion.outputs[0].text for completion in completions]


def generate_prompt(observation):
    """Build the chat-formatted judging prompt for a single observation.

    The instruction text (revision "v12") asks the model to place the
    observation on a seven-level generality scale, from "universal" down to
    "specific", and to end its answer with exactly one fixed verdict sentence.
    Earlier revisions of this prompt used three (v6), five (v8) and seven (v10)
    levels with different wording; only v12 is ever sent to the model.

    Args:
        observation: The observation text, interpolated verbatim into the prompt.

    Returns:
        The prompt rendered through the tokenizer's chat template as a single
        user turn, as text (``tokenize=False``) with the generation prompt
        appended.
    """
    # v12
    instructions = (
        "You are given an observation.\n\n"
        f"Observation:\n{observation}\n\n"
        "Your task is to evaluate the level of generality and applicability of this observation. Determine whether it is:\n\n"
        "1. **Universal**—a foundational, all-encompassing principle or truth that is valid across all disciplines, contexts, and problems, with no exceptions.\n"
        "2. **Broadly applicable**—a principle that holds true in a wide variety of contexts and disciplines, but may not extend universally to every situation.\n"
        "3. **Generally applicable**—a concept that applies to several contexts or fields but is not relevant across all problems or situations.\n"
        "4. **Somewhat applicable**—a notion that is valid in a limited set of contexts or scenarios, but not widely across disciplines.\n"
        "5. **Applicable**—a principle or observation that is relevant in only a few specific contexts.\n"
        "6. **Narrowly applicable**—an insight that pertains to a highly specific situation or a very small number of scenarios.\n"
        "7. **Specific**—a highly specialized observation that is confined to a particular domain, problem, or context.\n\n"
        "**Conclude your response with EXACTLY ONE of the following statements:**\n"
        "- \"The observation is universal\" if it applies across all disciplines, contexts, and problems without exception.\n"
        "- \"The observation is broadly applicable\" if it is applicable in many contexts and disciplines but not universally.\n"
        "- \"The observation is generally applicable\" if it is relevant across several contexts but not widely.\n"
        "- \"The observation is somewhat applicable\" if it applies in a limited range of contexts.\n"
        "- \"The observation is applicable\" if it is relevant only in a few contexts.\n"
        "- \"The observation is narrowly applicable\" if it applies in very few situations.\n"
        "- \"The observation is specific\" if it pertains exclusively to a particular domain, problem, or context.\n\n"
        "**This is NOT optional. Your response MUST end with one of the above statements.**\n"
        "**Limit your response to 100 words.**"
    )
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": instructions}],
        tokenize=False,
        add_generation_prompt=True,
    )


def get_judgement(observations, num_attempts=8):
    """Query the model ``num_attempts`` times for every observation.

    Args:
        observations: Iterable of observation texts to judge.
        num_attempts: How many completions to request per observation.

    Returns:
        A flat list of completion texts as returned by ``batch_query_llm``.
        Prompts are submitted observation by observation, so the completions
        for the k-th observation correspond to positions
        ``k * num_attempts`` through ``(k + 1) * num_attempts - 1``.
    """
    prompts = []
    for observation in observations:
        # The prompt depends only on the observation, so render the chat
        # template once and repeat it instead of rebuilding it per attempt.
        prompt = generate_prompt(observation)
        prompts.extend([prompt] * num_attempts)
    return batch_query_llm(prompts)


def save_to_hdf5(problem, output_path, mode='w'):
    """Store one record as attributes of an HDF5 group named after its ``idx``.

    Args:
        problem: Dict describing the record. Its ``"idx"`` entry names the
            group; every other entry becomes a group attribute. The dict is
            not modified.
        output_path: Path of the HDF5 file to open.
        mode: File mode passed to ``h5py.File`` (``'w'`` by default).

    Lists and dicts are stored as JSON strings; ints, floats and strings are
    stored directly. Values of any other type are reported on stdout and
    skipped. If the record has no ``"idx"`` (or it is ``None``), the file is
    still opened with ``mode`` but nothing is written to it.
    """
    with h5py.File(output_path, mode) as hdf5_file:
        # Read the index without pop() so the caller's dict is left intact.
        idx = problem.get("idx")
        if idx is None:
            print("Problem has no 'idx'; nothing was written.")
            return

        group = hdf5_file.create_group(str(idx))
        for key, value in problem.items():
            if key == "idx":
                # Already encoded in the group name.
                continue
            if isinstance(value, (list, dict)):
                group.attrs[key] = json.dumps(value)
            elif isinstance(value, (int, float, str)):
                group.attrs[key] = value
            else:
                print(f"Unsupported data type for key {key} in problem {idx}.")

def solve_problems(
        input_file_path,
        output_file_path,
        batch_size=1024,
        num_attempts=8,
        start_idx=0,
        end_idx=None,
):
    """Judge every observation from the input file and write results as JSONL.

    Args:
        input_file_path: HDF5 input, read through ``load_from_hdf5``.
        output_file_path: Destination JSONL file. It is opened in ``'w'`` mode,
            so an existing file is overwritten. A missing parent directory is
            created.
        batch_size: Number of observations sent to the model per batch.
        num_attempts: Number of completions requested per observation.
        start_idx: Forwarded to ``load_from_hdf5``.
        end_idx: Forwarded to ``load_from_hdf5``.

    Each output line is a JSON object with ``idx``, ``observation`` and
    ``judgements`` (a list of ``num_attempts`` completion texts). ``idx`` is
    the position of the observation in the list returned by
    ``load_available_actions``.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the output path actually has one.
    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    problem_data = load_from_hdf5(input_file_path, start_idx, end_idx, use_tqdm=True)
    observations = load_available_actions(problem_data, use_tqdm=True)
    observation_data = [{"idx": i, "observation": obs} for i, obs in enumerate(observations)]

    with open(output_file_path, 'w') as jsonl_file:
        for batch_start in range(0, len(observation_data), batch_size):
            batch = observation_data[batch_start:batch_start + batch_size]
            all_judgements = get_judgement([item["observation"] for item in batch], num_attempts)

            # get_judgement returns a flat list holding num_attempts
            # consecutive completions per observation, in batch order.
            judgement_start = 0
            for item in batch:
                record = {
                    "idx": item["idx"],
                    "observation": item["observation"],
                    "judgements": all_judgements[judgement_start:judgement_start + num_attempts],
                }
                judgement_start += num_attempts
                jsonl_file.write(json.dumps(record) + "\n")

            # Flush once per batch so finished batches reach the file even if
            # a later batch fails.
            jsonl_file.flush()


# Command-line entry point: python-fire exposes the parameters of
# solve_problems as command-line arguments.
if __name__ == '__main__':
    fire.Fire(solve_problems)
