import os
from typing import Sequence, Union
import numpy as np
import imageio
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import patches
from matplotlib.font_manager import FontProperties


def vis_nbody_seq(
        save_path,
        in_seq,
        target_seq=None,
        pred_seq=None,
        pred_label="pred",
        plot_stride=1, fs=10, norm="none"):
    r"""Save a static grid of frames, one row per sequence, to `save_path`.

    Rows are, in order: the context (`in_seq`), the optional target, then every
    prediction. Columns are the time indices ``0, plot_stride, 2 * plot_stride, ...``
    up to the longest sequence; cells past the end of a row's own sequence are
    left blank. When there is more than one row, the bottom row's columns are
    titled ``step {j + plot_stride}``.

    Parameters
    ----------
    save_path: str
        Output path passed to ``Figure.savefig``.
    in_seq, target_seq: np.ndarray
        shape = (T, H, W). Float value 0-1. `target_seq` may be None.
    pred_seq: np.ndarray or Sequence[np.ndarray] or None
        shape = (T, H, W). Float value 0-1
    pred_label: str or Sequence[str]
        A single label when `pred_seq` is one array, otherwise one label per
        element of `pred_seq`.
    plot_stride: int
        Plot every `plot_stride`-th frame.
    fs: int
        Font size of the row labels and step titles.
    norm: str
        "none" plots the values as-is; "to255" multiplies them by 255.

    Raises
    ------
    NotImplementedError
        If `norm` is not one of the supported names.
    """
    in_seq = in_seq.astype(np.float32)
    in_len = in_seq.shape[0]
    seq_list = [in_seq]
    label_list = ["context"]
    seq_len_list = [in_len]
    if target_seq is not None:
        target_seq = target_seq.astype(np.float32)
        target_len = target_seq.shape[0]
        seq_list.append(target_seq)
        label_list.append("target")
        seq_len_list.append(target_len)
    else:
        target_len = 0

    if isinstance(pred_seq, Sequence):
        pred_seq_list = [ele.astype(np.float32) for ele in pred_seq]
        # A bare str is itself a Sequence; reject it so it is not split into
        # one label per character.
        assert (isinstance(pred_label, Sequence)
                and not isinstance(pred_label, str)
                and len(pred_label) == len(pred_seq))
        pred_label_list = list(pred_label)
    elif isinstance(pred_seq, np.ndarray):
        pred_seq_list = [pred_seq.astype(np.float32)]
        assert isinstance(pred_label, str)
        pred_label_list = [pred_label]
    else:
        assert pred_seq is None
        pred_seq_list = []
        pred_label_list = []

    # Each prediction row keeps its own length, so a prediction shorter than the
    # target (or than another prediction) gets blank cells instead of an IndexError.
    pred_len_list = [ele.shape[0] for ele in pred_seq_list]
    pred_len = max(pred_len_list, default=0)
    out_len = max(target_len, pred_len)
    max_len = max(in_len, out_len)
    seq_list += pred_seq_list
    label_list += pred_label_list
    seq_len_list += pred_len_list

    if norm == "none":
        norm = {'scale': 1.0,
                'shift': 0.0}
    elif norm == "to255":
        norm = {'scale': 255,
                'shift': 0}
    else:
        raise NotImplementedError(f"Unsupported norm: {norm!r}")

    nrows = len(seq_list)
    ncols = (max_len - 1) // plot_stride + 1
    # squeeze=False keeps `ax` 2-D even with a single row or a single column,
    # so `ax[i][col]` indexing below is always valid.
    fig, ax = plt.subplots(nrows=nrows,
                           ncols=ncols,
                           figsize=(3 * ncols, 3 * nrows),
                           squeeze=False)

    last_row = len(seq_list) - 1
    for i, (seq, label, seq_len) in enumerate(zip(seq_list, label_list, seq_len_list)):
        ax[i][0].set_ylabel(f"{label}", fontsize=fs)
        for j in range(0, max_len, plot_stride):
            col = j // plot_stride
            if j < seq_len:
                x = seq[j] * norm['scale'] + norm['shift']
                ax[i][col].imshow(x, cmap="gray")
            else:
                ax[i][col].axis('off')
            if i == last_row and i > 0:  # the last row which is not the `in_seq`.
                ax[-1][col].set_title(f"step {int(j + plot_stride)}", y=-0.25, fontsize=fs)

    for axis in ax.flat:
        axis.xaxis.set_ticks([])
        axis.yaxis.set_ticks([])

    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    fig.savefig(save_path)
    plt.close(fig)

