# TODO: cleanup of config dicts; consider flat dict for case configs

import argparse
import json
import os
import re
import subprocess
import sys
import warnings
from datetime import datetime
from itertools import product
from typing import Any

from utils.logger_utils import logger

# Ignore warnings whose message starts with "Pydantic serializer warnings:"
warnings.filterwarnings("ignore", message="Pydantic serializer warnings:")

from offline_experiments.utils_offline_exper import partial_format
from utils.signal_utils import signal_manager

# Suffix appended (with no separator) to every evaluation output directory
# name built in build_all_eval_configs.
annotation = "k_question"
# Suffix appended (with no separator) to every first-pass sub-directory name
# built in get_first_pass_subdir_name.
ann_k = "k_question"
# The values below (except `envs_to_run` and `model`) are forwarded as-is
# into `run_config`.
overwrite = False
max_api_keys = 30
max_batch_runners = 100
# Only these environments are kept from `env_configs`.
envs_to_run = ["vwa"]
# Key into `llm_configs`; also used to derive `model_name` for output paths.
model = "gemini-2.5-flash-preview-04-17"
# model = "gpt-4.1-apr"
sort_by_config = True

# Experiment cases, keyed by case name.
#
# build_all_eval_configs expands each case into the cartesian product of its
# "eval_criterias", "cot_parts", "trace_info", "k_config" and "rules" lists
# (times the environment's domains). build_all_first_pass_configs only uses
# cases that define "k_config".
#
# Each "k_config" entry is a dict with a boolean "conditional", a "query" name
# (the naming helpers look for "expert" and a "v<int>" version in it) and a
# "cached_k_dir" override (None/empty means: derive it from the output dir).
#
# Commented-out entries are alternatives that are toggled by hand.
CASES = {
    "2p": {
        "eval_criterias": ["tri"],
        "cot_parts": [
            # "desc_cot_ss_trace",
            # "desc_trace",
            "desc_cot_v2",
            # "desc_compare_v2",
            # "desc_compare"
            # "desc_ss",
            # "desc_cot",
            # "no_cot",
            # "basic_cot",
        ],
        "trace_info": ["actions"],
        "k_config": [
            # {"conditional": False, "cached_k_dir": None},
            # {"conditional": False, "query": "k_2p", "cached_k_dir": None},
            # {"conditional": False, "query": "k_2p_expert", "cached_k_dir": None},
            # {"conditional": False, "query": "k_2p_expert_v2", "cached_k_dir": None},
            # {"conditional": False, "query": "k_2p_expert_v3", "cached_k_dir": None},
            {"conditional": False, "query": "k_2p_expert_v4", "cached_k_dir": None},
            # {"conditional": False, "query": "k_expert", "cached_k_dir": None},
            # {"conditional": True, "expert": False, "cached_k_dir": None},
        ],
        "add_summ_info": False,
        "add_expectation_info": True,
        # "rules": ["k_rule_efficiency"],
        "rules": ["k_rule_efficiency_vision"],
        # Any value not containing "base" is appended to the config name as
        # "-sys_<value>" by build_config_name.
        "sys_prompt": "base",  # TODO add same way as eval_criterias and other prompt parts
    },
    # "1p_no_k": {  # TODO
    #     "eval_criterias": ["bin", "tri"],
    #     "cot_parts": ["no_cot", "basic_cot"],
    #     "trace_info": ["actions"],
    #     "add_summ_info": False,
    #     "add_expectation_info": False,
    #     "rules": ["rule_efficiency_vision_1p"],
    # },
    # "1p_k": {  # TODO
    #     "eval_criterias": ["tri"],
    #     "cot_parts": ["desc_cot"],
    #     "trace_info": ["actions"],
    #     "add_summ_info": False,
    #     "add_expectation_info": True,
    #     "rules": ["k_rule_efficiency_1p_v2"],
    # },
}

# Run-level options assembled from the tunables at the top of the file.
# NOTE(review): nothing in this file reads `run_config`; presumably it is
# imported by the first_pass / verify modules launched from __main__ —
# confirm against those callers before renaming or removing keys.
run_config = {
    "sort_by_config": sort_by_config,
    "overwrite": overwrite,
    "batch_mode": True,
    "max_batch_size": -1,
    "max_api_keys": max_api_keys,
    # "multiprocess_batch_mode": True if "gemini" in gen_config["model"] else False,
    # "num_processes": 2,
    # "skip_payload": 15 * 1024 * 1024,
    "max_batch_size_runners": max_batch_runners,
    # "task_list": "rascunho_tasks.txt",
}


