import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from typing import Dict, Any, Tuple

from utils import get_grid_uniform, get_fdprs,get_circular_error

def decode_labels_NB(NB_model, J_decode, xsamples_test, label_params):
    """
    Decode continuous labels from a Naive Bayes GP model on a uniform grid.

    The per-(sample, grid point) squared error is expanded into norm and
    cross terms so that no [N, K, D] tensor is ever materialised.

    Parameters
    ----------
    NB_model : dict
        Contains:
          - 'mu_q': Tensor [Bdim, D], posterior mean of Fourier-domain weights.
          - 'inferred_hyperparameters': dict with key 'noise', a tensor that
            squeezes to [D] (or a scalar). It is squared here, i.e. it is
            treated as a per-feature standard deviation.
    J_decode : int
        Decoding resolution, forwarded to ``get_grid_uniform``.
    xsamples_test : torch.Tensor, shape [N, D]
        Test feature matrix (N samples, D features).
    label_params : dict
        Configuration dict. Keys read here:
          - 'no_of_outputs': number of label dimensions (1 triggers flattening
            of the outputs to shape [N]).
          - 'is_label_range_circular' (optional, default False): use a
            circular mean for the posterior-mean decoding.
        The dict is also forwarded to ``get_grid_uniform`` / ``get_fdprs``,
        which presumably read the 'dimension_*_range' keys (not visible here).

    Returns
    -------
    decoded_label_mode : torch.Tensor, shape [N] or [N, dims]
        MAP (mode) decoding: the grid point with the highest posterior.
    decoded_label_mean : torch.Tensor, shape [N] or [N, dims]
        Posterior-mean decoding; a circular mean mapped to [0, 2π) when
        'is_label_range_circular' is set.
    """

    # 1) Build the uniform decoding grid and its Fourier basis.
    grid_decode = get_grid_uniform(J_decode, label_params)      # [K, dims]
    fdprs_decode = get_fdprs(grid_decode, label_params)
    Bmat = fdprs_decode['Bmat'].float()                         # [K, Bdim]

    # 2) Map the Fourier-domain posterior mean to per-grid-point feature means.
    mu_q = NB_model['mu_q'].detach().float()                    # [Bdim, D]
    decoding_weights = Bmat @ mu_q                              # [K, D]

    # 3) Gaussian log-likelihood (up to a grid-independent constant) via
    #    ‖x–μ‖²/(2σ²) = ‖x‖²/(2σ²) + ‖μ‖²/(2σ²) – (x·μ)/σ²
    # The noise is detached and cast exactly like mu_q: decoding needs no
    # autograd graph, and mixed float32/float64 operands would make the
    # matmul below raise.
    noise = NB_model['inferred_hyperparameters']['noise'].detach().float()
    inv_two_sigma_sq = 1.0 / (2 * noise.squeeze().pow(2))       # [D] (or scalar)
    feature_scale = inv_two_sigma_sq.unsqueeze(0)               # broadcasts over rows
    xs = xsamples_test.float()                                  # [N, D]

    # 3a) ‖x‖²/(2σ²) per sample.
    xs_norm = (xs ** 2 * feature_scale).sum(dim=1, keepdim=True)        # [N, 1]
    # 3b) ‖μ‖²/(2σ²) per grid point.
    mu_norm = (decoding_weights ** 2 * feature_scale).sum(dim=1)        # [K]
    # 3c) (x·μ)/(2σ²); the factor 2 of the cross-term is applied in 3d.
    cross = (xs * feature_scale) @ decoding_weights.T                   # [N, K]
    # 3d) Negative scaled squared error = log-likelihood.
    log_lik = -(xs_norm + mu_norm.unsqueeze(0) - 2 * cross)             # [N, K]

    # 4) Normalise to posterior probabilities over the K grid points.
    posterior = torch.softmax(log_lik, dim=1)                           # [N, K]

    is_one_dimensional = label_params['no_of_outputs'] == 1

    # 5) Mode decoding: grid point with the highest posterior.
    decode_class = posterior.argmax(dim=1)                              # [N]
    decoded_label_mode = grid_decode[decode_class]
    if is_one_dimensional:
        # Flatten to [N]. A bare squeeze() would also drop the batch
        # dimension when N == 1 and return a 0-d tensor.
        decoded_label_mode = decoded_label_mode.reshape(-1)

    # 6) Mean decoding: posterior-weighted average.
    if label_params.get('is_label_range_circular', False):
        # 6a) Circular mean via atan2 of the weighted sine/cosine.
        #     Assumes a one-dimensional grid of shape [K, 1].
        grid_flat = grid_decode.squeeze(1)                              # [K]
        weighted_sin = torch.sum(posterior * torch.sin(grid_flat), dim=1)
        weighted_cos = torch.sum(posterior * torch.cos(grid_flat), dim=1)
        decoded_label_mean = torch.atan2(weighted_sin, weighted_cos)
        # Map angles from (-π, π] to [0, 2π).
        decoded_label_mean = (decoded_label_mean + 2 * np.pi) % (2 * np.pi)
    else:
        # 6b) Linear mean decoding.
        decoded_label_mean = posterior @ grid_decode                    # [N, dims]
        if is_one_dimensional:
            # Match the mode output: [N], not [N, 1]. Otherwise subtracting
            # [N] labels would broadcast to an [N, N] matrix.
            decoded_label_mean = decoded_label_mean.reshape(-1)

    return decoded_label_mode, decoded_label_mean

