import os
import json
from tqdm import tqdm
import numpy as np


# Traverse directory
def process_directory(base_path):
    """Combine the per-step JSON logs of every metric directory under ``base_path``.

    Walks ``base_path`` recursively, collects every directory that holds at
    least one per-step JSON file and passes it to ``process_files``.

    ``metrics.json`` and previously generated ``master_*.json`` files are not
    treated as per-step logs. This keeps the function safe to run again on a
    tree that has already been processed: master files are neither merged
    into themselves nor do they cause a directory to be picked up.

    Args:
        base_path: Root directory of the experiment results.
    """
    metric_directories = []

    # Collect everything before processing so the progress bar knows its
    # total and master files written during processing are never revisited.
    for root, _dirs, files in os.walk(base_path):
        step_files = [
            name
            for name in files
            if name.endswith(".json")
            and name != "metrics.json"
            and not name.startswith("master_")
        ]
        if step_files:
            metric_directories.append((root, step_files))

    for directory, step_files in tqdm(
        metric_directories, desc="Processing metric directories"
    ):
        process_files(directory, step_files)


# Combine JSON files and create a master JSON file. This merges the smaller JSON
# files, which were written separately for logging efficiency, into a single JSON
# file with the environment step count as the outermost keys.
def process_files(directory, files):
    """Merge the per-step JSON files of one metric directory into a master file.

    Every ``<step>.json`` file listed in ``files`` is loaded and stored under
    the key ``<step>``. The merged mapping is written, ordered by numeric step,
    to ``<directory>/master_<directory basename>.json``.

    ``metrics.json`` and existing ``master_*.json`` files are ignored, so the
    function can be run again on a directory that already contains its master
    file. If no per-step file is found, nothing is written and an existing
    master file is left untouched.

    Args:
        directory: Path of the metric directory.
        files: Names (not paths) of the files inside ``directory``.

    Raises:
        ValueError: If a per-step file name (without extension) is not an
            integer.
        json.JSONDecodeError: If a per-step file is not valid JSON.
    """
    combined_data = {}

    for file in files:
        if (
            not file.endswith(".json")
            or file == "metrics.json"
            or file.startswith("master_")
        ):
            continue

        key = os.path.splitext(file)[0]
        with open(os.path.join(directory, file), "r") as f:
            combined_data[key] = json.load(f)

    # Nothing to merge: do not clobber an existing master file with "{}".
    if not combined_data:
        return

    # Sort keys numerically so that e.g. "2" comes before "10".
    sorted_data = {k: combined_data[k] for k in sorted(combined_data, key=int)}

    metric_name = os.path.basename(directory)
    master_file_path = os.path.join(directory, f"master_{metric_name}.json")
    with open(master_file_path, "w") as f:
        json.dump(sorted_data, f)

# Remove all the smaller JSON files from the directory tree and keep only the
# created master files (and metrics.json).
def find_files_to_delete(base_path):
    """List the per-step JSON files under ``base_path`` that can be removed.

    A file qualifies when its name ends with ``.json``, is not
    ``metrics.json`` and does not start with ``master_``.

    Args:
        base_path: Root directory to scan recursively.

    Returns:
        A list of ``(directory, filename)`` tuples in ``os.walk`` order.
    """
    doomed = []

    for folder, _subdirs, names in os.walk(base_path):
        for name in names:
            if not name.endswith(".json"):
                continue
            if name == "metrics.json" or name.startswith("master_"):
                continue
            doomed.append((folder, name))

    return doomed


# Delete files
def delete_files(files_to_delete):
    """Remove every listed file from disk while showing a progress bar.

    Args:
        files_to_delete: Iterable of ``(directory, filename)`` pairs.
    """
    for folder, name in tqdm(files_to_delete, desc="Deleting files"):
        os.remove(os.path.join(folder, name))


def count_metrics_files(folder_path):
    """Count the files named ``metrics.json`` anywhere below ``folder_path``.

    Args:
        folder_path: Root directory to scan recursively.

    Returns:
        The number of ``metrics.json`` files found.
    """
    return sum(
        names.count("metrics.json") for _folder, _subdirs, names in os.walk(folder_path)
    )

# Merge data into a single JSON file for downstream computation of the
# sample efficiency curves, the probability of improvement plots, the aggregated
# environment scores and the performance profile curves.
def _split_env_task(folder_name):
    """Split an ``env:task`` or ``env_task`` folder name into ``(env, task)``.

    ``:`` is used as the separator when present, otherwise ``_``.

    Raises:
        ValueError: If the name does not split into exactly two parts.
    """
    separator = ":" if ":" in folder_name else "_"
    parts = folder_name.split(separator)
    if len(parts) != 2:
        raise ValueError(
            f"Cannot split {folder_name!r} into environment and task names "
            f"using {separator!r}"
        )
    return parts[0], parts[1]


