import os
import json
import numpy as np

import matplotlib.pyplot as plt

from algorithms.utils.utils import (
    get_run_arrays,
    get_iqm_mean_conf,
    get_iqm_mean_conf_parallel,
)
from algorithms.utils import COLORS, LABELS, ORDER


def analyse_one_path(path, save_path):
    """Plot the aDQN-vs-DQN average-return curves of one run and save them.

    The curves, axis limits, axis labels and grid are drawn by
    ``create_subfigure`` so that the stand-alone figure and the ablation
    panels cannot drift apart; this function only adds the legend and
    writes the files.

    Args:
        path: Run directory containing ``params.json`` and the ``adqn`` and
            ``dqn`` sub-directories read by ``create_subfigure``.
        save_path: File stem (without extension). The figure is written to
            ``../../plots/<save_path>.png`` and ``.pdf``, resolved against
            the current working directory.
    """
    fig, ax = plt.subplots()
    try:
        create_subfigure(path, ax)
        ax.legend()
        fig.tight_layout()
        fig.savefig(f"../../plots/{save_path}.png")
        fig.savefig(f"../../plots/{save_path}.pdf")
    finally:
        # Release the figure even when loading or plotting fails, so a failed
        # call does not leave a stale "current figure" behind.
        plt.close(fig)


def create_subfigure(path, ax):
    """Draw the aDQN (``eps_min``) and DQN return curves of one run on ``ax``.

    Args:
        path: Run directory. It must contain ``params.json`` (with the keys
            ``hidden_layers`` and ``m_seeds``), an ``adqn`` directory with
            one sub-directory per criterion, and a ``dqn`` directory with
            one ``<layers>_<seed>`` sub-directory per architecture.
        ax: Matplotlib axes that receives the curves, limits, labels and
            grid. No legend is added here.
    """

    def load_curve(run_dir):
        # x positions, centre line and band of the "avg_epoch_ret" metric.
        # The x axis is sized from the second dimension of the returned array.
        returns, _ = get_run_arrays("avg_epoch_ret", run_dir)
        mean, conf = get_iqm_mean_conf(returns)
        return np.arange(returns.shape[1]), mean, conf

    with open(os.path.join(path, "params.json"), "r") as params_file:
        run_params = json.load(params_file)
    layer_specs = run_params["hidden_layers"]
    seeds = run_params["m_seeds"]

    adqn_root = os.path.join(path, "adqn")
    criterion_dirs = [
        (name, os.path.join(adqn_root, name)) for name in os.listdir(adqn_root)
    ]
    dqn_root = os.path.join(path, "dqn")
    dqn_dirs = [
        (layers, os.path.join(dqn_root, f"{layers}_{seed}"))
        for layers, seed in zip(layer_specs, seeds)
    ]

    # aDQN: only the "eps_min" criterion is shown (solid line, then band).
    for criterion, run_dir in criterion_dirs:
        if criterion != "eps_min":
            continue
        xs, mean, conf = load_curve(run_dir)
        style_key = f"aDQN_{criterion}"
        ax.plot(
            xs,
            mean,
            color=COLORS[style_key],
            label=LABELS[style_key],
            zorder=ORDER[style_key],
        )
        ax.fill_between(
            xs,
            conf[1],
            conf[0],
            color=COLORS[style_key],
            zorder=ORDER[style_key],
            alpha=0.2,
        )

    # DQN baselines: one dashed line per architecture (band first, then line).
    for idx, (layers, run_dir) in enumerate(dqn_dirs):
        xs, mean, conf = load_curve(run_dir)
        style_key = f"DQN_{idx}"
        ax.fill_between(
            xs,
            conf[1],
            conf[0],
            color=COLORS[style_key],
            zorder=ORDER[style_key],
            alpha=0.2,
        )
        ax.plot(
            xs,
            mean,
            color=COLORS[style_key],
            linestyle="--",
            label=layers,
            zorder=ORDER[style_key],
        )

    ax.set_ylim(-300, 300)
    ax.set_xlabel("epoch")
    ax.set_ylabel("average return")
    ax.grid()


