import argparse
from collections import OrderedDict
from datetime import datetime
import hashlib
import json
from pathlib import Path
import re
import sys
import os
import warnings

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from experiments_utils.constants import EXPERIMENTS_CONFIG_REPO_PATH
from utils.file_utils import (
    copy_move,
    extract_run_datetime,
    find_files,
    get_args,
    get_attribute_from_dict,
    is_finished_execution,
    remove_empty_dirs,
    SanityChecker,
    stable_json_hash,
)

# TODO: (i) aggregate by all arguments related to agent config setting; (ii) create an ID for each config; (iii) create folders by the ID (iv) make an md log with the IDs and corresponding config

# Top-level keys of a run's args that define a cluster: `get_args_to_cluster` looks up
# each of these keys and runs whose selected values hash identically are grouped together.
ARGS_TO_CLUSTER = [
    "agents_configs",
    "observation_type",
    "current_viewport_only",
    "viewport_width",
    "viewport_height",
    "test_config_base_dir",
    "no_caption_text_obs",
    "parsing_failure_th",
    "repeating_action_failure_th",
    "fuzzy_match_provider",
    "max_steps",
    "seed",
]

# Per-agent whitelist of config keys kept when clustering (see `parse_agent_configs`).
# The outer keys are the module names expected inside `agents_configs`; indexing with a
# module name that is not listed here raises KeyError.
AGENT_CONFIGS_TO_CLUSTER = {
    "executor_agent": {
        "prompt",
        "num_previous_state_actions",
        "lm_config",
    },
    "critique_agent": {
        "mode",
        "max_critique_executor_loop",
        "prompt",
        "lm_config",
    },
    "image_refiner": {
        "max_refine",
        "lm_config",
    },
    "text_refiner": {
        "max_obs_length",
    },
    "request_refiner": {
        "captioning_model",
        "caption_input_img",
    },
}

# Whitelist of keys kept from every `lm_config` entry (see `parse_lm_configs`).
LM_CONFIGS_TO_CLUSTER = {
    "gen_config_alias",
    "temperature",
    "top_p",
    "max_tokens",
    "top_k",
    "text_first",
    "google_chars_per_token",
    "model",
}

# TODO: add prompt configs to cluster
# PROMPT_CONFIGS_TO_CLUSTER = {
#     "system_prompt",
#     "examples",
# }


# Layout of a cluster directory relative to the save dir. `config_to_dirname` substitutes
# `{model}` and `{domain}` via str.replace and `{config_name}` via str.format.
DIR_NAME_TEMPLATE = "{model}/{config_name}/{domain}"
# Directory where `parse_prompt_config` looks for `<prompt_name>.json` (relative path).
json_prompts_dir = "agent/prompts/jsons"

# Separators used when auto-generating config names:
sep1 = "____"  # wraps an agent tag, e.g. "____EXEC____"
sep2 = "_"  # between fields of the same agent, e.g. "..._t=3"
sep3 = "="  # between a field key and its value, e.g. "mem=..."
sep4 = "-"  # inside a value, e.g. "text-a"

# Maximum gap, in minutes, between consecutive runs placed in the same time cluster
# (passed to `cluster_runs_by_time` by `cluster_runs`).
max_minute_diff = 5

# ===============================================================================
# Helper functions
# ===============================================================================


def build_destination_tree(dest_dict: dict[str, list[str]]) -> dict:
    """
    Given a dict whose keys are destination directory paths and whose values
    are lists of execution directories (which will be moved as subdirectories),
    build a nested dict representing the folder structure.

    For each destination path, we split it by os.sep. At the leaf node we attach
    a special key "_run_dirs" with the list of run directories (using only their basenames).
    """
    tree = {}
    for dest_path, run_dirs in dest_dict.items():
        parts = dest_path.parts
        node = tree
        # Build nodes for each part of the destination path.
        for part in parts:
            if part not in node:
                node[part] = {}
            node = node[part]
        # At the destination leaf, add each exec directory as a child
        for ed in run_dirs:
            ed_name = os.path.basename(ed)
            if ed_name not in node:
                node[ed_name] = {}
    return tree


