import itertools
import os
from random import random, randrange, uniform
import time
import csv
import numpy as np

import matplotlib.pyplot as plt

from scipy import stats

from expground.types import List
from expground.logger import Log


# plt.style.use('bmh')


def label_mapping(label):
    """Return the legend name for an experiment key.

    Only the text after the last underscore of ``label`` is looked at; it is
    used as the key into a fixed table of display names.

    Raises:
        KeyError: if that trailing token is not one of the known algorithms.
    """
    display_names = {
        "epsro": "EPSRO",
        "psro": "PSRO",
        "iterative": "Mixed-Oracles",
        "sp": "Self-Play",
        "pepsro": "PEPSRO",
        "p2sro": "P-PSRO",
        "psrorn": "PSRO-rN",
    }
    algo_token = label.rsplit("_", 1)[-1]
    return display_names[algo_token]


def smooth(learning_curves: List[np.ndarray], ratio: float):
    """Exponentially smooth every curve, replacing each one in place.

    Each output point is ``ratio * sample + previous * (1.0 - ratio)``, where
    ``previous`` is the preceding output point and is seeded with 0 for the
    first sample of every curve.

    Args:
        learning_curves: indexable collection of iterables of numbers. Every
            entry is overwritten with a plain list of smoothed values.
        ratio: weight given to the incoming sample.

    Returns:
        The same ``learning_curves`` object that was passed in.
    """

    def blend(previous, sample):
        return ratio * sample + previous * (1.0 - ratio)

    for index, curve in enumerate(learning_curves):
        # accumulate() yields the seed first; drop it so only smoothed
        # samples remain.
        running = itertools.accumulate(curve, blend, initial=0)
        learning_curves[index] = list(running)[1:]
    return learning_curves


def plot_error(data, label="", alpha=0.4, x="epoch", training=10000, simulation=1000):
    """Draw the mean curve of several runs with a standard-error band.

    The curve is added to the current matplotlib figure via ``plt.plot`` and
    ``plt.fill_between``.

    Args:
        data: dict mapping a column name to a list of per-run sequences.
            ``"Value"`` is always read; ``"Wall time"`` is read when
            ``x == "wall_time"``. NOTE: the dict is modified in place, every
            sequence being truncated as described below.
        label: experiment key; converted to a legend name by
            ``label_mapping``.
        alpha: opacity of the shaded error band.
        x: which x axis to use: ``"epoch"``, ``"wall_time"`` or
            ``"sample_efficiency"``.
        training: multiplier applied to the iteration index on the
            ``"sample_efficiency"`` axis.
        simulation: multiplier of the cumulative ``(2 * n - 1)`` term on the
            ``"sample_efficiency"`` axis.

    Raises:
        KeyError: propagated from ``label_mapping`` for an unknown label.
        AssertionError: if any column has a run with 50 or fewer entries.
        ValueError: if ``x`` is not one of the supported axis names.
    """
    label = label_mapping(label)
    # Truncate all runs of a column to a common length (at most 60) so they
    # can be stacked into a rectangular array.
    for k in data.keys():
        min_length = min(60, min([len(e) for e in data[k]]))
        assert min_length > 50, (min_length, label)
        data[k] = [e[:min_length] for e in data[k]]
    values = np.array(data["Value"])
    data_mean = np.mean(values, axis=0)
    error_bars = stats.sem(values)
    if x == "epoch":
        x_axis = np.arange(len(data_mean))
        print("epoch", x_axis.shape)
    elif x == "wall_time":
        x_axis = np.array(data["Wall time"])
        # Elapsed time of every run, measured from that run's earliest stamp.
        print(
            label,
            " max:",
            np.max(x_axis - np.min(x_axis, axis=1, keepdims=True), axis=1),
        )
        if "EPSRO" not in label:
            # The first run is left out of the average for these labels.
            # NOTE(review): the reason is not visible in this file - confirm.
            x_axis = np.mean(
                x_axis[1:] - np.min(x_axis, axis=1, keepdims=True)[1:], axis=0
            )
        else:
            x_axis = np.mean(x_axis - np.min(x_axis, axis=1, keepdims=True), axis=0)
        print("wall_time", x_axis.shape, data_mean.shape)
    elif x == "sample_efficiency":
        n = np.arange(len(data_mean))
        simulation_grows = np.cumsum((2 * n - 1) * simulation)
        if "EPSRO" in label:
            # Substring match: also true for the "PEPSRO" legend name.
            x_axis = n * training
        elif "P-PSRO" in label:
            # Fixed constants here; ``training``/``simulation`` are ignored.
            x_axis = n * 10000 + np.cumsum((2 * n - 1) * 50)
        else:
            x_axis = n * training + simulation_grows
        print("sample", x_axis.shape)
    else:
        # Without this branch an unknown mode only surfaced later as an
        # UnboundLocalError on ``x_axis``.
        raise ValueError(
            "unknown x-axis mode {!r}; expected 'epoch', 'wall_time' or "
            "'sample_efficiency'".format(x)
        )
    plt.plot(x_axis, data_mean, label=label)
    plt.fill_between(
        x_axis,
        np.squeeze(data_mean - error_bars),
        np.squeeze(data_mean + error_bars),
        alpha=alpha,
    )


