import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, TextArea
from scipy.stats import gaussian_kde


def scatter_img(ax, img, pos, zoom=1.0, alpha=1.0, cmap="gray", framecolor=None, title=None):
    """Draw ``img`` centred at data coordinates ``pos`` on ``ax``.

    ``img`` holds one or more 28x28 images stored back to back (its size must
    be a multiple of 784). They are tiled left-to-right into a single
    ``28 x (28 * k)`` picture before drawing.

    Args:
        ax: matplotlib ``Axes`` to draw on.
        img: numpy array or CPU torch tensor with ``k * 784`` elements.
        pos: ``(x, y)`` position of the image centre in data coordinates.
        zoom: scale factor forwarded to ``OffsetImage``.
        alpha: image opacity forwarded to ``OffsetImage``.
        cmap: colormap name forwarded to ``OffsetImage``.
        framecolor: if not ``None``, draw a frame of this colour (linewidth 2)
            around the image.
        title: if not ``None``, ``str(title)`` is drawn 0.25 data units above
            ``pos``.
    """
    # Invisible point so autoscaling covers this position; artists added with
    # add_artist do not update the axes' data limits themselves.
    ax.scatter(pos[0], pos[1], alpha=0)

    # (k*784,) -> (k, 28, 28) -> (28, k, 28) -> (28, 28*k): digits side by side.
    digits = np.asarray(img).reshape(-1, 28, 28)
    strip = digits.transpose(1, 0, 2).reshape(28, -1)
    image_box = OffsetImage(strip, zoom=zoom, cmap=cmap, alpha=alpha)

    if framecolor is not None:
        ab = AnnotationBbox(image_box, pos, frameon=True, pad=0.1, bboxprops=dict(edgecolor=framecolor, lw=2))
    else:
        ab = AnnotationBbox(image_box, pos, frameon=False)
    ax.add_artist(ab)

    if title is not None:
        caption_pos = (pos[0], pos[1] + 0.25)
        ax.add_artist(AnnotationBbox(TextArea(str(title)), caption_pos))


