import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy.typing import NDArray


def exp_basis(decay: float, window_size: int, time_span: float):
    """Exponential decay basis sampled on a uniform time grid.

    The basis is proportional to exp(-decay * t), evaluated at the right
    edges t = dt, 2 dt, ..., time_span of ``window_size`` bins of width
    ``dt = time_span / window_size``. It is then rescaled so that its Riemann
    sum ``dt * basis.sum()`` equals 1. This is the discrete analogue of
    \\phi(t) = \\beta exp(-\\beta t) integrating to one.

    Parameters
    ----------
    decay : float
        Decay parameter.
    window_size : int
        Number of discretized time bins.
    time_span : float
        Maximum influence time span.

    Returns
    -------
    basis : torch.Tensor of shape (window_size,)
        Discretized, normalized basis.
    """

    dt = time_span / window_size
    t = torch.linspace(dt, time_span, window_size)
    basis = torch.exp(-decay * t)
    # Normalize numerically rather than with the analytic factor `decay`, so
    # the truncated, discretized window still has exactly unit mass.
    return basis / (dt * basis.sum(dim=0))


def convolve_spikes_with_basis(
    spikes_list: torch.Tensor,
    basis: torch.Tensor,
) -> torch.Tensor:
    """Causally convolve every neuron's spike train with a single basis.

    For each sample ``s`` and neuron ``j`` the output is

        out[s, t, j] = sum_{k=0}^{window_size-1} basis[k] * spikes_list[s, t-1-k, j]

    with spikes before ``t = 0`` treated as zero. Bin ``t`` itself is not
    included, so the filtered history at ``t`` uses strictly earlier bins.

    Parameters
    ----------
    spikes_list : torch.Tensor of shape (n_samples, n_time_bins, n_neurons)
        Spike train. Values may be continuous (e.g. a soft spike train).
    basis : torch.Tensor of shape (window_size,)
        Discretized basis; ``basis[k]`` weights a lag of ``k + 1`` bins.

    Returns
    -------
    convolved_spikes_list : torch.Tensor of shape (n_samples, n_time_bins, n_neurons)
        Convolved spike train, with the same device as ``spikes_list``.
    """

    window_size = len(basis)
    n_samples, n_time_bins, n_neurons = spikes_list.shape

    # Left-pad the time axis with `window_size` zero bins. `new_zeros` keeps
    # the padding on the input's device *and dtype*, so torch.cat does not
    # silently promote e.g. float16/bfloat16 inputs to float32.
    padding = spikes_list.new_zeros((n_samples, window_size, n_neurons))
    padded_spikes_list = torch.cat((padding, spikes_list), dim=-2)

    convolved_spikes_list = torch.zeros_like(spikes_list)
    for i in range(window_size):
        # Offset i in the padded tensor aligns input bin t - (window_size - i)
        # with output bin t; that lag is weighted by basis[window_size - 1 - i].
        convolved_spikes_list = (
            convolved_spikes_list
            + basis[-(i + 1)] * padded_spikes_list[:, i : n_time_bins + i, :]
        )
    return convolved_spikes_list


def logcosh(x: torch.Tensor) -> torch.Tensor:
    """ln((e^x + e^{-x}) / 2), computed stably via log-add-exp.

    Parameters
    ----------
    x : torch.Tensor of shape (*,)
        Input tensor.

    Returns
    -------
    result : torch.Tensor of shape (*,)
        Element-wise log-cosh of ``x``. It has ``x``'s floating dtype and
        device; non-floating inputs are first converted to torch's default
        floating dtype.
    """

    if not torch.is_floating_point(x):
        x = x.to(torch.get_default_dtype())
    # logaddexp(x, -x) = log(e^x + e^-x) without overflow for large |x|, and it
    # keeps x's dtype (a default-dtype scratch buffer would truncate float64).
    return torch.logaddexp(x, -x) - np.log(2)


def normal_log_prob(z: torch.Tensor) -> torch.Tensor:
    """Element-wise log density of a standard normal prior.

    log N(z; 0, 1) = -z^2 / 2 - log(2 pi) / 2

    Parameters
    ----------
    z : torch.Tensor of shape (*, n_components)
        Latent variable.

    Returns
    -------
    log_prob : torch.Tensor of shape (*, n_components)
        Per-component log probability under N(0, 1).
    """
    half_squared = 0.5 * z.square()
    half_log_two_pi = 0.5 * np.log(2 * np.pi)
    return -(half_squared + half_log_two_pi)


def logcosh_log_prob(z: torch.Tensor) -> torch.Tensor:
    """Element-wise log density of the logcosh (unit-variance logistic) prior.

    log p(z) = log(pi) - 2 logcosh(pi z / (2 sqrt 3)) - log(4 sqrt 3)

    Parameters
    ----------
    z : torch.Tensor of shape (*, n_components)
        Latent variable.

    Returns
    -------
    log_prob : torch.Tensor of shape (*, n_components)
        Per-component log probability under the logcosh prior.
    """
    scaled_z = np.pi * z / 2 / np.sqrt(3)
    log_numerator = np.log(np.pi)
    log_denominator = np.log(4 * np.sqrt(3))
    return log_numerator - 2 * logcosh(scaled_z) - log_denominator


