import math
import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# Global seaborn theme shared by every figure produced by this script.
sns.set_context(context="paper", font_scale=0.68)
sns.set_style("white", {"font.family": "serif"})

_results_path = "tracking_results.csv"

# Load the raw results and collapse them to the mean and std of the
# `success` column for each (environment, method) pair. The resulting frame
# is indexed by (environment, method) and has the columns "mean" and "std".
RESULTS = (
    pd.read_csv(os.path.expanduser(_results_path))
    .groupby(["environment", "method"])["success"]
    .agg(["mean", "std"])
)
# Divisor (under a square root) that `_get_result` applies to the std.
# NOTE(review): hard-coded rather than derived from the CSV -- confirm it
# matches the number of rows per (environment, method) group.
SEEDS = 3

# Maps the environment ids used to index RESULTS to the short display names
# shown in table headers and subplot titles.
# The insertion order matters: `results_table` iterates over the keys of this
# dict to decide the column order of the printed table.
ENV_NAMES = {
    "EnvApple-v0": "Apple",
    "EnvBottle-v0": "Bottle",
    "EnvDrawer-v0": "Drawer",
    "EnvFridge-v0": "Fridge",
    "EnvSponge-v0": "Sponge",
    "EnvHammer-v0": "Hammer",
    "EnvScissors-v0": "Scissors",
    "EnvPlier-v0": "Pliers",
}

# Maps the method ids used to index RESULTS to the row labels printed by the
# LaTeX tables (`results_table` and `oracle_table`).
# The plotting functions look methods up in RESULTS directly and do not need
# an entry here.
METHOD_NAMES = {
    "replay": "Replay",
    "scripted": "Oracle",
    "gemini_trajectory_oracle": "Traj. Oracle",
    "gemini_keypoint_oracle": "Keypoint Oracle",
    "gemini": "Ours",
    "gemini_few_shot_oracle": "Few-Shot",
    "gemini_iteration_2_oracle": "2. Iteration",
    # Currently disabled entry (kept for reference):
    # "gemini_iteration_3_oracle": "3. Iteration",
}


def _get_result(method, env):
    """Return ``(mean, standard_error)`` of success for one method/environment.

    The row is looked up in ``RESULTS`` under the index ``(env, method)``.
    The second value is the stored std divided by ``sqrt(SEEDS)``, i.e. a
    standard error rather than the raw standard deviation.

    Raises:
        KeyError: if ``RESULTS`` has no row for ``(env, method)``.
    """
    stats = RESULTS.loc[(env, method)]
    mean_success = stats["mean"]
    standard_error = stats["std"] / math.sqrt(SEEDS)
    return mean_success, standard_error


def results_table():
    """Print the main results as the rows of a LaTeX table.

    One header row is printed, followed by one row per method. Columns are
    the environments in ``ENV_NAMES`` (in dict order). Each cell shows the
    mean success in percent followed by ``\\Std{...}`` holding the error
    value returned by ``_get_result`` (also in percent).
    NOTE(review): ``\\Std`` is presumably a macro defined in the target
    LaTeX document -- confirm.

    Every cell, including the leading method-name cell, is left-justified to
    ``spaces`` characters so that the ``&`` separators line up in the
    printed source.
    """
    methods = ["replay", "scripted", "gemini", "gemini_few_shot_oracle"]
    envs = list(ENV_NAMES)  # Report every known environment.
    spaces = 16

    header_cells = ["\\textbf{Method} ".ljust(spaces)]
    header_cells += [
        f" \\textbf{{{ENV_NAMES[env]}}} ".ljust(spaces) for env in envs
    ]
    print("&".join(header_cells) + "\\\\")

    for method in methods:
        # Pad the row label like the header's first cell; without this the
        # data columns drift depending on the method name's length.
        row = [f"{METHOD_NAMES[method]} ".ljust(spaces)]
        for env in envs:
            mean, std = _get_result(method, env)
            row.append(f" {100 * mean:.1f} \\Std{{{100 * std:.1f}}} ".ljust(spaces))
        print("&".join(row) + "\\\\")


def waypoint_scaling():
    """Plot success rate against the number of waypoints and save a PDF.

    Draws one panel per environment in a single row with a shared y-axis.
    Each panel shows the mean success as a marked line and a shaded band of
    mean +/- the error value returned by ``_get_result``. The figure is
    written to ``waypoint_scaling.pdf`` in the current working directory.
    The error values of every environment are also printed to stdout.
    """
    panel_envs = ["EnvApple-v0", "EnvHammer-v0", "EnvSponge-v0", "EnvPlier-v0"]
    # The two lists below are parallel: the i-th method is plotted at the
    # i-th waypoint count (the plain "gemini" method sits at 20).
    method_keys = ["gemini_3", "gemini_5", "gemini_10", "gemini", "gemini_40"]
    waypoint_counts = [3, 5, 10, 20, 40]
    scale = 1.02
    fig, axes = plt.subplots(
        1,
        len(panel_envs),
        figsize=(0.95 * scale * len(panel_envs), scale),
        sharey=True,
    )

    for idx, (ax, env) in enumerate(zip(axes.flat, panel_envs)):
        stats = [_get_result(key, env) for key in method_keys]
        means = tuple(mean for mean, _ in stats)
        errors = tuple(err for _, err in stats)
        print(env, errors)

        sns.lineplot(ax=ax, x=waypoint_counts, y=means, errorbar=None, marker="o")
        mean_arr = np.array(means)
        err_arr = np.array(errors)
        ax.fill_between(
            waypoint_counts, mean_arr - err_arr, mean_arr + err_arr, alpha=0.2
        )

        ax.set_title(ENV_NAMES[env], pad=1)
        sns.despine(ax=ax)
        # Negative pad pulls the small tick labels close to the axes.
        for axis_name in ("y", "x"):
            ax.tick_params(axis=axis_name, pad=-2, labelsize=5)
        # Only these counts get a tick label; 3 is plotted but not labelled.
        ax.set_xticks([5, 10, 20, 40])
        ax.set_xlabel("Num Waypoints", labelpad=1)
        ax.set_ylim(0, 1)
        # The y-axis is shared, so only the left-most panel is labelled.
        if idx == 0:
            ax.set_ylabel("Success Rate", labelpad=1)

    fig.suptitle("Waypoint Scaling", fontsize=8, y=0.95)
    # Keep the top 5% of the figure free for the suptitle.
    plt.tight_layout(pad=0, rect=[0, 0, 1, 1 - 0.05])
    plt.savefig("waypoint_scaling.pdf")


