import os
import json
import shutil

import matplotlib.pyplot as plt
import numpy as np

from algorithms.utils.utils import (
    get_run_arrays,
    get_iqm_mean_conf,
    get_iqm_mean_conf_parallel,
)
from algorithms.utils import LABELS, ORDER, COLORS


# Top-level entries of ``runs/`` that are not experiment directories; the
# batch helpers in this module skip them when scanning ``runs/``.
exclude = [
    "server",
    "debug",
    "deprecated",
]


def delete_all_plots(seed_evals=False):
    """Remove every ``.png`` / ``.pdf`` file below the experiment dirs in ``runs/``.

    Entries of ``runs/`` listed in ``exclude`` (and non-directories) are
    skipped. Files located anywhere inside a ``seed_eval`` folder are only
    removed when ``seed_evals`` is true.
    """
    base = "runs"
    for name in os.listdir(base):
        exp_dir = os.path.join(base, name)
        if name in exclude or not os.path.isdir(exp_dir):
            continue
        for root, _subdirs, filenames in os.walk(exp_dir):
            # The seed_eval decision depends only on the directory, not the file.
            inside_seed_eval = "seed_eval" in root.split(os.path.sep)
            if inside_seed_eval and not seed_evals:
                continue
            for fname in filenames:
                if fname.endswith((".png", ".pdf")):
                    os.remove(os.path.join(root, fname))


def _regressor_confidence(losses, max_alpha=0.3):
    """Turn per-update regressor losses into shading alphas in ``[0, max_alpha]``.

    For each row (one target update) the two smallest losses are compared and
    the relative gap ``(second - best) / second`` is used as a confidence
    score. Scores are rescaled so the most confident update gets
    ``max_alpha``. Rows whose score is undefined (zero or non-finite
    denominator) get 0; if every score is 0 all alphas are 0; with fewer than
    two regressors every update gets ``max_alpha``.
    """
    losses = np.asarray(losses, dtype=float)
    if losses.shape[1] < 2:
        return np.full(losses.shape[0], max_alpha)
    sorted_losses = np.sort(losses, axis=1)
    best, second = sorted_losses[:, 0], sorted_losses[:, 1]
    with np.errstate(divide="ignore", invalid="ignore"):
        confidence = (second - best) / second
    confidence = np.where(np.isfinite(confidence), confidence, 0.0)
    peak = confidence.max() if confidence.size else 0.0
    if peak <= 0:
        # Avoid the 0/0 -> NaN alpha that matplotlib would reject.
        return np.zeros_like(confidence)
    return np.clip(confidence / peak, 0.0, 1.0) * max_alpha


def _shade_regressor_choice(xs, reg_index, n_models, alpha):
    """Paint which regressor was selected at each target update behind the curves.

    For every interval ``[xs[i], xs[i + 1]]`` the current y-range is split into
    stacked bands, one per regressor index ``j`` in ``range(n_models)`` that
    occurs in ``reg_index[i]``, with height proportional to how often ``j``
    occurs there. ``alpha`` is either a scalar or a sequence indexed by ``i``.
    """
    bottom, top = plt.ylim()
    height = top - bottom
    for i in range(len(xs) - 1):
        unique, counts = np.unique(reg_index[i], return_counts=True)
        total = counts.sum()
        band_alpha = float(alpha[i]) if np.ndim(alpha) else float(alpha)
        cum_frac = 0.0
        for j in range(n_models):
            hits = counts[unique == j]
            if hits.size == 0:
                continue
            band_bottom = bottom + height * cum_frac
            cum_frac += hits[0] / total  # scalar, not a 1-element array
            band_top = bottom + height * cum_frac
            plt.fill_between(
                [xs[i], xs[i + 1]],
                band_bottom,
                band_top,
                color=COLORS[f"DQN_{j}"],
                alpha=band_alpha,
            )


def _finish_seed_figure(xs, ylabel, width, out_stem):
    """Label, size and save the current figure as ``<out_stem>.pdf`` and ``.png``, then close it."""
    plt.xlabel("target updates")
    plt.ylabel(ylabel)
    plt.xlim(xs[0], xs[-1])
    plt.legend()
    plt.gcf().set_size_inches(width, 5)
    plt.tight_layout()
    plt.savefig(f"{out_stem}.pdf")
    plt.savefig(f"{out_stem}.png", dpi=200)
    plt.close()


