from datetime import datetime
from re import L
from matplotlib.style import available
from networkx import difference
import numpy as np
import pandas as pd
import os
import json

from pytest import skip
from CausalDisco.analytics import r2coeff
from tqdm import tqdm

from mi_estimators.mi_hsic import hsic_estimate, nHSIC_estimate
from plot_training_curve import additional_information, load_predictions


def experiment_type(experiment):
    """Classify an experiment by the keys present in ``experiment["result"]``.

    Args:
        experiment: Mapping with a ``"result"`` entry that supports ``in``.

    Returns:
        ``"CANM"`` if ``"results_xy"`` is present, ``"Diffusion"`` if
        ``"results_recovering_x_given_y"`` is present, ``"Baseline"`` if
        ``"linear_ANM"`` is present. Checks are made in that order.

    Raises:
        ValueError: If none of the marker keys is present.
    """
    result = experiment["result"]
    if "results_xy" in result:
        return "CANM"
    if "results_recovering_x_given_y" in result:
        return "Diffusion"
    # The previous condition tested the bare string literal "linear_ANM",
    # which is always truthy, so the ValueError below could never be raised.
    if "linear_ANM" in result:
        return "Baseline"
    raise ValueError("Unknown type of experiment")


def score_experiment(experiment):
    """Return 1 if ``difference_in_loss(experiment)`` is negative, else 0."""
    return 1 if difference_in_loss(experiment) < 0 else 0


def score_experiment_test(experiment):
    """Return 1 if ``difference_in_test_loss(experiment)`` is negative, else 0."""
    return 1 if difference_in_test_loss(experiment) < 0 else 0


def get_full_metric_data(
    experiment, metric, start_timestep=0, end_timestep=-1, suffix=""
):
    """Collect one metric from the per-entry test results of both directions.

    The entries are read from ``test_loss{suffix}`` when that key exists in
    the ``results_recovering_x_given_y`` results and holds a list; otherwise
    from ``test_full_loss{suffix}`` when that key exists. The key chosen from
    the x-given-y results is also used for the y-given-x results.

    Args:
        experiment: Experiment dict with ``result`` ->
            ``results_recovering_x_given_y`` / ``results_recovering_y_given_x``.
        metric: Key looked up in every selected entry.
        start_timestep: Slice start applied to the entry list.
        end_timestep: Slice end applied to the entry list. With the default
            of ``-1`` the final entry is excluded by the slice.
        suffix: Appended to the result key names.

    Returns:
        ``(x_given_y, y_given_x)`` lists of metric values, or ``([], [])``
        when neither key is usable.
    """
    primary_key = f"test_loss{suffix}"
    fallback_key = f"test_full_loss{suffix}"
    reverse_results = experiment["result"]["results_recovering_x_given_y"]

    primary_is_list = primary_key in reverse_results and isinstance(
        reverse_results[primary_key], list
    )
    if primary_is_list:
        key = primary_key
    elif fallback_key in reverse_results:
        key = fallback_key
    else:
        return [], []

    window = slice(start_timestep, end_timestep)
    x_given_y = [entry[metric] for entry in reverse_results[key][window]]
    # The forward direction is looked up only after the reverse values were
    # collected, so a missing forward entry surfaces at the same point as before.
    forward_entries = experiment["result"]["results_recovering_y_given_x"][key]
    y_given_x = [entry[metric] for entry in forward_entries[window]]

    return x_given_y, y_given_x


def difference_in_full_metric(
    experiment, metric, start_timestep=0, end_timestep=-1, suffix=""
):
    """Return mean(y_given_x) - mean(x_given_y) for the selected metric.

    The values come from ``get_full_metric_data`` with the same arguments.
    """
    reverse_values, forward_values = get_full_metric_data(
        experiment,
        metric,
        start_timestep,
        end_timestep,
        suffix,
    )
    reverse_mean = np.mean(reverse_values)
    forward_mean = np.mean(forward_values)
    return forward_mean - reverse_mean


def difference_in_full_metric_vote(
    experiment, metric, start_timestep=0, end_timestep=-1, suffix=""
):
    """Return the fraction of entries where y_given_x is below x_given_y.

    The values come from ``get_full_metric_data`` with the same arguments.
    Returns ``0.0`` when the x_given_y list is empty.
    """
    reverse_values, forward_values = get_full_metric_data(
        experiment, metric, start_timestep, end_timestep, suffix
    )
    forward_wins = np.less(forward_values, reverse_values)
    if len(reverse_values) == 0:
        return 0.0
    win_count = float(np.sum(forward_wins))
    return win_count / len(forward_wins)


def score_experiment_full_metric(
    experiment, metric, start_timestep=0, end_timestep=-1, suffix=""
):
    """Return 1 if ``difference_in_full_metric`` is negative, else 0."""
    diff = difference_in_full_metric(
        experiment, metric, start_timestep, end_timestep, suffix
    )
    return 1 if diff < 0 else 0


def score_experiment_full_metric_vote(
    experiment, metric, start_timestep=0, end_timestep=-1, suffix=""
):
    """Return 1 if y_given_x is below x_given_y in more than half the entries.

    The values come from ``get_full_metric_data`` with the same arguments;
    otherwise 0 is returned.
    """
    reverse_values, forward_values = get_full_metric_data(
        experiment,
        metric,
        start_timestep,
        end_timestep,
        suffix,
    )
    forward_wins = np.less(forward_values, reverse_values)
    majority = np.sum(forward_wins) > len(forward_wins) / 2
    return 1 if majority else 0


def score_experiment_full_metric_weighted_vote(
    experiment, metric, start_timestep=0, end_timestep=-1
):
    """Score an experiment by a magnitude-weighted vote over ``test_full_loss``.

    For every selected entry the relative difference
    ``(y - x) / (|x| + |y| + eps)`` is computed, where ``x`` / ``y`` are the
    metric values of the x-given-y / y-given-x direction. The score is 1 when
    the sum of these relative differences is positive, else 0.

    Args:
        experiment: Experiment dict with ``result`` ->
            ``results_recovering_x_given_y`` / ``results_recovering_y_given_x``.
        metric: Key looked up in every ``test_full_loss`` entry.
        start_timestep: Slice start applied to the entry lists.
        end_timestep: Slice end applied to the entry lists. With the default
            of ``-1`` the final entry is excluded by the slice.

    Returns:
        1 or 0. 0 is also returned when ``test_full_loss`` is missing or empty
        in the x-given-y results, or when its first entry lacks ``metric``.
    """
    # NOTE(review): this returns 1 when the y_given_x values are larger on
    # balance, whereas score_experiment_full_metric returns 1 when they are
    # smaller - confirm the sign convention is intended.
    reverse_results = experiment["result"]["results_recovering_x_given_y"]
    if "test_full_loss" not in reverse_results:
        return 0
    reverse_entries = reverse_results["test_full_loss"]
    # An empty list used to raise IndexError on the [0] lookup; treat it like
    # the other "no data" cases.
    if not reverse_entries or metric not in reverse_entries[0]:
        return 0

    window = slice(start_timestep, end_timestep)
    x_given_y = np.array([entry[metric] for entry in reverse_entries[window]])
    forward_entries = experiment["result"]["results_recovering_y_given_x"][
        "test_full_loss"
    ]
    y_given_x = np.array([entry[metric] for entry in forward_entries[window]])

    eps = 1e-8  # keeps the denominator non-zero when both values are 0
    relative_score = (y_given_x - x_given_y) / (
        np.abs(x_given_y) + np.abs(y_given_x) + eps
    )  # each term lies in (-1, 1)

    return 1 if relative_score.sum() > 0 else 0


