# -*- coding: utf-8 -*-
"""
Plotting functions
"""
import os
import numpy as np
import pandas as pd
from sklearn import decomposition, preprocessing
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler

from conn2res import utils
from .readout import _check_xy_type, _check_x_dims, _check_y_dims

# Project root: the parent of the directory that contains this module.
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Default output folder for figures; created on import when missing.
FIG_DIR = os.path.join(PROJ_DIR, 'figs')
os.makedirs(FIG_DIR, exist_ok=True)


# Reusable numpydoc fragments describing the parameters shared by the plotting
# functions of this module. The keys match the ``{placeholder}`` names that
# appear in the docstrings of those functions.
_figs_params_docs = dict(
    rc_params="""\
rc_params : dict, optional
    Dictionary of keyword arguments for `matplotlib.rcParams <https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.rcParams>`__.\
""",
    fig_params="""\
fig_params : dict, optional
    Dictionary of keyword arguments for `matplotlib.pyplot.figure <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html>`__.
    Values to set figure properties.\
""",
    ax_params="""\
ax_params : dict, optional
    Dictionary of keyword arguments for `matplotlib.pyplot.axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html#matplotlib.axes.Axes>`__.
    Values to set axes's properties.\
""",
    lg_params="""\
lg_params : dict, optional
    Dictionary of keyword arguments for `matplotlib.axes.Axes.legend <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html>`__.
    Values to set legend's properties.\
""",
    col_params="""\
col_params : dict, optional
    Dictionary of keyword arguments for `seaborn.color_palette <https://seaborn.pydata.org/generated/seaborn.color_palette.html>`__.
    Values to set color settings.\
""",
    title="""\
title : str, optional
    Title to be shown at the superior part of the figure.
""",
    show="""\
show : bool, optional
    If True, it will display the matplotlib.pyplot.figure object
""",
    savefig="""\
savefig : bool, optional
    If True, it will save the matplotlib.pyplot.figure object as a '.png' file by default.
    The format of the file can be changed using the 'savefig.format'
    keyword in the rc_params argument.
""",
    fname="""\
fname : str or path-like
    Path where the figure will be saved.
""",
    kwargs="""\
kwargs : key-value pairs
    Other keyword arguments are passed directly to the underlying seaborn plotting function.
"""
)


