from pathlib import Path

import PIL.Image
import pandas as pd
from torch import tensor
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import colors, cm, pyplot as plt
import matplotlib
import seaborn as sns

from matplotlib.colors import colorConverter
from matplotlib.image import AxesImage
from matplotlib.collections import PatchCollection

import matplotlib.font_manager as mfm
from typing import List, Optional, Tuple
import io

# Hard-coded location of the Symbola font file on the host system.
# NOTE(review): assumes the font is installed at this exact path — confirm on
# the target machine.
font_path = '/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf'
# Font properties bound to that file, suitable for passing as
# ``fontproperties=`` when drawing text with matplotlib.
prop = mfm.FontProperties(fname=font_path)
# Unicode code point U+1F4E1 (satellite antenna), usable as a text marker for
# radar stations.
radar_symbol = u"\U0001F4E1"


def plot_to_image(figure):
    """Render a matplotlib figure to an in-memory PNG and return it as a tensor.

    Args:
        figure: The matplotlib figure to render. It is closed by this call
            and must not be used afterwards.

    Returns:
        A torch tensor holding the decoded PNG pixel data with one extra
        trailing axis of length 1 appended.
    """
    buf = io.BytesIO()
    # Save through the figure itself: ``plt.savefig`` would render whichever
    # figure is currently active in pyplot, which need not be ``figure``.
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Decode the PNG into a NumPy array before building the tensor so the
    # conversion does not depend on torch inferring a dtype from a PIL object.
    pixels = np.array(PIL.Image.open(buf))
    return tensor(pixels)[..., None]


