import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as cm

# Text style applied globally to every figure created by this script.
plt.rcParams.update({'font.size': 16, 'font.family': 'serif'})

# Root of the results tree; files are read from DATA_PATH/<experiment>/<scene>/*.npy.
DATA_PATH = "output/exp/quantitative"
# Experiment key (also its directory name under DATA_PATH) -> label shown in legends and tick labels.
EXP_NAMES = {"cross_entropy_true": "Cross Entropy",
             "mmd_prior": "Kernel MMD",
             "kl_cls": "KL (Ratio Est.)",
             "closest_point": "Closest Point",
             "closest_point_dyn": "Closest Point (Dyn.)",
             "nn_point": "Closest Term. Point",
             "smooth_knn": "Smooth k-NN",
             "energy": "Energy",
             "mppi_cross_entropy_true": "MPPI (Cross Ent.)",
             "mppi_closest_point": "MPPI (Closest Pt.)"}
# Per-experiment plot styling: "c" is the colour, "line" is the line style used by the
# pass-rate figure, and "hatch" (only present on some entries) is the bar hatch looked up
# for every experiment in BASELINES by the path-length chart.
EXP_PROPS = {"cross_entropy_true": {"c": "tab:blue", "line": "--", "hatch": "/"},
             "closest_point": {"c": "tab:red", "line": "--", "hatch": "/"},
             "closest_point_dyn": {"c": "tab:purple", "line": "--", "hatch": "/"},
             "mppi_cross_entropy_true": {"c": "r", "line": "--", "hatch": "/"},
             "mppi_closest_point": {"c": "b", "line": "--", "hatch": "/"},
             "nn_point": {"c": "tab:olive", "line": "--", "hatch": "/"},
             "mmd_prior": {"c": "tab:green", "line": "-"},
             "kl_cls": {"c": "tab:orange", "line": "-"},
             "smooth_knn": {"c": "tab:pink", "line": "-"},
             "energy": {"c": "tab:cyan", "line": "-"}}
# Scene sub-directories read from inside every experiment directory.
SCENES = ["circles_gaussian",
          "get_in_box_uni",
          "obstacles_mix",
          "squares_mix",
          "squares_uni"]

# Opacity of the shaded +/- one standard deviation bands.
STD_ALPHA = 0.3
# Upper x-axis limit of the per-time-step figures.
NUM_ITERS = 120

# Experiments drawn with hatched bars in the path-length chart. Commented-out entries
# are excluded from every figure and table.
BASELINES = [
    "closest_point",
    "closest_point_dyn",
    "cross_entropy_true",
    # "mppi_cross_entropy_true",
    # "mppi_closest_point"
    # "nn_point",
]
# Experiments drawn with plain (un-hatched) bars in the path-length chart.
SET_COSTS = [
    "energy",
    "kl_cls",
    "mmd_prior",
    "smooth_knn"
]

# Get all the experiments.
# Baselines must come first: the path-length chart splits its list of means at
# len(BASELINES) to separate the two groups.
EXPERIMENTS = BASELINES + SET_COSTS
# Alternative kept for reference: use every directory found under DATA_PATH instead.
# EXPERIMENTS = []
# for exp_name in os.listdir(DATA_PATH):
#     exp_path = os.path.join(DATA_PATH, exp_name)
#     if not os.path.isdir(exp_path):
#         continue

#     EXPERIMENTS.append(exp_name)
# Number of selected experiments.
N_EXP = len(EXPERIMENTS)


def get_valid_data(data):
    """Return only the samples (entries along axis 0) that contain no NaN.

    A sample is dropped if any value anywhere in it is NaN. Works for arrays
    of any rank >= 1: for 1-D input the individual NaN entries are removed,
    for 2-D input whole rows, and for higher ranks whole sub-arrays.

    Args:
        data: Array-like of floats whose first axis indexes the samples.

    Returns:
        A new array holding the NaN-free samples in their original order.

    Side effects:
        Prints a warning when at least one sample was dropped.
    """
    data = np.asarray(data)

    has_nan = np.isnan(data)
    if has_nan.ndim > 1:
        # Collapse every axis except the sample axis into one flag per sample.
        has_nan = has_nan.any(axis=tuple(range(1, has_nan.ndim)))

    keep = ~has_nan
    num_valid = int(np.count_nonzero(keep))
    if num_valid < data.shape[0]:
        print("WARNING: Only {} / {} data sets are valid.".format(num_valid, data.shape[0]))
    return data[keep]