def visualise_mnist_chain(mdp, title=""):
    """Plot a chain of abstract states with their mean image, samples and density.

    For every state in ``mdp.states`` the figure shows:

    * the mean of ``state.data`` (as a 28x28 image) at the state's mean
      ``position``, drawn at height y = 3.5;
    * about 1% of the state's low-level images, each at its own position with
      a random vertical jitter in [0, 0.2);
    * a Gaussian KDE of the state's positions, drawn as a shaded curve whose
      baseline is y = 0.5 (skipped when the state has fewer than two distinct
      positions, which ``gaussian_kde`` cannot fit).

    Args:
        mdp: object with a ``states`` iterable. Each state needs ``data``
            (rows reshapeable to 28x28; ``clone()``/``numpy()`` are called on
            rows, so presumably a torch tensor — TODO confirm) and ``info``
            (a list of dicts with a ``"position"`` entry, presumably a
            length-1 array since a y coordinate is appended to it).
        title: if non-empty, the figure is written to ``{title}.pdf`` and
            ``{title}.svg`` and then closed; otherwise it is shown with
            ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    for state in mdp.states:
        positions = np.array([x["position"] for x in state.info])

        # abstract node: mean image placed at the mean position, lifted to y = 3.5
        avg_pos = np.concatenate([positions.mean(axis=0), [3.5]])
        ax.scatter(avg_pos[0], avg_pos[1], alpha=0)  # invisible, only extends autoscaling
        avg_img = state.data.mean(axis=0).reshape(28, 28)
        avg_box = OffsetImage(avg_img, zoom=1, cmap="gray")
        ax.add_artist(AnnotationBbox(avg_box, avg_pos, frameon=False))

        # Samples from the low-level space. A private RandomState seeded with
        # the same value draws exactly what np.random.seed(...) would, without
        # overwriting numpy's global RNG state for the caller.
        rs = np.random.RandomState(len(state.data))
        n_sample = int(len(state.data) * 0.01)
        for _ in range(n_sample):
            i = rs.randint(0, state.data.shape[0])
            pos = np.concatenate([state.info[i]["position"], 0.2 * rs.rand(1)])
            ax.scatter(pos[0], pos[1], alpha=0)
            img_i = state.data[i].clone().numpy().reshape(28, 28)
            sample_box = OffsetImage(img_i, zoom=0.4, alpha=1, cmap="gray")
            ax.add_artist(AnnotationBbox(sample_box, pos, frameon=False))

        # Density of the state's positions. gaussian_kde raises on a singular
        # covariance, so states with fewer than two distinct positions are skipped.
        points = positions.reshape(1, -1)
        if np.unique(points).size < 2:
            continue
        kde = gaussian_kde(points, bw_method=0.25)
        grid = np.linspace(points.min() - 0.5, points.max() + 0.5, 200)
        curve = kde.pdf(grid) + 0.5
        ax.plot(grid, curve, color="k", linestyle='-')
        ax.fill_between(grid, 0.5, curve, alpha=0.3, color="k")

    ax.axis('off')
    ax.axis('equal')
    fig.tight_layout()
    if title:
        with PdfPages(f"{title}.pdf") as pp:
            pp.savefig(fig)
        fig.savefig(f"{title}.svg")
        # release the figure; otherwise it stays registered with pyplot
        plt.close(fig)
    else:
        plt.show()


def visualise_mnist_mdp(mdp, sample_mdp, title=""):
    """Plot the abstract states of a two-factor MNIST MDP over low-level samples.

    The figure shows:

    * about 0.5% of the images of every state of ``sample_mdp`` at their own
      positions (``sample_mdp`` is used so the same set of points is shown);
    * for every state of ``mdp``, its mean image at its mean position,
      captioned with the number of samples in the state;
    * KDE curves of the x positions grouped by ``state.factors[0]`` (along the
      top, baseline y = 5.6) and of the y positions grouped by
      ``state.factors[1]`` (along the right, baseline x = 5.6). Groups with
      fewer than two distinct values are skipped because ``gaussian_kde``
      cannot fit them.

    Args:
        mdp: object with a ``states`` iterable; each state needs ``data``,
            ``info`` (list of dicts with a 2-D ``"position"``) and ``factors``.
        sample_mdp: object whose ``states`` provide ``data``, ``info`` and
            ``id`` (used as the per-state sampling seed).
        title: if non-empty, save to ``{title}.pdf`` and ``{title}.svg`` and
            close the figure; otherwise show it with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    # low-level samples; using sample_mdp just to show the same set of points
    for state in sample_mdp.states:
        # private RandomState: same draws as np.random.seed(state.id), but the
        # global numpy RNG is left untouched
        rs = np.random.RandomState(state.id)
        n_sample = int(len(state.data) * 0.005)
        for _ in range(n_sample):
            i = rs.randint(0, state.data.shape[0])
            pos = state.info[i]["position"].copy()
            scatter_img(ax, state.data[i], pos, cmap="Greys", zoom=0.3)

    # abstract nodes, plus per-factor collections of x / y positions
    f1 = {}
    f2 = {}
    for state in mdp.states:
        avg_pos = np.array([x["position"] for x in state.info]).mean(axis=0)
        scatter_img(ax, state.data.mean(axis=0), avg_pos, zoom=0.75, title=len(state.data))
        f1.setdefault(state.factors[0], []).extend(x["position"][0] for x in state.info)
        f2.setdefault(state.factors[1], []).extend(x["position"][1] for x in state.info)

    # Marginal densities. The fixed grid presumably covers the position range
    # of this environment — TODO confirm.
    grid = np.linspace(-0.6, 5.6, 200)
    for xs in f1.values():
        points = np.array(xs).reshape(1, -1)
        if np.unique(points).size < 2:
            continue  # singular covariance: gaussian_kde would raise
        curve = gaussian_kde(points, bw_method=0.25).pdf(grid) / 2 + 5.6
        ax.plot(grid, curve, color="k", linestyle='-')
        ax.fill_between(grid, 5.6, curve, alpha=0.3, color="k")
    for ys in f2.values():
        points = np.array(ys).reshape(1, -1)
        if np.unique(points).size < 2:
            continue  # singular covariance: gaussian_kde would raise
        curve = gaussian_kde(points, bw_method=0.25).pdf(grid) / 2 + 5.6
        ax.plot(curve, grid, color="k", linestyle='-')
        ax.fill_betweenx(grid, 5.6, curve, alpha=0.3, color="k")

    ax.axis('off')
    ax.axis('equal')
    fig.tight_layout()
    if title:
        with PdfPages(f"{title}.pdf") as pp:
            pp.savefig(fig)
        fig.savefig(f"{title}.svg")
        # release the figure; otherwise it stays registered with pyplot
        plt.close(fig)
    else:
        plt.show()