def vis_nbody_seq_gif(
        save_path: str,
        in_seq: np.ndarray,
        target_seq: np.ndarray,
        pred_seq: Sequence[np.ndarray] = None,
        pred_label: Sequence[Sequence[str]] = None,
        norm="none", plot_stride: int = 1,
        gif_fps=1.0, font_size=20, figsize=None, suptitle_y=0.98,
        ):
    r"""Render the sequences frame by frame and save them as an animated GIF.

    Each GIF frame is one figure with a column for the inputs, one for the
    target and one per prediction. Time indices ``0, plot_stride, ...`` of the
    concatenated input + target timeline are rendered. During the input phase
    the target column shows the first target frame and prediction columns are
    blank. During the prediction phase the input column keeps the last input
    frame.

    Parameters
    ----------
    save_path: str
        Output path passed to ``imageio.mimsave``.
    in_seq, target_seq: np.ndarray
        shape = (T, H, W). Float value 0-1
    pred_seq: np.ndarray or Sequence[np.ndarray] or None
        shape = (T, H, W). Float value 0-1. Predictions are indexed with the
        same time index as `target_seq`, so each must be at least as long as
        the target.
    pred_label: str or Sequence[str] or None
        A single label when `pred_seq` is one array, otherwise one label per
        element of `pred_seq`.
    norm: str
        "none" plots the values as-is; "to255" multiplies them by 255.
    plot_stride: int
        Render every `plot_stride`-th time index.
    gif_fps: float
        Passed to ``imageio.mimsave`` as `fps`.
    font_size: int
        Size of the prediction labels. Input/target labels use
        ``font_size + 8`` and the title uses ``font_size + 16``.
    figsize: tuple or None
        Figure size of each frame; defaults to ``(8 * ncols, 8)``.
    suptitle_y: float
        Vertical position of the per-frame title.

    Raises
    ------
    NotImplementedError
        If `norm` is not one of the supported names.
    """
    # See https://github.com/tomsilver/pddlgym/issues/47#issuecomment-748229565
    matplotlib.use('agg')

    def _serif_font(size):
        # Serif FontProperties of the given size.
        prop = FontProperties()
        prop.set_family('serif')
        prop.set_size(size)
        return prop

    font = _serif_font(font_size)
    gt_font = _serif_font(font_size + 8)
    title_font = _serif_font(font_size + 16)

    # normalize the data
    if norm == "none":
        norm = {'scale': 1.0,
                'shift': 0.0}
    elif norm == "to255":
        norm = {'scale': 255,
                'shift': 0}
    else:
        raise NotImplementedError(f"Unsupported norm: {norm!r}")

    if isinstance(pred_seq, Sequence):
        pred_seq_list = [ele.astype(np.float32) for ele in pred_seq]
        # A bare str is itself a Sequence; reject it so it is not split into
        # one label per character.
        assert (isinstance(pred_label, Sequence)
                and not isinstance(pred_label, str)
                and len(pred_label) == len(pred_seq))
        pred_label_list = list(pred_label)
    elif isinstance(pred_seq, np.ndarray):
        pred_seq_list = [pred_seq.astype(np.float32)]
        assert isinstance(pred_label, str)
        pred_label_list = [pred_label]
    else:
        assert pred_seq is None
        pred_seq_list = []
        pred_label_list = []

    in_seq = in_seq * norm['scale'] + norm['shift']
    target_seq = target_seq * norm['scale'] + norm['shift']
    pred_seq_list = [ele * norm['scale'] + norm['shift']
                     for ele in pred_seq_list]

    in_len = in_seq.shape[0]
    out_len = target_seq.shape[0]
    total_len = in_len + out_len

    img_np_list = []
    ncols = 2 + len(pred_seq_list)
    if figsize is None:
        figsize = (int(ncols * 8), 8)
    for i in range(0, total_len, plot_stride):
        fig, ax = plt.subplots(ncols=ncols,
                               figsize=figsize)
        in_phase = i < in_len

        ax[0].set_xlabel('Inputs', fontproperties=gt_font)
        ax[0].imshow(in_seq[i, :, :] if in_phase else in_seq[-1, :, :], cmap='gray')

        ax[1].set_xlabel('Target', fontproperties=gt_font)
        ax[1].imshow(target_seq[0, :, :] if in_phase else target_seq[i - in_len, :, :],
                     cmap='gray')

        # Plot model predictions
        for k, (pred, label) in enumerate(zip(pred_seq_list, pred_label_list)):
            if in_phase:
                ax[2 + k].axis('off')
            else:
                ax[2 + k].imshow(pred[i - in_len, :, :], cmap='gray')
                ax[2 + k].set_xlabel(label, fontproperties=font)

        for axis in ax:
            axis.xaxis.set_ticks([])
            axis.yaxis.set_ticks([])

        fig.subplots_adjust(hspace=0.05, wspace=0.05)
        # `i` already advances in multiples of `plot_stride`, so the step number
        # is `index + plot_stride` (the same convention as vis_nbody_seq's step
        # titles). Multiplying by `plot_stride` again would double-scale it.
        if in_phase:
            title = f"Input Step {i + plot_stride}"
        else:
            title = f"Predict Step {i - in_len + plot_stride}"
        fig.suptitle(title, y=suptitle_y, fontproperties=title_font)

        # Convert figure to np.ndarray. `tostring_rgb` is deprecated in Matplotlib
        # 3.8 and removed in 3.10, so read the RGBA buffer and drop alpha. The copy
        # detaches the frame from the canvas buffer before the figure is closed.
        fig.canvas.draw()  # draw the canvas, cache the renderer
        img_np = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        plt.close(fig)
        img_np_list.append(img_np)

    imageio.mimsave(save_path, img_np_list, fps=gif_fps)