def create_tup_ablation_figure(paths, tuf_list, save_path):
    """Plot one return panel per run side by side, with a shared legend.

    The number of panels follows ``len(paths)`` instead of being fixed at
    three; with three runs the figure keeps its original 15x4 inch size.

    Args:
        paths: Run directories, one per panel (see ``create_subfigure``).
        tuf_list: Values shown in the panel titles as ``T = <value>``; must
            provide at least one entry per path. Surplus entries are ignored.
        save_path: File stem (without extension). The figure is written to
            ``../../plots/<save_path>.png`` and ``.pdf``, resolved against
            the current working directory.

    Raises:
        ValueError: If ``paths`` is empty or ``tuf_list`` has fewer entries
            than ``paths``.
    """
    n_panels = len(paths)
    if n_panels == 0:
        raise ValueError("paths must contain at least one run directory")
    if len(tuf_list) < n_panels:
        raise ValueError(
            f"tuf_list has {len(tuf_list)} entries but {n_panels} paths were given"
        )

    # squeeze=False always yields a 2-D array of axes, so a single panel is
    # handled the same way as several; 5 inches per panel keeps 15x4 for three.
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)
    axs = axs[0]

    for ax, path, tuf in zip(axs, paths, tuf_list):
        create_subfigure(path, ax)
        ax.set_title(f"$T = {tuf}$")

    # The shared legend is built from the last panel's handles, as before.
    handles, labels = axs[-1].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=5)

    # Reserve a strip at the bottom of the figure for the shared legend.
    plt.tight_layout(rect=[0, 0.06, 1, 1])
    plt.savefig(f"../../plots/{save_path}.png")
    plt.savefig(f"../../plots/{save_path}.pdf")
    plt.show()


def create_comparison_figure(path, save_path):
    """Plot the average-return curves of the aDQN criteria of one run.

    Args:
        path: Run directory whose ``adqn`` sub-directory holds one directory
            per criterion.
        save_path: File stem (without extension). The figure is written to
            ``../../plots/<save_path>.png`` and ``.pdf``, resolved against
            the current working directory.
    """
    adqn_root = os.path.join(path, "adqn")
    # Directory names are unique, so a mapping replaces the linear search.
    criterion_dirs = {
        name: os.path.join(adqn_root, name) for name in os.listdir(adqn_root)
    }

    plt.figure()

    # Fixed drawing order; criteria without a directory are skipped.
    for criterion in ("eps_min", "min", "random", "max"):
        run_dir = criterion_dirs.get(criterion)
        if run_dir is None:
            continue

        returns, _ = get_run_arrays("avg_epoch_ret", run_dir)
        mean, conf = get_iqm_mean_conf(returns)
        xs = np.arange(returns.shape[1])
        style_key = f"aDQN_{criterion}"
        plt.plot(
            xs,
            mean,
            color=COLORS[style_key],
            label=LABELS[style_key],
            zorder=ORDER[style_key],
        )
        plt.fill_between(
            xs,
            conf[1],
            conf[0],
            color=COLORS[style_key],
            zorder=ORDER[style_key],
            alpha=0.2,
        )

    plt.ylim(-300, 300)
    plt.xlabel("epoch")
    plt.ylabel("average return")
    plt.legend()
    plt.grid()
    plt.tight_layout()
    for extension in ("png", "pdf"):
        plt.savefig(f"../../plots/{save_path}.{extension}")
    plt.close()