def _abs_decoding_error_NB(labels_true, labels_decoded, label_params):
    """Return the per-sample decoding error for 1-D or 2-D labels."""
    no_of_outputs = label_params['no_of_outputs']
    if no_of_outputs == 1:
        # Flatten both operands: a trailing singleton dimension on either one
        # ([N, 1] vs [N]) would otherwise broadcast the difference to [N, N]
        # and silently corrupt the average.
        labels_true = labels_true.reshape(-1)
        labels_decoded = labels_decoded.reshape(-1)
        if label_params.get('is_label_range_circular', False):
            return get_circular_error(labels_true, labels_decoded)
        return torch.abs(labels_true - labels_decoded)
    if no_of_outputs == 2:
        # Euclidean distance between decoded and true 2-D labels.
        return torch.norm(labels_decoded - labels_true, dim=1)
    raise ValueError(
        f"label_params['no_of_outputs'] must be 1 or 2, got {no_of_outputs!r}"
    )


def get_avg_decoding_error_vs_resolution_NB(NB_model, J_decode_range, xsamples_test, labels_test, label_params):
    """
    Compute the average decoding error of the Naive Bayes model for several
    decoding resolutions.

    Parameters
    ----------
    NB_model : dict
        The trained Naive Bayes model, passed through to ``decode_labels_NB``.
    J_decode_range : iterable of int
        Decoding resolutions to evaluate.
    xsamples_test : torch.Tensor
        Test samples.
    labels_test : torch.Tensor
        True labels for the test samples.
    label_params : dict
        Label parameters for the model. Keys read here:
          - 'no_of_outputs': 1 (absolute or circular error) or
            2 (Euclidean distance).
          - 'is_label_range_circular' (optional, default False): use
            ``get_circular_error`` for 1-D labels.

    Returns
    -------
    avg_errors_mode_NB : np.ndarray
        Average error of the mode (MAP) decoding, one entry per resolution.
    avg_errors_mean_NB : np.ndarray
        Average error of the posterior-mean decoding, one entry per resolution.

    Raises
    ------
    ValueError
        If 'no_of_outputs' is neither 1 nor 2 (raised while evaluating the
        first resolution).

    Notes
    -----
    Errors are returned in the units of the labels; no unit conversion is
    applied.
    """
    avg_errors_mode_NB = []  # mode decoding, one value per resolution
    avg_errors_mean_NB = []  # mean decoding, one value per resolution

    for J_decode in J_decode_range:
        # Decode the test labels with the NB model at this resolution.
        decoded_label_mode_NB, decoded_label_mean_NB = decode_labels_NB(
            NB_model, J_decode, xsamples_test, label_params
        )

        # Average the per-sample error across all test samples.
        abs_error_mode_NB = _abs_decoding_error_NB(labels_test, decoded_label_mode_NB, label_params)
        avg_errors_mode_NB.append(torch.mean(abs_error_mode_NB).item())

        abs_error_mean_NB = _abs_decoding_error_NB(labels_test, decoded_label_mean_NB, label_params)
        avg_errors_mean_NB.append(torch.mean(abs_error_mean_NB).item())

    return np.array(avg_errors_mode_NB), np.array(avg_errors_mean_NB)