def analyse_one_seed(path, seed, criterion, start_upd=0, end_upd=40):
    """Plot validation returns and regressor losses of a single aDQN seed.

    Both plots cover target updates ``[start_upd, end_upd)`` and are overlaid
    with a background showing which regressor was chosen at each update. Output
    goes to ``<path>/adqn/<criterion>/seed_eval/{seed}_returns.{pdf,png}`` and
    ``{seed}_losses.{pdf,png}``.

    Args:
        path: experiment directory containing ``params.json`` and ``adqn/``.
        seed: index of the seed to plot.
        criterion: aDQN selection criterion sub-directory (e.g. ``"eps_min"``).
        start_upd: first target update to plot.
        end_upd: end (exclusive) of the plotted window; ``-1`` or a value past
            the recorded length means "up to the last recorded update".
    """
    col_width = 0.1  # inches of figure width per plotted target update

    with open(os.path.join(path, "params.json"), "r") as file:
        params = json.load(file)
    models_layers = params["hidden_layers"]
    n_models = len(models_layers)

    adqn_path = os.path.join(path, "adqn", criterion)
    seed_eval_path = os.path.join(adqn_path, "seed_eval")
    os.makedirs(seed_eval_path, exist_ok=True)

    # Load each array once; the regressor index returned alongside "losses" is
    # the one used for the loss plot.
    # assumes returns: (n_seeds, target_updates, regressors, n_runs),
    # losses: (n_seeds, target_updates, regressors) — TODO confirm in get_run_arrays
    returns, ret_reg_index = get_run_arrays("returns", adqn_path)
    losses, loss_reg_index = get_run_arrays("losses", adqn_path)

    n_updates = returns.shape[1]
    if end_upd == -1 or end_upd > n_updates:
        end_upd = n_updates
    window = slice(start_upd, end_upd)
    xs = np.arange(start_upd, end_upd)
    fig_width = col_width * (end_upd - start_upd)

    # Treat evaluation episodes as seeds: IQM is taken over the n_runs axis.
    # After transpose: (regressors, n_runs, target_updates).
    returns = np.transpose(returns[seed][window, :, :], (1, 2, 0))
    losses = losses[seed, window, :]  # (target_updates, regressors)
    ret_reg_index = ret_reg_index[seed, window]
    loss_reg_index = loss_reg_index[seed, window]
    norm_confidence = _regressor_confidence(losses)

    # ----------------validation returns--------------------------
    plt.figure()
    plt.ylim(-400, 300)
    for i, arch in enumerate(models_layers):
        mean, conf = get_iqm_mean_conf_parallel(returns[i])
        color = COLORS[f"DQN_{i}"]
        plt.fill_between(xs, conf[1], conf[0], alpha=0.2, color=color)
        plt.plot(xs, mean, label=arch, color=color)
    _shade_regressor_choice(xs, ret_reg_index, n_models, norm_confidence)
    _finish_seed_figure(
        xs,
        "average return",
        fig_width,
        os.path.join(seed_eval_path, f"{seed}_returns"),
    )

    # ----------------losses-----------------------------------
    plt.figure()
    for i, arch in enumerate(models_layers):
        plt.plot(xs, losses[:, i], label=arch, color=COLORS[f"DQN_{i}"])
    _shade_regressor_choice(xs, loss_reg_index, n_models, 0.2)
    _finish_seed_figure(
        xs,
        "loss",
        fig_width,
        os.path.join(seed_eval_path, f"{seed}_losses"),
    )


def _plot_adqn_curve(crit, run_path):
    """Draw the IQM ``avg_epoch_ret`` curve and its interval band for one aDQN criterion."""
    returns, _ = get_run_arrays("avg_epoch_ret", run_path)
    mean, conf = get_iqm_mean_conf(returns)
    xs = np.arange(returns.shape[1])
    key = f"aDQN_{crit}"
    plt.plot(
        xs,
        mean,
        color=COLORS[key],
        label=LABELS[key],
        zorder=ORDER[key],
    )
    plt.fill_between(
        xs,
        conf[1],
        conf[0],
        color=COLORS[key],
        zorder=ORDER[key],
        alpha=0.2,
    )


def _save_epoch_figure(out_stem):
    """Apply the shared per-epoch axes styling, save ``<out_stem>.png``/``.pdf`` and close."""
    plt.ylim(-300, 300)
    plt.xlabel("epoch")
    plt.ylabel("average return")
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.savefig(f"{out_stem}.png")
    plt.savefig(f"{out_stem}.pdf")
    plt.close()


def analyse_one_path(path, save_path=None):
    """Plot learning curves for one experiment directory.

    Always writes ``avg_epoch_ret.{png,pdf}``: the aDQN ``eps_min`` curve (if
    that criterion was run) against every plain DQN architecture from
    ``params.json``. When more than one aDQN criterion directory exists it also
    writes ``criterion_comparison.{png,pdf}``.

    Args:
        path: experiment directory containing ``params.json``, ``adqn/`` and ``dqn/``.
        save_path: directory for the output plots; defaults to ``path``.
            Created if missing.
    """
    with open(os.path.join(path, "params.json"), "r") as file:
        params = json.load(file)
    models_layers = params["hidden_layers"]
    model_seeds = params["m_seeds"]

    out_dir = path if save_path is None else save_path
    os.makedirs(out_dir, exist_ok=True)

    # criterion name -> run directory; stray files in adqn/ are not runs.
    adqn_root = os.path.join(path, "adqn")
    adqn_runs = {
        name: os.path.join(adqn_root, name)
        for name in os.listdir(adqn_root)
        if os.path.isdir(os.path.join(adqn_root, name))
    }
    dqn_root = os.path.join(path, "dqn")
    model_paths = [
        (model, os.path.join(dqn_root, f"{model}_{seed}"))
        for model, seed in zip(models_layers, model_seeds)
    ]

    # ----------------compare returns dqn - adqn-------------------------
    plt.figure()
    if "eps_min" in adqn_runs:
        _plot_adqn_curve("eps_min", adqn_runs["eps_min"])

    for i, (arch, model_path) in enumerate(model_paths):
        returns, _ = get_run_arrays("avg_epoch_ret", model_path)
        mean, conf = get_iqm_mean_conf(returns)
        xs = np.arange(returns.shape[1])
        color, zorder = COLORS[f"DQN_{i}"], ORDER[f"DQN_{i}"]
        plt.fill_between(xs, conf[1], conf[0], color=color, zorder=zorder, alpha=0.2)
        plt.plot(xs, mean, color=color, linestyle="--", label=arch, zorder=zorder)

    _save_epoch_figure(os.path.join(out_dir, "avg_epoch_ret"))

    # ----------------compare criterions-------------------------
    if len(adqn_runs) > 1:
        plt.figure()
        # Fixed order keeps legend entries consistent across experiments.
        for crit in ["eps_min", "min", "random", "max"]:
            if crit in adqn_runs:
                _plot_adqn_curve(crit, adqn_runs[crit])
        _save_epoch_figure(os.path.join(out_dir, "criterion_comparison"))