def transform_data(
    data, feature_set=None, idx_features=None, n_features=None, scaler=None,
    model=None, seed=None, **kwargs
):
    """
    Select, combine and/or rescale the features (columns) of a data matrix.

    The steps are applied in this order: feature extraction (`feature_set`),
    column selection (`idx_features`), summation (only for 'coeff_sum') and
    scaling (`scaler`).

    Parameters
    ----------
    data : numpy.ndarray
        2D array whose columns (axis 1) are the features.
    feature_set : str, optional
        How features are derived from `data`:

        - 'pca': z-score `data` and project it on `n_features` principal
          components
        - 'rnd': keep `n_features` (default 1) randomly drawn columns;
          columns are drawn with replacement
        - 'decfun': output of ``model.decision_function(data)``
        - 'pred': output of ``model.predict(data)`` as a single column
        - 'coeff': keep the (at most) `n_features` (default 5) columns with
          the largest non-zero absolute coefficients in ``model.coef_`` and
          multiply each column by its coefficient
        - 'coeff_sum': same as 'coeff', then sum across columns

        Any other value (including None) leaves `data` untouched at this
        step, by default None
    idx_features : int, slice or array_like, optional
        Column selector applied after the feature extraction step,
        by default None
    n_features : int, optional
        Number of features used by 'pca', 'rnd', 'coeff' and 'coeff_sum',
        by default None
    scaler : str, int or float, optional
        Scaling applied at the end:

        - 'l1-norm', 'l2-norm', 'max': divide every column by its L1, L2 or
          infinity norm
        - a number: divide the data by that number
        - any other string: name of a function in `sklearn.preprocessing`,
          called as ``func(data, **kwargs)``; unknown names are ignored

        by default None
    model : object, optional
        Fitted estimator; required by 'decfun', 'pred', 'coeff' and
        'coeff_sum', by default None
    seed : int, array_like[ints], SeedSequence, BitGenerator, Generator, optional
        seed to initialize the random number generator, by default None
        for details, see numpy.random.default_rng()
    kwargs : key-value pairs
        'idx_class' (int, default 0) selects the row of a 2D ``model.coef_``.
        All remaining keyword arguments are forwarded to
        ``PCA.fit_transform`` ('pca') and to the `sklearn.preprocessing`
        scaler.

    Returns
    -------
    numpy.ndarray
        Transformed data. The input array is never modified in place.

    Raises
    ------
    TypeError
        If `scaler` is neither None, a string nor a number.
    """
    # 'idx_class' is only meaningful here; consume it so that it is not
    # forwarded to functions that do not accept it
    idx_class = kwargs.pop('idx_class', 0)

    if feature_set == 'pca':
        # transform data into principal components
        # (zscore first to remove bias due to different scales)
        pca = decomposition.PCA(n_components=n_features)
        data = pca.fit_transform(preprocessing.scale(data), **kwargs)

    elif feature_set == 'rnd':
        # update default number of features
        if n_features is None:
            n_features = 1

        # use random number generator for reproducibility
        rng = np.random.default_rng(seed=seed)

        # choose features randomly
        data = data[:, rng.choice(
            np.arange(data.shape[1]), size=n_features)]

    elif feature_set == 'decfun':
        # calculate decision function using model fitted on time series
        data = model.decision_function(data)

    elif feature_set == 'pred':
        # calculate predicted labels
        data = model.predict(data)[:, np.newaxis]

    elif feature_set in ['coeff', 'coeff_sum']:
        # update default number of features
        if n_features is None:
            n_features = 5

        # get coefficients from model
        if model.coef_.ndim > 1:
            coef = model.coef_[idx_class, :]
        else:
            coef = model.coef_

        # choose features that correspond to largest absolute coefficients
        idx_coef = np.argsort(np.absolute(coef))
        if np.sum(coef != 0) > n_features:
            # use top features
            idx_coef = idx_coef[-n_features:]
        else:
            # use non-zero features
            idx_coef = np.intersect1d(idx_coef, np.where(coef != 0)[0])

        # scale time series with coefficients (column-wise broadcasting)
        data = data[:, idx_coef] * coef[idx_coef]

    # select given features
    if idx_features is not None:
        data = data[:, idx_features]

    # sum features
    if feature_set == 'coeff_sum':
        data = np.sum(data, axis=1).reshape(-1, 1)

    # scale features
    if scaler is not None:
        norm_orders = {'l1-norm': 1, 'l2-norm': 2, 'max': np.inf}

        if isinstance(scaler, str):
            if scaler in norm_orders:
                # column-wise norms provided by numpy.linalg; out-of-place
                # division keeps the caller's array intact and also works
                # for integer input
                data = data / np.linalg.norm(
                    data, ord=norm_orders[scaler], axis=0)
            elif hasattr(preprocessing, scaler):
                # scalers provided by sklearn.preprocessing
                data = getattr(preprocessing, scaler)(data, **kwargs)
            # unknown scaler names are ignored

        elif isinstance(scaler, (int, float)):
            # constant scaling factor
            data = data / scaler

        else:
            raise TypeError(
                "scaler must be None, a string or a number, "
                f"got {type(scaler).__name__}"
            )

    return data