# Per environment: (directory holding the source traces, short name of that
# trace source). The short name is substituted for `{trace_source_path}` in
# `out_dir_template`, so it ends up in every output directory path.
traces_dirs = {
    "osw": ("osw_traces/ui-tars-1.5_50steps_2025-04-05", "ui-tars-50s_04-05"),
    "vwa": (
        "experiments/gemini-2.5-flash-preview-04-17/no_cot-expert-2025-05-06",
        "gemini-2.5-no_cot_expert-2025-05-06",
    ),
}

# Per-environment settings.
# - "domains": domains iterated by the config builders.
# - "trace_path_template": the doubled braces keep `{domain}` and `{task_id}`
#   as literal placeholders; the config builders fill `{domain}` via
#   partial_format and leave `{task_id}` for later substitution.
# - "additional_config": optional, copied into each generated config
#   (defaults to {} when absent).
env_configs = {
    "vwa": {
        "domains": ["shopping", "reddit", "classifieds"],
        "traces_dir": traces_dirs["vwa"][0],
        "trace_path_template": f"{traces_dirs['vwa'][0]}/{{domain}}/htmls/render_{{task_id}}.html",
        "task_ids_file": "",
    },
    "osw": {
        "domains": [
            "chrome",
            "gimp",
            "libreoffice_calc",
            "libreoffice_impress",
            "libreoffice_writer",
            "multi_apps",
            "os",
            "thunderbird",
            "vlc",
            "vs_code",
        ],
        "traces_dir": traces_dirs["osw"][0],
        "trace_path_template": f"{traces_dirs['osw'][0]}/{{domain}}/{{task_id}}/trajectory.json",
        "additional_config": {"annotate_actions_on_image": True},
    },
}

# Keep only the environments selected in `envs_to_run` (in that order).
# An unknown name in `envs_to_run` raises KeyError here, at import time.
env_configs = {env: env_configs[env] for env in envs_to_run}

# Generation settings keyed by a local model alias. The alias selected by
# `model` is copied into `gen_config`. The inner "model" value is the
# provider-side model identifier, which may differ from the alias (e.g. the
# two gemini aliases share one identifier and differ only in whether
# "thinking_budget" is set).
llm_configs = {
    "gpt-4.1-apr": {
        "model": "gpt-4.1-2025-04-14",
        "temperature": 1,
        "top_p": 0.01,
        "max_tokens": 8192,
    },
    "gpt-4o-aug": {
        "model": "gpt-4o-2024-08-06",
        "temperature": 1,
        "top_p": 0.01,
        "max_tokens": 8192,
    },
    "gemini-2.5-flash-preview-04-17": {
        "model": "gemini-2.5-flash-preview-04-17",
        "temperature": 0.5,
        "top_p": 0.01,
        "top_k": 40,
        "max_tokens": 8192,
        "thinking_budget": 0,
    },
    "gemini-2.5-flash-preview-apr-thinking-auto": {
        "model": "gemini-2.5-flash-preview-04-17",
        "temperature": 0.5,
        "top_p": 0.01,
        "top_k": 40,
        "max_tokens": 8192,
        # "thinking_budget": auto,
    },
}


# Day-granular stamp; stored under "date" by dump_exper_args.
timestamp = f"{datetime.now():%Y-%m-%d}"
exper_dir = "./offline_experiments"
# Last path component of the model alias, with any "/" replaced by "-".
model_name = model.split("/")[-1].replace("/", "-")
# Output directory pattern; `{env}`, `{trace_source_path}` and `{domain}` are
# left as placeholders to be filled with str.format by the config builders.
out_dir_template = "/".join(
    [
        exper_dir,
        "{env}",
        model_name + "_{trace_source_path}",
        "{domain}",
    ]
)


# Shallow copy of the generation settings for the selected model alias.
gen_config = dict(llm_configs[model])