def restructure_folders():
    """Migrate legacy experiment folders in ``runs/`` to the ``adqn/<criterion>/`` layout.

    The criterion is taken from the first ``_``-separated token of the
    experiment directory name that starts with ``c`` (e.g. ``call`` -> ``all``).

    * ``all``: sibling folders named ``adqn_<crit>`` are moved to ``adqn/<crit>``.
    * any other criterion: if ``adqn/returns`` exists, every entry of ``adqn/``
      is moved into ``adqn/<criterion>/``.

    Directories whose name has no criterion token are skipped; previously they
    either raised NameError or silently reused the criterion of the previously
    processed directory.
    """
    run_dir = "runs"
    exp_dirs = [
        os.path.join(run_dir, f)
        for f in os.listdir(run_dir)
        if os.path.isdir(os.path.join(run_dir, f)) and f not in exclude
    ]
    for path in exp_dirs:
        # basename is portable; split("/") breaks with os.sep == "\\".
        name_parts = os.path.basename(path).split("_")
        criterion = next(
            (part[1:] for part in name_parts if part.startswith("c")), None
        )
        if criterion is None:
            print(f'Skipping "{path}": no criterion token in its name')
            continue

        adqn_path = os.path.join(path, "adqn")

        if criterion == "all":
            if not os.path.exists(adqn_path):
                print(f'Restructuring "{path}"...')
                os.makedirs(adqn_path)

            for folder in os.listdir(path):
                if folder.startswith("adqn_"):
                    new_folder_name = folder.replace("adqn_", "")
                    shutil.move(
                        os.path.join(path, folder),
                        os.path.join(adqn_path, new_folder_name),
                    )

        else:
            returns_path = os.path.join(adqn_path, "returns")
            if os.path.exists(returns_path):
                print(f'Restructuring "{path}"...')
                new_folder_path = os.path.join(adqn_path, criterion)
                os.makedirs(new_folder_path, exist_ok=True)

                for item in os.listdir(adqn_path):
                    if item != criterion:
                        shutil.move(
                            os.path.join(adqn_path, item),
                            os.path.join(new_folder_path, item),
                        )


def analyse_all_paths():
    """Run ``analyse_one_path`` on every experiment directory still lacking plots.

    Experiment directories are the sub-directories of ``runs/`` not listed in
    ``exclude``; one counts as done once its ``avg_epoch_ret.png`` exists.
    """
    base = "runs"
    exp_dirs = [
        os.path.join(base, name)
        for name in os.listdir(base)
        if name not in exclude and os.path.isdir(os.path.join(base, name))
    ]
    for exp_dir in exp_dirs:
        if os.path.exists(os.path.join(exp_dir, "avg_epoch_ret.png")):
            continue  # summary plot already present
        print(f'Analyzing "{exp_dir}"...')
        analyse_one_path(exp_dir)


if __name__ == "__main__":
    # Default action: render the summary plots for every experiment
    # directory under runs/ that does not have them yet.
    analyse_all_paths()

    # Other maintenance entry points (enable as needed):
    #   delete_all_plots(seed_evals=False)
    #   restructure_folders()
    #   analyse_one_path(os.path.join("runs", "<experiment_dir>"))
    #
    # Per-seed inspection of one experiment:
    #   folder = os.path.join(
    #       "runs", "100-100_50-50-50_25-25-25-25_call_tuf200_mspe6000_ne60"
    #   )
    #   check_seeds = os.path.join(folder, "adqn", "eps_min", "avg_epoch_ret")
    #   seeds = [0]  # or: range(len(os.listdir(check_seeds)))
    #   for s in seeds:
    #       print(f"Analyzing seed {s} for {folder}...")
    #       analyse_one_seed(
    #           folder, seed=s, criterion="eps_min", start_upd=1260, end_upd=1400
    #       )