def plot_iodata(
    x, y, n_trials=7, palette=None,
    rc_params=None, fig_params=None, ax_params=None, lg_params=None,
    title=None, show=True, savefig=False, fname='io_data', **kwargs
):
    """
    Plot input (x) and output (y) data.

    Parameters
    ----------
    x : list or numpy.ndarray
        Input signal. When given as a list, every element is treated as one
        trial: only the first `n_trials` trials are plotted and a dashed
        vertical line marks the end of each of them.
    y : list or numpy.ndarray
        Output signal, organized like `x`.
    n_trials : int, optional
        Maximum number of trials to plot when `x` is a list, by default 7
    palette : optional
        Forwarded as `palette` to `seaborn.lineplot`, by default None
    {rc_params}
    {fig_params}
    {ax_params}
    {lg_params}
    {title}
    {show}
    {savefig}
    {fname}
    {kwargs}
    """
    # None stands for "no overrides" (avoids mutable default arguments)
    rc_params = {} if rc_params is None else rc_params
    fig_params = {} if fig_params is None else fig_params
    ax_params = {} if ax_params is None else ax_params
    lg_params = {} if lg_params is None else lg_params

    # get end points for trials to plot trial separators
    if isinstance(x, list):
        n_trials = min(len(x), n_trials)
        x = x[:n_trials]
        y = y[:n_trials]

        # cumulative number of time steps at the end of every trial
        tf, end_points = 0, []
        for trial in x:
            tf += len(trial)
            end_points.append(tf)
    else:
        end_points = None

    # check X and y are arrays
    x, y = _check_xy_type(x, y)

    # check X and y dimensions
    x = _check_x_dims(x)
    y = _check_y_dims(y)

    # set plotting theme
    rc_defaults = {'figure.titlesize': 12, 'axes.labelsize': 11,
                   'xtick.labelsize': 11, 'ytick.labelsize': 11,
                   'legend.fontsize': 8, 'legend.loc': 'best',
                   'lines.linewidth': 1, 'savefig.format': 'png'}
    rc_defaults.update(rc_params)
    sns.set_theme(style='ticks', rc=rc_defaults)

    # open figure and axes
    fig_defaults = {'figsize': (12, 2)}  # 12, 4.5
    fig_defaults.update(fig_params)
    fig = plt.figure(**fig_defaults)
    ax = fig.subplots(1, 1)

    # plot inputs (x) and outputs (y)
    sns.lineplot(
        data=x, palette=palette, dashes=False, legend=False, ax=ax,
        **kwargs
    )
    sns.lineplot(
        data=y, palette=palette, dashes=False, legend=False, ax=ax,
        linewidth=1.5, **kwargs
    )

    # set legend (built before the separators are drawn so that only the
    # signal lines are listed)
    x_labels = ['x'] if x.ndim == 1 else [f'x{n+1}' for n in range(x.shape[1])]
    y_labels = ['y'] if y.ndim == 1 else [f'y{n+1}' for n in range(y.shape[1])]
    lg_defaults = {'labels': x_labels + y_labels}
    lg_defaults.update(**lg_params)
    ax.legend(handles=ax.lines, **lg_defaults)

    # set axes properties
    ax_defaults = {'xlabel': 'time steps', 'ylabel': 'signal amplitude',
                   'xlim': [0, 200]}
    ax_defaults.update(**ax_params)
    ax.set(**ax_defaults)

    # plot trial line separators
    if end_points is not None:
        # keep the exact extent of y: truncating to int would collapse the
        # separators for signals contained in (-1, 1)
        y_range = np.array([np.min(y), np.max(y)])
        for tf in end_points:
            ax.plot([tf, tf], y_range, c='black', linestyle='--')

    # set title
    if title is not None:
        fig.suptitle(title)

    sns.despine(offset=10, trim=True,
                top=True, bottom=False,
                right=True, left=False)

    if show:
        plt.show(block=True)

    if savefig:
        fig.savefig(fname + '.' + mpl.rcParams['savefig.format'],
                    transparent=True, bbox_inches='tight', dpi=300)

    # close the figure created here rather than whichever one is current
    plt.close(fig)

    # reset rc defaults
    mpl.rcdefaults()


