# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from typing import Tuple
import os


def _adjust_axis_labels(
    ind_dim1: int,
    no_bins_dim1: int,
    no_bins_dim2: int,
    axarr: np.ndarray,
    ticklabelsize: int,
    axislabelsize: int,
    xlabel: str = None,
    ylabel: str = None,
):
    """Style one subplot of a vertically stacked column of channel plots.

    Only the bottom subplot (``ind_dim1 == no_bins_dim1 - 1``) keeps visible axes:
    its top and right spines are hidden, ticks are placed left/bottom, tick labels
    are resized and the optional axis labels are set. Every other subplot has its
    axes switched off. In all cases the x-range is set to ``[0, no_bins_dim2]``.

    :param ind_dim1: index into `axarr` of the subplot to adjust
    :param no_bins_dim1: total number of subplots; the last one is the bottom one
    :param no_bins_dim2: upper x-limit of the subplot
    :param axarr: flat array of matplotlib Axes, one per subplot
    :param ticklabelsize: font size of the major tick labels of the bottom subplot
    :param axislabelsize: font size of the axis labels of the bottom subplot
    :param xlabel: x-axis label of the bottom subplot (not set if None)
    :param ylabel: y-axis label of the bottom subplot (not set if None)
    """
    ax = axarr[ind_dim1]
    if ind_dim1 == no_bins_dim1 - 1:
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.yaxis.set_ticks_position("left")
        ax.xaxis.set_ticks_position("bottom")
        # `Tick.label` is a deprecated alias of `Tick.label1` and was removed in
        # Matplotlib 3.8; tick_params resizes the major tick labels on all versions.
        ax.tick_params(axis="both", which="major", labelsize=ticklabelsize)
        # Label this Axes explicitly rather than whatever `plt.gca()` happens to be.
        if xlabel is not None:
            ax.set_xlabel(xlabel, fontsize=axislabelsize)
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontsize=axislabelsize)
    else:
        ax.axis("off")
    ax.set_xlim([0, no_bins_dim2])