def dump_exper_args(
    config: dict,
):
    """Write a one-time ``exper_args.json`` snapshot into ``config["out_dir"]``.

    The snapshot contains the module-level ``gen_config`` (under
    ``"gen_config"``) and ``timestamp`` (under ``"date"``), merged with every
    entry of ``config``; on a key collision the value from ``config`` wins.
    If ``exper_args.json`` already exists it is left untouched.

    Args:
        config: Experiment configuration. Must contain ``"out_dir"`` and be
            JSON-serializable.

    Raises:
        KeyError: If ``config`` has no ``"out_dir"`` entry.
        TypeError, ValueError: Propagated from ``json.dumps`` when the merged
            arguments cannot be serialized.
    """
    output_dir = config["out_dir"]
    args_path = f"{output_dir}/exper_args.json"
    if os.path.exists(args_path):
        return

    # exist_ok avoids the check-then-create race when several runners share
    # the same output directory.
    os.makedirs(output_dir, exist_ok=True)

    args_to_dump = {
        "gen_config": gen_config,
        "date": timestamp,
        **config,
    }

    # Serialize before opening the file: if serialization fails, no truncated
    # exper_args.json is left behind to satisfy the exists-check above on the
    # next call.
    serialized = json.dumps(args_to_dump, indent=4)
    with open(args_path, "w") as f:
        f.write(serialized)


def get_first_pass_subdir_name(config):
    """Return the first-pass sub-directory name for ``config``.

    The name is assembled from ``config["prompt_args"]``:

    * ``"k-cond-"`` (plus ``"<trace_info>-"`` when trace_info is truthy) for a
      conditional ``k_config``, otherwise ``"k-uncond-"``;
    * ``"expert"`` / ``"expertv<N>"`` when the query contains ``"expert"``
      (``N`` taken from the first ``v<digits>`` in the query), otherwise
      ``"-<query>"``;
    * the module-level ``ann_k`` suffix, appended without a separator.

    Leading and trailing dashes are stripped from the result.
    """
    prompt_args = config["prompt_args"]
    trace_info = prompt_args["trace_info"]
    k_cfg = prompt_args["k_config"]

    if not k_cfg["conditional"]:
        mode_part = "uncond-"
    elif trace_info:
        mode_part = f"cond-{trace_info}-"
    else:
        mode_part = "cond-"

    query = k_cfg["query"]
    if "expert" in query:
        version = re.search(r"v(\d+)", query)
        query_part = f"expertv{version.group(1)}" if version else "expert"
    else:
        query_part = f"-{query}"

    return f"k-{mode_part}{query_part}{ann_k}".strip("-")


def build_config_name(eval_criteria="", cot_part="", trace_info="", k_config=None, env="", rule=None, case_config=None):
    """Compose the dash-separated name that identifies one configuration.

    The name is ``<eval_criteria>-<cot_part>-<trace_info>-<k part>-<rule>``,
    where falsy pieces are omitted and runs of dashes are collapsed. The k part
    is ``cond``/``uncond`` followed by ``expert`` / ``expertv<N>`` when the
    query contains ``"expert"`` (``N`` from the first ``v<digits>`` in the
    query) or by the query itself otherwise. When ``case_config`` has a
    ``"sys_prompt"`` that does not contain ``"base"``, ``-sys_<sys_prompt>`` is
    appended. Leading and trailing dashes are stripped.

    Args:
        eval_criteria: Evaluation criteria identifier.
        cot_part: Chain-of-thought prompt part identifier.
        trace_info: Trace information identifier.
        k_config: Optional dict with ``"conditional"`` and ``"query"`` keys.
        env: Accepted for call-site compatibility; not used in the name.
        rule: Optional rule identifier.
        case_config: Optional case dict; only ``"sys_prompt"`` is read.

    Returns:
        The configuration name (possibly empty).
    """
    head = "".join(
        [
            f"{eval_criteria}" if eval_criteria else "",
            f"-{cot_part}" if cot_part else "",
            f"-{trace_info}" if trace_info else "",
        ]
    )

    k_part = ""
    if k_config:
        k_part = "-cond" if k_config["conditional"] else "-uncond"
        query = k_config["query"]
        # TODO: fix older folders so the "expert" special case can be removed
        if "expert" in query:
            version = re.search(r"v(\d+)", query)
            k_part += "-expert" + (f"v{version.group(1)}" if version else "")
        else:
            k_part += f"-{query}"

    rule_part = rule or ""

    # Collapse the dash runs produced by empty pieces. This happens before the
    # sys-prompt suffix is attached, so the suffix itself is never rewritten.
    name = re.sub(r"-{2,}", "-", f"{head}-{k_part}-{rule_part}")

    sys_prompt = case_config.get("sys_prompt", "") if case_config else ""
    if sys_prompt and "base" not in sys_prompt:
        name += f"-sys_{sys_prompt}"

    return name.strip("-")