def plot_reservoir_states(
    x, reservoir_states, n_trials=7, palette=None,
    rc_params=None, fig_params=None, ax_params=None, lg_params=None,
    title=None, show=True, savefig=False, fname='res_states', **kwargs
):
    """
    Plot simulated reservoir states.

    The input signal is drawn on the top axes and the reservoir states on
    the bottom axes.

    Parameters
    ----------
    x : list or numpy.ndarray
        Input signal. When `reservoir_states` is a list, `x` is sliced
        trial-wise as well and its trial lengths define the separators.
    reservoir_states : list, tuple or numpy.ndarray
        Reservoir states. When given as a list, every element is treated as
        one trial: only the last `n_trials` trials are plotted and a dashed
        vertical line marks the end of each of them.
    n_trials : int, optional
        Maximum number of trials to plot when `reservoir_states` is a list,
        by default 7
    palette : optional
        Forwarded as `palette` to `seaborn.lineplot` for the input signal;
        the reservoir states always use the "tab10" palette, by default None
    {rc_params}
    {fig_params}
    ax_params : list of dict, optional
        list of dictionaries (one per axes, top to bottom) with keyword arguments for `matplotlib.pyplot.axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html#matplotlib.axes.Axes>`__.
        Values to set axes's properties, by default None (no overrides).
    {lg_params}
    {title}
    {show}
    {savefig}
    {fname}
    {kwargs}
    """
    # None stands for "no overrides" (avoids mutable default arguments)
    rc_params = {} if rc_params is None else rc_params
    fig_params = {} if fig_params is None else fig_params
    ax_params = [{}, {}] if ax_params is None else ax_params
    lg_params = {} if lg_params is None else lg_params

    # get end points for trials to plot trial separators
    if isinstance(reservoir_states, list):
        n_trials = min(len(x), n_trials)
        x = x[-n_trials:]
        reservoir_states = reservoir_states[-n_trials:]

        # cumulative number of time steps at the end of every trial
        tf, end_points = 0, []
        for trial in x:
            tf += len(trial)
            end_points.append(tf)
    else:
        end_points = None

    # check X is array
    x, _ = _check_xy_type(x, None)

    # check reservoir_states is array
    if isinstance(reservoir_states, (list, tuple)):
        reservoir_states = utils.concat(reservoir_states)

    # check X dimensions
    x = _check_x_dims(x)

    # set plotting theme
    rc_defaults = {'figure.titlesize': 12, 'axes.labelsize': 11,
                   'xtick.labelsize': 11, 'ytick.labelsize': 11,
                   'legend.fontsize': 8, 'legend.loc': 'best',
                   'lines.linewidth': 1, 'savefig.format': 'png'}
    rc_defaults.update(rc_params)
    sns.set_theme(style='ticks', rc=rc_defaults)

    # open figure and axes
    fig_defaults = {'figsize': (12, 4), 'layout': 'tight'}
    fig_defaults.update(fig_params)
    fig = plt.figure(**fig_defaults)
    axs = fig.subplots(2, 1, sharex=True)
    axs = axs.ravel()

    fig.subplots_adjust(wspace=0.1)

    # plot inputs (x) and reservoir states
    sns.lineplot(
        data=x, palette=palette, dashes=False, legend=False, ax=axs[0],
        **kwargs
    )

    palette = sns.color_palette("tab10", reservoir_states.shape[1])
    sns.lineplot(
        data=reservoir_states, palette=palette, dashes=False, legend=False,
        linewidth=0.5, ax=axs[1], **kwargs
    )

    # set legend (built before the separators are drawn so that only the
    # signal lines are listed)
    x_labels = ['x'] if x.ndim == 1 else [f'x{n+1}' for n in range(x.shape[1])]
    lg_defaults = {'labels': x_labels}
    lg_defaults.update(**lg_params)
    axs[0].legend(handles=axs[0].lines, **lg_defaults)

    # make the y-axis of the reservoir states symmetric around zero
    yabs_max = abs(max(axs[1].get_ylim(), key=abs))
    axs[1].set_ylim(ymin=-yabs_max, ymax=yabs_max)

    # set axes properties
    xlabel = ['', 'time steps']
    ylabel = ['x signal \namplitude', 'reservoir \nstates']
    for i, ax in enumerate(axs):
        ax_defaults = {'xlim': [0, 200], 'xlabel': xlabel[i],
                       'ylabel': ylabel[i]}
        ax_defaults.update(**ax_params[i])
        ax.set(**ax_defaults)

    # plot trial line separators
    if end_points is not None:
        # keep the exact extent of the signals: truncating to int would
        # collapse the separators for signals contained in (-1, 1)
        x_range = np.array([np.min(x), np.max(x)])
        states_range = np.array(
            [np.min(reservoir_states), np.max(reservoir_states)])
        for tf in end_points:
            axs[0].plot([tf, tf], x_range, c='black', linestyle='--')
            axs[1].plot([tf, tf], states_range, c='black', linestyle='--')

    # set title
    if title is not None:
        fig.suptitle(title)

    sns.despine(offset=10, trim=True,
                top=True, bottom=False,
                right=True, left=False)

    if show:
        plt.show(block=True)

    if savefig:
        fig.savefig(fname + '.' + mpl.rcParams['savefig.format'],
                    transparent=True, bbox_inches='tight', dpi=300)

    # close the figure created here rather than whichever one is current
    plt.close(fig)

    # reset rc defaults
    mpl.rcdefaults()


