import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from utils import get_fdprs
from NaiveBayes import Naive_Bayes

def train_Naive_Bayes(model, optimizer, num_epochs=None, tol=None, patience=None):
    """
    Train a model by maximizing its log posterior with early stopping.

    Each epoch calls ``model()`` to obtain the scalar log posterior, takes one
    optimizer step on its negative, and tracks the best value seen so far.
    Training stops once ``patience`` consecutive epochs fail to improve on the
    best log posterior by more than ``tol``, or after ``num_epochs`` epochs.

    Parameters
    ----------
    model : callable (typically torch.nn.Module)
        Called with no arguments; must return a single-element tensor holding
        the current log posterior, differentiable w.r.t. the optimized
        parameters.
    optimizer : torch.optim.Optimizer
        Optimizer instance (e.g., Adam) used to update the model parameters.
    num_epochs : int, optional
        Maximum number of training epochs. Defaults to 10,000.
    tol : float, optional
        Minimum improvement of the log posterior over the best value so far
        required to reset the early-stopping counter. Defaults to 1e-4.
    patience : int, optional
        Number of consecutive epochs without an improvement larger than
        ``tol`` before training stops. Defaults to 100.

    Returns
    -------
    loss_history : list of float
        The loss (negative log posterior) of every epoch that was run,
        including the epoch on which early stopping was triggered. Each value
        is the one evaluated before that epoch's parameter update.
    """
    # Fill in defaults for anything the caller left unspecified.
    if num_epochs is None:
        num_epochs = 10000
    if tol is None:
        tol = 1e-4
    if patience is None:
        patience = 100

    loss_history = []                   # loss (-log posterior) per epoch
    best_posterior = -float('inf')      # best log posterior seen so far
    patience_counter = 0                # epochs since the last improvement

    for _ in range(num_epochs):
        optimizer.zero_grad()           # reset gradients

        # Gradient descent on the negative log posterior == ascent on the
        # log posterior itself.
        posterior = model()
        loss = -posterior
        loss.backward()
        optimizer.step()

        # Work with a plain Python float from here on: keeping the tensor as
        # `best_posterior` would hold on to a stale autograd graph.
        loss_value = loss.item()
        current_posterior = -loss_value

        # Record before the early-stopping check so that the final epoch is
        # part of the returned history as well.
        loss_history.append(loss_value)

        # Reset the counter only on an improvement larger than `tol`.
        if current_posterior > best_posterior + tol:
            best_posterior = current_posterior
            patience_counter = 0
        else:
            patience_counter += 1

        # No significant improvement for `patience` epochs in a row: stop.
        if patience_counter >= patience:
            break

    return loss_history