def visualise_mnist_transitions(abstract_transitions, action_names):
    """Show abstract transitions as side-by-side mean images of source and target.

    Transitions are laid out three per row, two panels each:

    * the mean of ``s.data``, titled ``"<s.id>.<s.refinement> + <action> -> <s_prime.id>"``;
    * the mean of ``s_prime.data``, titled with probability, reward and step count.

    Each row of ``data`` is treated as ``ndim`` 28x28 images stored back to
    back (``ndim = data.shape[-1] // 784``, taken from ``s``), tiled
    horizontally. Unused panels in the last row are hidden.

    Args:
        abstract_transitions: sized sequence of
            ``(s, option, s_prime, reward, steps, prob)`` tuples.
        action_names: indexable by ``option`` to give the action label.
    """
    n_col = 3
    n_row = (len(abstract_transitions) + n_col - 1) // n_col  # ceil division
    if n_row == 0:
        return  # nothing to draw; plt.subplots rejects zero rows
    _, axes = plt.subplots(n_row, n_col*2, figsize=(n_col*2*2, n_row*2), squeeze=False)
    axes = axes.flatten()

    def mean_strip(data, ndim):
        # mean over samples, then (ndim, 28, 28) -> (28, 28*ndim)
        avg = np.mean(data, axis=0).reshape(ndim, 28, 28)
        return np.transpose(avg, (1, 0, 2)).reshape(28, -1)

    for i, (s, option, s_prime, reward, steps, prob) in enumerate(abstract_transitions):
        ndim = s.data.shape[-1] // 784
        src_ax, dst_ax = axes[2*i], axes[2*i+1]

        src_ax.imshow(mean_strip(s.data, ndim), cmap="Greys")
        src_ax.set_title("{}.{} + {} -> {}".format(s.id, s.refinement, action_names[option], s_prime.id))
        src_ax.axis('off')

        dst_ax.imshow(mean_strip(s_prime.data, ndim), cmap="Greys")
        dst_ax.set_title("p={}, r={}, t={}".format(round(prob, 2), round(reward, 1), round(steps, 0)))
        dst_ax.axis('off')

    # hide the empty panels left over in the last row
    for ax in axes[2 * len(abstract_transitions):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def visualise_mnist_states(abstract_states):
    """Show the mean image of each abstract state in a 5-column grid.

    Each row of ``state.data`` is treated as ``ndim`` 28x28 images stored
    back to back (``ndim = data.shape[-1] // 784``), tiled horizontally. Each
    panel is titled with ``state.factors``. Unused panels are hidden.

    Args:
        abstract_states: sized sequence of states with ``data`` (``.mean(dim=0)``
            and ``.permute`` are used, so presumably a torch tensor — TODO
            confirm) and ``factors``.
    """
    n_col = 5
    n_row = (len(abstract_states) + n_col - 1) // n_col  # ceil division
    if n_row == 0:
        return  # nothing to draw; plt.subplots rejects zero rows
    _, axes = plt.subplots(n_row, n_col, figsize=(n_col * 2, n_row * 2), squeeze=False)
    axes = axes.flatten()
    for ax, state in zip(axes, abstract_states):
        ndim = state.data.shape[-1] // 784
        avg = state.data.mean(dim=0).reshape(ndim, 28, 28)
        avg = avg.permute(1, 0, 2).reshape(28, -1)
        ax.imshow(avg, cmap="Greys")
        ax.set_title(f"{state.factors}")
        ax.axis('off')
    # hide the empty panels left over in the last row
    for ax in axes[len(abstract_states):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def visualise_monty_states(abstract_states, n_sample=100):
    """Show the average frame of each abstract Montezuma state in a 5-column grid.

    For each state, ``n_sample`` rows of ``state.data`` are drawn uniformly
    with replacement (using numpy's global RNG), averaged, and reshaped to a
    210x160x3 image. Each panel is titled
    ``"<id>.<factor.name>.<f_id>.<r_id>..."`` for every entry in
    ``state.factors``. Unused panels are hidden.

    Args:
        abstract_states: sized sequence of states with ``id``, ``data`` (rows
            of 210*160*3 values) and ``factors`` (iterable of
            ``(factor, f_id, r_id)`` where ``factor`` has a ``name``).
        n_sample: number of rows averaged per state.
    """
    n_col = 5
    n_row = (len(abstract_states) + n_col - 1) // n_col  # ceil division
    if n_row == 0:
        return  # nothing to draw; plt.subplots rejects zero rows
    _, axes = plt.subplots(n_row, n_col, figsize=(n_col * 2, n_row * 2), squeeze=False)
    axes = axes.flatten()
    for ax, state in zip(axes, abstract_states):
        indices = np.random.randint(0, len(state.data), (n_sample,))
        avg = state.data[indices].mean(axis=0).reshape(210, 160, 3)
        ax.imshow(avg)
        title = f"{state.id}" + "".join(f".{f.name}.{f_id}.{r_id}" for (f, f_id, r_id) in state.factors)
        ax.set_title(title)
        ax.axis('off')
    # hide the empty panels left over in the last row
    for ax in axes[len(abstract_states):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def visualise_monty_ram_states(abstract_states, name="", n_sample=100):
    """Render abstract Montezuma's Revenge RAM states as averaged game frames.

    For each state, ``n_sample`` rows of ``state.data`` are drawn uniformly
    with replacement (numpy's global RNG). For each row:

    * the environment is reset;
    * the row is written into the emulator RAM (values multiplied by 255, so
      presumably stored in [0, 1] — TODO confirm);
    * one action-0 step is taken.

    The resulting RGB observations are averaged. States are laid out eight
    per row and titled ``"<id> (<len(data)>)"``. Unused panels are hidden.

    Args:
        abstract_states: sized sequence of states with ``id`` and ``data``
            (rows of at most 128 RAM values).
        name: if non-empty, save the figure to this path and close it;
            otherwise show it with ``plt.show()``.
        n_sample: number of RAM vectors rendered and averaged per state.
    """
    n_col = 8
    n_row = (len(abstract_states) + n_col - 1) // n_col  # ceil division
    if n_row == 0:
        return  # nothing to draw; plt.subplots rejects zero rows

    from gym_montezuma.envs.montezuma_env import MontezumasRevengeEnv
    fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 2, n_row * 2), squeeze=False)
    axes = axes.flatten()
    ram_idx = list(range(128))
    env = MontezumasRevengeEnv(observation_mode="rgb_array")
    try:
        for ax, state in zip(axes, abstract_states):
            # Running sum rather than an (n_sample, 210, 160, 3) float64 buffer;
            # the mean is the same.
            frame_sum = np.zeros((210, 160, 3))
            indices = np.random.randint(0, len(state.data), (n_sample,))
            for idx in indices:
                env.reset()
                for j, v in enumerate(state.data[idx]):
                    env.unwrapped.ale.setRAM(ram_idx[j], int(255*v))
                env.unwrapped.step(0)
                frame_sum += env.observation.reshape(210, 160, 3) / 255
            ax.imshow(frame_sum / n_sample)
            ax.set_title(f"{state.id} ({len(state.data)})")
            ax.axis('off')
    finally:
        # release the emulator even if rendering a state fails
        env.close()

    # hide the empty panels left over in the last row
    for ax in axes[len(abstract_states):]:
        ax.axis('off')
    fig.tight_layout()
    if name != "":
        fig.savefig(name)
        plt.close(fig)
    else:
        plt.show()
