import warnings

import matplotlib.pyplot as plt
from typing import Optional
import pathlib
from datetime import datetime
import shutil
import numpy as np
import torch
import os
import torch.utils.tensorboard
from dataclasses import dataclass


@dataclass
class VisualizerConfig:
    """
    Styling options for the DLA animation.

    Every attribute is a dataclass field, so any of them can be overridden through the constructor.
    """
    # Appearance of the DLA particle markers.
    dla_particle_alpha: float = 1.0
    dla_particle_color: str = 'tab:blue'
    dla_particle_marker: str = 'o'

    # Appearance of the flow sample markers (semi-transparent by default).
    flow_particle_alpha: float = 0.3
    flow_particle_color: str = 'tab:orange'
    flow_particle_marker: str = 'o'

    # Appearance of the auxiliary line artists.
    arrow_alpha: float = 1.0
    arrow_color: str = 'k'

    # (x, y) position of the step counter text. Annotated so that it is a real dataclass field
    # (an un-annotated assignment is only a class attribute and cannot be set via __init__).
    step_text_location: tuple = (0.5, 0.85)


class DLAVisualizer:
    def __init__(self,
                 directory: pathlib.Path,
                 tensorboard: bool = False,
                 save_figures: bool = False,
                 show_figures: bool = False):
        self.directory = directory
        self.figures_directory = self.directory / 'figures'

        self.make_directories()

        self.config = VisualizerConfig()

        self.tensorboard = tensorboard
        self.save_figures = save_figures
        self.show_figures = show_figures
        self.writer = None if not tensorboard else torch.utils.tensorboard.SummaryWriter(log_dir=str(self.directory))

        self.enabled = self.tensorboard or self.save_figures or self.show_figures

    def make_directories(self):
        self.directory.mkdir(exist_ok=True, parents=True)
        self.figures_directory.mkdir(exist_ok=True, parents=True)

    def add_scalar(self, tag, value, step):
        if self.enabled and self.tensorboard:
            self.writer.add_scalar(tag=tag, scalar_value=value, global_step=step)

    @torch.no_grad()
    def add_figure(self, tag: str, x: torch.Tensor, step: int, dim0: int = 0, dim1: int = 1, **kwargs):
        if not self.enabled:
            return
        if not 0 <= dim0 < x.shape[1]:
            raise ValueError(f"Dimension 0 cannot be {dim0} with {x.shape[1]} dimensions")
        if not 0 <= dim1 < x.shape[1]:
            raise ValueError(f"Dimension 1 cannot be {dim1} with {x.shape[1]} dimensions")

        fig, ax = plt.subplots(**kwargs)
        ax.scatter(x[:, dim0], x[:, dim1])
        fig.tight_layout()
        if self.save_figures:
            fig.savefig(str(self.figures_directory / tag / f'{step:06d}.png'), bbox_inches='tight')
        if self.tensorboard:
            self.writer.add_figure(tag=tag, figure=fig, global_step=step, close=False)
        if self.show_figures:
            plt.show()
        plt.close(fig)

    def animate(self,
                dla_particles_dir: pathlib.Path,
                flow_samples_dir: pathlib.Path,
                output_dir: pathlib.Path,
                dim0: int = 0,
                dim1: int = 1,
                dpi: int = None,
                file_ext: str = 'gif',
                fig_kwargs: dict = None,
                legend_kwargs: dict = None,
                func_animation_kwargs: dict = None,
                ax_title: str = None,
                ax_limits=None,
                xscale=None,
                yscale=None,
                dynamic_limits=False):
        """
        Animate saved data.

        :param dim0: index of the first dimension to be animated.
        :param dim1: index of the second dimension to be animated.
        :param dpi: DPI used to render the animation. A higher DPI means higher quality, but more processing time.
        :param file_ext: file extension for the output animation. To use 'mp4', you must set the ffmpeg path in
            plt.rcParams.
        :param fig_kwargs: keyword arguments for plt.subplots.
        :param legend_kwargs: keyword arguments for ax.legend.
        :param func_animation_kwargs: keyword arguments for FuncAnimation.
        :param ax_title: plot title.
        """
        from matplotlib.animation import FuncAnimation

        if fig_kwargs is None:
            fig_kwargs = dict()
        if legend_kwargs is None:
            legend_kwargs = dict(loc='upper left')
        if func_animation_kwargs is None:
            func_animation_kwargs = dict(interval=10, blit=True, repeat=True)

        print(f'Animating data in {str(dla_particles_dir)}')
        animation_file_path = output_dir / f'animation-{dim0}-{dim1}.{file_ext}'
        numpy_files_dla = list(dla_particles_dir.glob('*.npy'))
        numpy_files_dla = list(sorted(numpy_files_dla, key=lambda i: int(os.path.splitext(os.path.basename(i))[0])))
        numpy_files_flow = list(flow_samples_dir.glob('*.npy'))
        numpy_files_flow = list(sorted(numpy_files_flow, key=lambda i: int(os.path.splitext(os.path.basename(i))[0])))

        # Make sure we don't load too many files into memory
        num_files = len(numpy_files_dla)
        dla_file_size = np.load(str(numpy_files_dla[0])).nbytes
        flow_file_size = np.load(str(numpy_files_flow[0])).nbytes
        total_nbytes = num_files * (dla_file_size + flow_file_size)
        print(f'Animation data size: {(total_nbytes / 1e6):.2f} MB')
        if total_nbytes > 50_000_000:
            warnings.warn('Animation data size exceeds 50 MB. Stopping.')
            return


        fig, ax = plt.subplots(**fig_kwargs)

        dla_particles, = ax.plot(
            [],
            [],
            color=self.config.dla_particle_color,
            linewidth=0,
            marker=self.config.dla_particle_marker,
            alpha=self.config.dla_particle_alpha,
            label='DLA particles'
        )
        flow_samples, = ax.plot(
            [],
            [],
            color=self.config.flow_particle_color,
            linewidth=0,
            marker=self.config.flow_particle_marker,
            alpha=self.config.flow_particle_alpha,
            label='Flow samples'
        )
        ax.legend(**legend_kwargs)
        ax.set_title(ax_title)

        line_plots = []
        for _ in range(np.load(str(numpy_files_dla[0])).shape[1]):
            lines_plot, = ax.plot([], [], color=self.config.arrow_color, alpha=self.config.arrow_alpha)
            line_plots.append(lines_plot)

        step_text = ax.text(
            *self.config.step_text_location,  # Plot space x and y
            "",
            bbox={
                'facecolor': 'w',
                'alpha': 0.5,
                'pad': 5
            },
            transform=ax.transAxes,
            ha="center"
        )

        min_x = np.infty
        min_y = np.infty
        max_x = -np.infty
        max_y = -np.infty

        def update(step):
            nonlocal min_x, max_x, min_y, max_y

            def get_limits(data, min_x, max_x, min_y, max_y, clip_min=-1e9, clip_max=1e9, dynamic=False):
                min_x_data = float(np.nanmin(data[..., dim0]))
                min_y_data = float(np.nanmin(data[..., dim1]))
                max_x_data = float(np.nanmax(data[..., dim0]))
                max_y_data = float(np.nanmax(data[..., dim1]))

                # print(min_x, min_y, max_x, max_y)
                # print(min_x_data, min_y_data, max_x_data, max_y_data)

                if dynamic:
                    min_x = min_x_data
                    max_x = max_x_data
                    min_y = min_y_data
                    max_y = max_y_data
                else:
                    min_x = min(min_x, min_x_data)
                    max_x = max(max_x, max_x_data)
                    min_y = min(min_y, min_y_data)
                    max_y = max(max_y, max_y_data)

                # print(min_x, min_y, max_x, max_y)

                min_x = max(clip_min, min_x)
                max_x = min(clip_max, max_x)
                min_y = max(clip_min, min_y)
                max_y = min(clip_max, max_y)

                # print(min_x, min_y, max_x, max_y)

                return min_x, max_x, min_y, max_y

            # Updating DLA particle data
            min_x_dla, max_x_dla, min_y_dla, max_y_dla = (0, 0, 0, 0)
            if step < len(numpy_files_dla):
                data = np.load(str(numpy_files_dla[step]))
                # print(f'DLA shape {data.shape}')
                min_x_dla, max_x_dla, min_y_dla, max_y_dla = get_limits(data, min_x, max_x, min_y, max_y,
                                                                        dynamic=(step == 0) or dynamic_limits)
                dla_particles.set_data(data[:, dim0], data[:, dim1])

                for i, line_plot in enumerate(line_plots):
                    line_plot.set_data(data[i, dim0], data[i, dim1])

            # Updating flow sample data
            min_x_flow, max_x_flow, min_y_flow, max_y_flow = (0, 0, 0, 0)
            if step < len(numpy_files_flow):
                data = np.load(str(numpy_files_flow[step]))
                # print(f'Flow shape {data.shape}')
                min_x_flow, max_x_flow, min_y_flow, max_y_flow = get_limits(data, min_x, max_x, min_y, max_y,
                                                                            dynamic=(step == 0) or dynamic_limits)
                flow_samples.set_data(data[:, dim0], data[:, dim1])

            step_text.set_text(f'Step {step:>5}')

            min_x, max_x, min_y, max_y = (
                min(min_x_flow, min_x_dla),
                max(max_x_flow, max_x_dla),
                min(min_y_flow, min_y_dla),
                max(max_y_flow, max_y_dla)
            )

            # print('Flow', min_x_flow, max_x_flow, min_y_flow, max_y_flow)
            # print('DLA', min_x_dla, max_x_dla, min_y_dla, max_y_dla)

            dx = (max_x - min_x) * 0.1
            dy = (max_y - min_y) * 0.1
            if ax_limits is not None:
                ax.axis(ax_limits)
            else:
                ax.axis([min_x - dx, max_x + dx, min_y - dy, max_y + dy])
            if xscale is not None:
                ax.set_xscale(xscale)
            if yscale is not None:
                ax.set_yscale(yscale)
            return dla_particles, step_text, *line_plots

        print('Writing data to file')
        ani = FuncAnimation(fig, update, frames=len(numpy_files_dla), **func_animation_kwargs)
        ani.save(str(animation_file_path), writer='ffmpeg', dpi=dpi)
        plt.close(fig)

        print(f'> Output file: {str(animation_file_path.absolute())}')


