import math
import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# Global plot style: compact "paper" context, white background, serif font.
sns.set_context(context="paper", font_scale=0.68)
sns.set_style("white", {"font.family": "serif"})

# Raw evaluation table; it must provide at least the columns
# "environment", "method" and "success".
_results_path = "tracking_results.csv"

# Collapse the raw rows into one row per (environment, method) pair that
# holds the mean and the (sample) standard deviation of "success".
RESULTS = (
    pd.read_csv(os.path.expanduser(_results_path))
    .groupby(["environment", "method"])["success"]
    .agg(["mean", "std"])
)

# Number of seeds behind each aggregate; `_get_result` divides the std by
# sqrt(SEEDS) to report a standard error.
SEEDS = 3

# Environment id (as it appears in the CSV) -> label drawn in the figures.
# Insertion order matters: `results_table` plots the environments in this order.
ENV_NAMES = {
    "EnvApple-v0": "Apple",
    "EnvBottle-v0": "Bottle",
    "EnvDrawer-v0": "Drawer",
    "EnvFridge-v0": "Fridge",
    "EnvSponge-v0": "Sponge",
    "EnvHammer-v0": "Hammer",
    "EnvScissors-v0": "Scissors",
    "EnvPlier-v0": "Pliers",
}

# Method id (as it appears in the CSV) -> legend label.
METHOD_NAMES = {
    "replay": "Pre-recorded",
    "scripted": "Oracle",
    "gemini_trajectory_oracle": "Traj. Oracle",
    "gemini_keypoint_oracle": "Keypoint Oracle",
    "gemini": "Ours (Zero-Shot)",
    "gemini_few_shot_oracle": "Ours (Few-Shot)",
    "gemini_iteration_2_oracle": "2. Iteration",
    "gemini_iteration_3_oracle": "3. Iteration",
}


def _get_result(method, env):
    """Look up the aggregated success of *method* on *env*.

    Returns a ``(mean, standard_error)`` pair, where the standard error is the
    aggregated std divided by ``sqrt(SEEDS)``. Raises ``KeyError`` (from
    ``DataFrame.loc``) if the pair is not present in ``RESULTS``.
    """
    row = RESULTS.loc[(env, method)]
    mean_value = row["mean"]
    standard_error = row["std"] / math.sqrt(SEEDS)
    return mean_value, standard_error


