import numpy as np

import logging
import matplotlib
from matplotlib import patches

matplotlib.use("Agg")  # Set the backend before importing pyplot
import matplotlib.pyplot as plt

logging.getLogger("matplotlib").setLevel(logging.ERROR)


def per_episode_in_context(eval_res, name, ylim=None, max_return=None, max_return_eps=None):
    """Plot the mean +/- std return per in-context episode and save it as a PNG.

    Args:
        eval_res: Mapping whose values are per-run return sequences. They are
            stacked row-wise, so the mean/std are taken across runs for every
            episode index.
        name: Used as the plot title and inside the output filename.
        ylim: Optional ``(low, high)`` y-axis limits.
        max_return: If given, drawn as a dashed goldenrod horizontal line.
        max_return_eps: If given, drawn as a dashed indigo horizontal line.

    Returns:
        The filename ``rets_vs_eps_{name}.png`` written to the working directory.
    """
    returns = np.vstack(list(eval_res.values()))
    mean_ret = returns.mean(axis=0)
    std_ret = returns.std(axis=0)
    episodes = np.arange(1, returns.shape[1] + 1)
    out_path = f"rets_vs_eps_{name}.png"

    fig, ax = plt.subplots(dpi=100)
    ax.grid(visible=True)
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    ax.plot(episodes, mean_ret)
    # Shaded band of one standard deviation around the mean curve.
    ax.fill_between(episodes, mean_ret - std_ret, mean_ret + std_ret, alpha=0.2)

    ax.set_ylabel("Return")
    ax.set_xlabel("Episodes In-Context")
    ax.set_title(f"{name}")

    if max_return is not None:
        ax.axhline(max_return, ls="--", color="goldenrod", lw=2,
                   label=f"optimal_return: {max_return:.2f}")
    if max_return_eps is not None:
        ax.axhline(max_return_eps, ls="--", color="indigo", lw=2,
                   label=f"max_perf_return: {max_return_eps:.2f}")
    # A legend is only meaningful when at least one labelled reference line exists.
    if max_return is not None or max_return_eps is not None:
        plt.legend()

    fig.savefig(out_path)
    plt.close()
    return out_path


def plot_episode_stats(eval_res, eval_dones, stat, name, ylim=None, max_return=None, max_return_eps=None):
    """Plot a per-step statistic, mark episode ends, and save it as a PNG.

    Args:
        eval_res: Per-step values of ``stat``; plotted against steps ``1..N``.
        eval_dones: Per-step done flags. Every non-zero entry is marked with a
            faint red dashed vertical line at the matching step.
        stat: Statistic name, used as the y-label and as the filename prefix.
        name: Used as the plot title and inside the output filename.
        ylim: Optional ``(low, high)`` y-axis limits.
        max_return: Accepted but currently unused (reference-line drawing is disabled).
        max_return_eps: Accepted but currently unused (reference-line drawing is disabled).

    Returns:
        The filename ``{stat}_{name}.png`` written to the working directory.
    """
    values = np.array(eval_res)
    dones = np.array(eval_dones)
    steps = np.arange(1, values.shape[0] + 1)

    fig, ax = plt.subplots(dpi=100)
    ax.grid(visible=True)
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    ax.plot(steps, values)

    ax.set_ylabel(f"{stat}")
    ax.set_xlabel("Steps In-Context")
    ax.set_title(f"{name}")

    # Done indices are 0-based while the step axis starts at 1, hence the +1 shift.
    for idx in np.nonzero(dones != 0)[0]:
        ax.axvline(x=idx + 1, color='r', linestyle='--', alpha=0.3)

    out_path = f"{stat}_{name}.png"
    fig.savefig(out_path)
    plt.close()
    return out_path


def plot_semi_circle(eval_states, eval_dones, goal, name, alpha=1.0, legend=True):
    """Plot 2-D trajectories inside the unit circle and save them as a PNG.

    The unit circle (black outline) and a translucent green goal disc of
    radius 0.2 are drawn first. States are then split into episodes at every
    non-zero done flag. The state at the done step closes its episode but is
    not drawn, and states after the final done flag are ignored.

    Args:
        eval_states: Sequence of 2-D positions; column 0 is x, column 1 is y.
        eval_dones: Per-step done flags aligned with ``eval_states``.
        goal: ``(x, y)`` centre of the goal disc.
        name: Used inside the output filename.
        alpha: Line/marker opacity for the trajectories.
        legend: Whether to draw a legend with one entry per episode.

    Returns:
        The filename ``SemiCircle_{name}.png`` written to the working directory.
    """
    states = np.array(eval_states)
    dones = np.array(eval_dones)

    fig, ax = plt.subplots()
    ax.add_patch(plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='none'))
    ax.add_patch(plt.Circle(goal, radius=0.2, edgecolor='g', facecolor='g', alpha=0.3))

    episodes = []
    current = []
    for state, done in zip(states, dones):
        current.append(state)
        if done != 0:
            episodes.append(current)
            current = []

    palette = plt.get_cmap('plasma')
    n_episodes = len(episodes)
    for idx, episode in enumerate(episodes):
        # Drop the final (done-step) state of each episode before drawing.
        pts = np.array(episode)[:-1]
        ax.plot(pts[:, 0], pts[:, 1], marker='o', linestyle='-', label=f"Episode {idx}",
                markersize=1, alpha=alpha, color=palette(idx / n_episodes))

    # A small margin keeps the whole unit circle visible; equal scaling keeps
    # it from being rendered as an ellipse.
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.set_aspect('equal')
    if legend:
        ax.legend()

    out_path = f"SemiCircle_{name}.png"
    fig.savefig(out_path)
    plt.close()
    return out_path