def plot_diagnostics(
    x, y, reservoir_states, trained_model, idx_features=None,
    n_features=None, scaler=None, palette=None,
    rc_params=None, fig_params=None, ax_params=None, lg_params=None,
    title=None, show=True, savefig=False, fname='diagnostics_curve', **kwargs
):
    """
    Plot decision function of readout module.

    Three vertically stacked axes show, from top to bottom, the input
    signal, the decision function of the model, and the target together
    with the predicted target. Only the first 160 time steps are plotted.
    Works only if `trained_model` exposes `decision_function` and
    `predict` (e.g., a classifier).

    Parameters
    ----------
    x : list or numpy.ndarray
        Input signal.
    y : list or numpy.ndarray
        Target signal.
    reservoir_states : list, tuple or numpy.ndarray
        Reservoir states fed to `trained_model`; lists and tuples are
        concatenated first.
    trained_model : object
        Fitted model used to compute the decision function and the
        predicted target.
    idx_features : optional
        Forwarded to `transform_data`, by default None
    n_features : optional
        Forwarded to `transform_data`, by default None
    scaler : optional
        Forwarded to `transform_data`, by default None
    palette : optional
        Forwarded as `palette` to `seaborn.lineplot` for the input,
        decision function and target; the predicted target always uses the
        "tab10" palette (skipping its first color), by default None
    {rc_params}
    {fig_params}
    ax_params : list of dict, optional
        list of dictionaries (one per axes, top to bottom) with keyword arguments for `matplotlib.pyplot.axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html#matplotlib.axes.Axes>`__.
        Values to set axes's properties, by default None (no overrides).
    lg_params : list of dict, optional
        list of dictionaries (one per axes, top to bottom) with keyword arguments for `matplotlib.axes.Axes.legend <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html>`__.
        Values to set legend's properties, by default None (no overrides).
    {title}
    {show}
    {savefig}
    {fname}
    kwargs : key-value pairs
        Other keyword arguments are passed to the `~conn2res.plotting.transform_data` function.
    """
    # None stands for "no overrides" (avoids mutable default arguments)
    rc_params = {} if rc_params is None else rc_params
    fig_params = {} if fig_params is None else fig_params
    ax_params = [{}, {}, {}] if ax_params is None else ax_params
    lg_params = [{}, {}, {}] if lg_params is None else lg_params

    # check X and y are arrays
    x, y = _check_xy_type(x, y)

    # check reservoir_states is an array
    if isinstance(reservoir_states, (list, tuple)):
        reservoir_states = utils.concat(reservoir_states)

    # check X and y dimensions
    x = _check_x_dims(x)
    y = _check_y_dims(y)

    # transform data
    x_trans = transform_data(
        x, feature_set='data', idx_features=idx_features,
        n_features=n_features, scaler=scaler, **kwargs
    )

    dec_func = transform_data(
        reservoir_states, feature_set='decfun', idx_features=idx_features,
        n_features=n_features, scaler=scaler, model=trained_model, **kwargs
    )

    y_trans = transform_data(
        y, feature_set='data', idx_features=idx_features,
        n_features=n_features, scaler=scaler, **kwargs
    )

    y_pred = transform_data(
        reservoir_states, feature_set='pred', idx_features=idx_features,
        n_features=n_features, scaler=scaler, model=trained_model, **kwargs
    )

    # set plotting theme
    rc_defaults = {'figure.titlesize': 12, 'axes.labelsize': 11,
                   'xtick.labelsize': 11, 'ytick.labelsize': 11,
                   'legend.fontsize': 8, 'legend.loc': 'upper right',
                   'lines.linewidth': 1, 'savefig.format': 'png'}
    rc_defaults.update(rc_params)
    sns.set_theme(style='ticks', rc=rc_defaults)

    # open figure and axes
    fig_defaults = {'figsize': (12, 6), 'layout': 'tight'}
    fig_defaults.update(fig_params)
    fig = plt.figure(**fig_defaults)
    axs = fig.subplots(3, 1, sharex=True)
    axs = axs.ravel()

    fig.subplots_adjust(wspace=0.1)

    # plot input, decision function and target
    for series, ax in zip([x_trans, dec_func, y_trans], axs):
        sns.lineplot(
            data=series[:160], palette=palette, dashes=False,
            legend=False, ax=ax)

    # plot predicted target on top of the target; one color per column,
    # skipping the first "tab10" color
    n_colors = 1 if y_pred.ndim == 1 else y_pred.shape[1]
    palette = sns.color_palette("tab10", n_colors+1)[1:]
    sns.lineplot(
        data=y_pred[:160], palette=palette, dashes=False, legend=False,
        ax=axs[2], linewidth=1.5)

    # set legend
    labels = [['x'] if x.ndim == 1 else [f'x{n+1}' for n in range(x.shape[1])],
              ['decision function'] if dec_func.ndim == 1 else [f'decision function {n+1}' for n in range(dec_func.shape[1])],
              ['target', 'predicted target']]
    for i, ax in enumerate(axs):
        lg_defaults = {'labels': labels[i]}
        lg_defaults.update(**lg_params[i])
        ax.legend(handles=ax.lines, **lg_defaults)

    # set axes properties
    xlabel = ['', '', 'time steps']
    ylabel = ['x signal \namplitude', 'decision \nfunction', 'y signal \namplitude']
    for i, ax in enumerate(axs):
        ax_defaults = {'xlim': [0, 160], 'xlabel': xlabel[i], 'ylabel': ylabel[i]}
        ax_defaults.update(**ax_params[i])
        ax.set(**ax_defaults)

    # set title
    if title is not None:
        fig.suptitle(title)

    sns.despine(offset=10, trim=False,
                top=True, bottom=False,
                right=True, left=False)

    if show:
        plt.show(block=True)

    if savefig:
        fig.savefig(fname=fname + '.' + mpl.rcParams['savefig.format'],
                    transparent=True, bbox_inches='tight', dpi=300)

    # close the figure created here rather than whichever one is current
    plt.close(fig)

    # reset rc defaults
    mpl.rcdefaults()