def build_all_first_pass_configs():
    """Build the first-pass configs for every environment, case and domain.

    Only cases in ``CASES`` that define ``"k_config"`` take part. For each
    ``k_config`` entry of such a case:

    * conditional entries are combined with every domain and every value of
      the case's ``"trace_info"``;
    * unconditional entries are combined with every domain only
      (``trace_info`` is ``None``).

    Every case is expanded inside the case loop, so all qualifying cases are
    built (not just the last one iterated) and cases without ``"k_config"``
    are genuinely skipped.

    Returns:
        ``{env: {key: {domain: config}}}`` where ``key`` is
        ``"<env>_<case_name>-"`` followed by ``build_config_name(...)`` and
        ``config`` holds the prompt arguments, trace path template, output
        directory and case metadata.
    """
    configs_per_env = {}
    for env, env_config in env_configs.items():
        trace_path_template = env_config["trace_path_template"]
        all_configs = {}
        for case_name, case_config in CASES.items():
            if "k_config" not in case_config:
                continue

            combinations = []
            for k_config in case_config["k_config"]:
                # Unconditional cases ignore trace_info, so they get a single
                # placeholder instead of one combination per trace_info value.
                trace_infos = case_config["trace_info"] if k_config["conditional"] else [None]
                combinations.extend(product(env_config["domains"], trace_infos, [k_config]))

            for domain, trace_info, k_config in combinations:
                key = f"{env}_{case_name}-"
                key += build_config_name(trace_info=trace_info, k_config=k_config, env=env)

                config: dict[str, Any] = {
                    "prompt_args": {
                        "eval_criteria": "",
                        "cot_part": "",
                        "trace_info": trace_info,
                        "add_summ_info": case_config["add_summ_info"],
                        "add_expectation_info": case_config["add_expectation_info"],
                        "k_config": k_config,
                        "rules": case_config.get("rules", []),
                    },
                    "trace_path_template": partial_format(trace_path_template, domain=domain),
                    "env": env,
                    "domain": domain,
                    "additional_config": env_config.get("additional_config", {}),
                    "meta_data": case_config,  # TODO: clean this up
                }
                trace_source_shortname = traces_dirs[env][1]
                base_dir = out_dir_template.format(
                    domain=domain, env=env, trace_source_path=trace_source_shortname
                )
                config["out_dir"] = f"{base_dir}/{get_first_pass_subdir_name(config)}"
                all_configs.setdefault(key, {})[domain] = config
        configs_per_env[env] = all_configs
    return configs_per_env