def nestedlp_log_prob(z: torch.Tensor, p0: float, p: NDArray) -> torch.Tensor:
    """Calculate the log probability of a two-level nested L^p prior distribution.

    The ``n_components`` latent dimensions are split into ``n_groups = len(p)``
    contiguous groups of equal size ``group_rank``. Define

        v_g = (sum_{i in g} |z_i|^{p_g})^{1/p_g}
        v0  = (sum_g v_g^{p0})^{1/p0}

    The density is ``psi0(v0) / (v0^{n-1} S_f(1))``. Here the radial density is
    ``psi0(v) = p0 v^{n-1} exp(-v^{p0}) / Gamma(n / p0)`` and ``S_f(1)`` is the
    surface of the nested unit sphere. The ``v0^{n-1}`` factors cancel, so

        log p(z) = log p0 - lgamma(n / p0) - v0^{p0} - log S_f(1)

    This stays finite at ``z = 0``.

    Parameters
    ----------
    z : torch.Tensor of shape (*, n_components)
        Latent variable.
    p0 : float
        The power of the root node (must be positive).
    p : NDArray of shape (n_groups,)
        The power of each nested group node (must be positive).

    Returns
    -------
    log_prob : torch.Tensor of shape (*,)
        Joint log density of each latent vector. Unlike the factorized priors,
        a nested L^p density does not factor over components, so there is no
        per-component axis.

    Raises
    ------
    ValueError
        If ``p`` is empty or ``n_components`` is not divisible by ``len(p)``.
    """
    import math

    p_np = np.asarray(p, dtype=float).reshape(-1)
    n_components = z.shape[-1]
    n_groups = p_np.size
    if n_groups == 0 or n_components % n_groups != 0:
        raise ValueError(
            f"n_components={n_components} must be a positive multiple of "
            f"len(p)={n_groups}"
        )
    group_rank = n_components // n_groups

    if not torch.is_floating_point(z):
        z = z.to(torch.get_default_dtype())
    p_t = torch.as_tensor(p_np, dtype=z.dtype, device=z.device)

    grouped = z.reshape(*z.shape[:-1], n_groups, group_rank)
    # v_g^{p_g} for every group: (*, n_groups)
    inner = grouped.abs().pow(p_t[:, None]).sum(dim=-1)
    # v0^{p0} = sum_g (v_g^{p_g})^{p0 / p_g}: (*,)
    v0_pow_p0 = inner.pow(p0 / p_t).sum(dim=-1)

    def _betaln(a: float, b: float) -> float:
        # log Beta(a, b) via log-gamma to avoid overflow of Gamma itself.
        return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

    # log S_f(1) = n log 2 + sum over inner nodes I of
    #   -(l_I - 1) log p_I + sum_{k=1}^{l_I-1} log B(n_{1:k} / p_I, n_{k+1} / p_I)
    log_sf1 = n_components * math.log(2.0)
    # Root node: n_groups children, each with group_rank leaves.
    log_sf1 += -(n_groups - 1) * math.log(p0) + sum(
        _betaln(k * group_rank / p0, group_rank / p0) for k in range(1, n_groups)
    )
    # Group nodes: group_rank leaf children each.
    for p_g in p_np:
        log_sf1 += -(group_rank - 1) * math.log(p_g) + sum(
            _betaln(k / p_g, 1.0 / p_g) for k in range(1, group_rank)
        )

    log_norm = math.log(p0) - math.lgamma(n_components / p0) - log_sf1
    return -v0_pow_p0 + log_norm