def plot_performance(
    df, x='alpha', y='score', normalize=False, hue=None,
    rc_params=None, fig_params=None, ax_params=None, lg_params=None,
    col_params=None,
    title=None, show=True, savefig=False, fname='performance_curve', **kwargs
):
    """
    Plot performance curve.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot; it is never modified by this function.
    x : str, optional
        Name of the column plotted on the x-axis, by default 'alpha'
    y : str, optional
        Name of the column plotted on the y-axis, by default 'score'
    normalize : bool, optional
        If True, column `y` is divided by its maximum (missing values are
        ignored when computing the maximum), by default False
    hue : str, optional
        Name of the column used to group the curves by color,
        by default None
    {rc_params}
    {fig_params}
    {ax_params}
    {lg_params}
    {col_params}
    {title}
    {show}
    {savefig}
    {fname}
    {kwargs}
    """
    # None stands for "no overrides" (avoids mutable default arguments)
    rc_params = {} if rc_params is None else rc_params
    fig_params = {} if fig_params is None else fig_params
    ax_params = {} if ax_params is None else ax_params
    lg_params = {} if lg_params is None else lg_params
    col_params = {} if col_params is None else col_params

    if normalize:
        # work on a copy so that the caller's DataFrame is left untouched
        df = df.copy()
        df[y] = df[y] / df[y].max()

    # set plotting theme
    rc_defaults = {'figure.titlesize': 12, 'axes.labelsize': 11,
                   'xtick.labelsize': 11, 'ytick.labelsize': 11,
                   'legend.fontsize': 8, 'legend.loc': 'upper right',
                   'savefig.format': 'png'}
    rc_defaults.update(rc_params)
    sns.set_theme(style='ticks', rc=rc_defaults)

    # open figure and axes
    fig_defaults = {'figsize': (6, 2)}
    fig_defaults.update(fig_params)
    fig = plt.figure(**fig_defaults)
    ax = fig.subplots(1, 1)

    # set color palette (one color per hue level)
    col_defaults = {'palette': 'husl'}
    if hue is not None:
        col_defaults.update(n_colors=len(pd.unique(df[hue])))
    col_defaults.update(**col_params)
    palette = sns.color_palette(**col_defaults)

    # plot
    sns.lineplot(
        data=df, x=x, y=y, hue=hue, palette=palette, dashes=False,
        legend=False, markers=True, ax=ax, **kwargs)

    # set legend; labels follow 'hue_order' when it is given
    if hue is not None:
        if 'hue_order' in kwargs:
            lg_defaults = {'labels': kwargs['hue_order']}
        else:
            lg_defaults = {'labels': list(pd.unique(df[hue]))}
        lg_defaults.update(**lg_params)
        ax.legend(handles=ax.lines, **lg_defaults)

    # set axes properties
    axes_defaults = {'xlabel': x, 'ylabel': y.replace('_', ' ')}
    axes_defaults.update(**ax_params)
    ax.set(**axes_defaults)

    # set title
    if title is not None:
        fig.suptitle(title)

    sns.despine(offset=10, trim=True,
                top=True, bottom=False,
                right=True, left=False)

    if show:
        plt.show(block=True)

    if savefig:
        fig.savefig(fname=fname + '.' + mpl.rcParams['savefig.format'],
                    transparent=True, bbox_inches='tight', dpi=300)

    # close the figure created here rather than whichever one is current
    plt.close(fig)

    # reset rc defaults
    mpl.rcdefaults()