class DLAFileWriter:
    """
    Store raw debugging data as ``.npy`` files.

    The directory layout is ``static``, ``flow_samples`` and ``dla_particles`` below the given root. The
    directories are always created; the write methods only touch the disk when the writer is enabled.
    """

    def __init__(self, directory: pathlib.Path, enabled: bool = False):
        """
        :param directory: root directory for all raw data.
        :param enabled: if False, every write method is a no-op.
        """
        self.directory = directory

        self.static_data_dir = directory / 'static'
        self.flow_data_dir = directory / 'flow_samples'
        self.dla_data_dir = directory / 'dla_particles'

        for folder in (self.directory, self.static_data_dir, self.flow_data_dir, self.dla_data_dir):
            folder.mkdir(exist_ok=True, parents=True)

        self.enabled = enabled

    def write_scalar(self, tag: str, value: float, step: int):
        """Placeholder for scalar output; currently nothing is written."""
        if not self.enabled:
            return
        # TODO implement

    def write_static_data(self, tag: str, value: torch.Tensor):
        """Save ``value`` as a NumPy array named after ``tag`` in the static data directory."""
        if not self.enabled:
            return
        target = self.static_data_dir / tag
        np.save(str(target), value.detach().cpu().numpy())

    def write_flow_samples(self, flow_samples: torch.Tensor, step: int):
        """Save the flow samples of ``step`` under a zero-padded six-digit file name."""
        if not self.enabled:
            return
        target = self.flow_data_dir / f'{step:06d}'
        np.save(file=str(target), arr=flow_samples.detach().cpu().numpy())

    def write_dla_particles(self, particles: torch.Tensor, step: int):
        """Save the DLA particles of ``step`` under a zero-padded six-digit file name."""
        if not self.enabled:
            return
        target = self.dla_data_dir / f'{step:06d}'
        np.save(file=str(target), arr=particles.detach().cpu().numpy())