def create_val_ret_ablation_figure(
    path, save_path, seed, criterion, start_upd=0, end_upd=40
):
    """Plot per-architecture returns over target updates for one aDQN seed.

    One curve with a band is drawn per entry of ``hidden_layers``. Behind the
    curves, every interval between two consecutive target updates is shaded
    with the colours of the regressors recorded for that update, stacked in
    proportion to how often each was recorded. The opacity of the shading
    grows with the relative gap between the two smallest losses.

    Args:
        path: Run directory containing ``params.json`` and ``adqn/<criterion>``.
        save_path: File stem (without extension). The figure is written to
            ``../../plots/<save_path>.png`` and ``.pdf``, resolved against
            the current working directory.
        seed: Index of the seed to plot (first axis of the loaded arrays).
        criterion: Name of the aDQN criterion sub-directory.
        start_upd: First target update shown.
        end_upd: One past the last target update shown. ``-1`` means "up to
            the last available update"; larger values are clamped to it.

    Raises:
        ValueError: If the selected range of target updates is empty.
    """
    col_width = 0.1  # figure width in inches per plotted target update

    with open(os.path.join(path, "params.json"), "r") as file:
        params = json.load(file)
    models_layers = params["hidden_layers"]

    # set paths
    adqn_path = os.path.join(path, "adqn", criterion)
    seed_eval_path = os.path.join(adqn_path, "seed_eval")
    os.makedirs(seed_eval_path, exist_ok=True)

    # get arrays
    # NOTE(review): per the original comments, returns has shape
    # (n_seeds, target_updates, regressors, n_runs) and losses has shape
    # (n_seeds, target_updates, regressors) -- confirm in get_run_arrays.
    returns, reg_index = get_run_arrays("returns", adqn_path)
    losses, _ = get_run_arrays("losses", adqn_path)

    # Clamp the window so the x axis can never be longer than the data.
    n_updates = returns.shape[1]
    end_upd = n_updates if end_upd == -1 else min(end_upd, n_updates)
    if end_upd <= start_upd:
        raise ValueError(
            f"empty target-update range: start_upd={start_upd}, end_upd={end_upd}"
        )

    # Move the target-update axis last, so that returns[i] holds the data of
    # regressor i with the remaining axis treated like seeds by the IQM helper.
    returns = np.transpose(returns[seed][start_upd:end_upd, :, :], (1, 2, 0))
    losses = losses[seed, start_upd:end_upd, :]
    reg_index = reg_index[seed, start_upd:end_upd]
    xs = np.arange(start_upd, end_upd)

    # Relative gap between the two smallest losses, rescaled to [0, 0.3] for
    # use as alpha. Ties and zero losses would otherwise yield NaN or inf,
    # which matplotlib rejects as an alpha value.
    sorted_losses = np.sort(losses, 1)
    with np.errstate(divide="ignore", invalid="ignore"):
        confidence = (sorted_losses[:, 1] - sorted_losses[:, 0]) / sorted_losses[:, 1]
    confidence = np.where(np.isfinite(confidence), confidence, 0.0)
    peak = np.max(confidence)
    if peak > 0:
        norm_confidence = np.clip(confidence / peak, 0.0, 1.0) * 0.3
    else:
        norm_confidence = np.zeros(confidence.shape)

    # Use a dedicated figure instead of whatever happens to be current.
    fig = plt.figure()
    plt.ylim(-400, 300)

    for i, arch in enumerate(models_layers):
        mean, conf = get_iqm_mean_conf_parallel(returns[i])
        color = COLORS[f"DQN_{i}"]
        plt.fill_between(xs, conf[1], conf[0], alpha=0.2, color=color)
        plt.plot(xs, mean, label=arch, color=color)

    # fill background with regressor choice
    bottom, top = plt.ylim()
    height = top - bottom
    for i, (left, right) in enumerate(zip(xs[:-1], xs[1:])):
        unique, counts = np.unique(reg_index[i], return_counts=True)
        total = counts.sum()
        cum_fraction = 0.0
        for j in range(len(models_layers)):
            chosen = unique == j
            if not chosen.any():
                continue
            reg_bottom = bottom + height * cum_fraction
            # counts[chosen] holds exactly one entry; take it as a scalar so
            # the band limits stay plain floats.
            cum_fraction += counts[chosen][0] / total
            reg_top = bottom + height * cum_fraction
            plt.fill_between(
                [left, right],
                reg_bottom,
                reg_top,
                color=COLORS[f"DQN_{j}"],
                alpha=float(norm_confidence[i]),
            )

    plt.xlabel("target updates")
    plt.ylabel("average return")
    plt.xlim(xs[0], xs[-1])
    plt.legend()
    fig.set_size_inches(col_width * (end_upd - start_upd), 5)
    plt.tight_layout()
    plt.savefig(f"../../plots/{save_path}.png")
    plt.savefig(f"../../plots/{save_path}.pdf")
    plt.close(fig)


if __name__ == "__main__":
    # Earlier figure invocations, kept for reference; uncomment to regenerate.

    # path = "runs/200-200_100-100_50-50_25-25_call_tuf200_mspe6000_ne80"
    # analyse_one_path(path, "adqn_4_archs")
    # create_comparison_figure(path, "criterion_comparison")

    # path = "runs/100-100_50-50-50_25-25-25-25_call_tuf200_mspe6000_ne80"
    # analyse_one_path(path, "adqn_3_archs")

    # path = "runs/100-100_100-100_100-100_call_tuf200_mspe6000_ne80"
    # analyse_one_path(path, "adqn_same_arch")

    # folder = os.path.join(
    #     "runs", "100-100_50-50-50_25-25-25-25_call_tuf200_mspe6000_ne60"
    # )

    # create_val_ret_ablation_figure(
    #     folder,
    #     "val_return_ablation",
    #     seed=0,
    #     criterion="eps_min",
    #     start_upd=1260,
    #     end_upd=1400,
    # )

    # Active figure: one panel per run directory, titled with the matching
    # entry of tuf_titles.
    run_dirs = [
        "runs/100-100_50-50-50_25-25-25-25_call_tuf200_mspe6000_ne80",
        "runs/100-100_50-50-50_25-25-25-25_call_tuf1000_mspe6000_ne80",
        "runs/100-100_50-50-50_25-25-25-25_call_tuf2000_mspe6000_ne80",
    ]
    tuf_titles = ["200", "1000", "2000"]
    create_tup_ablation_figure(run_dirs, tuf_titles, "adqn_tuf_ablation")