def print_tree(
    tree: dict, prefix: str = "", current_level: int = 0, max_level: int | None = None, print_leaves: bool = False
) -> None:
    """
    Recursively prints the nested dictionary using tree branch characters.

    tree (dict): The tree dictionary to print.
    prefix (str): The prefix string used for the current level.
    current_level (int): The current recursion depth (0-indexed).
    max_level (int | None): If provided, limits printing to this depth.
                      When reached, an ellipsis ("...") is printed to indicate more levels.
    print_leaves (bool): If False, nodes that have no children will not be printed.
    """
    # Stop recursing if max_level is reached.
    if max_level is not None and current_level >= max_level:
        if tree:  # There are deeper levels that we're not printing.
            print(prefix + "└── " + "...")
        return

    # Depending on show_leaves, filter out nodes with empty children.
    if print_leaves:
        items = sorted(tree.items())
    else:
        # Only include nodes that have at least one child.
        items = sorted((name, subtree) for name, subtree in tree.items() if subtree)

    count = len(items)
    for idx, (name, subtree) in enumerate(items):
        connector = "└── " if idx == count - 1 else "├── "
        print(prefix + connector + name)
        new_prefix = prefix + ("    " if idx == count - 1 else "│   ")
        print_tree(subtree, new_prefix, current_level + 1, max_level, print_leaves)


def path_without_duplicates(save_dir: str, base_dirname: str) -> str:
    """
    Join `base_dirname` onto `save_dir`, skipping components `save_dir` already has.

    Every path component of `base_dirname` that also occurs anywhere in `save_dir`
    is dropped before joining. Returns the resulting path as a string.
    """
    target_root = Path(save_dir)
    already_present = target_root.parts

    kept = []
    for piece in Path(base_dirname).parts:
        if piece not in already_present:
            kept.append(piece)

    return str(target_root / Path(*kept))


