import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from dmhp import utils, model


class HMMGLM(model.DPP):
    """Hidden Markov model + Generalized linear models.

    Each hidden state has its own GLM (background intensity and weight matrix). The hidden
    states follow a discrete-time Markov chain whose initial distribution is uniform.

    Parameters
    ----------
    n_states : int
        Number of states.
    n_neurons : int
        Number of neurons.
    dt : float
        Width of the time bin, in the same time unit as `T` in `simulate_sample_sequence`.
    basis : torch.FloatTensor of shape (window_size,)
        Basis function of the GLM, measuring the history influence.
    activation : str or callable, default='sigmoid'
        Non-linear activation function in ['linear' | 'exp' | 'sigmoid' | 'softplus'] or a callable function.
    bg_intensity : torch.FloatTensor of shape (n_states, n_neurons) or (n_neurons,), default=None
        The background intensity of each neuron in each state. A 1-D tensor is shared by all states.
        The None default means all zeros, shared by all states.
    weight : torch.FloatTensor of shape (n_states, n_neurons, n_neurons), default=None
        The weight matrix in $\\mathbb{R}^{\\text{n_states}\\times \\text{n_neurons} \\times \\text{n_neurons}}$.
        The None default means all zeros.
    intensity_upperbound : torch.FloatTensor of shape (n_neurons,), default=None
        The multiplier of the sigmoid output for each neuron (only used with `activation='sigmoid'`).
        The None default means all 1 / dt.
    transition_matrix : torch.FloatTensor of shape (n_states, n_states), default=None
        The Markov transition matrix; row i is the distribution of the next state given current state i.
        The None default puts 1 - p_change on the diagonal and spreads p_change evenly over the other states.
    transition_rate : float, default=None
        Only used when `transition_matrix` is None: p_change = transition_rate * dt.
        The None default means p_change = 0.02.
    device : torch.device, default=torch.device("cpu")
        Device on which the parameters are created.

    Attributes
    ----------
    flipped_basis : torch.FloatTensor of shape (window_size,)
        Flipped basis function that is easy to do convolution with spike trains.
    window_size : int
        Length of the discretized basis vector.

    Notes
    -----
    The transition matrix is only created when n_states > 1.
    """

    def __init__(self, n_states: int, n_neurons: int, dt: float, basis: torch.FloatTensor, activation='sigmoid', bg_intensity: torch.FloatTensor = None, weight: torch.FloatTensor = None, intensity_upperbound: torch.FloatTensor = None, transition_matrix: torch.FloatTensor = None, transition_rate: float = None, device: torch.device = torch.device("cpu")) -> None:
        super().__init__(dt, device)

        self.n_states = n_states
        self.n_neurons = n_neurons
        self.basis = basis
        self.flipped_basis = torch.flip(self.basis, (0,))
        self.window_size = len(self.basis)

        if activation == 'linear':
            self.activation = lambda x: torch.relu(x) + 1e-5
        elif activation == 'exp':
            self.activation = torch.exp
        elif activation == 'sigmoid':
            if intensity_upperbound is None:
                self.intensity_upperbound = nn.Parameter(1 / dt * torch.ones(self.n_neurons).to(self.device), requires_grad=False)
            else:
                self.intensity_upperbound = nn.Parameter(intensity_upperbound.detach().clone().to(self.device), requires_grad=False)
            self.activation = lambda x: self.intensity_upperbound * torch.sigmoid(x)
        elif activation == 'softplus':
            self.activation = nn.Softplus()
        else:
            # Any other value is used as the activation itself (expected to be a callable).
            self.activation = activation

        if bg_intensity is None:
            # 1-D background intensity: shared by all states (see the `bg_intensity` property).
            self._bg_intensity = nn.Parameter(torch.zeros((self.n_neurons), device=self.device))
        else:
            self._bg_intensity = nn.Parameter(bg_intensity.clone().detach().to(self.device))
        if weight is None:
            self._weight = nn.Parameter(torch.zeros((self.n_states, self.n_neurons, self.n_neurons), device=self.device))
        else:
            self._weight = nn.Parameter(weight.clone().detach().to(self.device))

        if self.n_states > 1:
            if transition_matrix is None:
                p_change = 0.02 if transition_rate is None else transition_rate * self.dt
                eye = torch.eye(n_states, device=self.device)
                transition_matrix = eye * (1 - p_change) + p_change / (n_states - 1) * (1 - eye)
            # Move to the target device *before* wrapping: `Parameter.to()` returns a plain
            # Tensor when the device changes, which would silently unregister the parameter.
            self.transition_matrix = nn.Parameter(transition_matrix.clone().detach().to(self.device), requires_grad=False)

    @property
    def bg_intensity(self):
        """Background intensity of shape (n_states, n_neurons); a shared 1-D one is expanded over states."""
        if len(self._bg_intensity.shape) == 1:
            return self._bg_intensity.expand(self.n_states, -1)
        else:
            return self._bg_intensity

    @property
    def weight(self):
        """Per-state weight matrices of shape (n_states, n_neurons, n_neurons)."""
        return self._weight

    @property
    def adjacency(self):
        """Adjacency derived from `weight` via `utils.weight_to_adjacency`."""
        return utils.weight_to_adjacency(self.weight)

    @property
    def adjacency_index(self):
        """Adjacency index derived from `weight` via `utils.weight_to_adjacency_index`."""
        return utils.weight_to_adjacency_index(self.weight)

    def permute_states(self, true_to_learned: torch.LongTensor):
        """Reorder the per-state parameters in place so that new state k takes old state `true_to_learned[k]`."""
        with torch.no_grad():
            self._weight.data[:] = self._weight.data[true_to_learned]
            # A shared (1-D) background intensity has no state axis to permute.
            if len(self._bg_intensity.shape) > 1:
                self._bg_intensity.data[:] = self._bg_intensity.data[true_to_learned]

    def firing_rates(self, convolved_spikes: torch.FloatTensor, states: torch.FloatTensor = None) -> torch.FloatTensor:
        """Compute the firing rates for all neurons at time t or the whole spike train.

        Parameters
        ----------
        convolved_spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Convolved spike train.
        states : None or torch.LongTensor of shape (n_time_bins, )
            States of the corresponding spike train. If provided, will only compute the firing rates using the states provided; otherwise, will compute the firing rates in all possible states.

        Returns
        -------
        fr : torch.FloatTensor of shape (n_time_bins, n_neurons) or (n_states, n_time_bins, n_neurons)
            If states is provided, will only be the firing rates using the states provided (n_time_bins, n_neurons); otherwise, will be the firing rates in all possible states (n_states, n_time_bins, n_neurons).
            The rates are per time bin, i.e. already multiplied by dt.
        """

        if states is None:
            return self.dt * self.activation(self.bg_intensity[:, None, :] + convolved_spikes @ self.weight.permute((0, 2, 1)))
        else:
            return self.dt * self.activation(self.bg_intensity[states, :] + (convolved_spikes[:, None, :] @ self.weight[states, :, :].permute((0, 2, 1)))[:, 0, :])

    def simulate_sample_sequence(self, n_seq: int, T: float, rng: torch.Generator = None) -> tuple:
        """Simulate `n_seq` sequences up to time T.

        Parameters
        ----------
        n_seq : int
            Number of sample sequences to be simulated.
        T : float
            Final time T, in the same unit as `dt`.
        rng : torch.Generator, optional
            Generator used for every random draw (states and spikes); it must live on `self.device`.
            The None default uses the global PyTorch RNG.

        Returns
        -------
        spikes_list : torch.FloatTensor of shape (n_seq, n_time_bins, n_neurons)
            Spikes list.
        firing_rates_list : torch.FloatTensor of shape (n_seq, n_time_bins, n_neurons)
            Firing rates list (per time bin).
        states_list : torch.LongTensor of shape (n_seq, n_time_bins)
            States list.
        """

        n_time_bins = int(T / self.dt)
        spikes_list = torch.zeros((n_seq, n_time_bins, self.n_neurons), device=self.device)
        firing_rates_list = torch.zeros_like(spikes_list, device=self.device)
        states_list = torch.zeros((n_seq, n_time_bins), dtype=torch.int64, device=self.device)
        if n_time_bins == 0:
            return spikes_list, firing_rates_list, states_list

        with torch.no_grad():
            # Hidden states: uniform initial distribution, then Markov transitions.
            init_p = torch.full((self.n_states,), 1 / self.n_states, device=self.device)
            states_list[:, 0] = torch.multinomial(init_p, num_samples=n_seq, replacement=True, generator=rng)
            if self.n_states > 1:  # with a single state every entry stays 0 (no transition matrix exists)
                for n in range(1, n_time_bins):
                    states_list[:, n] = torch.multinomial(self.transition_matrix[states_list[:, n-1]], num_samples=1, generator=rng)[:, 0]

            # Spikes: each bin sees at most `window_size` previous bins through the flipped basis.
            for n in range(n_time_bins):
                state = states_list[:, n]
                drive = self.bg_intensity[state]  # (n_seq, n_neurons)
                if n > 0:
                    history_len = min(n, self.window_size)
                    convolved = self.flipped_basis[-history_len:] @ spikes_list[:, n - history_len:n, :]  # (n_seq, n_neurons)
                    drive = drive + (convolved[:, None, :] @ self.weight[state].permute((0, 2, 1)))[:, 0, :]
                firing_rates_list[:, n, :] = self.dt * self.activation(drive)
                spikes_list[:, n, :] = torch.poisson(firing_rates_list[:, n, :], generator=rng)
        return spikes_list, firing_rates_list, states_list

    def forward_backward(self, spikes: torch.FloatTensor, convolved_spikes: torch.FloatTensor, distribution: str = 'Poisson') -> tuple:
        """Scaled forward-backward algorithm, corresponding to the E-step of the EM algorithm.

        Parameters
        ----------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train.
        convolved_spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Convolved spike train.
        distribution : str, optional
            Observation distribution in ["Poisson" | "Bernoulli"], by default "Poisson".

        Returns
        -------
        gamma : torch.FloatTensor of shape (n_time_bins, n_states)
            Posterior state probability of time t.
        xi : torch.FloatTensor of shape (n_time_bins - 1, n_states, n_states)
            Joint posterior probability of state i at time t and state j at time t+1;
            each xi[t] sums to one.

        Raises
        ------
        ValueError
            If the posterior contains NaN.
        """

        with torch.no_grad():
            n_time_bins = len(spikes)
            firing_rates_in_different_states = self.firing_rates(convolved_spikes) # n_states x n_time_bins x n_neurons
            log_emission = utils.log_likelihood(spikes, firing_rates_in_different_states, distribution=distribution).sum(dim=(-1,)).permute((1, 0)) # n_time_bins x n_states
            device, dtype = log_emission.device, log_emission.dtype

            # Floor the emission probabilities so a bin where every state underflows does not give 0/0.
            # The same floored values are used everywhere so that the recursions stay consistent.
            emission = log_emission.exp().clamp(min=1e-16)
            transition_matrix = self.transition_matrix.to(device=device, dtype=dtype)
            init_p = torch.full((self.n_states,), 1 / self.n_states, device=device, dtype=dtype)

            # Forward pass with per-step normalisation c[t].
            alpha = torch.zeros((n_time_bins, self.n_states), device=device, dtype=dtype)
            c = torch.zeros((n_time_bins,), device=device, dtype=dtype)
            alpha[0] = init_p * emission[0]
            c[0] = alpha[0].sum()
            alpha[0] = alpha[0] / c[0]
            for t in range(1, n_time_bins):
                alpha[t] = emission[t] * (transition_matrix.T @ alpha[t-1])
                c[t] = alpha[t].sum()
                alpha[t] = alpha[t] / c[t]

            # Backward pass, scaled by the same c[t].
            beta = torch.ones((n_time_bins, self.n_states), device=device, dtype=dtype)
            for t in range(n_time_bins - 2, -1, -1):
                beta[t] = transition_matrix @ (beta[t + 1] * emission[t + 1]) / c[t + 1]

            gamma = alpha * beta # posterior probability of hidden states
            if gamma.isnan().any():
                raise ValueError("NaN encountered in the posterior state probabilities.")
            # Scaled xi divides by c[t+1]; multiplying would make each slice sum to c[t+1]**2 instead of 1.
            xi = alpha[:-1, :, None] * emission[1:, None, :] * transition_matrix[None, :, :] * beta[1:, None, :] / c[1:, None, None] # (n_time_bins - 1) x n_states x n_states
            return gamma, xi

    def m_step(self, spikes: torch.FloatTensor, convolved_spikes: torch.FloatTensor, gamma: torch.FloatTensor, xi: torch.FloatTensor, distribution: str = 'Poisson', update_transition_matrix: bool = True) -> torch.FloatTensor:
        """M-step of the EM algorithm: evaluate Q and optionally re-estimate the transition matrix.

        The returned value is not computed under `torch.no_grad`, so it can be back-propagated
        to the GLM parameters.

        Parameters
        ----------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train.
        convolved_spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Convolved spike train.
        gamma : torch.FloatTensor of shape (n_time_bins, n_states)
            Posterior state probability of time t.
        xi : torch.FloatTensor of shape (n_time_bins - 1, n_states, n_states)
            Joint posterior probability of the states at time t and time t+1.
        distribution : str, optional
            Observation distribution in ["Poisson" | "Bernoulli"], by default "Poisson".
        update_transition_matrix : bool, optional
            Whether to update the transition matrix in place, by default True.

        Returns
        -------
        out : torch.FloatTensor of shape ()
            Q(\\theta, \\theta^{\\text{old}}), i.e., the target function of the EM algorithm.
        """

        firing_rates_in_different_states = self.firing_rates(convolved_spikes) # n_states x n_time_bins x n_neurons
        log_emission = utils.log_likelihood(spikes, firing_rates_in_different_states, distribution=distribution).sum(dim=(-1,)).permute((1, 0)) # n_time_bins x n_states
        init_p = torch.full((self.n_states,), 1 / self.n_states, device=log_emission.device, dtype=log_emission.dtype)

        term_1 = torch.sum(gamma[0] * init_p.log())  # expected log initial-state probability
        term_2 = torch.sum(torch.sum(xi, dim=0) * self.transition_matrix.log())
        term_3 = torch.sum(gamma * log_emission)
        if update_transition_matrix is True:
            # Row-normalised expected transition counts, floored at log-prob -9 and renormalised
            # by the softmax so that no transition probability becomes exactly zero.
            self.transition_matrix.data = F.softmax((xi.sum(dim=0) / xi.sum(dim=(0, 2))[:, None]).log().clamp_(min=-9), dim=-1)
        return term_1 + term_2 + term_3

    def viterbi(self, spikes: torch.FloatTensor, convolved_spikes: torch.FloatTensor, distribution: str = 'Poisson') -> torch.FloatTensor:
        """Viterbi algorithm to infer the most probable latent state sequence.

        Parameters
        ----------
        spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Spike train.
        convolved_spikes : torch.FloatTensor of shape (n_time_bins, n_neurons)
            Convolved spike train.
        distribution : str, optional
            Observation distribution in ["Poisson" | "Bernoulli"], by default "Poisson".

        Returns
        -------
        states_pred : torch.LongTensor of shape (n_time_bins,)
            Most probable state at each time bin.
        """

        n_time_bins = len(spikes)

        with torch.no_grad():
            firing_rates_in_different_states = self.firing_rates(convolved_spikes) # n_states x n_time_bins x n_neurons
            log_emission = utils.log_likelihood(spikes, firing_rates_in_different_states, distribution=distribution).sum(dim=(-1,)).permute((1, 0)) # n_time_bins x n_states
            device, dtype = log_emission.device, log_emission.dtype

            omega = torch.zeros((n_time_bins, self.n_states), device=device, dtype=dtype)
            psi = torch.zeros((n_time_bins, self.n_states), dtype=torch.int64, device=device)
            log_transition_matrix = self.transition_matrix.log().to(device=device, dtype=dtype)

            init_p = torch.full((self.n_states,), 1 / self.n_states, device=device, dtype=dtype)
            omega[0] = init_p.log() + log_emission[0]

            for t in range(1, n_time_bins):
                # temp_matrix[i, j]: best log-score of ending in i at t-1 and moving to j at t.
                temp_matrix = log_transition_matrix + omega[t - 1][:, None]
                values, psi[t] = torch.max(temp_matrix, dim=0)
                omega[t] = log_emission[t] + values

            # Backtrack from the best final state.
            states_pred = torch.zeros(n_time_bins, dtype=torch.int64, device=device)
            states_pred[-1] = omega[-1].argmax()
            for t in range(n_time_bins - 1, 0, -1):
                states_pred[t-1] = psi[t, states_pred[t]]
            return states_pred
        