def fewshot_scaling():
    """Plot success rate over few-shot iterations and save a PDF.

    Draws one panel per environment in a single row with a shared y-axis.
    The x positions 0, 1 and 2 correspond, in order, to the methods
    ``gemini``, ``gemini_few_shot_oracle`` and ``gemini_iteration_2_oracle``.
    Each panel shows the mean success as a marked line and a shaded band of
    mean +/- the error value returned by ``_get_result``. The figure is
    written to ``fewshot_scaling.pdf`` in the current working directory.
    """
    envs = ["EnvFridge-v0", "EnvHammer-v0", "EnvScissors-v0"]
    methods = ["gemini", "gemini_few_shot_oracle", "gemini_iteration_2_oracle"]
    iterations = [0, 1, 2]
    scale_factor = 1.02
    figsize = (0.95 * scale_factor * len(envs), scale_factor)
    # Layout is handled by tight_layout at the end. Do not also request
    # constrained_layout here: the two layout engines are mutually exclusive
    # and tight_layout would override it with a warning.
    fig, axes = plt.subplots(1, len(envs), figsize=figsize, sharey=True)

    for i, (ax, env) in enumerate(zip(axes.flat, envs)):
        y_mean, y_err = zip(*[_get_result(method, env) for method in methods])
        y_mean, y_err = np.array(y_mean), np.array(y_err)
        sns.lineplot(ax=ax, x=iterations, y=y_mean, errorbar=None, marker="o")
        ax.fill_between(iterations, y_mean - y_err, y_mean + y_err, alpha=0.2)
        ax.set_title(ENV_NAMES[env], fontsize=8, pad=5)
        sns.despine(ax=ax)
        ax.tick_params(axis="y", pad=0, labelsize=5)
        ax.tick_params(axis="x", pad=0, labelsize=5)
        ax.set_xticks(iterations)
        ax.set_xlabel("Iterations", labelpad=2)
        ax.set_ylim(0, 1)
        # The y-axis is shared, so only the left-most panel is labelled.
        if i == 0:
            ax.set_ylabel("Success Rate", labelpad=2)

    fig.suptitle("Few-shot Scaling", fontsize=10)
    # Keep the top 5% of the figure free for the suptitle.
    fig.tight_layout(pad=0, rect=[0, 0, 1, 1 - 0.05])
    fig.savefig("fewshot_scaling.pdf")


def oracle_table():
    """Print the oracle comparison as the rows of a LaTeX table.

    One header row is printed, followed by one row for each of ``gemini``,
    ``gemini_trajectory_oracle`` and ``gemini_keypoint_oracle``. Columns are
    a fixed subset of four environments. Each cell shows the mean success in
    percent followed by ``\\Std{...}`` holding the error value returned by
    ``_get_result`` (also in percent).
    NOTE(review): ``\\Std`` is presumably a macro defined in the target
    LaTeX document -- confirm.

    Every cell, including the leading method-name cell, is left-justified to
    ``spaces`` characters so that the ``&`` separators line up in the
    printed source.
    """
    envs = ["EnvApple-v0", "EnvDrawer-v0", "EnvSponge-v0", "EnvPlier-v0"]
    methods = ["gemini", "gemini_trajectory_oracle", "gemini_keypoint_oracle"]
    spaces = 16

    header_cells = ["\\textbf{Method} ".ljust(spaces)]
    header_cells += [
        f" \\textbf{{{ENV_NAMES[env]}}} ".ljust(spaces) for env in envs
    ]
    print("&".join(header_cells) + "\\\\")

    for method in methods:
        # Pad the row label like the header's first cell; without this the
        # data columns drift depending on the method name's length.
        row = [f"{METHOD_NAMES[method]} ".ljust(spaces)]
        for env in envs:
            mean, std = _get_result(method, env)
            row.append(f" {100 * mean:.1f} \\Std{{{100 * std:.1f}}} ".ljust(spaces))
        print("&".join(row) + "\\\\")


if __name__ == "__main__":
    # Run every report in order: the main LaTeX table, the two scaling
    # figures (saved as PDFs), and finally the oracle comparison table.
    for report in (results_table, waypoint_scaling, fewshot_scaling, oracle_table):
        report()