def plot_chamfer(exps, all_data, ylims=None, linewidth=3, linestyle="-", plt_std=True):
    """Plot the mean Chamfer distance per time step for several experiments.

    Opens a new figure and draws one mean curve per experiment, optionally
    surrounded by a shaded band of +/- one standard deviation.

    Args:
        exps: Experiment keys to plot, in drawing order. Each must be a key of
            ``all_data``, ``EXP_NAMES`` and ``EXP_PROPS``.
        all_data: Mapping from experiment key to a 2-D array. Mean and std are
            taken over axis 0 and plotted against the index along axis 1.
        ylims: Optional ``(bottom, top)`` limits for the y axis.
        linewidth: Width of the mean curves.
        linestyle: Either one line style shared by all curves, or a list with
            one entry per element of ``exps``.
        plt_std: Whether to draw the std band; either one flag shared by all
            curves, or a list with one entry per element of ``exps``.
    """
    plt.figure(dpi=150)
    for i, exp_name in enumerate(exps):
        data = all_data[exp_name]
        _, T = data.shape
        steps = np.arange(T)
        color = EXP_PROPS[exp_name]["c"]

        cd_mean = np.mean(data, axis=0)
        cd_std = np.std(data, axis=0)

        ls = linestyle if not isinstance(linestyle, list) else linestyle[i]
        plt.plot(steps, cd_mean, label=EXP_NAMES[exp_name], c=color,
                 linewidth=linewidth, linestyle=ls)

        fill = plt_std if not isinstance(plt_std, list) else plt_std[i]
        if fill:
            plt.fill_between(steps, (cd_mean - cd_std), (cd_mean + cd_std),
                             alpha=STD_ALPHA, color=color)

    plt.legend()
    plt.ylabel("Chamfer Distance")
    plt.xlabel("Time Step")
    plt.xlim(0, NUM_ITERS)
    if ylims is not None:
        plt.ylim(*ylims)
    plt.grid()

    # Run the layout pass last: it only accounts for artists that already exist, so
    # calling it before the axis labels are added can leave them clipped.
    plt.tight_layout()


# Read all the data.
# Every container is keyed by experiment name. The list-valued ones collect one array per
# scene and are replaced by a single concatenated array once the experiment is fully read.
success_data = {}
path_length_data = {name: {} for name in EXPERIMENTS}
terminal_dist_data = {name: [] for name in EXPERIMENTS}
chamfer_data = {name: [] for name in EXPERIMENTS}
cross_ent_data = {name: [] for name in EXPERIMENTS}

for exp_name in EXPERIMENTS:
    exp_dir = os.path.join(DATA_PATH, exp_name)
    success_flags = []

    for scene_name in SCENES:
        scene_dir = os.path.join(exp_dir, scene_name)

        # Cross entropy and Chamfer distance: samples containing NaNs are discarded.
        cross_ent_data[exp_name].append(
            get_valid_data(np.load(os.path.join(scene_dir, "cross_entropy.npy"))))
        chamfer_data[exp_name].append(
            get_valid_data(np.load(os.path.join(scene_dir, "chamfer_dists.npy"))))

        # Path length is kept per scene so it can be reported scene by scene.
        path_length_data[exp_name][scene_name] = np.load(
            os.path.join(scene_dir, "path_dists.npy"))

        # Success flags and terminal distances are pooled over all scenes.
        success_flags.append(np.load(os.path.join(scene_dir, "success.npy")))
        terminal_dist_data[exp_name].append(
            np.load(os.path.join(scene_dir, "terminal_dists.npy")))

    # Merge the per-scene arrays of this experiment.
    for pooled in (cross_ent_data, chamfer_data, terminal_dist_data):
        pooled[exp_name] = np.concatenate(pooled[exp_name])

    # Success rate is the fraction of non-zero flags over all scenes.
    success_flags = np.concatenate(success_flags)
    (num_total,) = success_flags.shape
    success_data[exp_name] = np.count_nonzero(success_flags) / num_total

# Plot cross entropy over iteration data: one mean curve per experiment with a shaded
# +/- one standard deviation band.
fig = plt.figure(dpi=150)
for exp_name in EXPERIMENTS:
    data = cross_ent_data[exp_name]
    _, T = data.shape
    steps = np.arange(T)
    color = EXP_PROPS[exp_name]["c"]

    ce_mean = np.mean(data, axis=0)
    ce_std = np.std(data, axis=0)

    plt.plot(steps, ce_mean, label=EXP_NAMES[exp_name], c=color)
    plt.fill_between(steps, (ce_mean - ce_std), (ce_mean + ce_std),
                     alpha=STD_ALPHA, color=color)

plt.legend()
plt.ylabel("Cross Entropy")
plt.xlabel("Time Step")
plt.xlim(0, NUM_ITERS)
# Fixed y range chosen by hand for this figure.
plt.ylim(-400, 1000)
plt.grid()

# Run the layout pass last: it only accounts for artists that already exist, so
# calling it before the axis labels are added can leave them clipped.
plt.tight_layout()

# Plot Chamfer distance over iteration.
# Figure 1: the baselines against the MMD cost, which is relabelled for this figure only.
old_mmd = EXP_NAMES["mmd_prior"]
EXP_NAMES["mmd_prior"] = "Goal Set (MMD)"
try:
    plot_chamfer(BASELINES + ["mmd_prior"], chamfer_data, ylims=[-10, 60])
finally:
    # Always restore the shared label so later figures and tables use the usual name.
    EXP_NAMES["mmd_prior"] = old_mmd