class AnimatedImages(object):
    """Animated density image with a velocity quiver overlay.

    Each frame shows one time step of ``df_test``: the density column is
    drawn as an image and the velocity columns as arrows (every third grid
    point). Circles of radius ``config["radius"]`` mark the unique ``x``/``y``
    locations found in ``df_train`` and, when given, ``df_val``.

    Args:
        df_train: Frame with at least ``x`` and ``y`` columns (circle centres).
        df_test: Frame with ``t``, ``x``, ``y``, the density column and the two
            velocity columns. Rows are sorted by ``t``, ``y``, ``x`` before
            being grouped into frames.
        image_extent: Axis limits, used both as the image extent and as the
            argument to ``Axes.axis``.
        config: Mapping providing ``"img_dim"`` and ``"radius"``. The image
            side length used for reshaping is ``config["img_dim"] - 1``.
        interval: Delay between frames in milliseconds (interactive display).
        scale: ``scale`` argument forwarded to ``Axes.quiver``.
        fps: Frame rate of the ffmpeg writer used by :meth:`save`.
        density_varname: Column of ``df_test`` drawn as the image.
        u_varname: Column of ``df_test`` holding the x velocity component.
        v_varname: Column of ``df_test`` holding the y velocity component.
        df_val: Optional frame with ``x`` and ``y`` columns (extra circles).
        interpolation: ``interpolation`` argument forwarded to ``imshow``.
        **kwargs: ``cmap`` (seaborn palette name, default ``"Blues"``) and
            ``norm`` (colour normalisation, default spans the density range).
    """

    def __init__(self, df_train: pd.DataFrame,
                 df_test: pd.DataFrame,
                 image_extent: List[int],
                 config: dict,
                 interval=200,
                 scale=15, fps=3,
                 density_varname="mass",
                 u_varname="u",
                 v_varname="v",
                 df_val: Optional[pd.DataFrame] = None,
                 interpolation="gaussian",
                 **kwargs):
        self.extent = image_extent
        # Forward the requested frame rate; without it the writer silently
        # falls back to its own default and ``fps`` has no effect.
        self.writer = animation.writers['ffmpeg'](fps=fps)
        self.df_train = df_train
        self.df_val = df_val
        self.df_test = df_test.sort_values(["t", "y", "x"])

        self.density_varname = density_varname
        self.u_varname = u_varname
        self.v_varname = v_varname

        self.config = config
        self.quiv_args = dict(angles="xy", scale_units='xy', scale=scale)
        self.img_dim = config["img_dim"] - 1
        # Only the value range and the size of the last axis (the number of
        # frames) of this array are used below.
        self.img_values = df_test[self.density_varname].values.reshape((self.img_dim,
                                                                        self.img_dim,
                                                                        -1))
        self.radius = config["radius"]
        self.test_grid = df_test.query("t==0")[["x", "y"]].values
        # The station circles are static artists; remember whether they have
        # been added so repeated init calls do not stack duplicates.
        self._stations_drawn = False

        self.fig, self.ax = plt.subplots()

        self.colormap = sns.color_palette(kwargs.get("cmap", "Blues"), as_cmap=True)

        # Peek at the first time step on a throw-away generator, then create
        # the stream that ``update`` actually consumes.
        first_frame, _ = next(self.data_stream())
        self.stream = self.data_stream()

        self.norm = kwargs.get("norm", matplotlib.colors.Normalize(vmin=np.min(self.img_values),
                                                                   vmax=np.max(self.img_values))
                               )

        image = first_frame["mass"].values.reshape((self.img_dim, self.img_dim))
        self.q = self.ax.quiver(first_frame["x"][::3], first_frame["y"][::3],
                                first_frame["u"][::3], first_frame["v"][::3], animated=True,
                                **self.quiv_args)

        # Draw on the axes owned by this object instead of pyplot's implicit
        # "current axes".
        self.im: AxesImage = self.ax.imshow(image, origin='lower', extent=self.extent,
                                            norm=self.norm, cmap=self.colormap,
                                            alpha=1., interpolation=interpolation,
                                            )
        # An integer frame count (rather than a one-shot generator) lets
        # matplotlib infer how many frames to save and rebuild the frame
        # sequence for every run of the animation.
        n_frames = self.img_values.shape[-1]
        self.ani = animation.FuncAnimation(self.fig, self.update,
                                           frames=n_frames,
                                           interval=interval,
                                           init_func=self.setup_plot, blit=True, repeat=False,
                                           )
        self.fig.colorbar(self.im, ax=self.ax)

    @staticmethod
    def calc_alpha(image):
        """Map an image to per-pixel opacities in ``[0, 1]``.

        The square root of the image is min-max normalised, clipped to
        ``[0, 1]`` and raised to the power 1.5. A constant image has no value
        range to normalise over; it maps to all zeros instead of NaN.

        Args:
            image: Array of values; the square root is taken element-wise.

        Returns:
            Float array with the same shape as ``image``.
        """
        root = np.sqrt(image)
        lowest = root.min()
        span = root.max() - lowest
        if span == 0:
            # Every pixel equals the minimum, which normalises to 0.
            return np.zeros_like(root, dtype=float)
        return np.clip((root - lowest) / span, 0., 1.) ** 1.5

    def _draw_station_circles(self, df, face_color, edge_color):
        """Add one circle of radius ``self.radius`` per unique (x, y) row of ``df``."""
        for center in df[["x", "y"]].drop_duplicates().values:
            self.ax.add_artist(plt.Circle(center,
                                          radius=self.radius,
                                          fc=colorConverter.to_rgba(face_color, alpha=0.01),
                                          ec=edge_color))

    def setup_plot(self):
        """Initial drawing of the plot (``init_func`` of the animation).

        Returns:
            The artists that the animation redraws on each frame.
        """
        self.ax.axis(self.extent)
        self.ax.set_aspect('equal')
        # Circles mark the training locations and, if given, the validation
        # locations. They are added only once even if this runs repeatedly.
        if not self._stations_drawn:
            self._draw_station_circles(self.df_train, 'blue', 'tab:gray')
            if self.df_val is not None:
                self._draw_station_circles(self.df_val, 'orange', 'tab:red')
            self._stations_drawn = True

        # FuncAnimation expects a sequence of the artists being animated.
        return self.im, self.q

    def data_stream(self):
        """Yield ``(frame, t)`` for every time step of the sorted test frame.

        ``frame`` is a dict with the keys ``x``, ``y``, ``u``, ``v`` and
        ``mass`` holding the matching columns of that time step's rows.
        """
        for t, group in self.df_test.groupby("t"):
            frame = dict(
                x=group.x,
                y=group.y,
                u=group[self.u_varname],
                v=group[self.v_varname],
                mass=group[self.density_varname])
            yield frame, t

    def update(self, i):
        """Draw frame ``i`` of the animation.

        Args:
            i: Frame index supplied by FuncAnimation.

        Returns:
            The artists that were modified.
        """
        if i == 0:
            # Restart the data at the first time step whenever the animation
            # (re)starts, so a second run does not hit an exhausted stream.
            self.stream = self.data_stream()
        frame, t = next(self.stream)
        image = frame["mass"].values.reshape((self.img_dim,
                                              self.img_dim))
        self.im.set_array(image)
        self.im.set_alpha(1.)
        self.q.set_UVC(frame["u"][::3], frame["v"][::3])
        self.ax.set_title(f"Time: {t:.2f}")

        # Both the image and the arrows changed, so both must be returned
        # for blitting to redraw them.
        return self.im, self.q

    def save(self, filename="animation.mp4"):
        """Write the animation to ``filename`` with the ffmpeg writer and close the figure."""
        self.ani.save(filename, writer=self.writer)
        plt.close(self.fig)


