import abc
import functools
from abc import ABC
from logging import Logger
from typing import Tuple

import wandb
from matplotlib.figure import Figure
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.base_drawer import Drawer


class DrawerHandler(AlgorithmCallbackHandler, ABC):
    """Callback handler that forwards a drawer's output to a reporting backend.

    Every algorithm callback asks ``self.drawer`` for drawings and hands them,
    one at a time, to :meth:`handle_drawing`, which concrete subclasses must
    implement for their backend.
    """

    def __init__(self, drawer: Drawer, name: str, *args, **kwargs):
        """Store the drawer and the handler name.

        Args:
            drawer: Object queried for drawings in every callback.
            name: Label of this handler, kept on ``self.name``.
            *args: Forwarded to the parent handler constructor.
            **kwargs: Forwarded to the parent handler constructor.
        """
        super().__init__(*args, **kwargs)
        self.drawer = drawer
        self.name = name
        # Number of completed ``handle_drawings`` calls.
        self.counter = 0

    def handle_drawings(self, drawings):
        """Dispatch each ``(drawing, name)`` pair to :meth:`handle_drawing`.

        Args:
            drawings: Iterable of ``(drawing, drawing_name)`` pairs, or
                ``None`` when the drawer produced nothing for this callback.

        The counter advances once per call, including calls with nothing to
        draw.
        """
        if drawings is None:
            # A drawer hook without output must not crash the callback chain.
            drawings = ()
        for drawing, drawing_name in drawings:
            self.handle_drawing(drawing, drawing_name)
        self.counter += 1

    @abc.abstractmethod
    def handle_drawing(self, drawing, drawing_name: str):
        """Report a single drawing to the backend; implemented by subclasses."""
        raise NotImplementedError()

    def on_algorithm_start(self, alg, *args, **kwargs):
        """Handle the drawings returned by ``drawer.start_drawing``."""
        drawings = self.drawer.start_drawing(alg, *args, **kwargs)
        self.handle_drawings(drawings)

    def on_algorithm_update(self, alg, *args, **kwargs):
        """Handle the drawings produced by the drawer on an algorithm update."""
        # NOTE(review): this calls ``start_drawing`` exactly like
        # ``on_algorithm_start`` does; it may be a copy-paste slip. Confirm
        # against the Drawer interface before changing the call.
        drawings = self.drawer.start_drawing(alg, *args, **kwargs)
        self.handle_drawings(drawings)

    def on_epoch_end(self, alg, *args, **kwargs):
        """Handle the drawings returned by ``drawer.update_data``."""
        drawings = self.drawer.update_data(alg, *args, **kwargs)
        self.handle_drawings(drawings)

    def on_algorithm_end(self, alg, *args, **kwargs):
        """Handle the drawings from ``drawer.end_drawing`` and close the drawer.

        The drawer is closed even if producing or reporting the final
        drawings raises; the exception still propagates to the caller.
        """
        try:
            drawings = self.drawer.end_drawing(alg, *args, **kwargs)
            self.handle_drawings(drawings)
        finally:
            self.drawer.close()


class WANDBHandler(DrawerHandler):
    """Drawer handler that reports drawings through ``wandb.log``.

    ``handle_drawing`` dispatches on the type of the drawing: matplotlib
    figures, floats, and ``(value, step)`` tuples are supported.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``DrawerHandler`` and reset ``epoch``."""
        super().__init__(*args, **kwargs)
        # Not read anywhere in this class; kept for existing users of the attribute.
        self.epoch = 0

    @functools.singledispatchmethod
    def handle_drawing(self, drawing, drawing_name: str):
        """Fallback for drawing types without a registered implementation.

        Raises:
            NotImplementedError: Always; the message names the offending type.
        """
        raise NotImplementedError(
            f"{type(self).__name__} cannot handle drawing {drawing_name!r} "
            f"of type {type(drawing).__name__}"
        )

    @handle_drawing.register(Figure)
    def handle_figure(self, drawing: Figure, drawing_name: str):
        """Log a matplotlib figure as a ``wandb.Image`` under ``drawing_name``."""
        wandb.log({drawing_name: wandb.Image(drawing)})

    # Distinct name: the float implementation used to be called
    # ``handle_figure`` too, which shadowed the Figure implementation above.
    @handle_drawing.register(float)
    def handle_float(self, drawing_data: float, drawing_name: str):
        """Log a float under ``drawing_name``."""
        wandb.log({drawing_name: drawing_data})

    @handle_drawing.register(tuple)
    def handle_scalar(self, drawing_data: Tuple[float, int], drawing_name: str):
        """Log ``drawing_data[0]`` at the step given by ``drawing_data[1]``."""
        value, step = drawing_data[0], drawing_data[1]
        wandb.log({drawing_name: value}, step=step)


