import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as patches


def plot_particles(ax, particles, px_pts=None, goal_samples=None, lims=None, title=None):
    """Clear ``ax`` and draw the current particle positions on it.

    Drawing order: optional background contour, optional goal samples, then
    the particles themselves.

    Args:
        ax: Matplotlib axes to draw on. It is cleared first.
        particles: Array whose columns 0 and 1 are the particle x/y coordinates.
        px_pts: Optional sequence unpacked into ``ax.contourf`` (30 levels).
        goal_samples: Optional array whose columns 0 and 1 are goal-sample x/y
            coordinates. When given, a legend is also drawn.
        lims: Optional ``(xmin, xmax, ymin, ymax)`` axis limits.
        title: Optional axes title.

    Returns:
        The ``Line2D`` holding the particle markers, so callers can update it
        in place (e.g. from an animation callback).
    """
    ax.clear()

    if title is not None:
        ax.set_title(title)

    if px_pts is not None:
        ax.contourf(*px_pts, 30)

    has_goal = goal_samples is not None
    if has_goal:
        ax.scatter(goal_samples[:, 0], goal_samples[:, 1], c='tab:orange',
                   zorder=1, label="Goal Samples")

    xs = particles[:, 0]
    ys = particles[:, 1]
    (particle_line,) = ax.plot(xs, ys, c='tab:cyan', marker='o',
                               linestyle='None', zorder=1, label="Stein Particles")

    # The legend only has more than one entry when goal samples were drawn.
    if has_goal:
        ax.legend()

    if lims is not None:
        ax.set_xlim(*lims[:2])
        ax.set_ylim(*lims[2:])

    return particle_line


def plot_metrics(ax, vals, labels, max_iter=None):
    """Plot each metric history on the axes selected by its label.

    Args:
        ax: Mapping from label to matplotlib axes (e.g. the dict returned by
            ``plt.subplot_mosaic``).
        vals: Sequence of metric histories. ``vals[k]`` is plotted against
            ``0 .. len(vals[k]) - 1`` on ``ax[labels[k]]``.
        labels: Labels selecting the target axes; also used as legend entries.
        max_iter: If given, every metric axes gets the x-range ``[0, max_iter]``.

    Returns:
        List of the created ``Line2D`` objects, one per (history, label) pair,
        in input order.
    """
    lines = []
    for history, name in zip(vals, labels):
        target = ax[name]
        steps = np.arange(len(history))
        (line,) = target.plot(steps, history, marker="x", label=name)
        lines.append(line)

        if max_iter is not None:
            target.set_xlim(0, max_iter)

        target.grid(which='both')
        target.legend()

    return lines