def draw_square(center, ax, face='none'):
    """Add a unit square centred on ``center`` to ``ax``.

    Args:
        center: ``(x, y)`` coordinates of the square's centre.
        ax: Matplotlib axes that receives the patch.
        face: Fill colour; the default ``'none'`` draws only the black outline.
    """
    half_edge = 1 / 2
    cx, cy = center[0], center[1]
    ax.add_patch(
        patches.Rectangle((cx - half_edge, cy - half_edge), 1, 1, edgecolor='black', facecolor=face)
    )


def draw_grid(ax, size=9):
    """Draw a ``size`` x ``size`` grid of outlined unit squares on an existing axes.

    Square centres sit at integer coordinates ``0 .. size - 1``. The axis
    limits ``[-1, size]`` leave a half-cell margin around the grid, and the
    aspect ratio is fixed to 1 so the squares are not distorted.

    Args:
        ax: Matplotlib axes to draw on (this function does not create a figure).
        size: Number of cells per side. Defaults to 9, the previously
            hard-coded grid size.
    """
    for x in range(size):
        for y in range(size):
            draw_square((x, y), ax)

    ax.set_xlim(-1, size)
    ax.set_ylim(-1, size)
    ax.set_aspect('equal')


def plot_dr(eval_states, eval_dones, goal, name):
    """Plot Dark Room trajectories over a 9x9 grid and save them as a PNG.

    Each state is a flat cell index decoded with ``divmod(state, 9)``; the
    first component is plotted on the horizontal axis and the second on the
    vertical axis. The goal cell is filled green.

    Trajectories are split at every non-zero done flag. The state at a done
    step opens the next trajectory rather than closing the current one, and
    states after the final done flag are not drawn.

    Args:
        eval_states: Sequence of flat cell indices.
        eval_dones: Per-step done flags aligned with ``eval_states``.
        goal: ``(x, y)`` goal cell in the same decoded coordinates.
        name: Used inside the output filename.

    Returns:
        The filename ``DarkRoom_{name}.png`` written to the working directory.
    """
    path = f"DarkRoom_{name}.png"
    fig, ax = plt.subplots()
    try:
        draw_grid(ax)
        draw_square(goal, ax, 'green')
        eval_states = np.array(list(map(lambda x: divmod(x, 9), eval_states)))
        eval_dones = np.array(eval_dones)

        trajs = []
        traj = []
        n_steps = len(eval_dones)
        for cnt, (p, d) in enumerate(zip(eval_states, eval_dones), start=1):
            if d != 0:
                # NOTE(review): an episode shorter than 20 steps presumably ended by
                # reaching the goal, so its path is closed on the goal cell. This is
                # skipped on the very last step. Confirm 20 is the episode length.
                if len(traj) < 20 and cnt != n_steps:
                    traj.append(goal)
                # Guard against an empty trajectory (a done flag on the only step),
                # which would otherwise crash on ``t[:, 0]`` below.
                if traj:
                    trajs.append(traj)
                traj = []
            traj.append(p)

        cmap = plt.get_cmap('plasma')
        for i, traj in enumerate(trajs):
            t = np.array(traj)
            color = cmap(i / len(trajs))
            ax.plot(t[:, 0], t[:, 1], marker='o', linestyle='-', label=f"Episode {i}",
                    markersize=1, alpha=0.5, color=color)

        # Equal aspect keeps the grid cells square.
        ax.set_aspect('equal')
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

        # The legend sits outside the axes (to the right), so save with a tight
        # bounding box; otherwise long legends are clipped from the image.
        fig.savefig(path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    return path


def plot_goal_pred(eval_predictions, goal, norm_coef, name):
    """Plot per-component goal predictions against the normalised true goal.

    For every component ``i`` of ``goal``, the prediction trace
    ``eval_predictions[:, i]`` is drawn as a solid line. The true value
    ``goal[i] / norm_coef`` is drawn as a dashed line of the same colour over
    the same steps.

    Args:
        eval_predictions: Per-step predictions, convertible to a 2-D array of
            shape ``(steps, n_components)``.
        goal: True goal vector (array-like); divided by ``norm_coef`` before plotting.
        norm_coef: Divisor applied to ``goal`` (scalar or broadcastable array).
        name: Used inside the output filename.

    Returns:
        The filename ``TargetPred_{name}.png`` written to the working directory.
    """
    goal = np.asarray(goal) / norm_coef
    preds = np.asarray(eval_predictions)  # convert once, not once per component
    n_comp = goal.shape[0]
    cmap = plt.get_cmap('plasma')
    path = f"TargetPred_{name}.png"

    # Draw on a dedicated figure. The previous pyplot state-machine calls drew
    # onto, saved and then closed whatever figure was current when called.
    fig, ax = plt.subplots()
    try:
        for i in range(n_comp):
            t = preds[:, i]
            color = cmap(i / n_comp)
            ax.plot(t, linestyle='-', label=f"Component {i} prediction", alpha=1.0, color=color)
            ax.plot([goal[i]] * len(t), linestyle='--', label=f"Component {i} true", alpha=1.0, color=color)

        ax.legend()
        fig.savefig(path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    return path