def cast_data_type(k, v):
    """Convert one raw CSV cell to the numeric type of its column.

    Args:
        k: column name.
        v: raw cell content.

    Returns:
        ``float(v)`` for the ``"Value"`` and ``"Wall time"`` columns,
        ``int(v)`` for the ``"Step"`` column.

    Raises:
        ValueError: if ``k`` is not one of the three known columns (the
            message names the offending column), or if ``v`` cannot be
            parsed by ``float``/``int``.
    """
    if k in ("Value", "Wall time"):
        return float(v)
    if k == "Step":
        return int(v)
    raise ValueError(
        "unexpected column {!r}; expected 'Value', 'Wall time' or 'Step'".format(k)
    )


# Human-readable environment name: selects the algorithm list below and is
# used in figure titles.
ENV = "Leduc Poker"
environment = ["leduc_poker"]

support = ["dqn"]

if ENV == "Kuhn Poker":
    algo = ["epsro", "psro", "iterative", "sp", "p2sro", "psrorn"]
else:
    algo = ["epsro", "psro", "iterative", "sp", "p2sro"]

data_dir = os.path.join("experiments", "dataset")

# results["<env>_<support>_<algo>"] -> {"Wall time": [...], "Step": [...],
# "Value": [...]}, each list holding one sequence per parsed CSV file.
results = {}
for e, s, a in itertools.product(environment, support, algo):
    group_dict = {"Wall time": [], "Step": [], "Value": []}
    results["{}_{}_{}".format(e, s, a)] = group_dict
    dir_path = os.path.join(data_dir, "{}_{}".format(e, s), a)
    # os.listdir gives no ordering guarantee; sort so that the order of runs
    # (which plot_error is sensitive to in wall-time mode) is reproducible.
    f_list = sorted(
        f_name
        for f_name in os.listdir(dir_path)
        if os.path.splitext(f_name)[-1] == ".csv"
    )
    Log.info("parse {} csv files under: {}".format(len(f_list), dir_path))
    for f_name in f_list:
        tmp = {"Wall time": [], "Step": [], "Value": []}
        f_path = os.path.join(dir_path, f_name)
        # newline="" lets the csv module do its own line-ending handling.
        with open(f_path, "r", newline="") as f:
            for line in csv.DictReader(f):
                for k, v in line.items():
                    # DictReader fills missing trailing cells with None.
                    if v is None:
                        raise ValueError(
                            "missing value for column {!r} in {}".format(k, f_path)
                        )
                    tmp[k].append(cast_data_type(k, v))
        # Files that contributed no rows are skipped so they do not add
        # empty runs to the group.
        for k, v in tmp.items():
            if v:
                group_dict[k].append(v)

# Every invocation writes into its own timestamped directory.
time_string = time.strftime("%Y%m%d-%H%M%S")
PATH_RESULTS = os.path.join("results", "{}/{}".format(environment[0], time_string))
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(PATH_RESULTS, exist_ok=True)
legends = list(results.keys())

# All figures share one title; set it once instead of relying on a value
# assigned only while drawing the first figure.
title = "NashConv on {}".format(ENV)

# (x-axis label, output file suffix, extra keyword arguments for plot_error)
figure_specs = [
    ("Iteration", "NashConv", {}),
    ("Walltime (seconds)", "NashConv_walltime", {"x": "wall_time"}),
    (
        "Number of Episodes",
        "NashConv_sample",
        {"x": "sample_efficiency", "training": 10000, "simulation": 1000},
    ),
]

for x_label, string, plot_kwargs in figure_specs:
    fig_handle = plt.figure()
    plt.yscale("log")
    plt.xlabel(x_label)
    for k in legends:
        plot_error(results[k], label=k, **plot_kwargs)

    plt.title(title)
    plt.ylabel("NashConv")
    plt.legend(loc="upper right")
    fig_handle.tight_layout()
    plt.savefig(os.path.join(PATH_RESULTS, "figure_" + string + ".pdf"))