def score_experiment_train_test(experiment):
    """Return 1 if ``difference_in_train_vs_test_loss`` is negative, else 0."""
    return 1 if difference_in_train_vs_test_loss(experiment) < 0 else 0


def score_experiment_last_train_test(experiment):
    """Return 1 if ``difference_in_last_train_vs_test_loss`` is negative, else 0."""
    return 1 if difference_in_last_train_vs_test_loss(experiment) < 0 else 0


def score_experiment_mutual_information(experiment):
    """Return 1 if ``difference_in_mutual_information`` is negative, else 0."""
    return 1 if difference_in_mutual_information(experiment) < 0 else 0


def score_train_metric(experiment, metric, average_over=1):
    """Turn ``difference_in_train_metric`` into a 0/1 score.

    For ``metric == "mi_binning"`` the score is 1 when the difference is
    positive; for every other metric it is 1 when the difference is negative.
    """
    diff = difference_in_train_metric(experiment, metric, average_over)
    if metric == "mi_binning":
        # mi_binning is the one metric scored with the opposite sign.
        return 1 if diff > 0 else 0
    return 1 if diff < 0 else 0


def difference_in_train_metric(experiment, metric, average_over=1):
    """Return mean(y_given_x) - mean(x_given_y) of a training-log metric.

    The means are taken over the last ``average_over`` entries of each
    direction's ``train_logging_infos``.

    Returns:
        The difference, or ``0.0`` when the last y-given-x log entry does not
        contain ``metric``.

    Raises:
        ValueError: If the experiment is not of type ``"Diffusion"``.
    """
    if experiment_type(experiment) != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward_logs = experiment["result"]["results_recovering_y_given_x"][
        "train_logging_infos"
    ]
    if metric not in forward_logs[-1]:
        return 0.0

    forward_values = [entry[metric] for entry in forward_logs[-average_over:]]

    reverse_logs = experiment["result"]["results_recovering_x_given_y"][
        "train_logging_infos"
    ]
    reverse_values = [entry[metric] for entry in reverse_logs[-average_over:]]

    return np.mean(forward_values) - np.mean(reverse_values)


def average_in_train_direction_metric(
    experiment, metric, direction="results_recovering_x_given_y", average_over=1
):
    """Return the mean of a training-log metric for one direction.

    The mean is taken over the last ``average_over`` entries of
    ``experiment["result"][direction]["train_logging_infos"]``.

    Returns:
        The mean, or ``0.0`` when the last log entry does not contain
        ``metric``.

    Raises:
        ValueError: If the experiment is not of type ``"Diffusion"``.
    """
    if experiment_type(experiment) != "Diffusion":
        raise ValueError("Unknown type of experiment")

    logs = experiment["result"][direction]["train_logging_infos"]
    if metric not in logs[-1]:
        return 0.0

    recent_values = [entry[metric] for entry in logs[-average_over:]]
    return np.mean(recent_values)


def difference_in_validation_metric(experiment, metric, average_over=1):
    """Return mean(y_given_x) - mean(x_given_y) of a validation-log metric.

    The means are taken over the last ``average_over`` entries of each
    direction's ``validation_logging_infos``.

    Returns:
        The difference, or ``0.0`` when the y-given-x validation log is empty
        or its last entry does not contain ``metric``.

    Raises:
        ValueError: If the experiment is not of type ``"Diffusion"``.
    """
    if experiment_type(experiment) != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward_logs = experiment["result"]["results_recovering_y_given_x"][
        "validation_logging_infos"
    ]
    if len(forward_logs) == 0 or metric not in forward_logs[-1]:
        return 0.0

    forward_values = [entry[metric] for entry in forward_logs[-average_over:]]

    reverse_logs = experiment["result"]["results_recovering_x_given_y"][
        "validation_logging_infos"
    ]
    reverse_values = [entry[metric] for entry in reverse_logs[-average_over:]]

    return np.mean(forward_values) - np.mean(reverse_values)


def difference_in_additional_metric(experiment, metric, average_over=1):
    """Return mean(y_given_x) - mean(x_given_y) of an additional-info metric.

    The values are read from ``experiment["additional_information"]`` and the
    means are taken over the last ``average_over`` entries of each direction.

    Returns:
        The difference, or ``0.0`` when the y-given-x list is missing or
        empty, when its first entry has no non-``None`` value for ``metric``,
        or when its last entry does not contain ``metric``.

    Raises:
        ValueError: If the experiment is not of type ``"Diffusion"``.
    """
    if experiment_type(experiment) != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward_infos = experiment.get("additional_information", {}).get(
        "results_recovering_y_given_x", []
    )
    if len(forward_infos) == 0 or forward_infos[0].get(metric, None) is None:
        return 0.0
    if metric not in forward_infos[-1]:
        return 0.0

    forward_values = [entry[metric] for entry in forward_infos[-average_over:]]

    reverse_infos = experiment["additional_information"][
        "results_recovering_x_given_y"
    ]
    reverse_values = [entry[metric] for entry in reverse_infos[-average_over:]]

    return np.mean(forward_values) - np.mean(reverse_values)


def score_additional_metric(experiment, metric, average_over=1):
    """Return 1 if ``difference_in_additional_metric`` is negative, else 0."""
    diff = difference_in_additional_metric(experiment, metric, average_over)
    is_negative = diff < 0
    return int(is_negative)


def score_validation_metric(experiment, metric, average_over=1):
    """Turn ``difference_in_validation_metric`` into a 0/1 score.

    For ``metric == "mi_binning"`` the score is 1 when the difference is
    positive; for every other metric it is 1 when the difference is negative.
    """
    diff = difference_in_validation_metric(experiment, metric, average_over)
    if metric == "mi_binning":
        # mi_binning is the one metric scored with the opposite sign.
        return 1 if diff > 0 else 0
    return 1 if diff < 0 else 0