class SingleStageDebugger:
    """
    Debugger for a single DLA stage.

    Combines a DLAFileWriter (raw data) and a DLAVisualizer (scalars, figures, animations) that share one
    working directory and one step counter.
    """

    def __init__(self,
                 directory: Optional[pathlib.Path] = None,
                 directory_suffix: Optional[str] = None,
                 delete_existing: bool = False,
                 tensorboard: bool = False,
                 save_raw_data: bool = False,
                 show_figures: bool = False,
                 save_figures: bool = False,
                 animate: bool = False,
                 static_data: Optional[dict] = None):
        """
        :param directory: working directory; defaults to a timestamped directory below ``runs``.
        :param directory_suffix: optional subdirectory appended to the working directory.
        :param delete_existing: remove the working directory first if it already exists.
        :param tensorboard: forwarded to the visualizer.
        :param save_raw_data: enables the raw data file writer.
        :param show_figures: forwarded to the visualizer.
        :param save_figures: forwarded to the visualizer.
        :param animate: allow animations; only effective together with ``save_raw_data``.
        :param static_data: mapping of tags to values that are written once as static data.
        """
        if directory is None:
            timestamp = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
            directory = 'runs' / pathlib.Path(str(timestamp))
        if directory_suffix is not None:
            directory = directory / directory_suffix
        if directory.exists() and delete_existing:
            shutil.rmtree(str(directory))
        static_data = dict() if static_data is None else static_data

        # Absolute path helps make logs clear
        self.directory = directory.absolute()
        self.animation_dir = self.directory / 'animations'

        for folder in (self.directory, self.animation_dir):
            folder.mkdir(exist_ok=True, parents=True)

        print(f'Debugger working directory: {str(self.directory)}')

        # DLA step counter
        self._step: int = 0
        # Animations are rendered from the raw data files, so both flags are required.
        self.animatable = animate and save_raw_data

        self.file_writer = DLAFileWriter(self.directory / 'raw_data', enabled=save_raw_data)
        self.visualizer = DLAVisualizer(
            self.directory / 'visualization',
            tensorboard=tensorboard,
            save_figures=save_figures,
            show_figures=show_figures
        )

        for tag, data in static_data.items():
            self.file_writer.write_static_data(tag, data)

    def step(self):
        """
        Advance the DLA step counter by one.
        """
        self._step = self._step + 1

    def add_scalar(self, tag: str, scalar: float):
        """Record a scalar for the current step with both the file writer and the visualizer."""
        current_step = self._step
        self.file_writer.write_scalar(tag=tag, value=scalar, step=current_step)
        self.visualizer.add_scalar(tag=tag, value=scalar, step=current_step)

    def add_flow_samples(self,
                         flow_samples: Optional[torch.Tensor] = None,
                         dim0: int = 0,
                         dim1: int = 1,
                         **figure_kwargs):
        """
        Store and plot flow samples for the current step; does nothing if ``flow_samples`` is None.

        :param flow_samples: samples to record.
        :param dim0: index of the dimension plotted on the x axis.
        :param dim1: index of the dimension plotted on the y axis.
        :param figure_kwargs: extra keyword arguments forwarded to the visualizer's ``add_figure``.
        """
        if flow_samples is None:
            return
        self.file_writer.write_flow_samples(flow_samples=flow_samples, step=self._step)
        self.visualizer.add_figure(
            tag='flow_samples',
            x=flow_samples,
            step=self._step,
            dim0=dim0,
            dim1=dim1,
            **figure_kwargs
        )

    def add_particles(self,
                      dla_particles: Optional[torch.Tensor] = None,
                      dim0: int = 0,
                      dim1: int = 1,
                      **figure_kwargs):
        """
        Store and plot DLA particles for the current step; does nothing if ``dla_particles`` is None.

        :param dla_particles: particles to record.
        :param dim0: index of the dimension plotted on the x axis.
        :param dim1: index of the dimension plotted on the y axis.
        :param figure_kwargs: extra keyword arguments forwarded to the visualizer's ``add_figure``.
        """
        if dla_particles is None:
            return
        self.file_writer.write_dla_particles(particles=dla_particles, step=self._step)
        self.visualizer.add_figure(
            tag='dla_particles',
            x=dla_particles,
            step=self._step,
            dim0=dim0,
            dim1=dim1,
            **figure_kwargs
        )

    def animate(self, **kwargs):
        """Render an animation from the saved raw data; does nothing unless the debugger is animatable."""
        if not self.animatable:
            return
        self.visualizer.animate(
            dla_particles_dir=self.file_writer.dla_data_dir,
            flow_samples_dir=self.file_writer.flow_data_dir,
            output_dir=self.animation_dir,
            **kwargs
        )


