import os
import json
import sys
import csv

import argparse

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline


def read_csv(exp_name, keys):
    """Load selected columns from an experiment's ``progress.csv``.

    The file is read from ``data/<exp_name>/progress.csv`` as a
    comma-delimited table whose first row holds the column names.

    Args:
        exp_name: Name of the experiment directory under ``data``.
        keys: Iterable of column names to extract.

    Returns:
        A dict mapping each requested key to that column of the table.

    Raises:
        AssertionError: If a requested key is not one of the table's columns.
    """
    csv_path = os.path.join("data", exp_name, "progress.csv")
    table = np.genfromtxt(csv_path, delimiter=",", names=True)
    available = table.dtype.names

    selected = {}
    for column in keys:
        # Fail on the first missing column and report what is available.
        assert column in available, "{} not in data keys: {}".format(
            column, available
        )
        selected[column] = table[column]
    return selected


def preprocess_data(X, Y, filter_outliers=False, window=1, spline=False):
    """Clean and smooth a curve before plotting.

    The steps are applied in order:

    1. Samples whose ``Y`` value is NaN or infinite are dropped.
    2. If ``filter_outliers`` is set, samples whose ``Y`` lies 3 or more
       standard deviations away from the median are dropped.
    3. ``Y`` is replaced by its moving average over ``window`` samples. Only
       windows that lie fully inside the data are kept, so the result has
       ``len(Y) - window + 1`` samples and ``window=1`` leaves the curve
       untouched. Each average is paired with the ``X`` value at offset
       ``window // 2`` inside its window.
    4. If ``spline`` is set, the curve is resampled at 500 evenly spaced
       ``X`` positions using ``make_interp_spline``.

    Args:
        X: Sequence of x values.
        Y: Sequence of y values, same length as ``X``.
        filter_outliers: Whether to remove outliers (step 2).
        window: Moving-average width in samples; must be at least 1.
        spline: Whether to resample the result with an interpolating spline.

    Returns:
        A tuple ``(X, Y)`` of processed arrays. Both are empty when fewer
        than ``window`` usable samples remain.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer, got {}".format(window))

    # Accept plain sequences as well as arrays; mask indexing needs arrays.
    X = np.asarray(X)
    Y = np.asarray(Y, dtype=float)

    ### Removing NaNs and infs for multistage training stuff
    finite = np.isfinite(Y)
    X, Y = X[finite], Y[finite]

    ### Filter Outliers
    if filter_outliers and Y.size:
        center, spread = np.median(Y), np.std(Y) + 1e-5
        keep = np.abs(Y - center) < 3 * spread
        X, Y = X[keep], Y[keep]

    ### Moving average
    n_valid = Y.size - window + 1
    if n_valid <= 0:
        # Not a single full window fits into the remaining data.
        return X[:0], Y[:0]
    # mode="valid" avoids the zero padding that would drag the averages at
    # both ends of the curve towards zero.
    Y = np.convolve(Y, np.ones(window) / window, mode="valid")
    start = window // 2
    X = X[start : start + n_valid]

    if spline:
        # NOTE: make_interp_spline may reject inputs with too few samples or
        # with X values that are not increasing; such errors propagate.
        X_Y_Spline = make_interp_spline(X, Y)

        # Resample on 500 evenly spaced positions over the covered X range.
        X = np.linspace(X.min(), X.max(), 500)
        Y = X_Y_Spline(X)

    return X, Y


def make_plot(
    exp_names: list,
    y_key: str,
    y_keys: list,
    x_key: str,
    filter_outliers: bool,
    window: int,
) -> None:
    """Plot curves from ``progress.csv`` files and show the figure.

    Two modes are supported:

    * ``y_keys`` is given: exactly one experiment is plotted, with one curve
      per column named in ``y_keys``. A single string is treated as a
      one-element list.
    * ``y_keys`` is ``None``: the column ``y_key`` is plotted once per
      experiment. Each curve is labelled with the ``"name"`` entry of the
      experiment's ``variant.json`` when it is present and not null, and with
      the experiment path otherwise.

    Args:
        exp_names: Experiment directories, each containing ``progress.csv``
            (and ``variant.json`` in the multi-experiment mode).
        y_key: Column plotted for every experiment when ``y_keys`` is ``None``.
        y_keys: Columns plotted for a single experiment, or ``None``.
        x_key: Column used for the x axis.
        filter_outliers: Forwarded to ``preprocess_data``.
        window: Smoothing window forwarded to ``preprocess_data``.

    Raises:
        ValueError: If both ``y_key`` and ``y_keys`` are ``None``.
        AssertionError: If ``y_keys`` is given with more than one experiment.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(y_keys, str):
        y_keys = [y_keys]
    if y_keys is None and y_key is None:
        raise ValueError("either y_key or y_keys must be provided")

    plt.figure()

    if y_keys is not None:
        assert len(exp_names) == 1, "y_keys requires exactly one experiment"
        file_name = os.path.join(exp_names[0], "progress.csv")
        data = np.genfromtxt(file_name, delimiter=",", names=True)
        print(data.dtype.names)
        for ykey in y_keys:
            X, Y = preprocess_data(
                data[x_key],
                data[ykey],
                filter_outliers=filter_outliers,
                window=window,
            )
            plt.plot(X, Y)

        plt.xlabel(x_key)
        plt.legend(y_keys)

    else:
        legend_keys = []
        for exp_name in exp_names:
            file_name = os.path.join(exp_name, "progress.csv")
            data = np.genfromtxt(file_name, delimiter=",", names=True)
            print(data.dtype.names)
            X, Y = preprocess_data(
                data[x_key],
                data[y_key],
                filter_outliers=filter_outliers,
                window=window,
            )
            plt.plot(X, Y)

            metadata_name = os.path.join(exp_name, "variant.json")
            # Close the metadata file deterministically instead of leaking it.
            with open(metadata_name) as metadata_file:
                metadata = json.load(metadata_file)
            label = metadata.get("name")
            legend_keys.append(exp_name if label is None else label)

        plt.ylabel(y_key)
        plt.xlabel(x_key)
        plt.legend(legend_keys)

    plt.show()