def merge_data(path_use):
    """Merge every run's ``metrics.json`` into a single ``<path_use>.json`` file.

    The directory layout read is
    ``<path_use>/<algo>/<env>_<task>/<seed>/metrics.json`` (``<env>:<task>``
    is accepted as well). The output is nested as
    ``data[env][task][algo]["run_<seed>"]`` and contains one
    ``"step_<k>"`` entry (``k`` starting at 1) per evaluation, holding
    ``step_count`` and ``test_return_mean``, plus an ``"absolute_metrics"``
    entry holding the first value of ``absolute_metric_return_mean``.

    Seed folders whose ``metrics.json`` is missing or cannot be parsed are
    reported on stdout and skipped. A run without any evaluation step is
    recorded with only its ``"absolute_metrics"`` entry.

    Args:
        path_use: Path of the experiment folder; the merged file is written
            next to it as ``<path_use>.json``.

    Raises:
        ValueError: If an ``<env>_<task>`` folder name does not split into
            exactly two parts.
        KeyError: If a ``metrics.json`` lacks one of the expected keys.
    """
    data = {}
    results_folder = path_use
    total_files = count_metrics_files(results_folder)

    # The context manager closes the bar even if a malformed run raises.
    with tqdm(
        total=total_files, desc="Processing metrics.json files"
    ) as progress_bar:
        for algo_folder in os.listdir(results_folder):
            algo_path = os.path.join(results_folder, algo_folder)
            for env_task_folder in os.listdir(algo_path):
                env_name, task_name = _split_env_task(env_task_folder)
                env_task_path = os.path.join(algo_path, env_task_folder)
                for seed_folder in os.listdir(env_task_path):
                    metrics_path = os.path.join(
                        env_task_path, seed_folder, "metrics.json"
                    )
                    try:
                        with open(metrics_path) as f:
                            metrics_data = json.load(f)
                    except FileNotFoundError:
                        # Not counted in the bar's total, so no update here.
                        print(f"Error: Missing {metrics_path}")
                        continue
                    except json.JSONDecodeError:
                        print(f"Error: Unable to parse {metrics_path}")
                        progress_bar.update(1)
                        continue

                    step_counts = metrics_data["test_return_mean_T"]
                    metric_data = metrics_data["test_return_mean"]
                    # Name of absolute metric return key
                    # absolute_metric_return_mean
                    absolute_return = metrics_data["absolute_metric_return_mean"][0]

                    # Bind the run entry before the step loop so that a run
                    # with no steps cannot reuse the previous seed's entry.
                    run_name = f"run_{seed_folder}"
                    run_data = (
                        data.setdefault(env_name, {})
                        .setdefault(task_name, {})
                        .setdefault(algo_folder, {})
                        .setdefault(run_name, {})
                    )

                    for i, step_count in enumerate(step_counts):
                        run_data[f"step_{i + 1}"] = {
                            "step_count": step_count,
                            "test_return_mean": metric_data[i],
                        }

                    run_data["absolute_metrics"] = {
                        "test_return_mean": absolute_return,
                    }

                    progress_bar.update(1)

    with open(f"{path_use}.json", "w") as outfile:
        json.dump(data, outfile, indent=4)


def interpolate_values(case_steps, num_entries=201):
    """Resample ``case_steps`` to ``num_entries`` items by index selection.

    ``num_entries`` evenly spaced 1-based positions between 1 and
    ``len(case_steps)`` are computed and truncated to integers; the element at
    each position is returned. No new values are computed: shorter inputs
    yield repeated elements, longer inputs are subsampled.

    Args:
        case_steps: Sequence to resample.
        num_entries: Number of items in the result.

    Returns:
        A list of ``num_entries`` elements taken from ``case_steps``.
    """
    positions = np.linspace(
        1, len(case_steps), num=num_entries, endpoint=True, dtype=int
    )
    # Positions are 1-based; shift to 0-based indices when looking up.
    return [case_steps[position - 1] for position in positions]