def difference_in_mutual_information(experiment):
    """Return the y_given_x minus x_given_y ``mi_estimate`` of the last train log.

    Raises:
        ValueError: If the experiment is not of type ``"Diffusion"``.
    """
    if experiment_type(experiment) != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward_mi = experiment["result"]["results_recovering_y_given_x"][
        "train_logging_infos"
    ][-1]["mi_estimate"]
    reverse_mi = experiment["result"]["results_recovering_x_given_y"][
        "train_logging_infos"
    ][-1]["mi_estimate"]
    return forward_mi - reverse_mi


def difference_in_loss(experiment):
    """Return the directional difference of the training objective.

    Returns:
        For ``"CANM"``: ``results_yx.train_likelihood`` minus
        ``results_xy.train_likelihood``. For ``"Diffusion"``: the y-given-x
        ``min_loss`` minus the x-given-y ``min_loss``.

    Raises:
        ValueError: For any other experiment type.
    """
    kind = experiment_type(experiment)

    if kind == "Diffusion":
        forward_loss = experiment["result"]["results_recovering_y_given_x"]["min_loss"]
        reverse_loss = experiment["result"]["results_recovering_x_given_y"]["min_loss"]
        return forward_loss - reverse_loss

    if kind == "CANM":
        xy_likelihood = experiment["result"]["results_xy"]["train_likelihood"]
        yx_likelihood = experiment["result"]["results_yx"]["train_likelihood"]
        return yx_likelihood - xy_likelihood

    raise ValueError("Unknown type of experiment")


def difference_in_test_loss(experiment):
    """Return the y-given-x ``test_loss`` minus the x-given-y ``test_loss``.

    Returns:
        The difference for ``"Diffusion"`` experiments whose y-given-x results
        contain ``test_loss``; ``0.0`` when that key is absent or the
        experiment is of type ``"CANM"``.

    Raises:
        ValueError: For any other experiment type.
    """
    kind = experiment_type(experiment)
    if kind == "CANM":
        return 0.0
    if kind != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward = experiment["result"]["results_recovering_y_given_x"]
    if "test_loss" not in forward:
        return 0.0
    reverse = experiment["result"]["results_recovering_x_given_y"]
    return forward["test_loss"] - reverse["test_loss"]


def difference_in_train_vs_test_loss(experiment):
    """Compare the ``|test_loss - min_loss|`` gap of the two directions.

    Returns:
        The y-given-x gap minus the x-given-y gap for ``"Diffusion"``
        experiments whose y-given-x results contain ``test_loss``; ``0.0``
        when that key is absent or the experiment is of type ``"CANM"``.

    Raises:
        ValueError: For any other experiment type.
    """
    kind = experiment_type(experiment)
    if kind == "CANM":
        return 0.0
    if kind != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward = experiment["result"]["results_recovering_y_given_x"]
    if "test_loss" not in forward:
        return 0.0

    forward_gap = abs(forward["test_loss"] - forward["min_loss"])
    reverse = experiment["result"]["results_recovering_x_given_y"]
    reverse_gap = abs(reverse["test_loss"] - reverse["min_loss"])
    return forward_gap - reverse_gap


def difference_in_last_train_vs_test_loss(experiment):
    """Compare the ``|losses[-1] - min_loss|`` gap of the two directions.

    Returns:
        The y-given-x gap minus the x-given-y gap for ``"Diffusion"``
        experiments whose y-given-x results contain ``test_loss``; ``0.0``
        when that key is absent or the experiment is of type ``"CANM"``.

    Raises:
        ValueError: For any other experiment type.
    """
    # NOTE(review): despite the function name, no test loss value enters the
    # computation; "test_loss" is only used as a presence guard. Confirm
    # whether "test_loss" was meant to be compared against losses[-1].
    kind = experiment_type(experiment)
    if kind == "CANM":
        return 0.0
    if kind != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward = experiment["result"]["results_recovering_y_given_x"]
    if "test_loss" not in forward:
        return 0.0

    forward_gap = abs(forward["losses"][-1] - forward["min_loss"])
    reverse = experiment["result"]["results_recovering_x_given_y"]
    reverse_gap = abs(reverse["losses"][-1] - reverse["min_loss"])
    return forward_gap - reverse_gap


def difference_in_last_validation_loss(experiment):
    """Return the difference of the last non-NA validation losses.

    Entries for which ``pd.isna`` is true are dropped from each direction's
    ``validation_losses`` before the last value is taken.

    Returns:
        The y-given-x value minus the x-given-y value for ``"Diffusion"``
        experiments; ``0.0`` when the y-given-x results have no
        ``validation_losses``, when either filtered list is empty, or when
        the experiment is of type ``"CANM"``.

    Raises:
        ValueError: For any other experiment type.
    """
    kind = experiment_type(experiment)
    if kind == "CANM":
        return 0.0
    if kind != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward = experiment["result"]["results_recovering_y_given_x"]
    if "validation_losses" not in forward:
        return 0.0

    forward_losses = [
        value for value in forward["validation_losses"] if not pd.isna(value)
    ]
    reverse = experiment["result"]["results_recovering_x_given_y"]
    reverse_losses = [
        value for value in reverse["validation_losses"] if not pd.isna(value)
    ]

    if not (forward_losses and reverse_losses):
        return 0.0
    return forward_losses[-1] - reverse_losses[-1]


def difference_in_last_10_validation_loss(experiment):
    """Return the difference of the mean of the last 10 validation losses.

    Entries for which ``pd.isna`` is true are dropped from each direction's
    ``validation_losses``; the mean is then taken over up to the last 10
    remaining values.

    Returns:
        The y-given-x mean minus the x-given-y mean for ``"Diffusion"``
        experiments; ``0.0`` when the y-given-x results have no
        ``validation_losses``, when either filtered list is empty, or when
        the experiment is of type ``"CANM"``.

    Raises:
        ValueError: For any other experiment type.
    """
    kind = experiment_type(experiment)
    if kind == "CANM":
        return 0.0
    if kind != "Diffusion":
        raise ValueError("Unknown type of experiment")

    forward = experiment["result"]["results_recovering_y_given_x"]
    if "validation_losses" not in forward:
        return 0.0

    forward_losses = [
        value for value in forward["validation_losses"] if not pd.isna(value)
    ]
    reverse = experiment["result"]["results_recovering_x_given_y"]
    reverse_losses = [
        value for value in reverse["validation_losses"] if not pd.isna(value)
    ]

    if not (forward_losses and reverse_losses):
        return 0.0

    forward_tail = forward_losses[-10:]
    reverse_tail = reverse_losses[-10:]
    forward_mean = sum(forward_tail) / len(forward_tail)
    reverse_mean = sum(reverse_tail) / len(reverse_tail)
    return forward_mean - reverse_mean


def score_experiment_last_validation_loss(experiment):
    """Return 1 if ``difference_in_last_validation_loss`` is negative, else 0."""
    return 1 if difference_in_last_validation_loss(experiment) < 0 else 0