def results_table():
    """Draw the main per-environment method comparison and save it as a PDF.

    One bar panel is drawn per entry of ``ENV_NAMES`` (all panels share the
    first panel's y-axis). Each panel has one bar per method showing the mean
    success rate in percent, an error bar from the second value returned by
    ``_get_result``, and the rounded mean printed above the bar. A legend row
    sits above the panels and a zero-height axis below them carries the tick
    marks. The figure is written to ``method_comparison.pdf`` and then closed.
    """
    methods = ["replay", "gemini", "gemini_few_shot_oracle", "scripted"]
    envs = list(ENV_NAMES.keys())
    scale_factor = 1.02
    figsize = (0.95 * scale_factor * len(envs), 1.2 * scale_factor)

    fig = plt.figure(figsize=figsize)

    # Three stacked rows: legend, bar panels, and a zero-height axis that
    # only carries the shared x tick marks.
    gs = fig.add_gridspec(3, 1, height_ratios=[0.3, 0.7, 0.00], hspace=0)

    legend_ax = fig.add_subplot(gs[0])
    legend_ax.set_axis_off()

    # One bar panel per environment, packed with no horizontal gap.
    gs_bars = gs[1].subgridspec(1, len(envs), wspace=0)

    axes = []
    for i in range(len(envs)):
        ax = fig.add_subplot(gs_bars[i])
        if i > 0:  # Share y axis with first panel
            ax.sharey(axes[0])
        axes.append(ax)

    colors = {
        "replay": "#66BDF9",  # RGB (102, 189, 249)
        "gemini_few_shot_oracle": "#FF9B00",  # RGB (255, 155, 0)
        "gemini": "#FFD796",  # RGB (255, 215, 150)
        "scripted": "#999999",  # RGB (153, 153, 153)
    }

    # Diagonal stripes mark the scripted oracle baseline.
    hatches = {method: "///" if method == "scripted" else "" for method in methods}

    # Zero-height shared x-axis: only its tick marks are visible.
    shared_ax = fig.add_subplot(gs[2])
    shared_ax.set_xlim(-0.5, len(methods) - 0.5)
    shared_ax.tick_params(length=4, width=0.8)
    shared_ax.set_xticks(range(len(methods)), labels=[])
    shared_ax.spines["top"].set_visible(False)
    shared_ax.spines["right"].set_visible(False)
    shared_ax.spines["left"].set_visible(False)
    shared_ax.set_yticks([])

    for i, (env, ax) in enumerate(zip(envs, axes)):
        ax.set_axisbelow(True)  # keep ticks/gridlines beneath the bars

        means = []
        stds = []
        for method in methods:
            mean, std = _get_result(method, env)
            means.append(mean * 100)  # fraction -> percent
            stds.append(std * 100)

        # width=1.0 makes adjacent bars touch.
        x = np.arange(len(methods))
        bars = ax.bar(
            x,
            means,
            yerr=stds,
            capsize=3,
            color=[colors[m] for m in methods],
            zorder=3,
            width=1.0,
        )

        for bar, method, mean, std in zip(bars, methods, means, stds):
            bar.set_hatch(hatches[method])
            # Value label just above the top of the error bar.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                mean + std + 2,
                f"{mean:.0f}",
                ha="center",
                va="bottom",
                fontsize=5,
            )

        # The per-panel x axis is replaced by the shared zero-height axis.
        ax.set_xticks([])
        ax.spines["bottom"].set_visible(False)

        sns.despine(ax=ax)
        ax.tick_params(axis="y", pad=-2, labelsize=5)
        ax.set_ylim(0, 100)
        ax.set_yticks(np.arange(0, 120, 20))

        # Only the leftmost panel shows y tick labels and the y label.
        if i == 0:
            ax.set_ylabel("Success Rate (%)", labelpad=1)
        else:
            # Remove the vertical line between panels.
            ax.spines["left"].set_visible(False)
            ax.tick_params(axis="y", which="both", length=0)
            ax.set_yticklabels([])

        # Environment name just below the panel (axes coordinates).
        ax.text(
            0.5,
            -0.04,
            ENV_NAMES[env],
            horizontalalignment="center",
            verticalalignment="top",
            transform=ax.transAxes,
            fontsize=8,
        )

    # matplotlib ties a patch's hatch colour to its edge colour, so
    # edgecolor="none" would make the legend stripes fully transparent.
    # Use the rc default patch edge colour instead and hide the outline with
    # linewidth=0. NOTE(review): assumes the bars' hatches also use
    # rcParams["patch.edgecolor"] (patch.force_edgecolor on) — confirm visually.
    legend_edge = plt.rcParams["patch.edgecolor"]
    legend_elements = [
        plt.Rectangle(
            (0, 0),
            1,
            1,
            facecolor=colors[m],
            edgecolor=legend_edge,
            linewidth=0,
            label=METHOD_NAMES[m],
            hatch=hatches[m],
        )
        for m in methods
    ]
    legend_ax.legend(
        handles=legend_elements,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.0),  # top edge of the legend row
        ncol=len(methods),
        frameon=False,
        fontsize=7,
        columnspacing=1.2,
        handletextpad=0.5,
        handlelength=1.5,
    )

    fig.savefig("method_comparison.pdf", bbox_inches="tight", pad_inches=0)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def waypoint_scaling():
    """Plot success rate against the number of waypoints and save a PDF.

    One subplot is drawn per environment in ``envs`` (shared y-axis). Each
    subplot shows the mean success rate in percent of the ``gemini`` waypoint
    variants, with a shaded band of +/- the second value returned by
    ``_get_result``. The figure is written to ``waypoint_scaling.pdf`` and
    then closed.
    """
    envs = ["EnvApple-v0", "EnvHammer-v0", "EnvSponge-v0", "EnvPlier-v0"]
    methods = ["gemini_3", "gemini_5", "gemini_10", "gemini", "gemini_40"]
    # x position of each entry in `methods`; the plain "gemini" run is
    # plotted at 20 waypoints.
    waypoints = [3, 5, 10, 20, 40]
    scale_factor = 1.02
    figsize = (0.95 * scale_factor * len(envs), scale_factor)
    fig, axes = plt.subplots(1, len(envs), figsize=figsize, sharey=True)

    for i, (env, ax) in enumerate(zip(envs, axes)):
        ax.set_axisbelow(True)  # keep ticks/gridlines beneath the data

        y_mean, y_std = zip(*[_get_result(method, env) for method in methods])
        y_mean = np.array(y_mean) * 100  # fraction -> percent
        y_std = np.array(y_std) * 100

        sns.lineplot(
            ax=ax,
            x=waypoints,
            y=y_mean,
            errorbar=None,
            marker="o",
            zorder=3,
            color="#FF9B00",
        )
        ax.fill_between(
            waypoints,
            y_mean - y_std,
            y_mean + y_std,
            alpha=0.2,
            zorder=2,
            color="#FF9B00",
        )
        ax.set_title(ENV_NAMES[env], pad=1)
        sns.despine(ax=ax)
        ax.tick_params(axis="y", pad=-2, labelsize=5)
        ax.tick_params(axis="x", pad=-2, labelsize=5)

        # 3 is left unlabeled; the x limits leave a little room at both ends.
        ax.set_xticks([5, 10, 20, 40])
        ax.set_xlim(2, 41)

        ax.set_xlabel("Num Waypoints", labelpad=1)
        ax.set_ylim(0, 100)
        ax.set_yticks(np.arange(0, 120, 20))
        if i == 0:
            ax.set_ylabel("Success Rate (%)", labelpad=1)

    fig.suptitle("Waypoint Scaling", fontsize=8, y=0.95)
    # tight_layout recomputes all subplot spacing (including wspace); the top
    # 5 % of the figure is kept free for the suptitle.
    fig.tight_layout(pad=0, rect=[0, 0, 1, 1 - 0.05])
    fig.savefig("waypoint_scaling.pdf")
    plt.close(fig)