class AnimatedScatter(object):
    """An animated scatter plot using matplotlib.animations.FuncAnimation.

    Each frame shows the point positions of one time step, coloured by that
    time step, optionally on top of a background image.

    Args:
        solutions_y: Array whose transpose is iterated frame by frame; each
            item is unpacked as the ``x`` and ``y`` arguments of ``scatter``.
            NOTE(review): this implies a (points, 2, time) layout — confirm
            against callers.
        timesteps: Sequence of time values, one per frame.
        img_values: Optional array whose last axis indexes the frames; slice
            ``[..., i]`` is shown as the background of frame ``i``.
        config: Stored on the instance; not used by this class.
        image_extent: ``extent`` argument for the background image.
        quiver_values: Stored on the instance. No quiver layer is drawn, so
            the values are currently ignored (a warning is emitted).
        interval: Delay between frames in milliseconds (interactive display).
        fps: Frame rate of the ffmpeg writer used by :meth:`save`.
    """

    def __init__(self, solutions_y, timesteps, img_values: Optional[np.ndarray] = None,
                 config=None, image_extent: Optional[List[int]] = None,
                 quiver_values: Optional[np.ndarray] = None,
                 interval=100, fps=3):
        self.solutions_y = solutions_y
        self.img_values = img_values
        self.quiver_values = quiver_values

        self.timesteps = timesteps
        self.config = config
        self.extent = image_extent

        # Setup the figure and axes...
        self.fig, self.ax = plt.subplots()
        # ``cm.get_cmap`` was deprecated in matplotlib 3.7 and later removed;
        # ``pyplot.get_cmap`` is the supported lookup by name.
        self.colormap = plt.get_cmap("inferno")

        # ``self.norm`` is applied to the time value of each frame in
        # ``update``, so it has to span the time steps. Building it from the
        # image values also failed whenever ``img_values`` was left as None.
        self.norm = matplotlib.colors.Normalize(vmin=np.min(self.timesteps),
                                                vmax=np.max(self.timesteps))

        self.writer = animation.writers['ffmpeg'](fps=fps)

        # Peek at the first frame on a throw-away generator.
        xy, img, _, _ = next(self.data_stream())
        self.scat = self.ax.scatter(*xy, s=1., marker='+', alpha=0.01)
        self.ax.set_xlim(-1.5, 1.5)
        self.ax.set_ylim(-1.5, 1.5)

        self.im: Optional[AxesImage] = None
        if self.img_values is not None:
            # The background is colour-scaled over its own value range and
            # drawn on this object's axes rather than pyplot's current axes.
            self.im = self.ax.imshow(img, origin='lower', extent=self.extent, alpha=0.5,
                                     vmin=np.min(self.img_values),
                                     vmax=np.max(self.img_values), cmap=self.colormap)

        # No quiver artist is created: the grid positions it would need are
        # not among the constructor arguments. Keep the attribute defined so
        # ``update`` can test for it instead of raising AttributeError.
        self.q = None
        if self.quiver_values is not None:
            import warnings
            warnings.warn("quiver_values was given but AnimatedScatter does not draw "
                          "a quiver layer; the values are ignored.")

        self.stream = self.data_stream()
        # An integer frame count (rather than a one-shot generator) lets
        # matplotlib infer how many frames to save and rebuild the frame
        # sequence for every run of the animation.
        self.ani = animation.FuncAnimation(self.fig, self.update, interval=interval,
                                           init_func=self.setup_plot, blit=True,
                                           frames=len(self.timesteps), repeat=False)

    def save(self, filename="animation.mp4"):
        """Write the animation to ``filename`` with the ffmpeg writer."""
        self.ani.save(filename, writer=self.writer)

    def _animated_artists(self):
        """Return the artists that change between frames, as a tuple."""
        artists = [self.scat]
        if self.im is not None:
            artists.append(self.im)
        if self.q is not None:
            artists.append(self.q)
        return tuple(artists)

    def setup_plot(self):
        """Initial drawing of the scatter plot (``init_func`` of the animation).

        Returns:
            The artists that the animation redraws on each frame.
        """
        self.ax.axis([-1.5, 1.5, -1.5, 1.5])
        self.ax.set_aspect('equal')

        # FuncAnimation expects a sequence of the artists being animated.
        return self._animated_artists()

    def data_stream(self):
        """Yield ``[xy, img, quiv, t]`` per time step, cycling forever.

        ``img`` and ``quiv`` are None when the respective input was not given.
        """
        n_steps = len(self.timesteps)
        while True:
            for i, (xy, t) in enumerate(zip(self.solutions_y.T, self.timesteps)):
                img = None if self.img_values is None else self.img_values[..., i]
                quiv = None if self.quiver_values is None else self.quiver_values[i::n_steps, :]
                yield [xy, img, quiv, t]

    def update(self, i):
        """Draw frame ``i`` of the animation.

        Args:
            i: Frame index supplied by FuncAnimation.

        Returns:
            The artists that were modified.
        """
        if i == 0:
            # Re-align the data with the frame counter whenever the animation
            # (re)starts.
            self.stream = self.data_stream()
        xy, img, quiv, t = next(self.stream)
        # Set x and y data...
        self.scat.set_offsets(xy.T)
        self.scat.set_color(self.colormap(self.norm(t)))
        self.ax.set_title(f"Time: {t:.2f}")
        if self.im is not None:
            self.im.set_array(img)
        if self.q is not None and quiv is not None:
            self.q.set_UVC(*quiv.T)

        # Every modified artist must be returned for blitting to redraw it.
        return self._animated_artists()