class GaussianHMMGLM(HMMGLM):
    """HMMGLM with a Gaussian prior on the per-state weights.

    Each weight element w_ij of every state is modelled as drawn from Normal(w_prior_ij, sigma_ij^2).

    Parameters
    ----------
    w_prior : torch.FloatTensor of shape (n_neurons, n_neurons), default=None
        Prior mean weight matrix. Each element w_ij is the mean of the Normal distribution from which
        the w_ij of each state is sampled. The None default means all zeros.
    sigma : float or torch.FloatTensor of shape () or (n_neurons, n_neurons), default=None
        The standard deviation of the Normal distribution for each element of the weight matrix.
        The None default means `dt / 3`.
    """

    def __init__(self, n_states: int, n_neurons: int, dt: float, basis: torch.FloatTensor, activation='sigmoid', bg_intensity: torch.FloatTensor = None, weight: torch.FloatTensor = None, intensity_upperbound: torch.FloatTensor = None, transition_matrix: torch.FloatTensor = None, transition_rate: float = None, w_prior: torch.FloatTensor = None, sigma: float = None, device: torch.device = torch.device("cpu")) -> None:
        super().__init__(n_states, n_neurons, dt, basis, activation, bg_intensity, weight, intensity_upperbound, transition_matrix, transition_rate, device)

        if w_prior is None:
            self.w_prior = nn.Parameter(torch.zeros((self.n_neurons, self.n_neurons), device=self.device), requires_grad=False)
        else:
            self.w_prior = nn.Parameter(w_prior.clone().detach().to(self.device), requires_grad=False)

        if sigma is None:
            self.sigma = nn.Parameter(torch.tensor(self.dt/3).to(self.device), requires_grad=False)
        else:
            # Accept a plain number (as the annotation advertises) as well as a tensor.
            if not isinstance(sigma, torch.Tensor):
                sigma = torch.as_tensor(sigma, dtype=torch.float32)
            self.sigma = nn.Parameter(sigma.clone().detach().to(self.device), requires_grad=False)

    def sample_weight(self, rng: torch.Generator = None):
        """Overwrite each state's weight matrix with a draw from Normal(w_prior, sigma^2)."""
        with torch.no_grad():
            for state in range(self.n_states):
                self._weight.data[state] = torch.normal(self.w_prior, self.sigma, generator=rng)

    def update_w_prior(self):
        """Set the prior mean to the mean of the weights over states."""
        with torch.no_grad():
            self.w_prior.data = self.weight.mean(dim=0)

    def update_sigma_prior(self):
        """Set sigma to the empirical std of the weights over states (averaged if sigma is a scalar)."""
        with torch.no_grad():
            empirical_std = self.weight.std(dim=0)
            if len(self.sigma.shape) == 0:
                self.sigma.data = empirical_std.mean()
            else:
                self.sigma.data = empirical_std

    def prior_log_likelihood(self):
        """Gaussian prior log-likelihood of all weights, or the negative regularization term with constant.

        Returns a scalar: the sum over every state and weight element of
        log Normal(w | w_prior, sigma^2), with `sigma` broadcast to the weight shape.
        """

        z = (self.weight - self.w_prior) / self.sigma
        # The normalising terms are added per element, so a per-element sigma still yields a scalar.
        log_density = -1/2 * z**2 - self.sigma.log() - 1/2 * np.log(2 * np.pi)
        return torch.sum(log_density)