def fewshot_scaling():
    """Plot success rate across few-shot iterations and save a PDF.

    One subplot is drawn per environment in ``envs`` (shared y-axis). Each
    subplot shows the mean success rate in percent for iterations 0-3 (the
    zero-shot ``gemini`` run followed by the few-shot / iteration variants),
    with a shaded band of +/- the second value returned by ``_get_result``.
    The figure is written to ``fewshot_scaling.pdf`` and then closed.
    """
    envs = ["EnvFridge-v0", "EnvHammer-v0", "EnvScissors-v0"]
    methods = [
        "gemini",
        "gemini_few_shot_oracle",
        "gemini_iteration_2_oracle",
        "gemini_iteration_3_oracle",
    ]
    # x position of each entry in `methods`.
    iterations = [0, 1, 2, 3]
    scale_factor = 1.02
    figsize = (0.95 * scale_factor * len(envs), scale_factor)
    fig, axes = plt.subplots(1, len(envs), figsize=figsize, sharey=True)

    for i, (env, ax) in enumerate(zip(envs, axes)):
        ax.set_axisbelow(True)  # keep ticks/gridlines beneath the data

        y_mean, y_std = zip(*[_get_result(method, env) for method in methods])
        y_mean = np.array(y_mean) * 100  # fraction -> percent
        y_std = np.array(y_std) * 100

        sns.lineplot(
            ax=ax,
            x=iterations,
            y=y_mean,
            errorbar=None,
            marker="o",
            zorder=3,
            color="#FF9B00",
        )
        ax.fill_between(
            iterations,
            y_mean - y_std,
            y_mean + y_std,
            alpha=0.2,
            zorder=2,
            color="#FF9B00",
        )
        ax.set_title(ENV_NAMES[env], pad=5)
        sns.despine(ax=ax)
        ax.tick_params(axis="y", pad=0, labelsize=5)
        ax.tick_params(axis="x", pad=0, labelsize=5)

        # One tick per iteration, with half a unit of padding on both sides.
        ax.set_xticks(iterations)
        ax.set_xlim(-0.5, 3.5)

        ax.set_xlabel("Iterations", labelpad=2)
        ax.set_ylim(0, 100)
        ax.set_yticks(np.arange(0, 120, 20))
        if i == 0:
            ax.set_ylabel("Success Rate (%)", labelpad=2)

    fig.suptitle("Few-shot Scaling", fontsize=8)
    # tight_layout recomputes all subplot spacing (including wspace); the top
    # 5 % of the figure is kept free for the suptitle.
    fig.tight_layout(pad=0, rect=[0, 0, 1, 1 - 0.05])
    fig.savefig("fewshot_scaling.pdf")
    plt.close(fig)


