# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple
import os


def _adjust_axis_labels(
    ind_dim1: int,
    no_bins_dim1: int,
    no_bins_dim2: int,
    axarr: np.ndarray,
    ticklabelsize: int,
    axislabelsize: int,
    xlabel: str = None,
    ylabel: str = None,
):
    """Style the axes of the subplot at index `ind_dim1`.

    The last subplot (`ind_dim1 == no_bins_dim1 - 1`) keeps its left and bottom
    axes, gets resized tick labels and, if given, axis labels. Every other
    subplot has its axes switched off. In all cases the x-range is set to
    `[0, no_bins_dim2]`.

    :param ind_dim1: index into `axarr` of the subplot to adjust
    :param no_bins_dim1: total number of subplots
    :param no_bins_dim2: upper x-limit applied to the subplot
    :param axarr: indexable collection of matplotlib axes, one per subplot
    :param ticklabelsize: font size of the major tick labels of the last subplot
    :param axislabelsize: font size of the axis labels of the last subplot
    :param xlabel: x-axis label of the last subplot (omitted if None)
    :param ylabel: y-axis label of the last subplot (omitted if None)
    """
    ax = axarr[ind_dim1]
    if ind_dim1 == (no_bins_dim1 - 1):
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.yaxis.set_ticks_position("left")
        ax.xaxis.set_ticks_position("bottom")
        # `Tick.label` no longer exists in recent matplotlib releases;
        # `tick_params` is the supported way to set the tick label size.
        ax.tick_params(axis="both", which="major", labelsize=ticklabelsize)
        # Label this axes object explicitly rather than going through pyplot,
        # which would act on whatever axes happens to be current.
        if xlabel is not None:
            ax.set_xlabel(xlabel, fontsize=axislabelsize)
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontsize=axislabelsize)
    else:
        ax.axis("off")
    ax.set_xlim([0, no_bins_dim2])


def multichannel_waveform_plot(
    data: np.ndarray,
    to_highlight: np.ndarray = None,
    waveform_color: str = "k",
    highlighting_color: str = "b",
    xlabel: str = None,
    ylabel: str = None,
    figsize: Tuple = None,
    figname: str = None,
    output_file: str = None,
    close: bool = True,
    markersize: int = 1,
    ticklabelsize: int = 6,
    axislabelsize: int = 10,
):
    """Draw one marker plot per row of a 2-dim. array in vertically stacked subplots.

    Only the bottom subplot shows axes, tick labels and axis labels; the axes
    of all other subplots are switched off.

    :param data: 2-dimensional array with input data; each row gets its own subplot
    :param to_highlight: boolean array of the same shape as `data`; entries marked
        True are drawn using `highlighting_color`, all others using `waveform_color`.
        If None, every entry is drawn using `waveform_color`.
    :param waveform_color: color of the (non-highlighted) data points
    :param highlighting_color: color of the highlighted data points
    :param xlabel: x-axis label of the bottom subplot (omitted if None)
    :param ylabel: y-axis label of the bottom subplot (omitted if None)
    :param figsize: figure size, forwarded to `plt.subplots`
    :param figname: window title of the figure (left unchanged if None)
    :param output_file: path the figure is saved to; a missing parent directory is
        created. Nothing is saved if None.
    :param close: whether to close the figure before returning
    :param markersize: size of the point markers
    :param ticklabelsize: font size of the tick labels of the bottom subplot
    :param axislabelsize: font size of the axis labels of the bottom subplot
    """

    # check inputs
    assert np.ndim(data) == 2, "data must be 2-dim."
    no_bins_dim1, no_bins_dim2 = data.shape

    if to_highlight is not None:
        assert data.shape == to_highlight.shape, "`data` and `to_highlight` must have same shape"
        # `np.bool` was removed from NumPy; `np.bool_` is the boolean scalar type.
        assert to_highlight.dtype == np.bool_, "`to_highlight` must be array of booleans"

    # each channel is plotted in a separate subplot; squeeze=False guarantees an
    # array of axes even when `data` has a single row
    f, axarr = plt.subplots(no_bins_dim1, 1, figsize=figsize, squeeze=False)
    if figname is not None:
        f.canvas.manager.set_window_title(figname)
    axarr = axarr.ravel()

    # markers only, no connecting lines
    line_specs = {"linestyle": "none", "marker": ".", "markersize": markersize}

    for ind_dim1 in range(no_bins_dim1):
        ax = axarr[ind_dim1]
        row = data[ind_dim1]

        # plot waveforms
        if to_highlight is None:
            ax.plot(row, color=waveform_color, **line_specs)
        else:
            highlight_mask = to_highlight[ind_dim1]
            regular_mask = np.logical_not(highlight_mask)
            # regular points first, highlighted points drawn on top
            ax.plot(
                np.flatnonzero(regular_mask),
                row[regular_mask],
                color=waveform_color,
                **line_specs,
            )
            ax.plot(
                np.flatnonzero(highlight_mask),
                row[highlight_mask],
                color=highlighting_color,
                **line_specs,
            )

        # axis labels only in bottom subplot
        _adjust_axis_labels(
            ind_dim1,
            no_bins_dim1,
            no_bins_dim2,
            axarr,
            ticklabelsize,
            axislabelsize,
            xlabel,
            ylabel,
        )

    # save figure if file name was provided
    if output_file is not None:
        output_directory = os.path.dirname(output_file)
        # a bare file name has no directory part; os.makedirs("") would raise
        if output_directory:
            os.makedirs(output_directory, exist_ok=True)
        f.savefig(output_file, bbox_inches="tight")
        print(f"Wrote {output_file}")
    if close:
        # close this figure explicitly instead of whichever one is current
        plt.close(f)


if __name__ == "__main__":
    # Demo: render the same random 30-channel data twice, once with the
    # highlighted samples in green and once in white.
    demo_data = np.random.random((30, 5000))
    demo_mask = demo_data < 0.5
    shared_options = dict(
        xlabel="samples",
        ylabel="amplitude",
        figsize=(20, 18),
        ticklabelsize=12,
        axislabelsize=10,
    )
    for color, target_file in (("g", "./a.png"), ("w", "./a_mod.png")):
        multichannel_waveform_plot(
            demo_data,
            demo_mask,
            highlighting_color=color,
            output_file=target_file,
            **shared_options,  # type: ignore
        )