def score_experiment_last_10_validation_loss(experiment):
    """Return 1 if ``difference_in_last_10_validation_loss`` is negative, else 0."""
    return 1 if difference_in_last_10_validation_loss(experiment) < 0 else 0


def get_data_generation_dict(config):
    """Locate the data-generation settings inside an experiment config.

    The first matching top-level key wins, checked in this order:
    ``data_generation_dict``, ``data_config`` (its ``dictionary`` entry),
    ``data_ingredient``, ``data`` (its ``dictionary`` entry).

    Raises:
        ValueError: If none of these keys is present in ``config``.
    """
    # (top-level key, nested key or None), in priority order.
    lookup_order = (
        ("data_generation_dict", None),
        ("data_config", "dictionary"),
        ("data_ingredient", None),
        ("data", "dictionary"),
    )
    for top_key, nested_key in lookup_order:
        if top_key not in config:
            continue
        section = config[top_key]
        return section if nested_key is None else section[nested_key]
    raise ValueError("No data generation config found")


def generate_dataframe_from_experiments(experiments):
    items = []
    for e in tqdm(experiments):
        data_generation_dict = get_data_generation_dict(e["config"])

        if experiment_type(e) == "Baseline":
            # We need unpacking
            for key, result in e["result"].items():
                items.append(
                    {
                        "epochs": 0,
                        "model_type": "NA",
                        "type": key,
                        "x_noise_type": data_generation_dict["X"]["type"],
                        "depth": data_generation_dict["depth"],
                        "size": data_generation_dict["X"]["length"],
                        "shape": data_generation_dict["shape"],
                        "transform_type": data_generation_dict["transformation"][
                            "type"
                        ],
                        "transform_parameters": str(
                            data_generation_dict["transformation"]["args"]
                        ),
                        "noise_type": data_generation_dict.get("noise_type", "N/A"),
                        "standardize": data_generation_dict.get("standardize", True),
                        "seed": data_generation_dict["seed"],
                        "noise_parameters": str(
                            data_generation_dict.get("noise_parameters", "N/A")
                        ),
                        "mediator_noise_parameters_std": data_generation_dict.get(
                            "mediator_noise_parameters", {}
                        )
                        .get("std", {})
                        .get("value", "N/A"),
                        **result,
                        "data_id": e["config"].get("data_config", {}).get("id", "N/A"),
                    }
                )
            continue

        # parse as datetime (Python 3.7+)
        start = datetime.fromisoformat(e["start_time"])
        stop = datetime.fromisoformat(e["stop_time"])

        # compute elapsed seconds
        runtime_seconds = (stop - start).total_seconds()

        if experiment_type(e) == "CANM":
            test_size_metrics = {
                "difference_in_loss": difference_in_loss(e),
                "correct": score_experiment(e),
                "correct_test": score_experiment_test(e),
                "seed": data_generation_dict["seed"],
                "depth": data_generation_dict["depth"],
                "transform_type": data_generation_dict["transformation"]["type"],
                "transform_parameters": str(
                    data_generation_dict["transformation"]["args"]
                ),
                "noise_type": data_generation_dict.get("noise_type", "N/A"),
                "bestN": e["config"].get("depth", "N/A"),
                "type": "CANM",
                "size": data_generation_dict["X"]["length"],
                "data_id": e["config"].get("data_config", {}).get("id", "N/A"),
                "runtime": runtime_seconds,
            }
            items.append(test_size_metrics)
            continue

        test_size_metrics = {}
        for test_size in [5, 10, 50, 100, 200, ""]:
            for model in ["_last_model", "_train_and_test_last_model"]:
                suffix = f"{model}_{test_size}"
                if model == "_train_and_test_last_model":
                    suffix = "_train_and_test_last_model"
                if test_size == "":
                    suffix = "_last_model"

                # metrics = ["npeet", "npeet_5", "npeet_10", "hsic_estimate"]
                metrics = ["hsic_estimate"]
                for metric in metrics:
                    test_size_metrics[f"full_average_{metric}_diff{suffix}"] = (
                        difference_in_full_metric(
                            e,
                            metric,
                            start_timestep=0,
                            end_timestep=-1,
                            suffix=suffix,
                        )
                    )
                    test_size_metrics[f"full_average_{metric}_vote{suffix}"] = (
                        difference_in_full_metric_vote(
                            e,
                            metric,
                            start_timestep=0,
                            end_timestep=-1,
                            suffix=suffix,
                        )
                    )
                    test_size_metrics[f"correct_{metric}_vote{suffix}"] = (
                        score_experiment_full_metric_vote(
                            e,
                            metric,
                            start_timestep=0,
                            end_timestep=-1,
                            suffix=suffix,
                        )
                    )
                    test_size_metrics[f"correct_{metric}_average{suffix}"] = (
                        score_experiment_full_metric(
                            e,
                            metric,
                            start_timestep=0,
                            end_timestep=-1,
                            suffix=suffix,
                        )
                    )

        item = {
            **test_size_metrics,
            **{
                "name": e["config"].get("data_config", {}).get("name", "N/A"),
                "data_id": e["config"].get("data_config", {}).get("id", "N/A"),
                "n_samples": e["config"].get("data_config", {}).get("n_samples", "N/A"),
                "learning_rate": e["config"]["training_hyperparameters"].get(
                    "learning_rate", "N/A"
                ),
                "use_best_validation": e["config"].get("use_best_validation", False),
                "test_evaluation_passes": e["config"].get("test_evaluation_passes", 1),
                "x_noise_type": data_generation_dict["X"]["type"],
                "correct": score_experiment(e),
                "depth": data_generation_dict["depth"],
                "seed": data_generation_dict["seed"],
                "standardize": data_generation_dict.get("standardize", True),
                "beta_schedule": str(
                    e["config"]["training_hyperparameters"].get("beta_schedule", {})
                ),
                "optimizer_settings": str(
                    e["config"]["training_hyperparameters"].get(
                        "optimizer_settings", {}
                    )
                ),
                "size": data_generation_dict["X"]["length"],
                "shape": data_generation_dict["shape"],
                "type": experiment_type(e),
                "noise_type": data_generation_dict.get("noise_type", "N/A"),
                "noise_parameters": str(
                    data_generation_dict.get("noise_parameters", "N/A")
                ),
                "runtime": runtime_seconds,
                "architecture": tuple(
                    e["config"]["training_hyperparameters"].get("layer_sizes", "N/A")
                ),
                "hidden_dim": e["config"]["training_hyperparameters"].get(
                    "hidden_dim", "N/A"
                ),
                "num_blocks": e["config"]["training_hyperparameters"].get(
                    "num_blocks", "N/A"
                ),
                "transform_type": data_generation_dict["transformation"]["type"],
                "transform_parameters": str(
                    data_generation_dict["transformation"]["args"]
                ),
                "difference_in_loss": difference_in_loss(e),
                # "difference_in_test_loss": difference_in_test_loss(e),
                # "difference_in_train_vs_test_loss": difference_in_train_vs_test_loss(e),
                # "correct_test": score_experiment_test(e),
                # "correct_train_test": score_experiment_train_test(e),
                # "correct_last_train_test": score_experiment_last_train_test(e),
                # "correct_last_validation_loss": score_experiment_last_validation_loss(e),
                # "correct_last_10_validation_loss": score_experiment_last_10_validation_loss(
                #     e
                # ),
                # "difference_in_last_10_validation_loss": difference_in_last_10_validation_loss(
                #     e
                # ),
                # "difference_in_last_validation_loss": difference_in_last_validation_loss(e),
                "correct_train_correlation": score_train_metric(e, "correlation"),
                "correct_train_hsic_500": score_train_metric(
                    e, "hsic_estimate", average_over=500
                ),
                "correct_train_hsic_200": score_train_metric(
                    e, "hsic_estimate", average_over=200
                ),
                "correct_train_hsic_100": score_train_metric(
                    e, "hsic_estimate", average_over=100
                ),
                "correct_train_hsic_50": score_train_metric(
                    e, "hsic_estimate", average_over=50
                ),
                "correct_train_hsic_10": score_train_metric(
                    e, "hsic_estimate", average_over=10
                ),
                # "correct_train_mi_binning": score_train_metric(e, "mi_binning"),
                # "correct_train_mi_regression": score_train_metric(e, "mi_regression"),
                # "correct_train_mi_estimate": score_train_metric(e, "mi_estimate"),
                # "correct_val_correlation": score_validation_metric(e, "correlation"),
                # "correct_val_hsic_100": score_validation_metric(
                #     e, "hsic_estimate", average_over=100
                # ),
                # "correct_val_hsic_50": score_validation_metric(
                #     e, "hsic_estimate", average_over=50
                # ),
                # "correct_val_hsic_10": score_validation_metric(
                #     e, "hsic_estimate", average_over=10
                # ),
                # "correct_val_hsic_3": score_validation_metric(
                #     e, "hsic_estimate", average_over=3
                # ),
                # "correct_val_hsic_2": score_validation_metric(
                #     e, "hsic_estimate", average_over=2
                # ),
                # "correct_val_hsic_1": score_validation_metric(
                #     e, "hsic_estimate", average_over=1
                # ),
                # "difference_in_val_hsic_10": difference_in_validation_metric(
                #     e, "hsic_estimate", average_over=10
                # ),
                # "difference_in_val_hsic_3": difference_in_validation_metric(
                #     e, "hsic_estimate", average_over=3
                # ),
                # "correct_val_mi_binning": score_validation_metric(e, "mi_binning"),
                # "correct_val_mi_regression": score_validation_metric(e, "mi_regression"),
                # "correct_val_mi_estimate": score_validation_metric(e, "mi_estimate"),
                # "difference_in_correlation": difference_in_train_metric(e, "correlation"),
                "difference_in_hsic_100": difference_in_train_metric(
                    e, "hsic_estimate", average_over=100
                ),
                "difference_in_hsic_50": difference_in_train_metric(
                    e, "hsic_estimate", average_over=50
                ),
                "difference_in_hsic_10": difference_in_train_metric(
                    e, "hsic_estimate", average_over=10
                ),
                # "difference_in_mi_binning": difference_in_train_metric(e, "mi_binning"),
                # "difference_in_mi_regression": difference_in_train_metric(
                #     e, "mi_regression"
                # ),
                # "correct_mi_estimate": score_train_metric(e, "mi_estimate"),
                # "correct_nhsic_3": score_additional_metric(e, "nHSIC", average_over=3),
                # "correct_nhsic": score_additional_metric(e, "nHSIC"),
                # "difference_in_nhsic": difference_in_additional_metric(e, "nHSIC"),
                # "correct_hsic_norm": score_additional_metric(e, "HSIC_norm"),
                # "difference_in_hsic_norm": difference_in_additional_metric(e, "HSIC_norm"),
                # "difference_in_mutual_information": difference_in_mutual_information(e),
                "epochs": e["config"]["epochs"],
                "condition_dim": e["config"]["training_hyperparameters"].get(
                    "condition_dim",
                    -1,
                ),
                "condition_embed": e["config"]["training_hyperparameters"].get(
                    "condition_embed",
                    "True",
                ),
                "train_loss": e["config"]["training_hyperparameters"].get(
                    "train_loss",
                    "mse",
                ),
                "test_loss": e["config"]["training_hyperparameters"].get(
                    "test_loss",
                    "mse",
                ),
                "pretrained_model": e["config"].get("model_path", {}),
                "model_type": e["config"].get("model_type", "plain"),
                "mediator_noise_type": data_generation_dict.get(
                    "mediator_noise_type", "N/A"
                ),
                "mediator_noise_parameters": str(
                    data_generation_dict.get("mediator_noise_parameters", {})
                ),
                "max_noise": e["config"]["training_hyperparameters"].get(
                    "max_noise", -1
                ),
                "train_loss_schedule": str(
                    e["config"]["training_hyperparameters"].get(
                        "train_loss_schedule", {}
                    )
                ),
                "mse_x_given_y": average_in_train_direction_metric(
                    e, "mse", "results_recovering_x_given_y"
                ),
                "mse_y_given_x": average_in_train_direction_metric(
                    e, "mse", "results_recovering_y_given_x"
                ),
                "hsic_x_given_y": average_in_train_direction_metric(
                    e,
                    "hsic_estimate",
                    "results_recovering_x_given_y",
                    average_over=50,
                ),
                "hsic_y_given_x": average_in_train_direction_metric(
                    e,
                    "hsic_estimate",
                    "results_recovering_y_given_x",
                    average_over=50,
                ),
                # "full_average_hsic_x_given_y": e["result"]["results_recovering_x_given_y"]
                # .get("average_test_hsic", {})
                # .get("value", 0.0),
                # "full_average_hsic_y_given_x": e["result"]["results_recovering_y_given_x"]
                # .get("average_test_hsic", {})
                # .get("value", 0.0),
                "full_average_hsic_diff_best_model": difference_in_full_metric(
                    e,
                    "hsic_estimate",
                    start_timestep=0,
                    end_timestep=-1,
                    suffix="_best_model",
                ),
                "full_average_hsic_diff_last_model": difference_in_full_metric(
                    e,
                    "hsic_estimate",
                    start_timestep=0,
                    end_timestep=-1,
                    suffix="_last_model",
                ),
                # "full_average_npeet_x_given_y_best_model": e["result"][
                #     "results_recovering_x_given_y"
                # ]
                # .get("average_test_npeet_best_model", {})
                # .get("value", 0.0),
                # "full_average_npeet_y_given_x_best_model": e["result"][
                #     "results_recovering_y_given_x"
                # ]
                # .get("average_test_npeet_best_model", {})
                # .get("value", 0.0),
                # "full_average_npeet_diff_last_model": (
                #     difference_in_full_metric(
                #         e,
                #         "npeet",
                #         start_timestep=0,
                #         end_timestep=-1,
                #         suffix="_last_model",
                #     )
                # ),
                # "full_average_npeet_5_diff_last_model": (
                #     difference_in_full_metric(
                #         e,
                #         "npeet_5",
                #         start_timestep=0,
                #         end_timestep=-1,
                #         suffix="_last_model",
                #     )
                # ),
                # "full_average_npeet_10_diff_last_model": (
                #     difference_in_full_metric(
                #         e,
                #         "npeet_10",
                #         start_timestep=0,
                #         end_timestep=-1,
                #         suffix="_last_model",
                #     )
                # ),
                # "full_average_npeet_x_given_y_last_model": e["result"][
                #     "results_recovering_x_given_y"
                # ]
                # .get("average_test_npeet_last_model", {})
                # .get("value", 0.0),
                # "full_average_npeet_y_given_x_last_model": e["result"][
                #     "results_recovering_y_given_x"
                # ]
                # .get("average_test_npeet_last_model", {})
                # .get("value", 0.0),
                "correct_full_average_hsic_best_model": score_experiment_full_metric(
                    e, "hsic_estimate", suffix="_best_model"
                ),
                "correct_full_average_hsic_last_model": score_experiment_full_metric(
                    e, "hsic_estimate", suffix="_last_model"
                ),
                # "correct_full_average_npeet_best_model": score_experiment_full_metric(
                #     e, "npeet", suffix="_best_model"
                # ),
                # "correct_full_average_npeet_last_model": score_experiment_full_metric(
                #     e, "npeet", suffix="_last_model"
                # ),
                # "correct_full_vote_hsic2_estimate": score_experiment_full_metric_vote(
                #     e, "hsic2_estimate"
                # ),
                # "correct_full_average_mse": score_experiment_full_metric(e, "mse"),
                # "correct_full_average_hsic_cut": score_experiment_full_metric(
                #     e,
                #     "hsic_estimate",
                #     start_timestep=0,
                #     end_timestep=50,
                # ),
                "correct_full_hsic_vote_last_model": score_experiment_full_metric_vote(
                    e,
                    "hsic_estimate",
                    suffix="_last_model",
                ),
                "correct_full_hsic_vote_best_model": score_experiment_full_metric_vote(
                    e,
                    "hsic_estimate",
                ),
                # "correct_full_average_hsic_vote_cut": score_experiment_full_metric_vote(
                #     e,
                #     "hsic_estimate",
                #     start_timestep=0,
                #     end_timestep=250,
                # ),
                # "correct_full_average_hsic_vote_cut_1": score_experiment_full_metric_vote(
                #     e,
                #     "hsic_estimate",
                #     start_timestep=0,
                #     end_timestep=160,
                # ),
                # "correct_full_average_hsic_vote_cut_2": score_experiment_full_metric_vote(
                #     e,
                #     "hsic_estimate",
                #     start_timestep=10,
                #     end_timestep=240,
                # ),
                # "difference_full_average_npeet": difference_in_full_metric(
                #     e, "npeet", suffix="_last_model"
                # ),
                # "difference_full_vote_npeet": difference_in_full_metric_vote(
                #     e, "npeet", suffix="_last_model"
                # ),
                # "difference_full_average_npeet_5": difference_in_full_metric(
                #     e, "npeet_5", suffix="_last_model"
                # ),
                # "difference_full_average_npeet_5_vote": difference_in_full_metric_vote(
                #     e, "npeet_5", suffix="_last_model"
                # ),
                # "difference_full_average_npeet_10": difference_in_full_metric(
                #     e, "npeet_10", suffix="_last_model"
                # ),
                # "difference_full_average_npeet_10_vote": difference_in_full_metric_vote(
                #     e, "npeet_10", suffix="_last_model"
                # ),
                "difference_full_hsic_vote": difference_in_full_metric_vote(
                    e, "hsic_estimate", suffix="_last_model"
                ),
                # "correct_full_averate_npeet": score_experiment_full_metric(
                #     e,
                #     "npeet",
                #     suffix="_last_model",
                # ),
                # "correct_full_vote_npeet": score_experiment_full_metric_vote(
                #     e,
                #     "npeet",
                #     suffix="_last_model",
                # ),
                # "correct_full_averate_npeet_5": score_experiment_full_metric(
                #     e,
                #     "npeet_5",
                #     suffix="_last_model",
                # ),
                # "correct_full_vote_npeet_5": score_experiment_full_metric_vote(
                #     e,
                #     "npeet_5",
                #     suffix="_last_model",
                # ),
                # "correct_full_averate_npeet_10": score_experiment_full_metric(
                #     e,
                #     "npeet_10",
                #     suffix="_last_model",
                # ),
                # "correct_full_vote_npeet_10": score_experiment_full_metric_vote(
                #     e,
                #     "npeet_10",
                #     suffix="_last_model",
                # ),
                # "correct_full_weighted_vote_npeet": score_experiment_full_metric_weighted_vote(
                #     e,
                #     "npeet",
                # ),
                # "correct_full_weighted_vote_hsic": score_experiment_full_metric_weighted_vote(
                #     e,
                #     "hsic_estimate",
                # ),
                # "correct_full_weighted_vote_correlation": score_experiment_full_metric_weighted_vote(
                #     e,
                #     "correlation",
                # ),
                # "correct_full_vote_npeet_cut_0_15": score_experiment_full_metric_vote(
                #     e,
                #     "npeet",
                #     start_timestep=0,
                #     end_timestep=15,
                # ),
                # "correct_full_vote_npeet_cut_0_50": score_experiment_full_metric_vote(
                #     e,
                #     "npeet",
                #     start_timestep=0,
                #     end_timestep=50,
                # ),
                # "correct_full_vote_npeet_cut_0_150": score_experiment_full_metric_vote(
                #     e,
                #     "npeet",
                #     start_timestep=0,
                #     end_timestep=150,
                # ),
                # "correct_full_vote_npeet_cut_70_120": score_experiment_full_metric_vote(
                #     e,
                #     "npeet",
                #     start_timestep=70,
                #     end_timestep=120,
                # ),
                # "correct_full_vote_mine": score_experiment_full_metric_vote(
                #     e,
                #     "mine_estimate",
                # ),
                # "correct_full_vote_mine_cut_0_100": score_experiment_full_metric_vote(
                #     e,
                #     "mine_estimate",
                #     start_timestep=0,
                #     end_timestep=100,
                # ),
            },
        }
        items.append(item)
    return pd.DataFrame(items)