def oracle_plot():
    """Compare zero-shot Gemini against its trajectory/keypoint oracle variants.

    One bar panel is drawn per environment in ``envs`` (y-axis shared with the
    first panel and limited to 50-100 %). Each panel has one bar per method
    showing the mean success rate in percent, an error bar from the second
    value returned by ``_get_result``, and the rounded mean printed above the
    bar. A legend row sits above the panels. The figure is written to
    ``oracle_comparison.pdf`` and then closed.
    """
    methods = ["gemini", "gemini_trajectory_oracle", "gemini_keypoint_oracle"]
    envs = ["EnvApple-v0", "EnvDrawer-v0", "EnvSponge-v0", "EnvPlier-v0"]
    scale_factor = 1.02
    figsize = (0.95 * scale_factor * len(envs), 1.2 * scale_factor)
    fig = plt.figure(figsize=figsize)

    # Three stacked rows: legend, bar panels, and a zero-height axis that
    # only carries the shared x tick marks.
    gs = fig.add_gridspec(3, 1, height_ratios=[0.3, 0.7, 0.00], hspace=0)

    legend_ax = fig.add_subplot(gs[0])
    legend_ax.set_axis_off()

    # One bar panel per environment with a small horizontal gap.
    gs_bars = gs[1].subgridspec(1, len(envs), wspace=0.1)

    axes = []
    for i in range(len(envs)):
        ax = fig.add_subplot(gs_bars[i])
        if i > 0:  # Share y axis with first panel
            ax.sharey(axes[0])
        axes.append(ax)

    # Zero-height shared x-axis: only its tick marks are visible.
    shared_ax = fig.add_subplot(gs[2])
    shared_ax.set_xlim(-0.5, len(methods) - 0.5)
    shared_ax.tick_params(length=4, width=0.8)
    shared_ax.set_xticks(range(len(methods)), labels=[])
    shared_ax.spines["top"].set_visible(False)
    shared_ax.spines["right"].set_visible(False)
    shared_ax.spines["left"].set_visible(False)
    shared_ax.set_yticks([])

    colors = {
        "gemini": "#FFD796",  # RGB (255, 215, 150)
        "gemini_trajectory_oracle": "#9B7CD9",  # violet
        "gemini_keypoint_oracle": "#A882A8",  # muted purple
    }

    # Both oracle variants are striped; the zero-shot method is plain.
    hatches = {
        "gemini": "",
        "gemini_trajectory_oracle": "///",
        "gemini_keypoint_oracle": "///",
    }

    for i, (env, ax) in enumerate(zip(envs, axes)):
        ax.set_axisbelow(True)  # keep ticks/gridlines beneath the bars

        means = []
        stds = []
        for method in methods:
            mean, std = _get_result(method, env)
            means.append(mean * 100)  # fraction -> percent
            stds.append(std * 100)

        # width=1.0 makes adjacent bars touch.
        x = np.arange(len(methods))
        bars = ax.bar(
            x,
            means,
            yerr=stds,
            capsize=3,
            color=[colors[m] for m in methods],
            zorder=3,
            width=1.0,
        )

        for bar, method, mean, std in zip(bars, methods, means, stds):
            bar.set_hatch(hatches[method])
            # Value label just above the top of the error bar.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                mean + std + 2,
                f"{mean:.0f}",
                ha="center",
                va="bottom",
                fontsize=5,
            )

        # Environment name just below the panel (axes coordinates).
        ax.text(
            0.5,
            -0.04,
            ENV_NAMES[env],
            horizontalalignment="center",
            verticalalignment="top",
            transform=ax.transAxes,
            fontsize=8,
        )
        # The per-panel x axis is replaced by the shared zero-height axis.
        ax.set_xticks([])
        ax.spines["bottom"].set_visible(False)
        sns.despine(ax=ax)
        ax.tick_params(axis="y", pad=-2, labelsize=5)
        ax.tick_params(axis="x", pad=-2, labelsize=5)
        # Bars are cut off at 50 % so differences near the top stay visible.
        ax.set_ylim(50, 100)
        ax.set_yticks(np.arange(50, 101, 10))
        if i == 0:
            ax.set_ylabel("Success Rate (%)", labelpad=1)
            ax.tick_params(axis="y", pad=-2, labelsize=5, length=4, width=0.8)
            ax.spines["left"].set_visible(True)
        else:
            # Remove the vertical line and tick labels between panels.
            ax.spines["left"].set_visible(False)
            ax.tick_params(axis="y", which="both", length=0)
            ax.set_yticklabels([])

    # matplotlib ties a patch's hatch colour to its edge colour, so
    # edgecolor="none" would make the legend stripes fully transparent.
    # Use the rc default patch edge colour instead and hide the outline with
    # linewidth=0. NOTE(review): assumes the bars' hatches also use
    # rcParams["patch.edgecolor"] (patch.force_edgecolor on) — confirm visually.
    legend_edge = plt.rcParams["patch.edgecolor"]
    legend_elements = [
        plt.Rectangle(
            (0, 0),
            1,
            1,
            facecolor=colors[m],
            edgecolor=legend_edge,
            linewidth=0,
            label=METHOD_NAMES[m],
            hatch=hatches[m],
        )
        for m in methods
    ]
    legend_ax.legend(
        handles=legend_elements,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.0),
        ncol=len(methods),
        frameon=False,
        fontsize=7,
        columnspacing=1.2,
        handletextpad=0.5,
        handlelength=1.5,
    )

    fig.savefig("oracle_comparison.pdf", bbox_inches="tight", pad_inches=0)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Bar-chart comparisons first, then the Gemini scaling line plots. Each
    # builder saves its own PDF under a relative path, i.e. into the current
    # working directory.
    for _build_figure in (
        results_table,
        oracle_plot,
        waypoint_scaling,
        fewshot_scaling,
    ):
        _build_figure()
