import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch
from torch import nn
import torch.nn.functional as F
from dmhp import utils


class DPP(nn.Module):
    """
    Base class for discrete point process models.

    Stores the time bin width and the torch device shared by subclasses.
    """

    def __init__(self, dt: float, device=torch.device("cpu")) -> None:
        super().__init__()
        # Width of one time bin.
        self.dt = dt
        # Device on which subclasses allocate their parameters.
        self.device = device

    def __str__(self) -> str:
        """Return one ``name: value`` line per registered parameter."""
        return "\n".join(
            f"{param_name}: {param.data}"
            for param_name, param in self.named_parameters()
        )


class DiscreteMultivariateHawkesProcess(DPP):
    """Discrete-time multivariate Hawkes process with a pointwise link function.

    The conditional intensity is ``activation(bg_intensity + coupling)``, where
    the coupling term is either computed from a pre-convolved spike train or,
    for a single time ``t``, from the raw spike history.

    NOTE(review): the single-time path (and therefore ``sample``) reads
    ``self.window_size`` and ``self.filter``, neither of which is assigned in
    this class. They are presumably set by the caller or a subclass before
    use -- confirm. ``self.filter`` is indexed with three axes, the last one
    running over the history window.
    """

    def __init__(self, n_neurons: int, dt: float, activation='sigmoid', device=torch.device("cpu")) -> None:
        """
        Parameters
        ----------
        n_neurons : int
            Number of neurons (dimensions of the process).
        dt : float
            Time bin width.
        activation : {'linear', 'exp', 'sigmoid', 'softplus'}
            Link function applied to the pre-activation intensity.
        device : torch.device
            Device on which parameters are allocated.

        Raises
        ------
        ValueError
            If ``activation`` is not one of the supported names.
        """
        super().__init__(dt, device)

        self.n_neurons = n_neurons

        if activation == 'linear':
            # Rectified linear with a small floor so that log(intensity) stays finite.
            self.activation = lambda x: torch.relu(x) + 1e-5
        elif activation == 'exp':
            self.activation = torch.exp
        elif activation == 'sigmoid':
            # Scaled so that intensity * dt lies in (0, 1).
            self.activation = lambda x: 1 / dt * torch.sigmoid(x)
        elif activation == 'softplus':
            self.activation = nn.Softplus()
        else:
            # Previously an unknown name left self.activation undefined and
            # only failed much later, inside conditional_intensity.
            raise ValueError(
                f"Unknown activation {activation!r}; expected one of "
                "'linear', 'exp', 'sigmoid', 'softplus'."
            )

    def init_params(self, bg_intensity=None, weight=None):
        """Create the learnable background intensity and weight matrix.

        Parameters
        ----------
        bg_intensity : torch.Tensor, optional
            Initial background intensity; copied and moved to ``self.device``.
            Defaults to ``0.5 / n_neurons`` for every neuron.
        weight : torch.Tensor, optional
            Initial weight matrix; copied and moved to ``self.device``.
            Defaults to an all-zero (n_neurons, n_neurons) matrix.
        """
        if bg_intensity is not None:
            self.bg_intensity = nn.Parameter(bg_intensity.detach().clone().to(self.device))
        else:
            self.bg_intensity = nn.Parameter(0.5/self.n_neurons*torch.ones(self.n_neurons, device=self.device))
        if weight is not None:
            self.weight = nn.Parameter(weight.detach().clone().to(self.device))
        else:
            self.weight = nn.Parameter(torch.zeros((self.n_neurons, self.n_neurons), device=self.device))

        # Disabled stability check kept for reference:
        # if torch.linalg.eigvals(self.weight).abs().max() > 1:
        #     print("Warning! Unstable weight matrix") # for stability

    def _time_to_bin(self, t) -> int:
        """Convert a time to a bin index, tolerating floating-point round-off.

        Plain ``int(t / dt)`` truncates values such as ``0.3 / 0.1`` (which
        evaluates to 2.9999999999999996) to 2 instead of 3. Ratios that are
        within a relative tolerance of an integer are snapped to that integer;
        all other ratios are truncated exactly as before.
        """
        ratio = float(t / self.dt)
        nearest = round(ratio)
        if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(nearest)):
            return int(nearest)
        return int(ratio)

    def conditional_intensity(self, t=None, spikes=None, convolved_spikes=None) -> torch.Tensor:
        """Compute the conditional intensity for all neurons at time t or the whole spike train.

        Parameters
        ----------
        t : float, optional
            Current time. When given, the intensity of the bin containing ``t``
            is computed from ``spikes``; otherwise ``convolved_spikes`` is used.
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train (required when ``t`` falls after the first bin).
        convolved_spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Convolved spike train (required when ``t`` is None).

        Returns
        -------
        cond_intensity : torch.FloatTensor of shape (n_neurons,) or (n_time_bins, n_neurons)
            Conditional intensity tensor at time t, or the whole spike train.

        Raises
        ------
        ValueError
            If neither ``t`` nor ``convolved_spikes`` is provided.
        """

        if t is not None:
            n = self._time_to_bin(t)
            if n >= self.window_size:
                # Full history window available: use the last window_size bins.
                cond_intensity = self.bg_intensity + torch.sum(
                    spikes.T[None, :, n-self.window_size:n] * self.weight[:, :, None] * self.filter,
                    dim=(1, 2),
                )
            elif n == 0:
                # No history yet: only the background term contributes.
                cond_intensity = self.bg_intensity
            else:
                # Partial history: pair the n available bins with the last n filter taps.
                cond_intensity = self.bg_intensity + torch.sum(
                    spikes.T[None, :, :n] * self.weight[:, :, None] * self.filter[:, :, -n:],
                    dim=(1, 2),
                )
        else:
            if convolved_spikes is None:
                raise ValueError("Either t (with spikes) or convolved_spikes must be provided.")
            cond_intensity = self.bg_intensity + convolved_spikes @ self.weight.T

        return self.activation(cond_intensity)

    def sample(self, T: float, distribution='Poisson', rng=None):
        """Draw a spike train of duration T bin by bin.

        Parameters
        ----------
        T : float
            Total duration; the number of bins is ``T / dt`` (round-off tolerant).
        distribution : {'Poisson', 'Bernoulli'} or callable
            Per-bin count distribution. A callable is invoked as
            ``distribution(rates, generator=rng)``.
        rng : torch.Generator, optional
            Random generator forwarded to the sampling function.

        Returns
        -------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Sampled spike train.
        firing_rates : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Expected value per bin (conditional intensity times ``dt``).

        Raises
        ------
        ValueError
            If ``distribution`` is neither a supported name nor a callable.
        """
        if distribution == 'Poisson':
            draw = torch.poisson
        elif distribution == 'Bernoulli':
            draw = torch.bernoulli
        elif callable(distribution):
            draw = distribution
        else:
            raise ValueError(
                f"Unknown distribution {distribution!r}; expected 'Poisson' or 'Bernoulli'."
            )
        n_time_bins = self._time_to_bin(T)
        spikes = torch.zeros(n_time_bins, self.n_neurons, device=self.device)
        firing_rates = torch.zeros(n_time_bins, self.n_neurons, device=self.device)
        with torch.no_grad():
            for n in range(n_time_bins):
                # The intensity of bin n depends only on spikes drawn in earlier bins.
                firing_rates[n, :] = self.conditional_intensity(n * self.dt, spikes=spikes) * self.dt
                spikes[n, :] = draw(firing_rates[n, :], generator=rng)
        return spikes, firing_rates

    def neg_log_likelihood(self, spikes: torch.FloatTensor, convolved_spikes: torch.FloatTensor, distribution='Poisson') -> torch.Tensor:
        """Compute the negative log-likelihood of the spike train.

        Parameters
        ----------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train.
        convolved_spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Convolved spike train.
        distribution : {'Poisson', 'Bernoulli'}
            Observation model for the per-bin counts.

        Returns
        -------
        nll : torch.FloatTensor of shape ()
            Negative log-likelihood of the spike train + constant.

        Raises
        ------
        ValueError
            If ``distribution`` is not a supported name.
        """

        firing_rates = self.conditional_intensity(convolved_spikes=convolved_spikes) * self.dt
        if distribution == 'Poisson':
            # The log(spikes!) term is constant in the parameters and is omitted.
            nll = torch.sum(-spikes * firing_rates.log() + firing_rates)
        elif distribution == 'Bernoulli':
            nll = -torch.sum(spikes * firing_rates.log() + (1-spikes) * (1-firing_rates).log())
        else:
            raise ValueError(
                f"Unknown distribution {distribution!r}; expected 'Poisson' or 'Bernoulli'."
            )
        return nll
    