def plot_phase_space(
    x, y, sample=None, palette=None,
    rc_params={}, fig_params={}, ax_params={},
    title=None, show=False, savefig=False, fname='phase_space'
):
    """
    Plot phase space diagram (y against x).

    Parameters
    ----------
    x : array_like
        Values along the horizontal axis, indexed by time step.
    y : array_like
        Values along the vertical axis, indexed by time step.
    sample : sequence, optional
        Arguments unpacked into `numpy.arange` to build the time steps to
        plot. If None, all time steps of `x` are used, by default None
    palette : optional
        Palette passed to `seaborn.color_palette`. When given, consecutive
        pairs of time steps are drawn as separate segments so that the
        color changes along the trajectory, by default None
    {rc_params}
    {fig_params}
    {ax_params}
    {title}
    {show}
    {savefig}
    {fname}
    """
    # time steps to plot
    t = np.arange(x.shape[0]) if sample is None else np.arange(*sample)
    n_segments = t.size - 1

    # plotting theme
    theme_rc = {'figure.titlesize': 12, 'axes.labelsize': 11,
                'xtick.labelsize': 11, 'ytick.labelsize': 11,
                'lines.linewidth': 1, 'savefig.format': 'png'}
    if palette is not None:
        # one color per segment, so that color is a function of time step
        segment_colors = sns.color_palette(palette, n_segments)
        theme_rc['axes.prop_cycle'] = cycler(color=segment_colors)
    theme_rc.update(rc_params)
    sns.set_theme(style='ticks', rc=theme_rc)

    # figure and axes
    figure_kws = {'figsize': (4, 4)}
    figure_kws.update(fig_params)
    fig = plt.figure(**figure_kws)
    ax = fig.subplots(1, 1)

    # draw the trajectory (these plots are easier with matplotlib)
    if palette is not None:
        for start in range(n_segments):
            segment = t[start:start + 2]
            ax.plot(x[segment], y[segment])
    else:
        ax.plot(x[t], y[t])

    # axis properties
    axes_kws = {'xlim': [0.2, 1.4], 'ylim': [0.2, 1.4]}
    axes_kws.update(**ax_params)
    ax.set(**axes_kws)

    # title
    if title is not None:
        fig.suptitle(title)

    sns.despine(offset=10, trim=False,
                top=True, bottom=False,
                right=True, left=False)

    if show:
        plt.show(block=True)

    if savefig:
        out_path = fname + '.' + mpl.rcParams['savefig.format']
        fig.savefig(fname=out_path,
                    transparent=True, bbox_inches='tight', dpi=300)

        # the figure is only closed once it has been written to disk
        plt.close()

    # reset rc defaults
    mpl.rcdefaults()