def animate_particles(particles, px_pts, goal_samples=None, metric_vals=None, metric_labels=None,
                      lims=None, figsize=(15, 6), dpi=100, out_path="movie.mp4"):
    """Render the evolution of a particle set (and optional metric curves) to a file.

    One frame is produced per entry of ``particles``, 200 ms apart, and the
    result is written with ``Animation.save`` using matplotlib's default writer.

    Args:
        particles: Sequence indexed by iteration; for frame ``i``,
            ``particles[i][:, 0]`` / ``particles[i][:, 1]`` are the x/y positions.
        px_pts: Contour data forwarded to ``plot_particles`` (unpacked into
            ``ax.contourf``).
        goal_samples: Optional goal samples drawn on the particle axes.
        metric_vals: Optional sequence of full metric histories. Frame ``i``
            shows ``v[:i + 1]``, and each history's min/max fixes its y-range.
        metric_labels: One label per metric. Metric subplots are only created
            when both ``metric_vals`` and ``metric_labels`` are given.
        lims: Optional ``(xmin, xmax, ymin, ymax)``. Used as the particle axis
            limits, and its y-range/x-range ratio is used as the axes aspect.
        figsize: Figure size in inches.
        dpi: Figure resolution.
        out_path: Destination passed to ``Animation.save``.
    """
    num_ax = 1 if metric_vals is None or metric_labels is None else 2
    aspect = (lims[3] - lims[2]) / (lims[1] - lims[0]) if lims is not None else 1
    n_iter = len(particles)

    if num_ax == 1:
        fig = plt.figure(figsize=figsize, dpi=dpi)
        ax_part = fig.add_subplot(1, num_ax, 1)
    else:
        # One mosaic row per metric: 'particles' spans the whole left column,
        # and each metric gets its own cell in the right column.
        fig, axd = plt.subplot_mosaic([['particles', lbl] for lbl in metric_labels],
                                      figsize=figsize, dpi=dpi, constrained_layout=True)
        ax_part = axd['particles']

    try:
        ax_part.set_aspect(aspect)
        markers = plot_particles(ax_part, particles[0], px_pts,
                                 goal_samples=goal_samples, lims=lims, title="Iteration: 0")

        metric_markers = []
        if num_ax == 2:
            # Fix the y-range up front so the growing curves never rescale the axes.
            for vals, lbl in zip(metric_vals, metric_labels):
                axd[lbl].set_ylim(np.min(vals), np.max(vals))

            metric_markers = plot_metrics(axd, [[v[0]] for v in metric_vals], metric_labels,
                                          max_iter=n_iter)

        # With blit=True, FuncAnimation only redraws the artists returned by the
        # callbacks, so every per-frame-modified line must be listed here.
        animated_artists = (markers, *metric_markers)

        def _init():  # only required for blitting to give a clean slate.
            return animated_artists

        def _animate(i):
            ax_part.set_title("Iteration: {}".format(i))
            markers.set_xdata(particles[i][:, 0])  # update particles
            markers.set_ydata(particles[i][:, 1])  # update particles
            if num_ax == 2:
                for m, vals in zip(metric_markers, metric_vals):
                    history = vals[:i + 1]
                    m.set_xdata(np.arange(len(history)))
                    m.set_ydata(history)
            return animated_artists

        ani = animation.FuncAnimation(fig, _animate, init_func=_init, interval=200,
                                      blit=True, save_count=n_iter)
        ani.save(out_path)
    finally:
        # Release the figure even if plotting or saving fails (e.g. no movie writer).
        plt.close(fig)


def plot_trajectory_2D(ax, trajectory, map_img=None, lims=None,
                       x_goal=None, p_goal=None, bbox=None, vels=None, title=None):
    """Draw a single 2D trajectory on ``ax`` together with optional scene context.

    The axes is not cleared.

    Args:
        ax: Matplotlib axes to draw on.
        trajectory: Array whose columns 0 and 1 are the x/y waypoints.
        map_img: Optional image drawn in greyscale beneath everything else,
            placed with ``imshow(extent=lims)``.
        lims: Optional ``(xmin, xmax, ymin, ymax)``. Used as the image extent,
            passed to ``p_goal.eval_grid``, and applied as axis limits.
        x_goal: Optional goal(s). A 1-D value is one point, drawn above the
            trajectory. A 2-D array is a set of points (columns 0/1 used),
            drawn beneath it. Plain (nested) sequences are accepted.
        p_goal: Optional object exposing ``eval_grid(lims, 50) -> (X, Y, Z)``,
            drawn as contours. NOTE(review): presumably needs ``lims`` set.
        bbox: Optional flat ``(x0, y0, x1, y1)`` pair of opposite corners,
            drawn as a rectangle outline.
        vels: Optional per-waypoint values used to colour a scatter over the
            trajectory. A colorbar is added for it.
        title: Optional axes title.
    """
    if title is not None:
        ax.set_title(title)

    if map_img is not None:
        ax.imshow(map_img, extent=lims, cmap="Greys", zorder=-1)

    if p_goal is not None:
        X, Y, Z = p_goal.eval_grid(lims, 50)
        ax.contour(X, Y, Z, cmap="Oranges", alpha=0.7, zorder=0)

    if bbox is not None:
        # Unpack explicitly so plain lists/tuples work, not only arrays
        # supporting element-wise subtraction.
        x0, y0, x1, y1 = bbox
        rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1,
                                 edgecolor='tab:cyan', facecolor='none')
        ax.add_patch(rect)

    if x_goal is not None:
        # Array-likes (numpy/torch) are used as-is; plain sequences are converted.
        if not hasattr(x_goal, "ndim"):
            x_goal = np.asarray(x_goal)
        if x_goal.ndim == 1:
            # A single goal sits above the trajectory line (zorder=1).
            x, y, zorder = x_goal[0], x_goal[1], 2
        else:
            # Many goals sit beneath the trajectory so they don't hide it.
            x, y, zorder = x_goal[:, 0], x_goal[:, 1], 0
        ax.scatter(x, y, c='tab:red', marker='x', zorder=zorder)

    ax.plot(trajectory[:, 0], trajectory[:, 1], linewidth=2, zorder=1)
    if vels is not None:
        sc = ax.scatter(trajectory[:, 0], trajectory[:, 1], c=vels, zorder=1)
        # Attach the colorbar to ax's own figure, not pyplot's "current" figure.
        ax.figure.colorbar(sc, ax=ax, fraction=0.045)

    if lims is not None:
        ax.set_xlim(lims[:2])
        ax.set_ylim(lims[2:])


