import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mpl_toolkits.axes_grid1 import make_axes_locatable


def exp_basis(decay: float, window_size: int, time_span: float) -> torch.Tensor:
    """Discretized, normalized exponential decay basis.

    The continuous kernel is \\phi(t) = \\beta exp(-\\beta t) with
    \\beta = ``decay``. It is sampled at ``window_size`` evenly spaced points
    ``t = dt, 2 dt, ..., time_span`` where ``dt = time_span / window_size``
    (i.e. the right edges of the bins covering ``(0, time_span]``). The
    samples are then rescaled so that the Riemann sum ``dt * basis.sum()``
    equals 1. Because of this rescaling, the \\beta prefactor has no effect on
    the result.

    Parameters
    ----------
    decay : float
        Decay parameter \\beta.
    window_size : int
        Number of discretized time bins. Must be at least 1.
    time_span : float
        Maximum influence time span.

    Returns
    -------
    basis : torch.Tensor of shape (window_size,)
        Discretized basis, normalized so that ``dt * basis.sum() == 1``.

    Raises
    ------
    ValueError
        If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    dt = time_span / window_size
    # Sample points start at dt (not 0): lag 1 bin is the first influenced bin.
    t = torch.linspace(dt, time_span, window_size)
    basis = torch.exp(-decay * t)
    # Normalize to unit area under the discretized (Riemann-sum) kernel.
    return basis / (dt * basis.sum())


def convolve_spikes_with_basis(
    spikes_list: torch.Tensor,
    basis: torch.Tensor,
) -> torch.Tensor:
    """Causally convolve every neuron's spike train with a single basis.

    For each sample ``s``, time bin ``t`` and neuron ``j`` the output is::

        out[s, t, j] = sum_{k=1}^{window_size} basis[k - 1] * spikes_list[s, t - k, j]

    Bins before the start of the train are treated as zero. The current bin
    ``t`` itself is not included (lags start at 1), so ``basis[0]`` weights
    the immediately preceding bin.

    Parameters
    ----------
    spikes_list : torch.Tensor of shape (n_samples, n_time_bins, n_neurons)
        Spike train. Values may be continuous (e.g. a soft spike train).
    basis : torch.Tensor of shape (window_size,)
        Discretized basis.

    Returns
    -------
    convolved_spikes_list : torch.Tensor of shape (n_samples, n_time_bins, n_neurons)
        Convolved spike train.
    """

    window_size = len(basis)
    n_samples, n_time_bins, n_neurons = spikes_list.shape

    # Left-pad the time axis with `window_size` zero bins. `new_zeros` matches
    # the input's dtype and device, so the concatenation never mixes dtypes.
    padding = spikes_list.new_zeros((n_samples, window_size, n_neurons))
    padded_spikes_list = torch.cat((padding, spikes_list), dim=-2)

    convolved_spikes_list = torch.zeros_like(spikes_list)
    for i in range(window_size):
        # Window offset i corresponds to lag (window_size - i), whose weight is
        # basis[window_size - i - 1] == basis[-(i + 1)].
        shifted = padded_spikes_list[:, i : n_time_bins + i, :]
        convolved_spikes_list = convolved_spikes_list + basis[-(i + 1)] * shifted
    return convolved_spikes_list


def logcosh(x: torch.Tensor) -> torch.Tensor:
    """Numerically stable ln(cosh(x)) = ln((e^x + e^{-x}) / 2).

    Computed as ``logaddexp(x, -x) - ln 2``, so large ``|x|`` does not
    overflow as ``cosh`` would. Floating-point inputs keep their dtype and
    device. Non-floating inputs are first cast to the default floating dtype.

    Parameters
    ----------
    x : torch.Tensor of shape (*,)
        Input tensor.

    Returns
    -------
    result : torch.Tensor of shape (*,)
        Element-wise ln(cosh(x)).
    """

    if not torch.is_floating_point(x):
        # logaddexp needs a floating dtype; mirror the old int -> float behavior.
        x = x.to(torch.get_default_dtype())
    # Unlike writing into a fresh default-dtype buffer, this never downcasts
    # float64 inputs to float32.
    return torch.logaddexp(x, -x) - np.log(2)


def normal_log_prob(z: torch.Tensor) -> torch.Tensor:
    """Element-wise log density of the standard normal prior N(0, 1).

    Evaluates ``log N(z; 0, 1) = -z**2 / 2 - log(2 * pi) / 2`` as the
    negated Gaussian negative log-likelihood of target ``z`` under zero mean
    and unit variance (``full=True`` keeps the constant term).

    Parameters
    ----------
    z : torch.Tensor of shape (*, n_components)
        Latent variable.

    Returns
    -------
    log_prob : torch.Tensor of shape (*, n_components)
        Log density of each entry of ``z`` under the standard normal.
    """
    mean = torch.zeros_like(z)
    unit_var = torch.ones_like(z)
    nll = F.gaussian_nll_loss(
        input=mean,
        target=z,
        var=unit_var,
        full=True,
        reduction="none",
    )
    return -nll


def logcosh_log_prob(z: torch.Tensor) -> torch.Tensor:
    """Element-wise log density of a unit-variance logistic ("logcosh") prior.

    The density is ``p(z) = pi / (4 * sqrt(3)) * sech(pi * z / (2 * sqrt(3)))**2``.
    This is a logistic distribution with scale ``sqrt(3) / pi`` and therefore
    unit variance. It is evaluated in log space through ``logcosh``.

    Parameters
    ----------
    z : torch.Tensor of shape (*, n_components)
        Latent variable.

    Returns
    -------
    log_prob : torch.Tensor of shape (*, n_components)
        Log density of each entry of ``z`` under the logcosh prior.
    """
    scaled_z = np.pi * z / 2 / np.sqrt(3)
    log_numerator = np.log(np.pi)
    log_denominator = np.log(4 * np.sqrt(3))
    return log_numerator - 2 * logcosh(scaled_z) - log_denominator