class multichannel_waveform_animation(animation.TimedAnimation):
    """Animate one marker plot per channel, stacked vertically in a single figure.

    Points flagged in `to_highlight` are drawn in `highlighting_color` and are the
    only points whose y-values change between frames. All other points keep the
    values of the first step (``data[0]``).
    """

    def __init__(
        self,
        data: np.ndarray,
        to_highlight: np.ndarray,
        waveform_color: str = "k",
        highlighting_color: str = "b",
        xlabel: str = None,
        ylabel: str = None,
        figsize: Tuple = None,
        figname: str = None,
        output_file: str = None,
        close: bool = True,
        markersize: int = 1,
        ticklabelsize: int = 6,
        axislabelsize: int = 10,
        ms_per_frame: int = 200,
        repeat: bool = True,
    ):
        """Draw a marker plot for every entry along the 2nd dim. of `data`.

        :param data: 3-dim. array indexed as (step, channel, sample); one animation
                     frame per step
        :param to_highlight: 2-dim. boolean array of shape ``data.shape[1:]``
                             marking the samples drawn using `highlighting_color`
        :param waveform_color: color of the samples that are not highlighted
        :param highlighting_color: color of the highlighted samples
        :param xlabel: x-axis label of the bottom subplot (optional)
        :param ylabel: y-axis label of the bottom subplot (optional)
        :param figsize: figure size passed on to `plt.subplots`
        :param figname: window title of the figure (optional)
        :param output_file: if given, the animation is saved to this path using
                            the ffmpeg writer; missing parent directories are
                            created
        :param close: whether to close the figure at the end of the constructor
        :param markersize: size of the "." markers
        :param ticklabelsize: font size of the tick labels of the bottom subplot
        :param axislabelsize: font size of the axis labels of the bottom subplot
        :param ms_per_frame: delay between frames in milliseconds
        :param repeat: whether the animation restarts after the last frame
        """

        # check inputs
        assert np.ndim(data) == 3, "data must be 3-dim."
        assert np.ndim(to_highlight) == 2, "to_highlight must be 2-dim."
        no_steps, no_bins_dim1, no_bins_dim2 = data.shape

        assert (
            data.shape[1:] == to_highlight.shape
        ), "`data` and `to_highlight` must have same shape"
        # `np.bool` was removed in NumPy 1.24; `np.bool_` exists in all versions
        assert to_highlight.dtype == np.bool_, "`to_highlight` must be array of booleans"

        # Each channel is plotted in a separate subplot. squeeze=False keeps a 2-dim.
        # array of Axes even for a single channel, so that ravel() always works.
        f, axarr = plt.subplots(no_bins_dim1, 1, figsize=figsize, squeeze=False)
        if figname is not None:
            f.canvas.manager.set_window_title(figname)
        axarr = axarr.ravel()
        self.spl = {}

        self.curr_step = 0
        _data = data[self.curr_step]

        line_specs = {"linestyle": "none", "marker": ".", "markersize": markersize}

        for ind_dim1 in range(no_bins_dim1):
            mask = to_highlight[ind_dim1]
            not_mask = np.logical_not(mask)

            # Plot waveforms. The non-highlighted points are drawn once; the
            # highlighted ones are updated in `_draw_frame`.
            (self.spl[f"data{ind_dim1}"],) = axarr[ind_dim1].plot(
                np.where(not_mask)[0],
                _data[ind_dim1][not_mask],
                color=waveform_color,
                **line_specs,
            )
            (self.spl[f"highlighted{ind_dim1}"],) = axarr[ind_dim1].plot(
                np.where(mask)[0],
                _data[ind_dim1][mask],
                color=highlighting_color,
                **line_specs,
            )

            # axis labels only in bottom subplot
            _adjust_axis_labels(
                ind_dim1,
                no_bins_dim1,
                no_bins_dim2,
                axarr,
                ticklabelsize,
                axislabelsize,
                xlabel,
                ylabel,
            )

        self.data = data
        self.to_highlight = to_highlight
        self.no_steps = no_steps
        self.f = f

        self.f.tight_layout()

        super().__init__(self.f, interval=ms_per_frame, blit=True, repeat=repeat)

        # save animation if file name was provided
        if output_file is not None:
            output_directory = os.path.dirname(output_file)
            # A bare file name has no directory part, and os.makedirs("") would raise.
            if output_directory:
                os.makedirs(output_directory, exist_ok=True)
            super().save(output_file, writer="ffmpeg")
            print(f"Wrote {output_file}")
        if close:
            # close this animation's figure, not whichever figure is current
            plt.close(self.f)

    def new_frame_seq(self):
        """Return one frame index per step after the first (``no_steps - 1`` frames)."""
        return iter(range(self.no_steps - 1))

    def _draw_frame(self, framedata):
        """Advance to the next step and update the highlighted points of each channel.

        `framedata` is ignored. The step counter in `self.curr_step` wraps back to
        1 (not 0), so step 0 is only shown when the figure is first drawn.
        """
        # NOTE(review): `_drawn_artists` is never set here, so with blit=True
        # Matplotlib presumably falls back to full redraws. Confirm before relying on
        # blitting performance.
        self.curr_step += 1
        if self.curr_step == self.no_steps:
            self.curr_step = 1
        _data = self.data[self.curr_step]
        for ind_dim1 in range(_data.shape[0]):
            y_to_highlight = _data[ind_dim1][self.to_highlight[ind_dim1]]
            self.spl[f"highlighted{ind_dim1}"].set_ydata(y_to_highlight)


if __name__ == "__main__":
    # Demo: 5 steps of 30 channels x 5000 uniform-noise samples. Samples whose
    # value in the first step is below 0.5 are highlighted; the result is written
    # to ./a.mp4.
    demo_data = np.random.random((5, 30, 5000))
    demo_mask = demo_data[0] < 0.5
    plot_options = dict(
        xlabel="samples",
        ylabel="amplitude",
        figsize=(20, 18),
        ticklabelsize=12,
        axislabelsize=10,
    )
    # keep a reference so the animation object is not garbage collected
    anim = multichannel_waveform_animation(
        demo_data,
        demo_mask,
        highlighting_color="g",
        output_file="./a.mp4",
        **plot_options,  # type: ignore
    )