def build_all_eval_configs(out_dir_template=out_dir_template):
    """Build the evaluation configs for every environment, case and domain.

    Each case in ``CASES`` is expanded into the cartesian product of the
    environment's domains and the case's ``"eval_criterias"``, ``"cot_parts"``,
    ``"trace_info"``, ``"k_config"`` (default ``[None]``) and ``"rules"``
    (default ``[None]``).

    When a combination has a ``k_config``, a copy of it is stored in the
    config and its ``"cached_k_dir"`` is set to the first-pass sub-directory,
    rooted at the ``k_config``'s own ``"cached_k_dir"`` if that is truthy and
    at the domain's output directory otherwise.

    Args:
        out_dir_template: Format string with ``{domain}``, ``{env}`` and
            ``{trace_source_path}`` placeholders.

    Returns:
        ``{env: {key: {domain: config}}}`` where ``key`` is
        ``"<env>_<case_name>-<config_name>"``.
    """
    configs_per_env = {}
    for env, env_settings in env_configs.items():
        path_template = env_settings["trace_path_template"]
        configs_for_env = {}
        for case_name, case_config in CASES.items():
            grid = product(
                env_settings["domains"],
                case_config["eval_criterias"],
                case_config["cot_parts"],
                case_config["trace_info"],
                case_config.get("k_config", [None]),
                case_config.get("rules", [None]),
            )

            for domain, eval_criteria, cot_part, trace_info, k_config, rule in grid:
                config_name = build_config_name(eval_criteria, cot_part, trace_info, k_config, env, rule, case_config)
                key = f"{env}_{case_name}-{config_name}"

                source_short_name = traces_dirs[env][1]
                domain_root = out_dir_template.format(
                    domain=domain, env=env, trace_source_path=source_short_name
                )

                config: dict[str, Any] = {
                    "prompt_args": {
                        "eval_criteria": eval_criteria,
                        "cot_part": cot_part,
                        "trace_info": trace_info,
                        "add_summ_info": case_config["add_summ_info"],
                        "add_expectation_info": case_config["add_expectation_info"],
                        # Copied so that setting cached_k_dir below does not
                        # mutate the shared entry in CASES.
                        "k_config": k_config.copy() if k_config else None,
                        "rule": rule or None,
                    },
                    "env": env,
                    "domain": domain,
                    "trace_path_template": partial_format(path_template, domain=domain),
                    "out_dir": f"{domain_root}/{case_name}-{config_name}{annotation}",
                    "additional_config": env_settings.get("additional_config", {}),
                    "meta_data": case_config,  # TODO: clean this up
                }

                if k_config:
                    # An explicit cached_k_dir acts as the root; otherwise the
                    # first-pass output lives under this domain's output dir.
                    k_root = k_config.get("cached_k_dir") or domain_root
                    first_pass_subdir = get_first_pass_subdir_name(config)
                    config["prompt_args"]["k_config"]["cached_k_dir"] = f"{k_root}/{first_pass_subdir}"

                configs_for_env.setdefault(key, {})[domain] = config

        configs_per_env[env] = configs_for_env
    return configs_per_env


def _tee_process_output(command, log_stream):
    """Run ``command``, echoing each output line to stdout and ``log_stream``."""
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    signal_manager.register_termination_signals(lambda: process.terminate())  # type:ignore

    # Iterating the pipe blocks until a line is available and stops at EOF,
    # so there is no busy polling if the child closes stdout before exiting.
    for output_line in process.stdout:  # type:ignore
        print(output_line, end="")  # Print to command line
        log_stream.write(output_line)
        log_stream.flush()  # keep the log readable while the child is running
    return process.wait()


def run_and_wait(command, log_file_path):
    """Run ``command`` to completion, teeing its output to the console and a log.

    stderr is merged into stdout. The log target is taken from the
    ``log_file_path`` argument itself rather than from a module-level
    ``logfile`` name, so the function also works when called from other
    modules.

    Args:
        command: Argument list passed to ``subprocess.Popen``.
        log_file_path: Either an already-open writable text stream (it is
            written to and flushed, but not closed), or a ``str`` /
            ``os.PathLike`` path, which is opened in append mode for the
            duration of the call.

    Returns:
        The exit code of the child process.
    """
    if isinstance(log_file_path, (str, os.PathLike)):
        with open(log_file_path, "a") as log_stream:
            return _tee_process_output(command, log_stream)
    return _tee_process_output(command, log_file_path)


if __name__ == "__main__":
    from llms.constants.constants import API_KEYS_REPO

    # --fp runs the first-pass module, --v runs the verify module; both may be
    # given, in which case they run in that order and share one log file.
    parser = argparse.ArgumentParser()
    for flag in ("--fp", "--v"):
        parser.add_argument(flag, action="store_true")
    args = parser.parse_args()

    # Log the process ID
    logger.info(f"Process ID: {os.getpid()}")

    # A trace function is installed (e.g. running under a debugger): build the
    # configs in-process so they can be inspected, then force both stages on.
    if sys.gettrace():
        build_all_first_pass_configs()
        build_all_eval_configs()
        # exit()
        args.fp = True
        args.v = True

    logs_dir = f"{exper_dir}/logs"
    timestamp_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_file_path = f"{logs_dir}/{model_name}-{timestamp_time}.log"

    os.makedirs(logs_dir, exist_ok=True)

    # (enabled, module to run, log open mode): the first pass starts a fresh
    # log, verification appends to it.
    stages = (
        (args.fp, "offline_experiments.first_pass", "w"),
        (args.v, "offline_experiments.verify", "a"),
    )
    for enabled, module_name, open_mode in stages:
        if not enabled:
            continue
        # The handle stays bound to the module-level name `logfile`, as before.
        with open(log_file_path, open_mode) as logfile:
            run_and_wait(
                ["python", "-u", "-m", module_name],
                logfile,
            )