class MultiStageDebugger:
    """
    DLA debugger with multiple beta stages.

    Wraps one SingleStageDebugger per stage; each stage gets its own ``stage_<index>`` subdirectory. All
    recording methods are forwarded to the debugger of the current stage.
    """

    def __init__(self, stage: int = 0, directory: Optional[pathlib.Path] = None,
                 delete_existing: bool = False, tensorboard: bool = False, save_raw_data: bool = False,
                 show_figures: bool = False, save_figures: bool = False, animate: bool = False,
                 static_data: Optional[dict] = None):
        """
        :param stage: index of the initial stage.
        :param directory: root directory; defaults to a timestamped directory below ``runs``.

        The remaining arguments are stored and passed to every stage debugger that gets created.
        """
        self.current_stage = stage

        if directory is None:
            timestamp = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
            directory = 'runs' / pathlib.Path(str(timestamp))
        self._directory = directory
        self._delete_existing = delete_existing
        self._tensorboard = tensorboard
        self._save_raw_data = save_raw_data
        self._show_figures = show_figures
        self._save_figures = save_figures
        self._animate = animate
        self._static_data = static_data

        self.debugger = self.create_stage_debugger()

    def create_stage_debugger(self):
        """Build a SingleStageDebugger for the current stage from the stored settings."""
        stage_directory = self._directory / f'stage_{self.current_stage}'
        settings = dict(
            delete_existing=self._delete_existing,
            tensorboard=self._tensorboard,
            save_raw_data=self._save_raw_data,
            show_figures=self._show_figures,
            save_figures=self._save_figures,
            animate=self._animate,
            static_data=self._static_data,
        )
        return SingleStageDebugger(directory=stage_directory, **settings)

    def stage_step(self):
        """Move on to the next stage and create a fresh debugger for it."""
        self.current_stage = self.current_stage + 1
        self.debugger = self.create_stage_debugger()

    def reset(self):
        """Go back to stage 0 and create a fresh debugger for it."""
        self.current_stage = 0
        self.debugger = self.create_stage_debugger()

    def animate(self, **kwargs):
        """Forward to the current stage debugger's ``animate``."""
        self.debugger.animate(**kwargs)

    def step(self):
        """Forward to the current stage debugger's ``step``."""
        self.debugger.step()

    def add_scalar(self, *args, **kwargs):
        """Forward to the current stage debugger's ``add_scalar``."""
        self.debugger.add_scalar(*args, **kwargs)

    def add_particles(self, *args, **kwargs):
        """Forward to the current stage debugger's ``add_particles``."""
        self.debugger.add_particles(*args, **kwargs)

    def add_flow_samples(self, *args, **kwargs):
        """Forward to the current stage debugger's ``add_flow_samples``."""
        self.debugger.add_flow_samples(*args, **kwargs)
