import torch
import torch.nn as nn
import numpy as np
from utils import krbf_fourier_cdiag

class Naive_Bayes(nn.Module):
    """
    A simple Naive Bayes–style GP model for a single feature dimension.

    The latent function over the label grid is ``f = Bmat @ w``. The
    Fourier-domain weights ``w`` have a diagonal Gaussian prior whose
    variances come from ``krbf_fourier_cdiag`` (an RBF spectrum). Only a
    point estimate ``mu_q`` of ``w`` is learned, with no variational
    covariance. So ``forward`` returns an unnormalised log posterior
    (log-likelihood + log-prior), i.e. a MAP objective rather than a full
    ELBO.

    Args
    ----
    xsamp : torch.Tensor
        Observed feature values. Must broadcast against the ``[K, 1]`` mean
        ``Bmat @ mu_q``, e.g. shape ``[K, 1]``. A 1-D ``[K]`` vector is also
        accepted and treated as ``[K, 1]``.
    fdprs : dict
        Precomputed Fourier-domain parameters containing:
          - 'Bmat': torch.Tensor [K, Bdim], Fourier basis evaluated at label grid.
          - 'wwsq': torch.Tensor, squared normalized frequencies; a trailing
            singleton dimension is dropped when ``no_of_outputs == 1``.
    label_params : dict
        Configuration dict with key:
          - 'no_of_outputs': 1 or 2, dimensionality of the label space.
    """
    def __init__(self, xsamp, fdprs, label_params):
        super().__init__()

        # 1) Dimensionality of the label space (one length-scale per dim)
        no_of_outputs = label_params['no_of_outputs']

        # 2) Fourier basis and squared frequencies
        Bmat = torch.as_tensor(fdprs['Bmat'])   # [K, Bdim]: weights -> function values
        wwsq = torch.as_tensor(fdprs['wwsq'])
        if no_of_outputs == 1:
            # 1-D labels: drop a trailing singleton dim (no-op if absent)
            wwsq = wwsq.squeeze(-1)

        # Registered as non-persistent buffers: they follow .to()/.cuda()/
        # .double() with the module, but are not written to state_dict, so
        # checkpoint keys are the same as when they were plain attributes.
        self.register_buffer('Bmat', Bmat, persistent=False)
        self.register_buffer('wwsq', wwsq, persistent=False)

        self.K = Bmat.shape[0]   # number of grid points (stimuli)
        Bdim = Bmat.shape[1]     # number of Fourier basis functions

        # 3) Point estimate of the Fourier weights, one output dimension
        self.mu_q = nn.Parameter(torch.zeros(Bdim, 1))

        # 4) Observed responses for the likelihood term; detached copy so no
        #    external autograd graph is kept alive
        self.register_buffer(
            'x', torch.as_tensor(xsamp).detach().clone(), persistent=False
        )

        # 5) GP marginal variance rho, stored as log(rho); init rho = 1.
        #    Built in the default dtype (np.log would give a float64 param).
        self.log_rho = nn.Parameter(torch.tensor(0.0))

        # 6) RBF length-scale per label dimension, in log-space; init 0.1
        self.log_lengthscale = nn.Parameter(torch.log(torch.ones(no_of_outputs) * 0.1))

        # 7) Observation noise standard deviation sigma_y, in log-space; init 0.1
        self.log_sigma_y = nn.Parameter(torch.log(torch.tensor(0.1)))

    def forward(self):
        """
        Compute the unnormalised log posterior of the Fourier weights.

        Returns
        -------
        log_posterior : torch.Tensor, scalar
            ``log N(x | Bmat @ mu_q, sigma_y^2 I) + log N(mu_q | 0, diag(k_m))``
            with additive constants (the ``2*pi`` terms) dropped.
        """
        rho = torch.exp(self.log_rho)                 # GP marginal variance
        lengthscale = torch.exp(self.log_lengthscale)
        log_var_y = 2.0 * self.log_sigma_y            # log sigma_y^2, computed stably
        var_y = torch.exp(log_var_y)

        # Prior variance of each Fourier weight (diagonal prior covariance)
        k_m = krbf_fourier_cdiag(lengthscale, rho, self.wwsq)
        if k_m.dim() == 1:
            # Align a flat [Bdim] spectrum with mu_q's [Bdim, 1]; otherwise the
            # elementwise ops below broadcast to [Bdim, Bdim] and over-count.
            k_m = k_m.unsqueeze(-1)

        # Gaussian log-likelihood at the point estimate:
        #   -0.5 * N * log sigma_y^2 - 0.5 * ||x - B mu||^2 / sigma_y^2
        # N is the number of residual terms actually summed.
        f_mean = torch.matmul(self.Bmat, self.mu_q)   # [K, 1]
        x = self.x if self.x.dim() > 1 else self.x.unsqueeze(-1)
        resid = x - f_mean
        n_obs = resid.numel()
        log_likelihood = -0.5 * n_obs * log_var_y - 0.5 * resid.pow(2).sum() / var_y

        # Log-density of the diagonal Gaussian prior at mu_q:
        #   -0.5 * sum_i [ log k_i + mu_i^2 / k_i ]
        log_prior = -0.5 * torch.sum(torch.log(k_m) + self.mu_q.pow(2) / k_m)

        return log_likelihood + log_prior

