from abc import ABC
from typing import Union, List, Tuple

from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure


class Drawer(ABC):
    """Base interface for objects that visualise an algorithm object ``alg``.

    The four drawing hooks raise ``NotImplementedError`` unless overridden;
    ``close`` is a no-op by default. The hooks are presumably invoked in the
    order start_drawing -> update_data / draw_data (possibly repeatedly) ->
    end_drawing -> close; TODO confirm against the caller.

    The hooks deliberately raise ``NotImplementedError`` instead of being
    ``@abstractmethod``: subclasses can be instantiated without overriding
    every hook.
    """

    def start_drawing(self, alg, *args, **kwargs):
        """Prepare to draw ``alg``; must be overridden by subclasses."""
        raise NotImplementedError()

    def update_data(self, alg, *args, **kwargs):
        """Update the drawer from ``alg``; must be overridden by subclasses."""
        raise NotImplementedError()

    def draw_data(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Render ``alg`` and return ``(figure, string)`` pairs.

        The meaning of the string is defined by the caller (e.g.
        MultipleDrawerEpoch always returns ``""``). Must be overridden.
        """
        raise NotImplementedError()

    def end_drawing(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Finish drawing ``alg`` and return ``(figure, string)`` pairs.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def close(self):
        """Release any resources held by the drawer (no-op by default)."""
        pass


class PLTDrawer(Drawer, ABC):
    """Drawer bound to a matplotlib figure/axes pair.

    Exposes ``self.fig`` and ``self.ax`` for subclasses to draw on; ``close``
    closes ``self.fig`` through pyplot.
    """

    def __init__(self, fig: Figure = None, ax: Axes = None):
        """Bind the drawer to a figure and axes, creating whichever is missing.

        Args:
            fig: Figure to draw on. When omitted and ``ax`` is given, the
                figure that owns ``ax`` is used; when both are omitted a new
                13x10 inch pyplot figure is created.
            ax: Axes to draw on. When omitted, a new axes spanning the whole
                figure (``[0, 0, 1, 1]``) is added to the figure.
        """
        if fig is None:
            # Reuse the axes' own figure so ``fig`` and ``ax`` stay consistent
            # and no unrelated (and never-closed) figure is created.
            fig = ax.figure if ax is not None else plt.figure(figsize=[13, 10])
        self.fig = fig
        self.ax = ax if ax is not None else self.fig.add_axes([0, 0, 1, 1])

    def update_plot(self, fig: Figure = None, ax: Axes = None):
        """Rebind the drawer to a new figure and/or axes.

        Arguments left as ``None`` keep the current binding, except that
        passing only ``ax`` also switches ``self.fig`` to the figure owning
        that axes, so ``close`` releases the right figure.
        """
        if ax is not None:
            self.ax = ax
            if fig is None:
                fig = ax.figure
        if fig is not None:
            self.fig = fig
        # NOTE(review): passing only ``fig`` keeps the previous ``self.ax``,
        # which may belong to a different figure — confirm this is intended.

    def close(self):
        """Close ``self.fig`` via pyplot."""
        plt.close(self.fig)


class StaticDrawer(Drawer, ABC):
    """Drawer whose lifecycle hooks all delegate to a single ``draw`` call.

    Subclasses implement only ``draw``. ``start_drawing``, ``update_data``,
    ``draw_data`` and ``end_drawing`` forward their arguments to it unchanged
    and return its result as-is.
    """

    def draw(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Render ``alg``; must be overridden by subclasses."""
        raise NotImplementedError()

    def draw_data(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Delegate to ``draw``."""
        return self.draw(alg, *args, **kwargs)

    def update_data(self, alg, *args, **kwargs):
        """Delegate to ``draw``."""
        return self.draw(alg, *args, **kwargs)

    def start_drawing(self, alg, *args, **kwargs):
        """Delegate to ``draw``."""
        return self.draw(alg, *args, **kwargs)

    def end_drawing(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Delegate to ``draw``."""
        return self.draw(alg, *args, **kwargs)


class StaticPLTDrawer(PLTDrawer, StaticDrawer, ABC):
    """Matplotlib-backed static drawer.

    Through the MRO, figure/axes setup, ``update_plot`` and ``close`` come from
    PLTDrawer, while ``start_drawing``, ``update_data``, ``draw_data`` and
    ``end_drawing`` come from StaticDrawer and forward to ``draw``, which
    subclasses implement.
    """


class MultipleDrawerEpoch(Drawer):
    """Composite drawer that renders several PLTDrawers onto one shared figure.

    ``start_drawing`` rebinds every child drawer to a single full-figure axes
    on ``self.fig``. Every hook is then forwarded to all children in list
    order, and each hook returns ``[(self.fig, "")]``.
    """

    def __init__(self, plt_drawers: List[PLTDrawer], fig: Figure = None):
        """
        Args:
            plt_drawers: Child drawers; they are rebound to this object's
                figure in ``start_drawing`` and closed by ``close``.
            fig: Shared figure. A new 12.8x9.6 inch pyplot figure is created
                when omitted.
        """
        self.drawers = plt_drawers
        self.fig = fig if fig is not None else plt.figure(figsize=[12.8, 9.6])

    def start_drawing(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Attach all children to a fresh axes on ``self.fig``, then start them."""
        # NOTE(review): every call adds another full-figure axes on top of any
        # existing ones; confirm start_drawing runs once per figure.
        ax = self.fig.add_axes([0, 0, 1, 1])
        for drawer in self.drawers:
            previous_fig = getattr(drawer, "fig", None)
            drawer.update_plot(self.fig, ax)
            if previous_fig is not None and previous_fig is not self.fig:
                # The child's previous figure (e.g. the default one PLTDrawer
                # creates) is no longer referenced by the drawer; close it so
                # pyplot does not keep it open indefinitely.
                plt.close(previous_fig)
        for drawer in self.drawers:
            drawer.start_drawing(alg, *args, **kwargs)
        return [(self.fig, "")]

    def update_data(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Forward ``update_data`` to every child drawer."""
        for drawer in self.drawers:
            drawer.update_data(alg, *args, **kwargs)
        return [(self.fig, "")]

    def draw_data(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Forward ``draw_data`` to every child drawer; child results are discarded."""
        for drawer in self.drawers:
            drawer.draw_data(alg, *args, **kwargs)
        return [(self.fig, "")]

    def end_drawing(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Forward ``end_drawing`` to every child drawer; child results are discarded."""
        for drawer in self.drawers:
            drawer.end_drawing(alg, *args, **kwargs)
        return [(self.fig, "")]

    def close(self):
        """Close every child drawer, then the shared figure.

        Forwarding to the children also releases figures they still own when
        ``start_drawing`` was never called.
        """
        for drawer in self.drawers:
            drawer.close()
        plt.close(self.fig)