# Make the evaluation step counts between on-policy and off-policy algorithms the same. 
def process_algo_stepcounts(path_use):
    """Normalise the step entries of every run in ``<path_use>.json`` in place.

    For each run:

    * if the largest ``step_count`` exceeds 2,500,000, every ``step_count``
      of that run is integer-divided by 10;
    * if the run does not have exactly 201 step entries, the entries are
      resampled with ``interpolate_values``;
    * the step entries are renumbered ``step_1`` .. ``step_N`` and the run's
      ``absolute_metrics`` entry is carried over unchanged.

    The file is read, rewritten and saved under the same name.

    Args:
        path_use: Path prefix of the merged data file (without ``.json``).
    """
    merged_path = f"{path_use}.json"
    with open(merged_path, "r") as input_file:
        json_data = json.load(input_file)

    new_data = {}
    for env, tasks in json_data.items():
        env_out = new_data[env] = {}
        for task, algorithms in tasks.items():
            task_out = env_out[task] = {}
            for algo, runs in algorithms.items():
                algo_out = task_out[algo] = {}
                for run, steps in runs.items():
                    run_out = algo_out[run] = {}

                    # Largest environment step count reached by this run.
                    peak_step_count = max(
                        steps[f"step_{int(key.split('_')[1])}"]["step_count"]
                        for key in steps
                        if "step" in key
                    )

                    entries = [
                        (name, payload)
                        for name, payload in steps.items()
                        if name.startswith("step_")
                    ]

                    # High-step-count runs are rescaled by a factor of 10.
                    if peak_step_count > 2_500_000:
                        for _name, payload in entries:
                            payload["step_count"] //= 10

                    if len(entries) != 201:
                        entries = interpolate_values(entries)

                    for position, (_name, payload) in enumerate(entries, start=1):
                        run_out[f"step_{position}"] = payload

                    # Carry the absolute metric over as its own entry.
                    run_out["absolute_metrics"] = steps["absolute_metrics"]

    # Save the modified JSON data
    with open(merged_path, "w") as output_file:
        json.dump(new_data, output_file, indent=4)

# Remove additional seeds when not all seeds completed. This is for downstream
# computation of the sample efficiency curves, the probability of improvement plots,
# the aggregated environment scores and the performance profile curves.
def _run_sort_key(run_key):
    """Return a sort key that orders ``run_<n>`` keys by the integer ``n``.

    Keys without a decimal suffix after the last ``_`` sort after all
    numbered keys, in plain string order.
    """
    prefix, _sep, suffix = run_key.rpartition("_")
    if suffix.isdecimal():
        return (0, prefix, int(suffix))
    return (1, run_key, 0)


def remove_extra_runs(path_use):
    """Trim every algorithm in ``<path_use>.json`` to the same number of runs.

    The smallest run count found over all ``env/task/algo`` combinations is
    used as the target. For each algorithm the runs are ordered by the numeric
    suffix of their key (so ``run_2`` comes before ``run_10``), the first
    ``target`` runs are kept and renamed ``run_0`` .. ``run_<target-1>``.
    The file is rewritten in place.

    Args:
        path_use: Path prefix of the merged data file (without ``.json``).
    """
    with open(f"{path_use}.json", "r") as f:
        input_json = json.load(f)

    # Find the minimum number of runs across all configurations.
    min_runs = min(
        (
            len(runs)
            for tasks in input_json.values()
            for algos in tasks.values()
            for runs in algos.values()
        ),
        default=0,
    )

    # Create a new JSON with the fixed number of runs.
    output_json = {}
    for env, tasks in input_json.items():
        output_json[env] = {}
        for task, algos in tasks.items():
            output_json[env][task] = {}
            for algo, runs in algos.items():
                # Plain sorted() is lexicographic and would rank "run_10"
                # ahead of "run_2"; sort on the numeric suffix instead.
                kept_keys = sorted(runs, key=_run_sort_key)[:min_runs]
                output_json[env][task][algo] = {
                    f"run_{i}": runs[old_key] for i, old_key in enumerate(kept_keys)
                }

    # Write the fixed JSON file
    with open(f"{path_use}.json", "w") as f:
        json.dump(output_json, f, indent=2)


if __name__ == "__main__":
    # Replace with the path to the data directory and experiment name
    data_path = "path to where data is stored"
    exp_name = "name of the experiment folder"
    # Join with the platform separator so that data_path works with or
    # without a trailing slash.
    path_use = os.path.join(data_path, exp_name)

    # 1. Merge the per-step JSON logs of each metric into master files,
    #    then remove the per-step files that were merged.
    process_directory(path_use)
    files_to_delete = find_files_to_delete(path_use)
    delete_files(files_to_delete)

    # 2. Build "<path_use>.json" from all metrics.json files, then
    #    post-process it in place.
    merge_data(path_use)
    process_algo_stepcounts(path_use)
    remove_extra_runs(path_use)