def fit_NB_model(xsamples, labels, NB_fixed_params):
    """
    Fit a separate Naive Bayes–style model to each feature dimension using
    a Fourier‐domain GP prior.

    For every column of ``xsamples`` a fresh ``Naive_Bayes`` model is built,
    trained with Adam via ``train_Naive_Bayes``, and its learned parameters
    are copied (detached from autograd) into the returned tensors.

    Parameters
    ----------
    xsamples : torch.Tensor, shape [L, D]
        Observed feature matrix (L samples, D features).
    labels : torch.Tensor, shape [L, 1]
        Continuous labels corresponding to each sample.
    NB_fixed_params : dict
        Training hyperparameters and label configuration:
          - 'learning_rate'          : float, Adam learning rate. The legacy
                                       misspelled key 'learing_rate' is used
                                       when 'learning_rate' is absent.
          - 'label_params'           : dict, passed to get_fdprs() and
                                       Naive_Bayes(); must contain
                                       'no_of_outputs'.
          - 'maximum_num_epochs'     : int, max training epochs (per feature).
          - 'tolerance'              : float, log-posterior improvement
                                       threshold for early stopping.
          - 'patience'               : int, early‐stop patience.

    Returns
    -------
    NB_model : dict
        Contains:
          - 'mu_q'                    : Tensor [Bdim, D], learned Fourier means.
          - 'inferred_hyperparameters': dict with keys
                'len'   : Tensor [D, no_of_outputs], inferred length-scales
                'rho'   : Tensor [D, 1], inferred process variances
                'noise' : Tensor [D, 1], inferred noise variances
        None of the returned tensors is attached to an autograd graph.

    Raises
    ------
    KeyError
        If a required entry is missing from ``NB_fixed_params`` or
        ``label_params``.
    """
    # --- 1) Unpack training parameters ---
    # Prefer the correctly spelled key; fall back to the historical
    # misspelling so existing configurations keep working.
    if 'learning_rate' in NB_fixed_params:
        learning_rate = NB_fixed_params['learning_rate']
    else:
        learning_rate = NB_fixed_params['learing_rate']
    label_params   = NB_fixed_params['label_params']
    max_num_epochs = NB_fixed_params['maximum_num_epochs']
    tolerance      = NB_fixed_params['tolerance']
    patience       = NB_fixed_params['patience']
    no_of_outputs  = label_params['no_of_outputs']

    # --- 2) Compute Fourier‐domain basis for non-uniform labels ---
    fdprs_labels = get_fdprs(labels, label_params)
    Bdim = fdprs_labels['fBdims'][1]  # number of Fourier basis functions
    D = xsamples.shape[1]             # number of feature dimensions
    print("Number of Fourier coefficients:", Bdim)

    # --- 3) Preallocate storage for inferred parameters ---
    # mu_q: posterior mean of Fourier weights, one column per feature
    mu_q = torch.zeros((Bdim, D))
    # lengthscale, rho, and noise for each feature
    lengthscale_inferred_all = torch.zeros((D, no_of_outputs))
    rho_inferred_all         = torch.zeros((D, 1))
    noise_inferred_all       = torch.zeros((D, 1))

    # --- 4) Fit a separate model for each feature dimension ---
    for d in range(D):
        # Only print every 5th feature (i.e. when d is 0, 5, 10, …)
        if d % 5 == 0:
            print(f"Training Naive Bayes model for feature {d+1}/{D}...")

        # Extract one column of xsamples and ensure shape [L, 1]
        xsamp_d = xsamples[:, d].unsqueeze(1)

        # Initialize a fresh Naive_Bayes model and optimizer for this feature
        NB_model_d = Naive_Bayes(xsamp_d, fdprs_labels, label_params)
        optimizer  = optim.Adam(NB_model_d.parameters(), lr=learning_rate)

        # Train with the caller-supplied epoch budget and early-stopping
        # settings; the loss history is not needed here.
        train_Naive_Bayes(
            NB_model_d,
            optimizer,
            num_epochs=max_num_epochs,
            tol=tolerance,
            patience=patience
        )

        # --- 5) Extract learned parameters for this dimension ---
        # Copy under no_grad: otherwise the in-place writes would attach the
        # result tensors to every per-feature autograd graph, keeping those
        # graphs alive and making the outputs require grad.
        with torch.no_grad():
            # Fourier-domain mean
            mu_q[:, d] = NB_model_d.mu_q.squeeze()
            # length-scale(s) → positive via exp(log_lengthscale)
            lengthscale_inferred_all[d, :] = torch.exp(NB_model_d.log_lengthscale)
            # process variance
            rho_inferred_all[d] = torch.exp(NB_model_d.log_rho)
            # observation noise variance
            noise_inferred_all[d] = torch.exp(NB_model_d.log_sigma_y)

    # --- 6) Package all inferred results into a dict ---
    inferred_hyperparameters = {
        'len'  : lengthscale_inferred_all,
        'rho'  : rho_inferred_all,
        'noise': noise_inferred_all
    }
    NB_model = {
        'mu_q'                    : mu_q,
        'inferred_hyperparameters': inferred_hyperparameters
    }
    return NB_model