def plot_trajectories_2D(ax, trajectories, weights, map_img=None, lims=None,
                         x_goal=None, p_goal=None, bbox=None, title=None, alpha=0.4):
    """Draw a set of 2D trajectories on ``ax``, each coloured by its weight.

    The axes is not cleared.

    Args:
        ax: Matplotlib axes to draw on.
        trajectories: Iterable of arrays whose columns 0 and 1 are x/y waypoints.
        weights: One value per trajectory, mapped to a colour with
            ``plt.cm.viridis``. Float colormap inputs are expected in [0, 1].
        map_img: Optional image drawn in greyscale beneath everything else,
            placed with ``imshow(extent=lims)``.
        lims: Optional ``(xmin, xmax, ymin, ymax)``. Used as the image extent,
            passed to ``p_goal.eval_grid``, and applied as axis limits.
        x_goal: Optional goal(s). A 1-D value is one point (zorder 2). A 2-D
            array is a set of points (columns 0/1 used, zorder 0). Plain
            (nested) sequences are accepted.
        p_goal: Optional object exposing ``eval_grid(lims, 50) -> (X, Y, Z)``,
            drawn as contours. NOTE(review): presumably needs ``lims`` set.
        bbox: Optional flat ``(x0, y0, x1, y1)`` pair of opposite corners,
            drawn as a rectangle outline.
        title: Optional axes title.
        alpha: Transparency of the trajectory lines.

    Returns:
        List of ``Line2D`` objects, one per drawn trajectory, in input order,
        so callers can update them in place.
    """
    if title is not None:
        ax.set_title(title)

    if map_img is not None:
        ax.imshow(map_img, extent=lims, cmap="Greys", zorder=-1)

    if p_goal is not None:
        X, Y, Z = p_goal.eval_grid(lims, 50)
        ax.contour(X, Y, Z, cmap="Oranges", zorder=0)

    if bbox is not None:
        # Unpack explicitly so plain lists/tuples work, not only arrays
        # supporting element-wise subtraction.
        x0, y0, x1, y1 = bbox
        rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1,
                                 edgecolor='tab:cyan', facecolor='none')
        ax.add_patch(rect)

    if x_goal is not None:
        # Array-likes (numpy/torch) are used as-is; plain sequences are converted.
        if not hasattr(x_goal, "ndim"):
            x_goal = np.asarray(x_goal)
        if x_goal.ndim == 1:
            x, y, zorder = x_goal[0], x_goal[1], 2
        else:
            x, y, zorder = x_goal[:, 0], x_goal[:, 1], 0
        ax.scatter(x, y, c='tab:red', marker='x', zorder=zorder)

    markers = []
    for w_k, tr in zip(weights, trajectories):
        m, = ax.plot(tr[:, 0], tr[:, 1], color=plt.cm.viridis(w_k), alpha=alpha, zorder=0)
        markers.append(m)

    if lims is not None:
        ax.set_xlim(lims[:2])
        ax.set_ylim(lims[2:])

    return markers