class OnehotHMMGLM(HMMGLM):
    """HMMGLM whose weights are a soft one-hot connection type times a connection strength.

    The last axis of the adjacency encodes the connection type:
    index 0 = inhibitory, 1 = no connection, 2 = excitatory.

    Parameters
    ----------
    adj_prior : torch.FloatTensor of shape (n_neurons, n_neurons, 3), default=None
        Prior probabilities of the Categorical distribution over connection types for each neuron pair;
        the adjacency of each state is sampled from it. The None default means uniform (1/3 each).
    log_adjacency : torch.FloatTensor of shape (n_states, n_neurons, n_neurons, 3), default=None
        Logits of the soft one-hot adjacency matrix in each state. The None default means all zeros.
    logit_strength : torch.FloatTensor of shape (n_states, n_neurons, n_neurons, 2) or (n_states, n_neurons, n_neurons), default=None
        Pre-nonlinearity connection strengths. With a trailing axis of 2, index 0 is the inhibitory and
        index 1 the excitatory strength; without it the same strength is used for both. The None default means all zeros.
    tau : float, default=0.2
        Temperature of the Gumbel-softmax distribution used by `sample_adjacency` and the prior.
    weight_tau : float, default=0.2
        Temperature used by `adjacency` when `gumbel_softmax_weight` is True.
    strength_nonlinearity : str, default='sigmoid'
        'sigmoid' maps strengths into (dt / 10, dt); 'softplus' gives dt * log(1 + 1e-5 + exp(logit)).

    Raises
    ------
    ValueError
        If `strength_nonlinearity` is neither 'sigmoid' nor 'softplus'.
    """

    def __init__(self, n_states: int, n_neurons: int, dt: float, basis: torch.FloatTensor, activation='sigmoid', bg_intensity: torch.FloatTensor = None, weight: torch.FloatTensor = None, intensity_upperbound: torch.FloatTensor = None, transition_matrix: torch.FloatTensor = None, transition_rate: float = None, adj_prior: torch.FloatTensor = None, log_adjacency: torch.FloatTensor = None, logit_strength: torch.FloatTensor = None, tau: float = 0.2, weight_tau: float = 0.2, strength_nonlinearity: str = 'sigmoid', device: torch.device = torch.device("cpu")) -> None:
        super().__init__(n_states, n_neurons, dt, basis, activation, bg_intensity, weight, intensity_upperbound, transition_matrix, transition_rate, device)

        if adj_prior is None:
            self.adj_prior = nn.Parameter(1/3 * torch.ones((self.n_neurons, self.n_neurons, 3), device=self.device), requires_grad=True)
        else:
            self.adj_prior = nn.Parameter(adj_prior.clone().detach().to(self.device), requires_grad=True)

        if log_adjacency is None:
            self.log_adjacency = nn.Parameter(torch.zeros((self.n_states, self.n_neurons, self.n_neurons, 3)).to(self.device), requires_grad=True)
        else:
            self.log_adjacency = nn.Parameter(log_adjacency.clone().detach().to(self.device), requires_grad=True)

        if logit_strength is None:
            self.logit_strength = nn.Parameter(torch.zeros((self.n_states, self.n_neurons, self.n_neurons, 2)).to(self.device), requires_grad=True)
        else:
            self.logit_strength = nn.Parameter(logit_strength.clone().detach().to(self.device), requires_grad=True)

        # When True, `adjacency` draws a Gumbel-softmax sample instead of a plain softmax.
        self.gumbel_softmax_weight = False

        if strength_nonlinearity == 'sigmoid':
            self.strength_nonlinearity = 'sigmoid'
            self.strength_lowerbound = self.dt / 10
            self.strength_upperbound = self.dt
        elif strength_nonlinearity == 'softplus':
            self.strength_nonlinearity = 'softplus'
        else:
            # Fail early instead of leaving the attribute unset and crashing on first `.weight` access.
            raise ValueError(f"Unknown strength_nonlinearity {strength_nonlinearity!r}; expected 'sigmoid' or 'softplus'.")

        self.tau = tau
        self.weight_tau = weight_tau

    @property
    def adjacency(self):
        """Soft one-hot adjacency of shape (n_states, n_neurons, n_neurons, 3)."""
        if self.gumbel_softmax_weight is False:
            return F.softmax(self.log_adjacency, dim=-1)
        else:
            return F.gumbel_softmax(self.log_adjacency, tau=self.weight_tau, dim=-1)

    @property
    def adjacency_index(self):
        """Hard connection type per pair: -1 inhibitory, 0 none, 1 excitatory."""
        with torch.no_grad():
            return self.log_adjacency.argmax(dim=-1) - 1

    @property
    def strength(self):
        """Connection strengths after the configured nonlinearity."""
        if self.strength_nonlinearity == 'sigmoid':
            return self.strength_lowerbound + (self.strength_upperbound - self.strength_lowerbound) * torch.sigmoid(self.logit_strength)
        elif self.strength_nonlinearity == 'softplus':
            return self.dt * (self.logit_strength.exp() + 1 + 1e-5).log()

    @property
    def weight(self):
        """Signed weight: excitatory share times excitatory strength minus inhibitory share times inhibitory strength."""
        adjacency = self.adjacency
        strength = self.strength

        if len(strength.shape) == 4:
            exc_matrix = 1 * adjacency[:, :, :, 2] * strength[:, :, :, 1]
            inh_matrix = (-1) * adjacency[:, :, :, 0] * strength[:, :, :, 0]
        else:
            exc_matrix = 1 * adjacency[:, :, :, 2] * strength[:, :, :]
            inh_matrix = (-1) * adjacency[:, :, :, 0] * strength[:, :, :]
        return exc_matrix + inh_matrix

    def permute_states(self, true_to_learned: torch.LongTensor):
        """Reorder the per-state parameters in place so that new state k takes old state `true_to_learned[k]`."""
        with torch.no_grad():
            self.log_adjacency.data[:] = self.log_adjacency.data[true_to_learned]
            self.logit_strength.data[:] = self.logit_strength.data[true_to_learned]
            if len(self._bg_intensity.shape) > 1:
                self._bg_intensity.data[:] = self._bg_intensity.data[true_to_learned]

    def sample_adjacency(self, tau: float = None, hard: bool = False):
        """Overwrite each state's `log_adjacency` with the log of a Gumbel-softmax sample from `adj_prior`."""
        if tau is None:
            tau = self.tau
        with torch.no_grad():
            for state in range(self.n_states):
                self.log_adjacency.data[state] = F.gumbel_softmax(self.adj_prior.log(), tau=tau, dim=-1, hard=hard).log()

    def prior_log_likelihood(self, constant: bool = False):
        """Gumbel-softmax (Concrete) prior log-likelihood with temperature `self.tau`, or the negative regularization term.

        Parameters
        ----------
        constant : bool, optional
            If True, also add the terms that do not involve `adj_prior`.
        """

        adjacency = F.softmax(self.log_adjacency, dim=-1)

        if constant is True:
            # `self.tau` is a plain float, which has no `.log()` method.
            log_tau = torch.log(torch.as_tensor(self.tau))
            constant = self.n_states * self.n_neurons**2 * (torch.lgamma(torch.tensor(3.)) + 2 * log_tau) - (self.tau + 1) * adjacency.log().sum()
        else:
            constant = 0

        return -3 * (self.adj_prior / adjacency**self.tau).sum(dim=-1).log().sum() + self.n_states * self.adj_prior.log().sum() + constant

    def prior_entropy(self):
        """Entropy of the soft adjacency, summed over all states and pairs."""
        adjacency = F.softmax(self.log_adjacency, dim=-1)
        return -(adjacency * adjacency.log()).sum()

    def reset_adjacency_from_weight(self):
        """Re-initialise `log_adjacency` from the current signed weights.

        Positive weights get an excitatory share proportional to weight / max(weight) and non-positive
        weights an inhibitory share proportional to weight / min(weight); the rest goes to "no connection".
        """
        with torch.no_grad():
            weight = self.weight
            adjacency = torch.zeros((self.n_states, self.n_neurons, self.n_neurons, 3), device=weight.device, dtype=weight.dtype)
            idx = weight > 0
            adjacency[:, :, :, 2][idx] = weight[idx] / weight.max()
            adjacency[:, :, :, 1][idx] = 1 - adjacency[:, :, :, 2][idx]
            w_min = weight.min()
            # If no weight is negative, the non-positive ones are all zero: leave their inhibitory
            # share at 0 instead of computing 0 / 0.
            if w_min < 0:
                adjacency[:, :, :, 0][~idx] = weight[~idx] / w_min
            adjacency[:, :, :, 1][~idx] = 1 - adjacency[:, :, :, 0][~idx]
            self.log_adjacency.data = (adjacency + 0.01).log()