def get_run_data(experiment_dir):
    """
    Read the JSON artifacts of a single experiment directory into one dict.

    ``run.json`` provides the base dictionary, the parsed ``config.json`` is
    stored under the ``"config"`` key and, when ``additional_info.json`` is
    present, its parsed content is stored under ``"additional_information"``.

    Args:
        experiment_dir (str): Directory containing the experiment results.

    Returns:
        dict: The parsed run data extended with the keys described above.

    Raises:
        FileNotFoundError: If ``run.json`` does not exist in ``experiment_dir``.
    """

    def _read_json(file_name):
        # Parse one JSON file that lives directly inside the experiment directory.
        with open(os.path.join(experiment_dir, file_name), "r") as handle:
            return json.load(handle)

    run_path = os.path.join(experiment_dir, "run.json")
    if not os.path.exists(run_path):
        raise FileNotFoundError(f"Run file not found: {run_path}")

    payload = _read_json("run.json")
    payload["config"] = _read_json("config.json")

    # The additional-information file is optional.
    if os.path.exists(os.path.join(experiment_dir, "additional_info.json")):
        payload["additional_information"] = _read_json("additional_info.json")

    return payload


def load_results(results_dir, id_range=None):
    """
    Load every completed experiment stored under ``results_dir``.

    Only sub-entries whose name consists of digits are considered; they are
    visited in ascending numeric order. For each one that has a ``run.json``
    the parsed run data is extended with ``"config"`` (parsed ``config.json``),
    ``"id"``, ``"path"`` and, if ``additional_info.json`` exists,
    ``"additional_information"``.

    An experiment is skipped (and counted in the printed summary) when its
    JSON cannot be decoded, its ``config.json`` is missing, it has no
    ``"result"`` entry, or its ``"status"`` is absent or not ``"COMPLETED"``.
    Experiments without a ``run.json`` are ignored without being counted.

    Args:
        results_dir (str): Directory holding one sub-directory per experiment id.
        id_range (tuple | None): Optional ``(start, end)`` pair; only ids with
            ``start <= id < end`` are loaded.

    Returns:
        list[dict]: One dictionary per successfully loaded experiment. Empty
        if ``results_dir`` does not exist.
    """
    if not os.path.exists(results_dir):
        print(f"Results directory not found: {results_dir}")
        return []

    exp_ids = sorted(int(name) for name in os.listdir(results_dir) if name.isdigit())

    if id_range is not None:
        start, end = id_range
        exp_ids = [exp_id for exp_id in exp_ids if start <= exp_id < end]

    all_results = []
    skipped = 0

    for exp_id in exp_ids:
        exp_path = os.path.join(results_dir, str(exp_id))
        run_file = os.path.join(exp_path, "run.json")
        if not os.path.exists(run_file):
            continue

        try:
            with open(run_file, "r") as f:
                run_data = json.load(f)
            with open(os.path.join(exp_path, "config.json"), "r") as f:
                run_data["config"] = json.load(f)
            run_data["id"] = exp_id
            run_data["path"] = exp_path

            additional_info_file = os.path.join(exp_path, "additional_info.json")
            if os.path.exists(additional_info_file):
                with open(additional_info_file) as f:
                    run_data["additional_information"] = json.load(f)
        except json.JSONDecodeError:
            print(f"Error decoding JSON for experiment {exp_id}")
            skipped += 1
            continue
        except FileNotFoundError:
            # run.json exists but config.json does not: treat it like any other
            # unusable run instead of aborting the whole load.
            print(f"Missing config file for experiment {exp_id}")
            skipped += 1
            continue

        # .get() so that a run file without a status is skipped, not a KeyError.
        if "result" not in run_data or run_data.get("status") != "COMPLETED":
            skipped += 1
            continue

        all_results.append(run_data)

    if skipped > 0:
        print(f"Skipped {skipped} some experiments missing results")
    return all_results