def vis_nbody_energy(
        save_path: str,
        KE: Union[Sequence[float], np.ndarray],
        PE: Union[Sequence[float], np.ndarray],
        KE_color: str = 'red',
        PE_color: str = 'blue',
        sum_color: str = 'black',
        marker = "s",
        marker_size = 8,
        pred_KE: Union[Sequence[float], np.ndarray] = None,
        pred_PE: Union[Sequence[float], np.ndarray] = None,
        pred_KE_color: str = 'red',
        pred_PE_color: str = 'blue',
        pred_sum_color: str = 'black',
        pred_marker = "+",
        pred_marker_size = 8,
    ):
    """Scatter-plot KE, PE and their sum against the time index and save the figure.

    Predicted curves are overlaid with their own colors and marker when
    `pred_KE` and/or `pred_PE` are given; their sum is drawn only when both are.

    Parameters
    ----------
    save_path: str
        Output path passed to ``Figure.savefig``.
    KE, PE: Sequence[float] or np.ndarray
        Per-step values; must have the same length.
    pred_KE, pred_PE: Sequence[float] or np.ndarray or None
        Optional predicted values, plotted against the same time axis as `KE`
        (so presumably the same length as `KE`).
    *_color, marker, marker_size, pred_marker, pred_marker_size:
        Forwarded to ``Axes.scatter`` as `color`, `marker` and `s`.
    """
    # Convert to arrays so `+` is an element-wise sum. With plain Python lists
    # (allowed by the type hints) `KE + PE` would concatenate them, and the
    # total-energy scatter would fail on mismatched x/y sizes.
    KE = np.asarray(KE)
    PE = np.asarray(PE)
    seq_len = len(KE)
    assert seq_len == len(PE)
    t_all = np.arange(seq_len)

    fig, ax = plt.subplots()
    ax.scatter(t_all, KE, label='KE',
               color=KE_color, marker=marker, s=marker_size,)
    ax.scatter(t_all, PE, label='PE',
               color=PE_color, marker=marker, s=marker_size,)
    ax.scatter(t_all, KE + PE, label='Etot',
               color=sum_color, marker=marker, s=marker_size,)
    if pred_KE is not None:
        pred_KE = np.asarray(pred_KE)
        ax.scatter(t_all, pred_KE, label='pred_KE',
                   color=pred_KE_color, marker=pred_marker, s=pred_marker_size,)
    if pred_PE is not None:
        pred_PE = np.asarray(pred_PE)
        ax.scatter(t_all, pred_PE, label='pred_PE',
                   color=pred_PE_color, marker=pred_marker, s=pred_marker_size,)
    if pred_KE is not None and pred_PE is not None:
        ax.scatter(t_all, pred_KE + pred_PE, label='pred_Etot',
                   color=pred_sum_color, marker=pred_marker, s=pred_marker_size,)
    ax.legend()
    ax.grid(True)
    fig.savefig(save_path)
    plt.close(fig)