class ScottHMMGLM(HMMGLM):
    """HMMGLM whose weights are gated by a soft binary connection indicator per state.

    The last axis of the connection encodes: index 0 = no connection, index 1 = connection
    (the weight is `_weight * connection[..., 1]`).

    Parameters
    ----------
    con_prior : torch.FloatTensor of shape (n_neurons, n_neurons, 2), default=None
        Prior probabilities of the Categorical distribution over {no connection, connection} for each
        neuron pair; the connection of each state is sampled from it. The None default means 1/2 each.
    log_connection : torch.FloatTensor of shape (n_states, n_neurons, n_neurons, 2), default=None
        Logits of the soft one-hot connection matrix in each state. The None default means all zeros.
    tau : float, default=0.2
        Temperature of the Gumbel-softmax distribution used by `sample_connection` and the prior.
    weight_tau : float, default=0.2
        Temperature used by `connection` when `gumbel_softmax_weight` is True.
    """

    def __init__(self, n_states: int, n_neurons: int, dt: float, basis: torch.FloatTensor, activation='sigmoid', bg_intensity: torch.FloatTensor = None, weight: torch.FloatTensor = None, intensity_upperbound: torch.FloatTensor = None, transition_matrix: torch.FloatTensor = None, transition_rate: float = None, con_prior: torch.FloatTensor = None, log_connection: torch.FloatTensor = None, tau: float = 0.2, weight_tau: float = 0.2, device: torch.device = torch.device("cpu")) -> None:
        super().__init__(n_states, n_neurons, dt, basis, activation, bg_intensity, weight, intensity_upperbound, transition_matrix, transition_rate, device)

        if con_prior is None:
            self.con_prior = nn.Parameter(1/2 * torch.ones((self.n_neurons, self.n_neurons, 2), device=self.device), requires_grad=True)
        else:
            self.con_prior = nn.Parameter(con_prior.clone().detach().to(self.device), requires_grad=True)

        if log_connection is None:
            self.log_connection = nn.Parameter(torch.zeros((self.n_states, self.n_neurons, self.n_neurons, 2)).to(self.device), requires_grad=True)
        else:
            self.log_connection = nn.Parameter(log_connection.clone().detach().to(self.device), requires_grad=True)

        # When True, `connection` draws a Gumbel-softmax sample instead of a plain softmax.
        self.gumbel_softmax_weight = False

        self.tau = tau
        self.weight_tau = weight_tau

    @property
    def connection(self):
        """Soft one-hot connection of shape (n_states, n_neurons, n_neurons, 2)."""
        if self.gumbel_softmax_weight is False:
            return F.softmax(self.log_connection, dim=-1)
        else:
            return F.gumbel_softmax(self.log_connection, tau=self.weight_tau, dim=-1)

    @property
    def weight(self):
        """Raw weights gated by the probability of a connection."""
        return self._weight * self.connection[:, :, :, 1]

    def permute_states(self, true_to_learned: torch.LongTensor):
        """Reorder the per-state parameters in place so that new state k takes old state `true_to_learned[k]`."""
        with torch.no_grad():
            self.log_connection.data[:] = self.log_connection.data[true_to_learned]
            self._weight.data[:] = self._weight.data[true_to_learned]
            if len(self._bg_intensity.shape) > 1:
                self._bg_intensity.data[:] = self._bg_intensity.data[true_to_learned]

    def sample_connection(self, tau: float = None, hard: bool = False):
        """Overwrite each state's `log_connection` with the log of a Gumbel-softmax sample from `con_prior`."""
        if tau is None:
            tau = self.tau
        with torch.no_grad():
            for state in range(self.n_states):
                self.log_connection.data[state] = F.gumbel_softmax(self.con_prior.log(), tau=tau, dim=-1, hard=hard).log()

    def prior_log_likelihood(self, constant: bool = False):
        """Gumbel-softmax (Concrete) prior log-likelihood with temperature `self.tau`, or the negative regularization term.

        Parameters
        ----------
        constant : bool, optional
            If True, also add the terms that do not involve `con_prior`.
        """

        connection = F.softmax(self.log_connection, dim=-1)

        if constant is True:
            # `self.tau` is a plain float, which has no `.log()` method.
            log_tau = torch.log(torch.as_tensor(self.tau))
            constant = self.n_states * self.n_neurons**2 * (torch.lgamma(torch.tensor(2.)) + 1 * log_tau) - (self.tau + 1) * connection.log().sum()
        else:
            constant = 0

        return -2 * (self.con_prior / connection**self.tau).sum(dim=-1).log().sum() + self.n_states * self.con_prior.log().sum() + constant

    def prior_entropy(self):
        """Entropy of the soft connection, summed over all states and pairs."""
        connection = F.softmax(self.log_connection, dim=-1)
        return -(connection * connection.log()).sum()