def parse_args():
    """
    Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with:
            d (str): directory containing the run subdirectories to organize (required).
            e (list[str]): substrings of paths to exclude; always a fresh list so the
                caller can extend it safely.
            s (bool): whether to run the sanity check.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--d", help="Directory containing the subdirs to organize", type=str, required=True)
    # nargs="*" keeps `--e` a list whether it comes from the default or the command line;
    # with a bare `type=str` a user-supplied value would be a single string.
    parser.add_argument(
        "--e",
        type=str,
        nargs="*",
        default=["debug", "prompt_tuning_tests", "zzOld", "zzIgnore"],
        help="Paths containing any of these strings are excluded",
    )
    parser.add_argument("--s", action="store_true", help="Enable sanity check")

    args = parser.parse_args()
    # Copy so that extending the result never mutates the parser's default list.
    args.e = list(args.e)
    return args


# ===============================================================================
# LINK: repo hashmap functions
# ===============================================================================


def normalize_params(configs_dict: dict, configs_dict_path: str):
    """
    Resolve the configs repo from either an in-memory dict or a JSON file path.

    A non-empty `configs_dict` takes precedence (a warning is emitted when a path is
    given as well). Otherwise the JSON file at `configs_dict_path` is loaded.

    Raises:
        ValueError: if both arguments are empty.
        FileNotFoundError: if only the path is given and it does not exist.
    """
    has_dict = bool(configs_dict)
    has_path = bool(configs_dict_path)

    if has_dict:
        if has_path:
            warnings.warn("Both `configs_dict` and `configs_dict_path` are provided. Using `configs_dict`.")
        return configs_dict

    if not has_path:
        raise ValueError("No `configs_dict` or `configs_dict_path` provided")

    if not os.path.exists(configs_dict_path):
        raise FileNotFoundError(f"No configs_dict found at {configs_dict_path}")

    with open(configs_dict_path, "r") as f:
        return json.load(f)


def get_config_by_hash(hash: str, configs_dict: dict = {}, configs_dict_path="") -> tuple[dict | None, str | None]:
    """
    Look up the config registered under `hash` in the configs repo.

    The repo is resolved through `normalize_params` from `configs_dict` or
    `configs_dict_path`. Its "args_hash_to_int" entry maps hashes to config IDs, and
    each config is stored in the repo under its ID.

    Returns:
        (config, config_id) when the hash is registered, otherwise (None, None).
        A warning is emitted when the repo has no "args_hash_to_int" entry.
    """
    configs_dict = normalize_params(configs_dict, configs_dict_path)

    if "args_hash_to_int" in configs_dict:
        hash_index = configs_dict["args_hash_to_int"]
        if hash not in hash_index:
            return None, None
        config_int_id = hash_index[hash]
        return configs_dict[config_int_id], config_int_id

    warnings.warn("No `args_hash_to_int` in configs_repo")
    return None, None


def add_config_by_hash(
    hash: str, new_data: dict | None = None, configs_dict: dict | None = None, configs_dict_path=""
) -> tuple[dict, str]:
    """
    Register (or update) the config stored under `hash` in the configs repo.

    Args:
        hash: Hash identifying the config.
        new_data: Data to store for the config. Defaults to a new empty dict.
        configs_dict: In-memory configs repo; takes precedence over the path.
        configs_dict_path: Path of a JSON configs repo, used when no dict is given.

    Returns:
        (configs_dict, int_id): the (possibly loaded) repo and the zero-padded ID of
        the config. If the hash was already registered, its existing entry is updated
        with `new_data`; otherwise a new ID (highest existing ID + 1, or "001") is
        created and associated with the hash.

    Raises:
        ValueError / FileNotFoundError: propagated from `normalize_params`.
    """
    # A fresh dict per call: a shared mutable default would end up stored in the repo
    # and be aliased by every entry created without explicit data.
    if new_data is None:
        new_data = {}

    configs_dict = normalize_params(configs_dict, configs_dict_path)

    hash_index = configs_dict.setdefault("args_hash_to_int", {})

    _, int_id = get_config_by_hash(hash, configs_dict=configs_dict)

    # If config already exists, update it and return its ID
    if int_id is not None:
        configs_dict[int_id].update(new_data)
        return configs_dict, int_id

    # If no config exists:

    # 1) create a new integer ID (IDs are stored as zero-padded strings)
    if hash_index:
        next_id = max(int(existing_id) for existing_id in hash_index.values()) + 1
        int_id = f"{next_id:03d}"
    else:
        int_id = "001"

    # 2) associate the hash with the new integer ID
    hash_index[hash] = int_id

    # 3) add the new config to the configs_dict
    configs_dict[int_id] = new_data

    return configs_dict, int_id


def get_add_config_name(cluster_args: dict, configs_dict_path: str = EXPERIMENTS_CONFIG_REPO_PATH):
    """
    Return the config name for `cluster_args`, registering a new config if needed.

    The configs repo is read from the JSON file at `configs_dict_path`. If the hash of
    `cluster_args` (see `get_config_hash`) is already registered, the stored
    "config_id" is returned. Otherwise a new entry is created, the user is prompted
    on stdin for a name (empty input -> the numeric ID, "-a" -> an auto-generated
    name, anything else -> used verbatim), and the repo file is rewritten.

    A missing, unparsable, or empty repo file is treated as a brand-new repo.
    """
    try:
        with open(configs_dict_path, "r") as f:
            configs_repo = json.load(f)
    except (json.decoder.JSONDecodeError, FileNotFoundError):
        configs_repo = None

    # The lookup helpers treat an empty dict as "no repo provided" and raise, so a
    # valid-but-empty file (e.g. `{}`) must be initialized like a missing one.
    if not configs_repo:
        configs_repo = {"args_hash_to_int": {}}

    config_hash = get_config_hash(cluster_args)
    config, int_id = get_config_by_hash(config_hash, configs_dict=configs_repo)

    # If config already exists, return its ID
    if config is not None:
        return config["config_id"]

    # If no config exists, create a new one
    configs_repo, int_id = add_config_by_hash(hash=config_hash, configs_dict=configs_repo)

    # Get a name for the new config
    config_name = input(
        f"\nNo previous config found for args:\n{cluster_args}.\nOptions:\n1. Press enter to use {int_id}"
        "\n2. Type '-a' to automatically generate a name"
        "\n3. Enter a name manually."
    ).strip()

    if config_name.lower() == "-a":
        config_name = create_config_name(cluster_args)
        # Pause so the user can read the generated name before continuing.
        input(f"Name created: {config_name}. Press enter to continue...")
    else:
        config_name = config_name if config_name else int_id

    # Add the new config to the configs_dict
    configs_repo[int_id] = {"config_id": config_name, "args": cluster_args}

    # Save the configs_dict to the file
    with open(configs_dict_path, "w") as f:
        json.dump(configs_repo, f, indent=2)

    return config_name


# ===============================================================================
# LINK: Config to name functions
# ===============================================================================


# config to dirnames
def history_type_to_name(history_type: str, prompt_filename: str):
    """
    Return the short tag used in config names for an executor's history type.

    Only "rationale_action" is supported: the tag is "text<sep4>a" when the prompt
    name contains "prev_actions", and "text<sep4>u<sep4>a" otherwise.

    Raises:
        ValueError: for any other history type.
    """
    if history_type != "rationale_action":
        raise ValueError(f"history type naming not implemented for: {history_type}")

    suffix = "a" if "prev_actions" in prompt_filename else f"u{sep4}a"
    return f"text{sep4}{suffix}"


def agents_configs_to_dirname(cluster_args: dict, config_name: str):
    """
    Append a short description of each configured agent to `config_name`.

    For every agent present in `cluster_args["agents_configs"]` a section tag
    (EXEC / CRIT / IMGREF, wrapped in `sep1`) is appended unless that tag already
    occurs in the name, followed by `key<sep3>value` fields for the settings found.
    Returns the extended name.
    """
    agents = cluster_args["agents_configs"]

    if "executor_agent" in agents:
        if "EXEC" not in config_name:
            config_name += f"{sep1}EXEC{sep1}"
        executor_cfg = agents["executor_agent"]

        if "prompt" in executor_cfg:
            prompt_cfg = executor_cfg["prompt"]
            prompt_name = prompt_cfg["prompt_name"]
            history_type = get_attribute_from_dict(key_path="meta_data:history_type", dict_data=prompt_cfg)
            memory_tag = history_type_to_name(history_type, prompt_filename=prompt_name)
            config_name += f"mem{sep3}{memory_tag}"

        if "num_previous_state_actions" in executor_cfg:
            config_name += f"{sep2}t{sep3}{executor_cfg['num_previous_state_actions']}"

    if "critique_agent" in agents:
        if "CRIT" not in config_name:
            config_name += f"{sep1}CRIT{sep1}"
        critique_cfg = agents["critique_agent"]

        if "mode" in critique_cfg:
            pass_tag = "2p" if critique_cfg["mode"] == "two_pass" else "1p"
            config_name += f"m{sep3}{pass_tag}"

        if "prompt" in critique_cfg:
            # The evaluation type is inferred from the prompt name.
            eval_tag = "bin" if "binary" in critique_cfg["prompt"]["prompt_name"] else "tri"
            config_name += f"{sep2}eval{sep3}{eval_tag}"
        # NOTE: `num_previous_state_actions` is not encoded for the critique agent.

    if "image_refiner" in agents:
        if "IMGREF" not in config_name:
            config_name += f"{sep1}IMGREF{sep1}"
        refiner_cfg = agents["image_refiner"]

        if "max_refine" in refiner_cfg:
            config_name += f"{sep2}max{sep3}{refiner_cfg['max_refine']}"

    return config_name


def lm_configs_to_dirname(cluster_args: dict):
    """Placeholder for naming by LM configs; not implemented, always returns None."""
    return None


def create_config_name(cluster_args: dict):
    """
    Auto-generate a config name from `cluster_args`.

    Currently only "agents_configs" contributes to the name; an empty string is
    returned when that key is absent.
    """
    name_parts = []

    if "agents_configs" in cluster_args:
        name_parts.append(agents_configs_to_dirname(cluster_args, ""))

    # TODO: also encode "lm_configs" (via `lm_configs_to_dirname`) once implemented.

    return "".join(name_parts)


def config_to_dirname(cluster_args: dict, config_name: str = "", dir_name_template: str = DIR_NAME_TEMPLATE):
    """
    Fill `dir_name_template` with the model, config name and domain of a cluster.

    - `{model}`: last ":"-separated piece of the first `agents_models` entry, or,
      when that key is absent, the executor agent's `lm_config` model.
    - `{domain}`: parsed from `test_config_base_dir`; "-orig_tasks" is appended unless
      the directory (without the domain) contains "not_vague".
    - `{config_name}`: the given name, or an auto-generated one when it is empty.

    Obs.: written for dir_name_template = "{model}/{config_name}/{domain}".
    TODO: make it more general
    """
    # Placeholder -> value, applied in insertion order (model first, then domain).
    replacements = {}

    if "agents_models" in cluster_args:
        model = cluster_args["agents_models"][0].split(":")[-1].strip()
        replacements["{model}"] = f"{model}"
    elif "agents_configs" in cluster_args:
        model = cluster_args["agents_configs"]["executor_agent"]["lm_config"]["model"]
        replacements["{model}"] = f"{model}"

    if "test_config_base_dir" in cluster_args:
        dir_no_domain, domain = parse_test_config_dir(cluster_args["test_config_base_dir"])
        replacements["{domain}"] = domain if "not_vague" in dir_no_domain else f"{domain}-orig_tasks"

    for placeholder, value in replacements.items():
        dir_name_template = dir_name_template.replace(placeholder, value)

    # Fall back to an auto-generated name when none was provided.
    final_name = config_name or create_config_name(cluster_args)

    return dir_name_template.format(config_name=final_name)


# config to config_ids
def get_config_hash(cluster_args: dict):
    """
    Hash `cluster_args` with the domain removed from `test_config_base_dir`.

    The hash is computed on a shallow copy in which `test_config_base_dir` is replaced
    by its directory part (see `parse_test_config_dir`), so `cluster_args` itself is
    left untouched.
    """
    hashable_args = cluster_args.copy()
    dir_no_domain = parse_test_config_dir(hashable_args["test_config_base_dir"])[0]
    hashable_args.update({"test_config_base_dir": dir_no_domain})
    return stable_json_hash(hashable_args)


# ===============================================================================
# LINK Parsers for specific `args` entries
# ===============================================================================


def parse_prompt_config(prompt_name: str):
    """
    Load the JSON config of a prompt and tag it with its name.

    Reads `<json_prompts_dir>/<prompt_name>.json`; when that file does not exist, a
    dict holding only the prompt name is returned.
    """
    prompt_path = f"{json_prompts_dir}/{prompt_name}.json"

    if os.path.exists(prompt_path):
        with open(prompt_path, "r") as f:
            prompt_dict = json.load(f)
        # prompt_dict = subselect_configs(prompt_dict, PROMPT_CONFIGS_TO_CLUSTER)
        prompt_dict["prompt_name"] = prompt_name
        return prompt_dict

    return {"prompt_name": prompt_name}


def parse_test_config_dir(test_config_value):
    """
    Split a test-config directory into (directory without domain, domain).

    The domain is the last path component with every "test_" occurrence removed;
    the rest of the path is returned as a string.
    """
    components = Path(test_config_value).parts
    domain = components[-1].replace("test_", "")
    parent_dir = Path(*components[:-1])
    return str(parent_dir), domain


def subselect_configs(all_configs: dict, configs_to_cluster: list[str]):
    """
    Return an OrderedDict with the entries of `all_configs` whose key is in
    `configs_to_cluster`, keeping the iteration order of `all_configs`.
    """
    return OrderedDict((key, value) for key, value in all_configs.items() if key in configs_to_cluster)


def parse_lm_configs(all_lm_args: dict):
    """Sort the LM args by key and keep only the keys in LM_CONFIGS_TO_CLUSTER."""
    alphabetical = OrderedDict((name, all_lm_args[name]) for name in sorted(all_lm_args))
    return subselect_configs(alphabetical, LM_CONFIGS_TO_CLUSTER)


def parse_agent_configs(agent_key, all_agent_args: dict):
    """
    Reduce one agent's config to the entries relevant for clustering.

    The args are sorted by key and filtered with AGENT_CONFIGS_TO_CLUSTER[agent_key].
    Entries whose key contains "prompt" are then expanded with `parse_prompt_config`,
    and entries whose key contains "lm_config" are reduced with `parse_lm_configs`.
    """
    ordered_args = OrderedDict(sorted(all_agent_args.items()))
    selected = subselect_configs(ordered_args, AGENT_CONFIGS_TO_CLUSTER[agent_key])

    # Iterate over a snapshot of the keys while the values are being replaced.
    for name in list(selected):
        value = selected[name]

        if "prompt" in name:
            value = parse_prompt_config(value)

        if "lm_config" in name:
            value = parse_lm_configs(value)

        selected[name] = value

    return selected


def get_args_to_cluster(all_args: dict) -> dict:
    """
    Get a subset of `args` that will be used to cluster runs.

    Each key in ARGS_TO_CLUSTER is looked up in `all_args`; keys whose lookup
    yields None are skipped. For "agents_configs", every module whose config is a dict
    is reduced with `parse_agent_configs`; non-dict module entries are kept as they are.

    Returns:
        OrderedDict of the selected args, in ARGS_TO_CLUSTER order.
    """
    args_to_cluster = OrderedDict()
    for key in ARGS_TO_CLUSTER:
        arg = get_attribute_from_dict(key_path=key, dict_data=all_args, delimiter=":")

        # If the args do not contain the arg to cluster, skip.
        if arg is None:
            continue

        # For `agents_configs`, parse only selected elements in AGENT_CONFIGS_TO_CLUSTER
        if key == "agents_configs":
            # Build a new mapping rather than assigning back into `arg`: the looked-up
            # value may be the very object held by `all_args`, and writing the parsed
            # configs into it would alter the caller's data and make a second call on
            # the same `all_args` re-parse already parsed entries.
            arg = {
                module_name: (
                    parse_agent_configs(module_name, module_config)
                    if isinstance(module_config, dict)
                    else module_config
                )
                for module_name, module_config in arg.items()
            }

        args_to_cluster[key] = arg
    return args_to_cluster


# ===============================================================================
# LINK: Clustering functions
# ===============================================================================


def cluster_runs_by_time(
    runs_with_times: list[tuple[str, datetime | None]], max_diff_minutes: int
) -> list[list[tuple[str, datetime | None]]]:
    """
    Given a list of (execution_dir, run_time), where run_time is a datetime or None,
    cluster them so consecutive runs in the same cluster do not exceed 'max_diff_minutes'
    in difference from the *previous* run in that cluster.

    Returns a list of clusters, where each cluster is a list of (execution_dir, run_time).
    """
    if not runs_with_times:
        return []

    # Filter out runs that have no datetime
    timed_dirs = [(exec_dir, t) for (exec_dir, t) in runs_with_times if t is not None]

    # Sort known times by actual datetime
    timed_dirs.sort(key=lambda x: x[1])

    clusters = []
    if timed_dirs:
        last_time = timed_dirs[0][1]
        current_cluster = {"init_time": last_time, "run_dirs": [timed_dirs[0][0]]}

        for exec_dir, t in timed_dirs[1:]:
            diff = (t - last_time).total_seconds() / 60.0
            if diff <= max_diff_minutes:
                # same cluster
                current_cluster["run_dirs"].append(exec_dir)
            else:
                # start a new cluster
                clusters.append(current_cluster)
                current_cluster = {"init_time": t, "run_dirs": [exec_dir]}
            last_time = t

        # Append the final cluster
        clusters.append(current_cluster)

    non_timed_dirs = [exec_dir for (exec_dir, t) in runs_with_times if t is None]
    if non_timed_dirs:
        clusters.append({"init_time": None, "run_dirs": non_timed_dirs})
    return clusters


def cluster_runs(
    results_dir, save_dir, exclude_dirs: list[str] | None = None, verbose: bool = False
) -> dict[Path, list[str]]:
    """
    Plan the reorganization of the runs under `results_dir`.

    Runs (directories holding an `args.json`) are grouped by the hash of their
    clustering args, each group is mapped to a directory name (see
    `get_add_config_name` / `config_to_dirname`), and the runs of a group are further
    split by start time (see `cluster_runs_by_time`).

    Args:
        results_dir: Directory searched (downwards) for `args.json` files.
        save_dir: Directory under which the cluster directories are placed.
        exclude_dirs: Files whose resolved path contains any of these strings are ignored.
        verbose: If True, print the proposed tree and wait for confirmation on stdin
            ("q" exits the program).

    Returns:
        Mapping of destination directory (Path) -> run directories to move into it.
        Nothing is moved by this function.
    """
    exclude_dirs = [] if exclude_dirs is None else exclude_dirs

    # Get all args.json paths in the `results_dir` and its subdirs
    args_files = find_files(results_dir, "args.json", upwards=False, downwards=True)

    # Remove the files containing any of the strings in `exclude_dirs`
    args_files = [f for f in args_files if not any(d in str(Path(f).resolve()) for d in exclude_dirs)]

    # Remove the files that are not finished running (redundant with the caller, kept for safety)
    args_files = [f for f in args_files if is_finished_execution(Path(f).parent)]

    # Group run directories that have identical clustering args
    grouped = {}  # maps args hash -> {"args": clustering args, "run_dirs": run directories}
    for args_file in args_files:
        run_dir = os.path.dirname(args_file)
        args_to_cluster = get_args_to_cluster(get_args(args_file))
        args_hash = stable_json_hash(args_to_cluster)
        group = grouped.setdefault(args_hash, {"args": args_to_cluster, "run_dirs": []})
        group["run_dirs"].append(run_dir)

    # Creates directories for each cluster
    # Note: the directory names are not the same as the clusters by args.
    # To create the directories, the code below also clusters by time and applies special logic to
    # some elements in `args` such as the domain. This way, we get something like: `.../config_id/`
    # containing `shopping/runs`, `reddit/runs`, etc. See `config_to_dirname` and DIR_NAME_TEMPLATE

    topdir_to_rundirs: dict[Path, list[str]] = {}  # maps the final directories to the individual run directories
    for group in grouped.values():
        # Get the config_id for config in current args
        config_name = get_add_config_name(group["args"])

        # Get the directory name for the corresponding config
        cluster_dir_name = config_to_dirname(group["args"], config_name)
        # Depending on the directory being organized, can duplicate some parts
        # E.g.: if "results/google/gemini-1.5-flash-002" given => "results/google/gemini-1.5-flash-002/gemini-1.5-flash-002"
        # Added this for convenience to use this script both on `results/google` and
        # `results/google/gemini-1.5-flash-002` without creating duplicate parts.
        cluster_dir_name = path_without_duplicates(save_dir=save_dir, base_dirname=cluster_dir_name)

        # --- For directories with the same args, additionally cluster them by time ---
        runs_with_times = [(run_dir, extract_run_datetime(run_dir)) for run_dir in group["run_dirs"]]

        # Assign a directory for each time cluster
        for time_cluster in cluster_runs_by_time(runs_with_times, max_minute_diff):
            init_time = time_cluster["init_time"]
            time_ann = "" if init_time is None else f"-{init_time.strftime('%Y-%m-%d-%H%M')}"
            cluster_dir = Path(f"{cluster_dir_name}{time_ann}")
            # Different groups can resolve to the same directory (e.g. two configs given
            # the same name); accumulate instead of overwriting so no run is dropped.
            topdir_to_rundirs.setdefault(cluster_dir, []).extend(time_cluster["run_dirs"])

    if verbose:
        print("=================")
        print("Proposed directory structure:")
        print("=================")
        # Build the tree from the destination dict:
        dest_tree = build_destination_tree(topdir_to_rundirs)
        print_tree(dest_tree, print_leaves=False)
        user_input = input("Press Enter to continue enter 'q' to quit...").lower()
        if user_input == "q":
            sys.exit()

    return topdir_to_rundirs


def get_not_finished_dirs(dir_to_organize: str, exclude_dirs: list[str] = []):
    """
    Return the run directories under `dir_to_organize` that have not finished.

    A run directory is the parent of an `args.json` file. Files whose path contains
    any string in `exclude_dirs` are ignored. Each directory is reported once, as a
    string.
    """
    unfinished = set()
    for args_file in find_files(dir_to_organize, "args.json", upwards=False, downwards=True):
        if any(excluded in args_file for excluded in exclude_dirs):
            continue
        run_dir = Path(args_file).parent
        if not is_finished_execution(run_dir):
            unfinished.add(run_dir)
    return [str(run_dir) for run_dir in unfinished]


if __name__ == "__main__":
    # Script entry point: plan the new directory layout for the runs under `--d`,
    # ask for confirmation, then move every run directory into its cluster directory.
    args = parse_args()
    dir_to_organize = args.d
    exclude_dirs = args.e
    sanity_check = args.s

    # dir_to_organize = "results/gemini-2.0-flash-001"

    if not os.path.exists(dir_to_organize):
        sys.exit(f"Directory {dir_to_organize} does not exist")

    # Runs that have not finished are added to the exclusions so they are left in place.
    exclude_dirs.extend(get_not_finished_dirs(dir_to_organize, exclude_dirs=exclude_dirs))

    if sanity_check:
        # Record hashes of the directory before anything is moved; they are compared
        # against the hashes taken after the reorganization (see the end of the script).
        print("Creating hashes for original directory...")
        sanity_checker = SanityChecker(num_processes=-1)
        # dir_to_compare = "results/"
        sanity_checker.set_original_hashes(dir_to_organize, exclude_dirs=exclude_dirs)

    # The directory is reorganized in place: destinations live under the same root.
    save_dir = dir_to_organize
    topdir_to_rundirs = cluster_runs(dir_to_organize, save_dir=save_dir, exclude_dirs=exclude_dirs, verbose=True)

    for cluster_dirname, run_dirs in topdir_to_rundirs.items():
        os.makedirs(cluster_dirname, exist_ok=True)
        for exec_dir in run_dirs:
            dest_path = Path(cluster_dirname) / os.path.basename(exec_dir)
            # Never overwrite: a run whose destination already exists stays where it is.
            if os.path.exists(dest_path):
                print(f"Skipping {exec_dir} because {dest_path} already exists")
                continue
            # move exec_dir to top_dirname
            copy_move(
                src=exec_dir,
                dst=dest_path,
                mode="move",
                merge_dirs=False,  # False: in cases like copy/move `classifieds/tasks1` -> `dirA/classifieds/tasks1` and  `dirA/classifieds/tasks1` already exists,
                # then `classifieds/tasks1` -> `dirA/classifieds/tasks1_<uid>` instead (these are cases of different experiments but equal relative paths)
                copy_move_only_new=False,  # False as we are refactoring the same folder (if True, would copy/move only files from src not in dst)
                overwrite_file=False,  # If copy/move `x.html` -> `dirA/x.html` and  `dirA/x.html` already exists, then `x.html` -> `dirA/x_<uid>.html`
            )

    # Clean up the directories left empty after their runs were moved out.
    remove_empty_dirs(dir_to_organize, exclude_dirs=exclude_dirs)

    if sanity_check:
        # Hash the reorganized directory and compare with the hashes recorded earlier.
        print("Creating hashes for new directory...")
        sanity_checker.set_new_hashes(save_dir, exclude_dirs=exclude_dirs)
        sanity_checker.sanity_check()