class LpNestedPrior:
    """Nested L^p-symmetric prior with a generalized-gamma radial density.

    The exponent tree ``p`` is encoded as nested lists. A leaf is a one-element
    list (e.g. ``[1.0]``). An inner node is ``[p_I, [child_1, ..., child_l]]``,
    where ``p_I`` is that node's exponent. Leaves are consumed left to right,
    so leaf ``k`` reads component ``k`` of the last axis of ``z``.

    Parameters
    ----------
    p : list, optional
        Nested exponent tree. Defaults to
        ``[0.5, [[1.0], [2.0, [[1.0], [1.0]]]]]`` (three latent dimensions).
    scale : float, default=1.0
        Scale ``s`` of the radial density, which is proportional to
        ``v0**(n-1) * exp(-v0**p0 / s)``.
    """

    def __init__(self, p=None, scale=1.0):
        if p is None:
            # Build a fresh default per instance instead of sharing one
            # mutable list across every call.
            p = [0.5, [[1.0], [2.0, [[1.0], [1.0]]]]]
        self.p = p
        self.scale = scale
        self.n = self.count_leaves(self.p)  # latent dimensionality
        self.Sf1_log = self.Sf_log(self.p)  # log surface of the unit ball

    def dimz(self):
        """Return the latent dimensionality (number of leaves in the tree)."""
        return self.n

    def pradial_log_prob(self, v0, p0, n, s=1.0):
        """Log density of the generalized-gamma radial distribution.

        log psi(v0) = log p0 + (n - 1) log v0 - lgamma(n / p0)
                      - (n / p0) log s - v0**p0 / s
        """
        u = n / p0
        a = np.log(p0) + torch.log(v0) * (n - 1)
        b = scipy.special.gammaln(u) + np.log(s) * (u)
        c = -(v0**p0) / s
        val = a - b + c
        return val

    def count_leaves(self, p):
        """Count the leaves of subtree ``p`` (a list of length <= 1 is a leaf)."""
        if len(p) <= 1:
            return 1
        return sum(self.count_leaves(p_Ik) for p_Ik in p[1])

    def betaFunction(self, a, b):
        """Beta function B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b)."""
        return (
            scipy.special.gamma(a) * scipy.special.gamma(b) / scipy.special.gamma(a + b)
        )

    def betaFunctionLog(self, a, b):
        """Log of the Beta function, computed from log-gamma values."""
        return (
            scipy.special.loggamma(a)
            + scipy.special.loggamma(b)
            - scipy.special.loggamma(a + b)
        )

    def Sf_recursive(self, p):
        """Product over inner nodes of the surface factors (without 2**n)."""
        if len(p) <= 1:
            # leaf node
            return 1
        p_I = p[0]
        p_children = p[1]
        l_I = len(p_children)
        val = 1.0 / p_I ** (l_I - 1)
        n_Ik_sum = 0
        for k in range(len(p_children) - 1):
            # Beta(leaves in children 0..k / p_I, leaves in child k+1 / p_I)
            n_Ik = self.count_leaves(p_children[k])
            n_Ik_sum = n_Ik_sum + n_Ik
            n_Ik_plus = self.count_leaves(p_children[k + 1])
            val = val * self.betaFunction(n_Ik_sum / p_I, n_Ik_plus / p_I)
        for p_child in p_children:
            val = val * self.Sf_recursive(p_child)
        return val

    # compute the surface of a unit ball
    def Sf(self, p):
        """Surface of the nested unit ball: 2**n times the recursive factor."""
        n = self.count_leaves(p)
        return 2**n * self.Sf_recursive(p)

    def Sf_log_recursive(self, p):
        """Log-space counterpart of ``Sf_recursive``."""
        if len(p) <= 1:
            # leaf node
            return 0
        p_I = p[0]
        p_children = p[1]
        l_I = len(p_children)
        val = -np.log(p_I) * (l_I - 1)
        n_Ik_sum = 0
        for k in range(len(p_children) - 1):
            n_Ik = self.count_leaves(p_children[k])
            n_Ik_sum = n_Ik_sum + n_Ik
            n_Ik_plus = self.count_leaves(p_children[k + 1])
            val = val + self.betaFunctionLog(n_Ik_sum / p_I, n_Ik_plus / p_I)
        for p_child in p_children:
            val = val + self.Sf_log_recursive(p_child)
        return val

    # compute the logarithm of the surface of a unit ball
    def Sf_log(self, p):
        """Log surface of the nested unit ball: n log 2 + recursive log factor."""
        n = self.count_leaves(p)
        return n * np.log(2.0) + self.Sf_log_recursive(p)

    def f_recursive(self, z, p, k):
        """Evaluate subtree ``p`` on ``z`` starting at leaf index ``k``.

        Returns ``(val, next_k, p_I)``. For a leaf, ``val`` is ``|z[..., k]|``
        and ``p_I`` is 1.0. For an inner node, ``val`` is
        ``sum_children child_val ** (p_I / p_child)``, i.e. the node's norm
        raised to ``p_I``.
        """
        if len(p) <= 1:
            # leaf node
            return torch.abs(z[..., k]), k + 1, 1.0
        p_I = p[0]
        p_children = p[1]
        val = 0
        for p_child in p_children:
            op, k, p_Iplus = self.f_recursive(z, p_child, k)
            val = val + op ** (p_I / p_Iplus)
        return val, k, p_I

    def f(self, z, p):
        """Nested L^p norm v0 of ``z`` along its last axis."""
        val, k, p_I = self.f_recursive(z, p, k=0)
        return val ** (1.0 / p_I)

    def log_prob(self, z):
        """Log density of ``z`` (components along the last axis).

        log p(z) = log psi(v0) - (n - 1) log v0 - log S_f(1). The ``v0**(n-1)``
        factors of the radial density and of the surface cancel analytically.
        They are dropped here so that ``z = 0`` yields a finite value instead of
        ``-inf - (-inf) = nan``.
        """
        p0 = self.p[0]
        # At the root, f_recursive already returns v0**p0 (no 1/p0 root needed).
        v0_pow_p0, _, _ = self.f_recursive(z, self.p, k=0)
        u = self.n / p0
        log_norm = float(
            np.log(p0) - scipy.special.gammaln(u) - np.log(self.scale) * u
        )
        return -v0_pow_p0 / self.scale + (log_norm - float(self.Sf1_log))
