import math
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical
from utils import krbf_fourier_cdiag_multi
from utils import constrained_to_unconstrained

class CMLR(nn.Module):
    """
    Continuous Multinomial Logistic Regression (CMLR) with a Fourier-domain GP prior.

    The model keeps a factorized Gaussian variational posterior over
    Fourier-domain decoding weights (`mu_q`, `log_sigma_q`), a prior whose
    per-coefficient variances are produced by `krbf_fourier_cdiag_multi`,
    and a multinomial logistic likelihood normalized over a uniform grid of
    label values. Supports 1D or 2D label spaces as specified by
    `label_params['no_of_outputs']`.
    """

    def __init__(
            self,
            xsamples: torch.Tensor,
            fdprs_labels: dict,
            fdprs_uniform_grid: dict,
            label_params: dict
        ):
        """
        Initialize CMLR.

        Args:
            xsamples (Tensor[L, D]):
                Observed feature vectors (neural responses), where L is number
                of trials and D is number of features/neurons.
            fdprs_labels (dict):
                Fourier parameters for non-uniform labels:
                - 'Bmat': Tensor [L, Bdim]
                - 'wwsq' : Tensor [Bdim] or [Bdim,1]
            fdprs_uniform_grid (dict):
                Fourier parameters for uniform grid:
                - 'Bmat': Tensor [K, Bdim]
            label_params (dict):
                - 'no_of_outputs': int, 1 or 2 dimensions.
                - 'dimension_1_range': [min, max] for axis y1.
                - 'dimension_2_range': [min, max] for axis y2 (if 2D).
                - Optional 'lengthscale_bounds': (min, max) override.
        """
        super().__init__()

        # --- 1) Collect the [min, max] range of every label axis ---
        no_of_outputs = label_params['no_of_outputs']
        label_range = {'y1': label_params['dimension_1_range']}
        if no_of_outputs != 1:
            # Anything other than 1 is treated as the 2D case.
            label_range['y2'] = label_params['dimension_2_range']

        # --- 2) Register Fourier basis matrices and data as buffers ---
        # Buffers follow the module across .to(device) calls and are stored
        # as float32 copies, so the caller's tensors are never modified.
        Bmat_non_uniform = fdprs_labels['Bmat']         # [L, Bdim]
        wwsq = fdprs_labels['wwsq']                     # [Bdim] or [Bdim,1]
        Bmat_uniform = fdprs_uniform_grid['Bmat']       # [K, Bdim]
        self.register_buffer("Bmat_non_uniform", Bmat_non_uniform.clone().float())
        self.register_buffer("Bmat_uniform", Bmat_uniform.clone().float())
        self.register_buffer("wwsq", wwsq.clone().float())
        self.register_buffer("xsamp", xsamples.clone().float())

        # --- 3) Extract core dimensions ---
        self.D = xsamples.shape[1]            # number of features/neurons
        self.L = Bmat_non_uniform.shape[0]    # number of label samples (L)
        Bdim = Bmat_non_uniform.shape[1]      # number of Fourier coefficients

        # --- 4) Variational parameters for Fourier-domain weights ---
        # q(w) ~ N(mu_q, sigma_q^2), factorized over coefficients and neurons.
        self.mu_q = nn.Parameter(torch.zeros(Bdim, self.D))          # [Bdim, D]
        nn.init.normal_(self.mu_q, mean=0.0, std=0.01)               # small random init
        # Start with a narrow posterior: sigma_q = 1e-3 everywhere.
        self.log_sigma_q = nn.Parameter(
            torch.full((Bdim, self.D), math.log(1e-3))
        )                                                            # [Bdim, D]

        # --- 5) Determine lengthscale bounds from the label ranges ---
        max_span = max(high - low for low, high in label_range.values())
        # Default bounds scale with the widest axis; the caller may override.
        lengthscale_bounds = label_params.get(
            'lengthscale_bounds', (0.01, 2 * max_span)
        )
        self.ls_bounds = lengthscale_bounds
        ls_min, ls_max = lengthscale_bounds

        # --- 6) Initialize unconstrained hyperparameters ---
        # Map an initial lengthscale of a quarter of the widest label span
        # into unconstrained space.
        # NOTE(review): if user-supplied bounds do not contain max_span / 4,
        # the inverse transform may be undefined — confirm how
        # constrained_to_unconstrained handles out-of-range values.
        init_val_len = constrained_to_unconstrained(
            torch.tensor(max_span / 4), ls_min, ls_max
        )
        # One lengthscale per neuron x per label dimension.
        self.unconstrained_lengthscale = nn.Parameter(
            torch.ones(self.D, no_of_outputs) * init_val_len
        )                                                            # [D, no_of_outputs]
        # One log-variance per neuron, initialized to log(1) = 0 (rho = 1).
        self.log_rho = nn.Parameter(torch.zeros(self.D))             # [D]

    def forward(
        self,
        x_batch: torch.Tensor = None,
        indices: torch.Tensor = None,
        num_samples: int = None,
        eps_samples: torch.Tensor = None
        ) -> torch.Tensor:
        """
        Compute the Evidence Lower Bound (ELBO) for the CMLR model.

        Args:
            x_batch (Tensor[B, D], optional):
                Mini-batch of feature vectors. Defaults to full dataset.
            indices (Tensor[B], optional):
                Rows of the non-uniform label basis that belong to the rows
                of `x_batch`, in the same order. Defaults to all rows.
            num_samples (int, optional):
                Monte Carlo samples for ELBO. Defaults to 1 or eps_samples length.
            eps_samples (Tensor[S, Bdim, D], optional):
                Pre-drawn noise for reproducibility.

        Returns:
            elbo (Tensor):
                Scalar ELBO estimate = E_q[log p] - KL[q||p].

        Raises:
            ValueError: If the number of label rows selected by `indices`
                differs from the number of rows in `x_batch`, or if
                `eps_samples` does not have shape (num_samples, Bdim, D).
        """

        # --- A) Select data batch and the matching label-basis rows ---
        if x_batch is None:
            x_batch = self.xsamp
        if indices is None:
            indices = torch.arange(self.L, device=self.Bmat_non_uniform.device)
        Bmat_batch = self.Bmat_non_uniform[indices, :]   # [B, Bdim]
        B = x_batch.shape[0]                             # mini-batch size
        # Each trial must be paired with exactly one label row; otherwise the
        # data term would silently pair trials with the wrong labels.
        if Bmat_batch.shape[0] != B:
            raise ValueError(
                f"indices select {Bmat_batch.shape[0]} label rows but x_batch "
                f"has {B} rows; pass matching x_batch and indices"
            )
        N_full = self.xsamp.shape[0]                     # total number of data points
        # Largest multiple of B that fits in the dataset, so N_effective / B
        # is the whole number of such batches (presumably mirrors a loader
        # that drops the last partial batch — confirm against the training loop).
        N_effective = (N_full // B) * B

        # --- B) Constrain hyperparameters ---
        ls_min, ls_max = self.ls_bounds
        # Sigmoid maps the unconstrained value to (0, 1), then rescale to
        # the interval (ls_min, ls_max).
        lengthscale = ls_min + (ls_max - ls_min) * torch.sigmoid(
            self.unconstrained_lengthscale
        )                                                # [D, no_of_outputs]
        rho = torch.exp(self.log_rho)                    # [D]
        sigma_q = torch.exp(self.log_sigma_q)            # [Bdim, D]

        # Prior variances for each Fourier coefficient & neuron.
        k_m = krbf_fourier_cdiag_multi(lengthscale, rho, self.wwsq)  # [Bdim, D]

        # --- C) Prepare Monte Carlo noise for sampling weights ---
        if num_samples is None:
            num_samples = eps_samples.shape[0] if eps_samples is not None else 1
        if eps_samples is None:
            eps_samples = torch.randn(
                num_samples, *self.mu_q.shape, device=self.mu_q.device
            )
        else:
            expected_shape = (num_samples,) + self.mu_q.shape
            if eps_samples.shape != expected_shape:
                raise ValueError(f"eps_samples should have shape {expected_shape}, but got {eps_samples.shape}")

        # Reparameterized weight samples: [num_samples, Bdim, D]
        w_samp = self.mu_q.unsqueeze(0) + sigma_q.unsqueeze(0) * eps_samples

        # --- D) Expected log-likelihood via Monte Carlo ---
        # Decoding weights at each trial's own label: [num_samples, B, D]
        w_labels = torch.matmul(Bmat_batch, w_samp)
        # Decoding weights at every grid point: [num_samples, K, D]
        w_grid = torch.matmul(self.Bmat_uniform, w_samp)
        # Data term: logit of each trial at its own label, summed over trials.
        # This equals the trace of the [B, B] trial-by-label logit matrix
        # without materializing the off-diagonal entries.
        trace_vals = (x_batch.unsqueeze(0) * w_labels).sum(dim=(1, 2))  # [num_samples]
        # Normalizer: log-sum-exp over the grid for each trial, then sum.
        logits_uniform = torch.matmul(x_batch, w_grid.transpose(1, 2))  # [num_samples, B, K]
        logsumexp_vals = torch.logsumexp(logits_uniform, dim=2).sum(dim=1)  # [num_samples]
        mini_batch_ll = (trace_vals - logsumexp_vals).mean()
        # Scale the mini-batch log-likelihood to approximate the full-data likelihood.
        log_likelihood_est = mini_batch_ll * (N_effective / B)

        # --- E) KL divergence between q(w) and p(w) ---
        # KL[N(mu, s^2) || N(0, k)] = 0.5 * (log(k / s^2) - 1 + (s^2 + mu^2) / k).
        # The log-ratio uses log_sigma_q directly so that it stays finite
        # even when sigma_q**2 underflows.
        sigma_q_sq = sigma_q ** 2
        KL_divergence = 0.5 * torch.sum(
            torch.log(k_m) - 2.0 * self.log_sigma_q - 1.0
            + (sigma_q_sq + self.mu_q ** 2) / k_m
        )

        # --- F) ELBO = E[log-likelihood] - KL divergence ---
        return log_likelihood_est - KL_divergence