def delete_experiments(experiment_dir, filter_func, debug_flag=False):
    """
    Delete experiments that match a specific filter function.

    Every sub-directory of ``experiment_dir`` counts as one experiment. Its
    ``config.json`` (or, if that is absent, ``params.json``) is parsed and
    passed to ``filter_func``; directories without such a file or with
    invalid JSON are reported and left untouched.

    Args:
        experiment_dir (str): Directory containing experiments
        filter_func (Callable): Function that takes the parsed config of an
            experiment and returns True if it should be deleted
        debug_flag (bool): If True, only print what would happen without deleting

    Returns:
        tuple: (deleted_count, total_count). ``deleted_count`` is the number of
        directories actually removed (or, with ``debug_flag``, the number that
        would be removed); directories whose removal failed are not counted.
    """
    import shutil

    deleted_count = 0
    total_count = 0

    # Ensure the directory exists
    if not os.path.exists(experiment_dir):
        print(f"Directory not found: {experiment_dir}")
        return (0, 0)

    # os.listdir returns a list, so removing entries while looping is safe.
    for exp_name in os.listdir(experiment_dir):
        exp_path = os.path.join(experiment_dir, exp_name)

        # Skip non-directories
        if not os.path.isdir(exp_path):
            continue

        total_count += 1

        # Look for config.json or params.json to evaluate
        config_path = os.path.join(exp_path, "config.json")
        if not os.path.exists(config_path):
            config_path = os.path.join(exp_path, "params.json")

        if not os.path.exists(config_path):
            print(f"Skipping {exp_name}: No config/params file found")
            continue

        # Load experiment configuration
        try:
            with open(config_path, "r") as f:
                config = json.load(f)
        except json.JSONDecodeError:
            print(f"Skipping {exp_name}: Invalid JSON in config file")
            continue

        # Apply filter function to determine if experiment should be deleted
        if not filter_func(config):
            continue

        if debug_flag:
            deleted_count += 1
            print(f"Would delete: {exp_path}")
            print(f"{config}")
            continue

        print(f"Deleting: {exp_path}")
        try:
            shutil.rmtree(exp_path)
        except OSError as e:
            # Best effort: report the failure and keep going, but do not count
            # the experiment as deleted.
            print(f"Error deleting {exp_path}: {e}")
        else:
            deleted_count += 1

    # Print summary
    if debug_flag:
        print(f"Would delete {deleted_count} out of {total_count} experiments")
    else:
        print(f"Deleted {deleted_count} out of {total_count} experiments")

    return (deleted_count, total_count)