class TensorboardDrawerHandler(DrawerHandler):
    """Drawer handler that writes drawings to a TensorBoard ``SummaryWriter``.

    Every entry is tagged ``"<handler name> <drawing name>"``. Figures and
    floats use the handler's ``counter`` as the global step; ``(value, step)``
    tuples carry their own step.
    """

    def __init__(self, *args, writer=None, **kwargs):
        """Set up the handler.

        Args:
            *args: Forwarded to ``DrawerHandler``.
            writer: Writer to use; a default ``SummaryWriter()`` is created
                only when this is ``None``.
            **kwargs: Forwarded to ``DrawerHandler``.
        """
        super().__init__(*args, **kwargs)
        # Explicit ``None`` check so a caller-supplied writer is never
        # replaced just because it happens to evaluate as falsy.
        self.writer = writer if writer is not None else SummaryWriter()

    def on_algorithm_end(self, alg, *args, **kwargs):
        """Run the base end-of-algorithm handling, then flush and close the writer.

        The writer is flushed and closed even if the base handling raises;
        the exception still propagates to the caller.
        """
        try:
            super().on_algorithm_end(alg, *args, **kwargs)
        finally:
            self.writer.flush()
            self.writer.close()

    @functools.singledispatchmethod
    def handle_drawing(self, drawing_data, drawing_name: str):
        """Fallback for drawing types without a registered implementation.

        Raises:
            NotImplementedError: Always; the message names the offending type.
        """
        raise NotImplementedError(
            f"{type(self).__name__} cannot handle drawing {drawing_name!r} "
            f"of type {type(drawing_data).__name__}"
        )

    @handle_drawing.register(Figure)
    def handle_figure(self, drawing_data: Figure, drawing_name: str):
        """Add a matplotlib figure at the current ``counter`` step."""
        self.writer.add_figure(
            f"{self.name} {drawing_name}", drawing_data, global_step=self.counter
        )

    # Distinct name: the float implementation used to be called
    # ``handle_figure`` too, which shadowed the Figure implementation above.
    @handle_drawing.register(float)
    def handle_float(self, drawing_data: float, drawing_name: str):
        """Add a float scalar at the current ``counter`` step."""
        self.writer.add_scalar(
            f"{self.name} {drawing_name}", drawing_data, global_step=self.counter
        )

    @handle_drawing.register(tuple)
    def handle_scalar(self, drawing_data: Tuple[float, int], drawing_name: str):
        """Add ``drawing_data[0]`` as a scalar at step ``drawing_data[1]``."""
        value, step = drawing_data[0], drawing_data[1]
        self.writer.add_scalar(f"{self.name} {drawing_name}", value, global_step=step)


class LoggerDrawerHandler(DrawerHandler):
    """Drawer handler that writes float and tensor values to a ``Logger``.

    Each value is emitted at INFO level as
    ``"<handler name> - <drawing name>: <value>"``.
    """

    def __init__(self, *args, logger: Logger, **kwargs):
        """Set up the handler.

        Args:
            *args: Forwarded to ``DrawerHandler``.
            logger: Logger that receives the messages (keyword-only).
            **kwargs: Forwarded to ``DrawerHandler``.
        """
        super().__init__(*args, **kwargs)
        self.logger = logger

    @functools.singledispatchmethod
    def handle_drawing(self, drawing_data, drawing_name: str):
        """Fallback for drawing types without a registered implementation.

        Raises:
            NotImplementedError: Always; the message names the offending type.
        """
        raise NotImplementedError(
            f"{type(self).__name__} cannot handle drawing {drawing_name!r} "
            f"of type {type(drawing_data).__name__}"
        )

    # One implementation registered for both supported types; previously two
    # identical methods shared the name ``handle_figure`` and shadowed each other.
    @handle_drawing.register(float)
    @handle_drawing.register(Tensor)
    def handle_value(self, drawing_data, drawing_name: str):
        """Log a float or tensor value at INFO level."""
        # f-string kept on purpose: it formats through ``__format__``, which
        # ``%s``-style lazy logging would bypass.
        self.logger.info(f"{self.name} - {drawing_name}: {drawing_data}")

    # Backward-compatible alias for code that referenced the old method name.
    handle_figure = handle_value