# Figure 2: every set cost (solid, with std band) against the cross-entropy reference
# (dashed, no band). The per-curve lists are derived from SET_COSTS so they stay the
# right length when that list is edited.
plot_chamfer(SET_COSTS + ["cross_entropy_true"], chamfer_data, ylims=[-2, 15],
             linestyle=["-"] * len(SET_COSTS) + ["--"],
             plt_std=[True] * len(SET_COSTS) + [False])

# Success rate curve: for each distance threshold, the percentage of runs whose final
# distance to the goal is strictly below that threshold.
LINEWIDTH = 3
MAX_DIST = 0.4
x = np.linspace(0, MAX_DIST)

fig = plt.figure(dpi=150)
for exp_name in EXPERIMENTS:
    dists = terminal_dist_data[exp_name]
    n_dists, = dists.shape
    # Thresholds as a column against distances as a row gives a (threshold, run) table;
    # counting along the last axis yields the number of passing runs per threshold.
    # Scaled by 100 so the values match the "%" unit in the axis label.
    pass_rate = 100. * np.count_nonzero(dists < np.expand_dims(x, -1), -1) / n_dists

    plt.plot(x, pass_rate, label=EXP_NAMES[exp_name], linewidth=LINEWIDTH,
             c=EXP_PROPS[exp_name]["c"], linestyle=EXP_PROPS[exp_name]["line"])

plt.legend()
plt.xlim(0, MAX_DIST)
plt.ylim(top=100.)
plt.ylabel("Pass Rate (%)")
plt.xlabel("Final Distance to Goal (m)")
plt.grid()

# Run the layout pass last: it only accounts for artists that already exist, so
# calling it before the axis labels are added can leave them clipped.
plt.tight_layout()

# Path length bar chart: mean path length per experiment over the runs that ended
# within TERM_THRESH of the goal, with every such run overlaid as a dot.
TERM_THRESH = 0.15
DOT_SIZE = 8

means = []
max_pt = 0
fig = plt.figure(dpi=150)
for i, exp_name in enumerate(EXPERIMENTS):
    # Concatenate in SCENES order so entries line up with terminal_dist_data, which was
    # pooled in the same order.
    all_data = np.concatenate([path_length_data[exp_name][s] for s in SCENES])

    data = all_data[terminal_dist_data[exp_name] < TERM_THRESH]

    print("For exp", exp_name, "keeping", data.shape[0], "of", all_data.shape[0], "examples")

    if data.size == 0:
        # No run met the threshold. max() of an empty array raises and mean() warns,
        # so record NaN (drawn as an empty bar) and skip the dots.
        means.append(np.nan)
        continue

    means.append(np.mean(data))
    # Dots sit at integer x = i, which is where matplotlib places the i-th string
    # category of the bars drawn below.
    plt.scatter(i * np.ones(len(data), dtype=int), data, c=EXP_PROPS[exp_name]["c"], s=DOT_SIZE)
    max_pt = max(max_pt, data.max())

# Baselines have hatches.
colors = [EXP_PROPS[exp]["c"] for exp in BASELINES]
hatches = [EXP_PROPS[exp]["hatch"] for exp in BASELINES]
alpha_colors = [cm.to_rgba(c, alpha=0.3) for c in colors]
x_names = [EXP_NAMES[exp_name] for exp_name in BASELINES]

# means follows EXPERIMENTS order (baselines first), so the first len(BASELINES)
# entries belong to the baselines.
plt.bar(x_names, means[:len(BASELINES)], color=alpha_colors, edgecolor=colors, hatch=hatches)

# Methods, no hatches.
colors = [EXP_PROPS[exp]["c"] for exp in SET_COSTS]
alpha_colors = [cm.to_rgba(c, alpha=0.3) for c in colors]
x_names = [EXP_NAMES[exp_name] for exp_name in SET_COSTS]

plt.bar(x_names, means[len(BASELINES):], color=alpha_colors, edgecolor=colors)

plt.xticks(rotation=30, ha='right')
plt.ylabel("Path Length (m)")
plt.ylim(top=max_pt + 0.5)  # For some reason, the bar chart does not automatically find the top point.
plt.tight_layout()

# Text summary: per-scene path length (mean and standard deviation over all runs of
# that scene, in a LaTeX-friendly form), followed by the overall success rate.
print("\nPath Lengths:\n")
for scene_name in SCENES:
    print("***", scene_name, "***")
    for exp_name in EXPERIMENTS:
        lengths = path_length_data[exp_name][scene_name]
        mean_str = "{:.2f}".format(np.mean(lengths))
        std_str = "{:.2f}".format(np.std(lengths))
        print(EXP_NAMES[exp_name], mean_str, "$\\pm$", std_str)
    print()

print("\nPath Success:\n")
for exp_name in EXPERIMENTS:
    label = EXP_NAMES[exp_name]
    rate = success_data[exp_name]
    print(label, ":", rate)

# Display every figure created above.
plt.show()