def animate_trajectories_2D(trajectories, weights, map_img=None, lims=None, x_goal=None,
                            p_goal=None, figsize=(6, 6), out_path="trajectories.mp4"):
    """Render per-iteration sets of weighted trajectories to a file.

    One frame is produced per entry of ``trajectories``, 200 ms apart, and the
    result is written with ``Animation.save`` using matplotlib's default writer.

    Args:
        trajectories: Sequence indexed by iteration. ``trajectories[i]`` is the
            set of trajectories shown in frame ``i``.
        weights: Parallel to ``trajectories``. ``weights[i][k]`` colours
            trajectory ``k`` of frame ``i`` via ``plt.cm.viridis``.
        map_img, lims, x_goal, p_goal: Forwarded to ``plot_trajectories_2D``.
            ``lims`` also sets the axes aspect to its y-range/x-range ratio.
        figsize: Figure size in inches.
        out_path: Destination passed to ``Animation.save``.
    """
    n_iter = len(trajectories)
    aspect = (lims[3] - lims[2]) / (lims[1] - lims[0]) if lims is not None else 1

    fig = plt.figure(figsize=figsize)
    try:
        ax = fig.add_subplot(111)
        ax.set_aspect(aspect)

        markers = plot_trajectories_2D(ax, trajectories[0], weights[0], map_img=map_img,
                                       lims=lims, x_goal=x_goal, p_goal=p_goal,
                                       title="Iteration: 0")

        def _init():  # only required for blitting to give a clean slate.
            return markers

        def _animate(i):
            ax.set_title("Iteration: {}".format(i))
            # NOTE: the number of lines is fixed by frame 0; zip stops at the
            # shortest of weights/trajectories/markers for later frames.
            for w_k, tr, m in zip(weights[i], trajectories[i], markers):
                m.set_xdata(tr[:, 0])
                m.set_ydata(tr[:, 1])
                m.set_color(plt.cm.viridis(w_k))
            return markers

        ani = animation.FuncAnimation(fig, _animate, init_func=_init, interval=200,
                                      blit=True, save_count=n_iter)
        ani.save(out_path)
    finally:
        # Release the figure even if plotting or saving fails (e.g. no movie writer).
        plt.close(fig)


def trajectory_files_to_figures(save_path, state=None, rollout=None,
                                map_img=None, lims=None, x_goal=None,
                                p_goal=None, figsize=(6, 6), out_path=None):
    """Load trajectories saved to ``.npy`` files and save an image for each one.

    Every file in ``save_path`` ending in ``.npy`` whose name does not start
    with ``"grads"`` is processed in sorted filename order. Each is drawn with
    ``plot_trajectories_2D`` (all weights zero) and written as
    ``<out_path>/<name>.jpg``.

    Args:
        save_path: Directory containing the ``.npy`` files.
        state: Optional initial state. Used only together with ``rollout``, and
            its ``dtype``/``device`` are applied to the loaded array.
        rollout: Optional callable. When it and ``state`` are given, each loaded
            array is passed as ``rollout(state, tensor)``, and the plotted data
            becomes ``.state()`` of the result, moved to CPU numpy.
        map_img, lims, x_goal, p_goal: Forwarded to ``plot_trajectories_2D``.
        figsize: Figure size in inches for every image.
        out_path: Output directory; defaults to ``save_path``. It is created if
            missing (only when there is at least one file to render).
    """
    files = sorted(
        os.path.join(save_path, name)
        for name in os.listdir(save_path)
        if name.endswith(".npy") and not name.startswith("grads")
    )
    if not files:
        return

    # If output path is not provided, write the images next to the inputs.
    if out_path is None:
        out_path = save_path
    os.makedirs(out_path, exist_ok=True)

    # Use a private figure rather than a fixed pyplot figure number, which
    # could belong to the caller (and would ignore figsize if it already existed).
    fig = plt.figure(figsize=figsize)
    try:
        ax = fig.add_subplot(111)
        for f in files:
            traj = np.load(f)
            # If a rollout function is provided, treat the file as an action
            # sequence and roll it out from the given state.
            if rollout is not None and state is not None:
                traj = torch.as_tensor(traj, dtype=state.dtype, device=state.device)
                traj = rollout(state, traj).state()
                traj = traj.cpu().numpy()

            # Portable basename; strip only the trailing extension.
            seq = os.path.splitext(os.path.basename(f))[0]
            out_file = os.path.join(out_path, "{}.jpg".format(seq))

            ax.clear()
            plot_trajectories_2D(ax, traj, np.zeros(len(traj)),
                                 title="Iteration {}".format(seq),
                                 map_img=map_img, lims=lims,
                                 x_goal=x_goal, p_goal=p_goal)
            fig.savefig(out_file)
    finally:
        plt.close(fig)