class GLMChangingDynamics(DPP):
    """GLM point-process model whose coupling weights change piecewise in time.

    The recording is split into ``n_pieces`` consecutive pieces of
    ``piece_length`` bins; each piece has its own (n_neurons, n_neurons)
    weight matrix. The spike history enters through the spike train
    convolved with ``basis``.
    """

    def __init__(self, dt: float, basis: torch.FloatTensor, activation='sigmoid', device=torch.device("cpu")) -> None:
        """
        Parameters
        ----------
        dt : float
            Time bin width.
        basis : torch.FloatTensor of shape (window_size,)
            History kernel the spike trains are convolved with.
        activation : {'linear', 'exp', 'sigmoid', 'softplus'}
            Link function applied to the pre-activation intensity.
        device : torch.device
            Device on which parameters and data are placed.

        Raises
        ------
        ValueError
            If ``activation`` is not one of the supported names.
        """
        super().__init__(dt, device)

        self.basis = basis
        self.flipped_basis = torch.flip(self.basis, (0,))
        self.window_size = len(self.basis)

        if activation == 'linear':
            # Rectified linear with a small floor so that log(intensity) stays finite.
            self.activation = lambda x: torch.relu(x) + 1e-5
        elif activation == 'exp':
            self.activation = torch.exp
        elif activation == 'sigmoid':
            # Scaled so that intensity * dt lies in (0, 1).
            self.activation = lambda x: 1 / dt * torch.sigmoid(x)
        elif activation == 'softplus':
            self.activation = nn.Softplus()
        else:
            raise ValueError(
                f"Unknown activation {activation!r}; expected one of "
                "'linear', 'exp', 'sigmoid', 'softplus'."
            )

    def _convolve_with_basis(self, spikes: torch.Tensor) -> torch.Tensor:
        """Convolve every neuron's spike train with the basis, strictly causally.

        Row ``n`` of the result depends only on spikes in bins before ``n``;
        row 0 is all zeros. The result is returned as a CPU tensor.
        """
        # NumPy needs host memory, so detach and copy to CPU first.
        spikes_np = spikes.detach().cpu().numpy()
        basis_np = self.basis.detach().cpu().numpy()
        convolved = np.zeros_like(spikes_np)
        for neuron in range(spikes_np.shape[1]):
            # Full convolution has N + W - 1 samples; dropping the last W and
            # writing from row 1 shifts the history by one bin.
            convolved[1:, neuron] = np.convolve(spikes_np[:, neuron], basis_np)[:-self.window_size]
        return torch.from_numpy(convolved)

    def add_data(self, spikes: torch.FloatTensor, bg_intensity=None, weight=None, n_pieces=100) -> None:
        """Attach a spike train and (re)initialise the parameters.

        Parameters
        ----------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train.
        bg_intensity : torch.Tensor of shape (n_neurons,), optional
            Initial background term; defaults to zeros.
        weight : torch.Tensor of shape (n_pieces, n_neurons, n_neurons), optional
            Initial per-piece weights; defaults to zeros.
        n_pieces : int
            Number of time pieces with independent weights.

        Raises
        ------
        ValueError
            If ``n_pieces`` is smaller than 1.
        """
        if n_pieces < 1:
            raise ValueError(f"n_pieces must be at least 1, got {n_pieces}.")
        spikes = torch.as_tensor(spikes)
        self.n_time_bins, self.n_neurons = spikes.shape
        self.n_pieces = n_pieces
        # Ceiling division (at least 1) so that n_pieces * piece_length always
        # covers n_time_bins; identical to floor division when it divides evenly.
        self.piece_length = max(1, -(-self.n_time_bins // self.n_pieces))

        # Keep data on the same device as the parameters.
        self.spikes = spikes.to(self.device)
        self.convolved_spikes = self._convolve_with_basis(spikes).to(self.device)

        if bg_intensity is not None:
            self.bg_intensity = nn.Parameter(bg_intensity.detach().clone().to(self.device))
        else:
            self.bg_intensity = nn.Parameter(torch.zeros(self.n_neurons, device=self.device))
        if weight is not None:
            self.weight = nn.Parameter(weight.detach().clone().to(self.device))
        else:
            self.weight = nn.Parameter(torch.zeros((self.n_pieces, self.n_neurons, self.n_neurons), device=self.device))

    def weight_at_time_bin(self, time_bin: int):
        """Return the (n_neurons, n_neurons) weight matrix active in ``time_bin``."""
        return self.weight[time_bin // self.piece_length]

    def conditional_intensity(self) -> torch.Tensor:
        """Return the conditional intensity, shape (n_time_bins, n_neurons)."""
        # Expand the per-piece weights to one matrix per time bin.
        weight_per_bin = self.weight.repeat_interleave(self.piece_length, dim=0)[:self.n_time_bins]
        # (N, 1, n) @ (N, n, n) -> (N, 1, n). Squeeze only the singleton row
        # axis so a single-neuron model keeps its (N, 1) shape.
        drive = (self.convolved_spikes[:, None, :] @ weight_per_bin.permute((0, 2, 1))).squeeze(1)
        return self.activation(self.bg_intensity + drive)

    def firing_rates(self) -> torch.Tensor:
        """Return the expected spike count per bin (intensity times ``dt``)."""
        return self.conditional_intensity() * self.dt

    def neg_log_likelihood(self, distribution='Poisson') -> torch.Tensor:
        """Return the negative log-likelihood of the attached spike train.

        Parameters
        ----------
        distribution : {'Poisson', 'Bernoulli'}
            Observation model for the per-bin counts.

        Raises
        ------
        ValueError
            If ``distribution`` is not a supported name.
        """
        firing_rates = self.firing_rates()
        if distribution == 'Poisson':
            nll = torch.sum(-self.spikes * firing_rates.log() + firing_rates + torch.lgamma(self.spikes+1))
        elif distribution == 'Bernoulli':
            nll = -torch.sum(self.spikes * firing_rates.log() + (1-self.spikes) * (1-firing_rates).log())
        else:
            raise ValueError(
                f"Unknown distribution {distribution!r}; expected 'Poisson' or 'Bernoulli'."
            )
        return nll

    def loss_fn(self, regularizer=(1,)):
        """Return the NLL plus smoothness penalties on the weights over pieces.

        ``regularizer[i]`` scales the mean squared (i+1)-th order finite
        difference of the weights along the piece axis.
        """
        loss = self.neg_log_likelihood()
        for order, coefficient in enumerate(regularizer, start=1):
            loss = loss + coefficient * torch.mean(torch.diff(self.weight, n=order, dim=0) ** 2)
        return loss

    def train(self, optimizer=True, regularizer=(1,)):
        """Run one optimisation step, or switch the module's training mode.

        This method shadows ``nn.Module.train(mode)``, which PyTorch itself
        calls (for example from ``eval()``) with a boolean. A boolean argument
        is therefore forwarded to ``nn.Module.train``; anything else is
        treated as an optimizer.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer or bool
            Optimizer to step, or a training-mode flag.
        regularizer : sequence of float
            Smoothness penalty coefficients, see ``loss_fn``.

        Returns
        -------
        float or nn.Module
            The loss value of the step, or ``self`` when toggling the mode.
        """
        if isinstance(optimizer, bool):
            return super().train(optimizer)
        loss = self.loss_fn(regularizer)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate(self, regularizer=(1,)):
        """Return the current loss value without tracking gradients."""
        with torch.no_grad():
            loss = self.loss_fn(regularizer)
        return loss.item()
        

class OnehotGLMChangingDynamics(GLMChangingDynamics):
    """Piecewise GLM whose connections are relaxed three-way categorical choices.

    For every piece and neuron pair, ``logit_onehot_connection`` holds three
    logits. Category 0 contributes a negative weight, category 2 a positive
    weight, and category 1 contributes nothing. Weight magnitudes are shared
    across pieces and constrained to the interval (0.02, 0.1).
    """

    def __init__(self, dt: float, basis: torch.FloatTensor, activation='sigmoid', device=torch.device("cpu")) -> None:
        super().__init__(dt, basis, activation, device)

    def add_data(self, spikes: torch.FloatTensor, bg_intensity=None, logit_onehot_connection=None, logit_weight=None, n_pieces=100) -> None:
        """Attach a spike train and (re)initialise the parameters.

        Parameters
        ----------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train.
        bg_intensity : torch.Tensor of shape (n_neurons,), optional
            Initial background term; defaults to zeros.
        logit_onehot_connection : torch.Tensor of shape (n_pieces, n_neurons, n_neurons, 3), optional
            Initial connection-type logits; defaults to zeros.
        logit_weight : torch.Tensor of shape (n_neurons, n_neurons, 2), optional
            Initial logits of the weight magnitudes (index 0 is used for the
            negative category, index 1 for the positive one); defaults to zeros.
        n_pieces : int
            Number of time pieces with independent connection logits.

        Raises
        ------
        ValueError
            If ``n_pieces`` is smaller than 1.
        """
        if n_pieces < 1:
            raise ValueError(f"n_pieces must be at least 1, got {n_pieces}.")
        spikes = torch.as_tensor(spikes)
        self.n_time_bins, self.n_neurons = spikes.shape
        self.n_pieces = n_pieces
        # Ceiling division (at least 1) so that n_pieces * piece_length always
        # covers n_time_bins; identical to floor division when it divides evenly.
        self.piece_length = max(1, -(-self.n_time_bins // self.n_pieces))

        # NumPy needs host memory, so detach and copy to CPU for the convolution.
        spikes_np = spikes.detach().cpu().numpy()
        basis_np = self.basis.detach().cpu().numpy()
        convolved = np.zeros_like(spikes_np)
        for neuron in range(self.n_neurons):
            # Full convolution has N + W - 1 samples; dropping the last W and
            # writing from row 1 makes row n depend only on bins before n.
            convolved[1:, neuron] = np.convolve(spikes_np[:, neuron], basis_np)[:-self.window_size]

        # Keep data on the same device as the parameters.
        self.spikes = spikes.to(self.device)
        self.convolved_spikes = torch.from_numpy(convolved).to(self.device)

        if bg_intensity is not None:
            self.bg_intensity = nn.Parameter(bg_intensity.detach().clone().to(self.device))
        else:
            self.bg_intensity = nn.Parameter(torch.zeros(self.n_neurons, device=self.device))
        if logit_onehot_connection is not None:
            self.logit_onehot_connection = nn.Parameter(logit_onehot_connection.detach().clone().to(self.device))
        else:
            self.logit_onehot_connection = nn.Parameter(torch.zeros((self.n_pieces, self.n_neurons, self.n_neurons, 3), device=self.device))
        if logit_weight is not None:
            self.logit_weight = nn.Parameter(logit_weight.detach().clone().to(self.device))
        else:
            self.logit_weight = nn.Parameter(torch.zeros((self.n_neurons, self.n_neurons, 2), device=self.device))

    def connection_at_time_bin(self, time_bin: int):
        """Return the most likely connection type per pair in ``time_bin``.

        Values are -1, 0 or 1 (argmax over the three categories, minus one).
        """
        return torch.argmax(self.logit_onehot_connection[time_bin // self.piece_length], dim=-1) - 1

    def conditional_intensity(self) -> torch.Tensor:
        """Return the conditional intensity, shape (n_time_bins, n_neurons).

        The connection types are drawn with a Gumbel-softmax relaxation, so
        repeated calls give different results.
        """
        # Weight magnitudes constrained to (0.02, 0.1).
        w = 0.08 * torch.sigmoid(self.logit_weight) + 0.02
        onehot_connection = F.gumbel_softmax(self.logit_onehot_connection, tau=0.5, dim=-1)
        exc_matrix = 1 * onehot_connection[:, :, :, 2] * w[:, :, 1]
        inh_matrix = (-1) * onehot_connection[:, :, :, 0] * w[:, :, 0]
        weight = exc_matrix + inh_matrix

        # Expand the per-piece weights to one matrix per time bin.
        weight_per_bin = weight.repeat_interleave(self.piece_length, dim=0)[:self.n_time_bins]
        # (N, 1, n) @ (N, n, n) -> (N, 1, n). Squeeze only the singleton row
        # axis so a single-neuron model keeps its (N, 1) shape.
        drive = (self.convolved_spikes[:, None, :] @ weight_per_bin.permute((0, 2, 1))).squeeze(1)
        return self.activation(self.bg_intensity + drive)

    def loss_fn(self, regularizer=(1,)):
        """Return the NLL plus smoothness penalties on the connection probabilities.

        ``regularizer[i]`` scales the mean squared (i+1)-th order finite
        difference of the softmax connection probabilities along the piece axis.
        """
        loss = self.neg_log_likelihood()
        onehot_connection = F.softmax(self.logit_onehot_connection, dim=-1)
        for order, coefficient in enumerate(regularizer, start=1):
            loss = loss + coefficient * torch.mean(torch.diff(onehot_connection, n=order, dim=0) ** 2)
        return loss