def plot_spike_raster(tspike, x1, x2, title="Spike Raster"):
    """
    Plot a spike raster plot.

    Every spike is drawn as a vertical tick at (spike time, neuron index).
    The figure is shown with `matplotlib.pyplot.show`.

    Parameters
    ----------
    tspike : (N_spikes, 2) numpy.ndarray
        spike times (in s)
        N_spikes: number of spikes
        tspike[:, 0]: spike neuronal indices
        tspike[:, 1]: spike times (in s)
        An empty array produces an empty raster.
    x1 : float
        start time (in s)
    x2 : float
        end time (in s)
    title : str, optional
        title of the plot, by default "Spike Raster"
    """
    tspike = np.asarray(tspike)
    if tspike.size == 0:
        # no spikes at all: keep the two-column layout so indexing works
        tspike = tspike.reshape(0, 2)

    neuron_ids = tspike[:, 0]
    spike_times_ms = tspike[:, 1] * 1000  # s -> ms

    # neurons are indexed from 0, hence the largest index + 1
    nneurons = int(np.max(neuron_ids) + 1) if neuron_ids.size else 0

    plt.figure(figsize=(max((x2 - x1)/0.1*10, 10), max(.02*nneurons, 1)))
    plt.title(title)
    plt.xlabel("Time (ms)")
    plt.ylabel("Neuron")

    # a single scatter call draws all spikes at once instead of scanning
    # the spike list and creating one artist per neuron
    plt.scatter(spike_times_ms, neuron_ids, marker='|', color='black')

    plt.xlim(x1*1000, x2*1000)
    # keep a non-degenerate y-range even when there are no neurons
    plt.ylim(-0.5, max(nneurons, 1) - 0.5)
    plt.gca().invert_yaxis()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.show()

def plot_membrane_voltages(membrane_voltages, x1, x2, neuron_idx = None, 
                           dt = 0.05, title="Membrane Voltages"):
    """
    Plot the membrane voltage traces of selected neurons, one axes per neuron.

    The figure is shown with `matplotlib.pyplot.show`.

    Parameters
    ----------
    membrane_voltages : (nt, N) numpy.ndarray
        membrane voltage tracings (mV)
        nt: number of time steps
        N: number of nodes
    x1 : int
        first time step to plot (used as a slice index)
    x2 : int
        time step at which to stop (exclusive, used as a slice index)
    neuron_idx : numpy.ndarray, optional
        indices of neurons to plot; all neurons if None
        Default: None
    dt : float, optional
        duration of one time step; time steps are multiplied by `dt` to
        build the time axis
        NOTE(review): previously documented as "in s" while the axis is
        labelled in ms -- confirm the intended unit
        Default: 0.05
    title : str, optional
        title of the plot, by default "Membrane Voltages"
    """
    # default to every column of the voltage matrix
    if neuron_idx is None:
        neuron_idx = np.arange(membrane_voltages.shape[1])
    nneurons = len(neuron_idx)

    # time axis for the selected window
    time_arr = dt * np.arange(x1, x2)

    # squeeze=False always returns a 2D array of axes, so the single-neuron
    # case needs no special handling
    fig_width = max(10, (x2 - x1)/1000*10)
    fig, axes = plt.subplots(
        nneurons, 1, figsize=(fig_width, nneurons),
        sharex=True, sharey=False, squeeze=False,
    )

    for row, ax in enumerate(axes.ravel()):
        neuron = neuron_idx[row]
        ax.plot(time_arr, membrane_voltages[x1:x2, neuron], c='k')
        ax.set_ylabel(f"Neuron {neuron}")
        ax.grid(True)

    # figure-level title and shared axis labels
    fig.suptitle(title, y=1.03)
    fig.supylabel('Voltage (mV)', x=-0.03)
    fig.supxlabel('Time (ms)', y=-0.03)
    fig.tight_layout()
    plt.show()