def make_multipolicy_plot_grid(
    exp_names,
    n_policies,
    y_key,
    x_key="EvaluationIteration",
    filter_outliers=True,
    window=1,
):
    """Draw one subplot per policy, with one curve per experiment.

    For every experiment the columns ``x_key`` and ``y_key.format(i)`` for
    ``i`` in ``range(n_policies)`` are read through ``read_csv`` (that is,
    from ``data/<exp_name>/progress.csv``), run through ``preprocess_data``
    and plotted on the i-th subplot. The legend uses the ``"name"`` entry of
    ``data/<exp_name>/variant.json`` when it is present and not null, and the
    experiment name otherwise.

    The figure is created but not shown; this function does not call
    ``plt.show()``.

    Args:
        exp_names: Experiment directory names under ``data``.
        n_policies: Number of policies, i.e. number of subplots.
        y_key: Column-name template with a ``{}`` placeholder for the policy
            index.
        x_key: Column used for the x axis.
        filter_outliers: Forwarded to ``preprocess_data``.
        window: Smoothing window forwarded to ``preprocess_data``.
    """
    # squeeze=False always yields a 2-D array of axes; without it a single
    # subplot is returned as a lone Axes object that cannot be indexed.
    _, axes_grid = plt.subplots(
        1, n_policies, figsize=(n_policies * 4, 4), squeeze=False
    )
    axs = axes_grid[0]

    l_keys = []
    full_y_keys = [y_key.format(i) for i in range(n_policies)]
    data_keys = [x_key] + full_y_keys
    for exp_name in exp_names:
        data = read_csv(exp_name, data_keys)
        for ax, policy_key in zip(axs, full_y_keys):
            # Share the cleaning/smoothing pipeline with make_plot instead of
            # keeping a second copy of it here.
            X, Y = preprocess_data(
                data[x_key],
                data[policy_key],
                filter_outliers=filter_outliers,
                window=window,
            )
            ax.plot(X, Y)

        metadata_name = os.path.join("data", exp_name, "variant.json")
        # Close the metadata file deterministically instead of leaking it.
        with open(metadata_name) as metadata_file:
            metadata = json.load(metadata_file)
        label = metadata.get("name")
        l_keys.append(exp_name if label is None else label)

    ### Label axis and stuff

    axs[0].set_ylabel(y_key.replace("{}", ""))

    for i, ax in enumerate(axs):
        ax.set_title("Policy {}".format(i))
        ax.set_xlabel(x_key)
    plt.legend(l_keys)


# Example usage of make_multipolicy_plot_grid (kept for reference):
# make_multipolicy_plot_grid(
#     ["dnc_sac_ReacherCluster-v0_ent_norm_0"],
#     2,
#     "LocalPolicy{}AverageTrainAverageReturn",
# )


def _parse_bool_flag(text: str) -> bool:
    """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool."""
    normalized = text.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(text)
    )


def main() -> None:
    """Parse command-line arguments and plot the requested curves.

    Positional arguments are experiment directories. ``--y_keys`` plots
    several columns of a single experiment; when it is omitted, the column
    given by ``--y_key`` is plotted once per experiment.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("exp_names", nargs="+")
    # y_keys must default to None: make_plot switches to the single-experiment
    # mode whenever it is set, which would make --y_key unusable.
    parser.add_argument("--y_key", default="EvaluationAverageReturn")
    parser.add_argument("--y_keys", nargs="+", default=None)
    parser.add_argument("--x_key", default="TotalEnvSteps")
    # type=bool would turn every non-empty string, including "False", into True.
    parser.add_argument("--filter_outliers", type=_parse_bool_flag, default=False)
    parser.add_argument("--window", type=int, default=1)

    args = parser.parse_args()
    make_plot(
        args.exp_names,
        args.y_key,
        args.y_keys,
        args.x_key,
        args.filter_outliers,
        args.window,
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