def analyze_experiments(
    results_range,
    result_dir=None,
    results_dirs=None,
    filter_func=lambda x: x,
    filter_columns=None,
    only_last_seed=True,
    head=None,
):
    """
    Load experiment results, generate a dataframe, print the unique values for all grouping keys,
    and group the data only by keys that have more than one unique category to compute aggregation statistics.

    Args:
        results_range (tuple): Forwarded to ``load_results`` as its ``id_range``
            to select which experiment ids are loaded.
        result_dir (str | None): A single results directory. Mutually exclusive
            with ``results_dirs``.
        results_dirs (list | None): Several results directories whose
            experiments are concatenated. Mutually exclusive with ``result_dir``.
        filter_func (Callable): Called with the experiment dataframe; the
            dataframe it returns is the one that gets analysed.
        filter_columns (list | None): If given, only result columns whose
            top-level name is in this collection are kept.
        only_last_seed (bool): If True and a ``seed`` column exists, the
            dataframe is grouped by seed plus the grouping keys and only the
            ``.last()`` entry of each group is kept.
        head (int | None): If truthy, only the first ``head`` rows of each
            group take part in the aggregation.

    Returns:
        tuple: ``(styled, df)`` where ``styled`` is a pandas Styler over the
        aggregated statistics (mean, std, count per group; mean, var, count
        when there is nothing to group by) and ``df`` is the dataframe the
        statistics were computed from.

    Raises:
        ValueError: If both or neither of ``result_dir`` / ``results_dirs``
            are given.
    """

    if results_dirs is not None and result_dir is not None:
        raise ValueError("Specify either results_dirs or result_dir, not both.")
    if results_dirs is None and result_dir is None:
        raise ValueError("Specify one of results_dirs or result_dir.")

    if result_dir is not None:
        results_dirs = [result_dir]

    # Load experiments from the specified directories and range
    experiments = []
    for results_dir in results_dirs:
        experiments.extend(load_results(results_dir, results_range))

    # Convert experiments to a dataframe
    df = generate_dataframe_from_experiments(experiments)
    df = filter_func(df)

    # Define all potential grouping keys
    group_keys = [
        "model_type",
        "type",
        "depth",
        "standardize",
        "transform_type",
        "noise_type",
        "epochs",
        "size",
        "optimizer_settings",
        "condition_dim",
        "condition_embed",
        # "train_loss",
        "mediator_noise",
        "max_noise",
        "noise_parameters",
        "transform_parameters",
        "pretrained_model",
        "train_loss_schedule",
        "name",
        "data_id",
        "n_samples",
        "hidden_dim",
        "num_blocks",
        "beta_schedule",
        "learning_rate",
        "mediator_noise_type",
        "mediator_noise_parameters",
        "mediator_noise_parameters_std",
        "bestN",
        "use_best_validation",
        "test_evaluation_passes",
        # "seed",
    ]

    # Aggregate every numeric column that is not itself a grouping key
    agg_columns = [
        col
        for col in df.columns
        if col not in group_keys and np.issubdtype(df[col].dtype, np.number)
    ]

    # Keys that vary across experiments are grouped on; these two are grouped
    # on even when constant. The parentheses matter: `a and b or c` would add
    # them even when the column is missing, and groupby would raise KeyError.
    always_group = ("transform_type", "noise_type")
    smart_group_keys = [
        key
        for key in group_keys
        if key in df.columns and (df[key].nunique() > 1 or key in always_group)
    ]
    fixed_keys = [
        key for key in group_keys if key in df.columns and df[key].nunique() == 1
    ]

    # Print unique values for each grouping key that is used
    if smart_group_keys:
        print("Grouping keys with multiple unique values:")
        for key in smart_group_keys:
            print(f"{key}: {df[key].unique()}")
    else:
        print("No grouping keys have more than one unique value.")

    # Print unique values for the keys that are fixed (only one unique value)
    if fixed_keys:
        print("\nOther keys with a single unique value:")
        for key in fixed_keys:
            print(f"{key}: {df[key].unique()}")

    # NOTE(review): groupby drops rows whose key is NaN by default, and this
    # runs before the fillna below -- confirm that dropping them is intended.
    if only_last_seed and "seed" in df.columns:
        df = df.groupby(["seed"] + smart_group_keys, as_index=False).last()

    if "seed" in df.columns:
        print(f"Seeds used: {df['seed'].unique()}")

    # Group the dataframe and compute the aggregations if there are any variable keys,
    # otherwise aggregate over the entire dataframe.
    if smart_group_keys:
        # Replace missing key values so those rows form their own "N/A" group.
        df[smart_group_keys] = df[smart_group_keys].fillna("N/A")
        group = df.groupby(smart_group_keys)
        if head:
            group = group.head(head).groupby(smart_group_keys)
        result = group[agg_columns].agg(["mean", "std", "count"])
    else:
        result = df[agg_columns].agg(["mean", "var", "count"])

    tooltip_texts = pd.DataFrame(
        "", index=result.index, columns=result.columns  # Default: empty
    )

    # Show the full group key as a tooltip on every "mean" cell.
    try:
        for method in result.columns.levels[0]:
            if (method, "mean") in result.columns:
                tooltip_texts[(method, "mean")] = result.index.map(
                    lambda idx: f"{idx}"
                ).to_list()
    except Exception as e:
        print(f"An error occurred: {e}")

    def adaptive_multiindex_row_separators(frame):
        # Build a frame of CSS strings (same shape as `frame`) that draws a top
        # border on rows depending on how the index differs from the row above.
        n_cols = frame.shape[1]

        if not isinstance(frame.index, pd.MultiIndex):
            # Handle single-level index: just check changes between rows
            styles = []
            for i in range(len(frame)):
                if i == 0 or frame.index[i] != frame.index[i - 1]:
                    styles.append([""] * n_cols)
                else:
                    styles.append(["border-top: 2px solid black"] * n_cols)
            return pd.DataFrame(styles, index=frame.index, columns=frame.columns)

        # MultiIndex case: the outermost level that changes decides the border
        # thickness (outer levels get thicker borders).
        n_levels = frame.index.nlevels
        styles = []

        for i in range(len(frame)):
            row_styles = [""] * n_cols
            if i == 0:
                styles.append(row_styles)
                continue

            for level in range(n_levels):
                prev = frame.index[i - 1][level]
                curr = frame.index[i][level]
                if prev != curr:
                    thickness = max(n_levels - level - 1, 0)
                    row_styles = [f"border-top: {thickness}px solid black"] * n_cols
                    break

            styles.append(row_styles)

        return pd.DataFrame(styles, index=frame.index, columns=frame.columns)

    if filter_columns is not None:
        # Filter the columns based on the provided filter_columns at the top level index
        result = result.loc[:, result.columns.get_level_values(0).isin(filter_columns)]

    styler = result.style
    if isinstance(result.columns, pd.MultiIndex):
        # The (column, "mean") slice only exists for grouped results.
        styler = styler.bar(
            subset=pd.IndexSlice[:, (slice(None), "mean")], cmap="coolwarm"
        )

    styled = (
        styler.set_tooltips(tooltip_texts)
        .apply(adaptive_multiindex_row_separators, axis=None)
        .set_table_styles(
            [
                # Align row index cells (row headings) to the top
                {"selector": "th.row_heading", "props": [("vertical-align", "top")]},
                # Optional: also align column headers
                {"selector": "th.col_heading", "props": [("vertical-align", "bottom")]},
            ]
        )
    )

    return styled, df
