import math
import numpy as np
import torch
import matplotlib.pyplot as plt
from typing import Dict, Any
from utils import get_fdprs
from utils import get_grid_uniform

def _nb_panel_grid(n_panels):
    """Create one subplot per neuron (at most 10 columns) and return flat axes.

    Panels beyond ``n_panels`` in the last row are hidden.
    """
    ncols = min(10, n_panels)
    nrows = math.ceil(n_panels / ncols)
    # squeeze=False keeps the result a 2-D array even for a 1x1 grid; with the
    # default squeeze=True a single panel comes back as a bare Axes object,
    # which has no .flatten().
    _, axs = plt.subplots(
        nrows, ncols,
        figsize=(4 * ncols, 2.5 * nrows),
        constrained_layout=True,
        squeeze=False,
    )
    axs = axs.flatten()
    for ax in axs[n_panels:]:
        ax.set_visible(False)
    return axs


def _nb_annotate_panel(ax, neuron_index, nmae_value, xlabel, ylabel):
    """Add the NMAE text box, the 1-based neuron title and the axis labels."""
    ax.text(
        0.05, 0.90,
        f"NMAE = {100*nmae_value:.1f}%",
        transform=ax.transAxes,
        va='top', ha='left',
        fontsize=10,
        bbox=dict(boxstyle="round,pad=0.2", alpha=0.2)
    )
    ax.set_title(f"Neuron {neuron_index+1}", fontsize=12)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)


def decoding_weight_comparisons_NB(
    NB_model: torch.nn.Module,
    w_true_FD: torch.Tensor,
    label_params: Dict[str, Any]
):
    """
    Compare inferred decoding weights to ground truth by reconstructing
    both on a uniform grid and visualizing errors.

    Both sets of Fourier-domain weights are projected onto a 1000-point
    uniform grid through the basis matrix returned by ``get_fdprs``, each
    neuron's curve is mean-centered, and a per-neuron normalized mean
    absolute error (MAE divided by the range of the true curve) is computed.
    A summary line is printed, a true-vs-inferred scatter figure is shown,
    and for 1D label spaces a second figure overlays the two curves against
    the label value.

    Parameters
    ----------
    NB_model : torch.nn.Module
        Accessed by key (``NB_model['mu_q']``), so any mapping-like object
        works. Must contain at least:
          - 'mu_q': Tensor [Bdim, D], posterior mean of Fourier-domain weights.
    w_true_FD : torch.Tensor, shape [Bdim, D]
        Ground-truth Fourier weights for each of the D neurons.
    label_params : dict
        Configuration dict with keys:
          - 'no_of_outputs': 1 or 2, dimensionality of label space.
          - 'dimension_1_range': [min, max] for axis y1.
          - 'dimension_2_range': [min, max] for axis y2 (if 2D).

    Returns
    -------
    None
        Results are printed and displayed with ``plt.show()``.

    Raises
    ------
    ValueError
        If the weight matrices contain no neurons (D == 0).
    """
    # 1) Build a uniform grid of label values and its Fourier basis
    grid_no_of_points = 1000
    grid = get_grid_uniform(grid_no_of_points, label_params)   # [K, dims]
    fdprs_grid = get_fdprs(grid, label_params)
    Bmat = fdprs_grid['Bmat'].detach().float()                 # [K, Bdim]

    # 2) Reconstruct inferred and true weight functions on the grid.
    # Both weight tensors are detached (so .numpy() works even when they
    # require grad) and moved to Bmat's device (so the matmul cannot fail on
    # a CPU/GPU mismatch).
    mu_q = NB_model['mu_q'].detach().float().to(Bmat.device)   # [Bdim, D]
    w_true_coeffs = w_true_FD.detach().float().to(Bmat.device)  # [Bdim, D]

    w_inferred = (Bmat @ mu_q).cpu().numpy()                   # [K, D]
    w_true = (Bmat @ w_true_coeffs).cpu().numpy()              # [K, D]

    # Center each neuron's function by subtracting its mean over the grid
    w_inferred -= w_inferred.mean(axis=0, keepdims=True)
    w_true -= w_true.mean(axis=0, keepdims=True)

    D = w_true.shape[1]  # number of neurons/features
    if D == 0:
        raise ValueError(
            "decoding_weight_comparisons_NB: weight matrices contain no neurons"
        )
    eps = 1e-8           # small constant to avoid divide-by-zero

    # 3) Compute per-neuron normalized MAE
    err = w_inferred - w_true            # [K, D]
    mae = np.mean(np.abs(err), axis=0)   # [D]
    range_true = w_true.max(axis=0) - w_true.min(axis=0) + eps  # [D]
    nmae = mae / range_true              # [D]

    # Print overall error summary
    print(
        f"NMAE in inferred decoding weights over {D} neurons: "
        f"{100*nmae.mean():.2f}% ± {100*nmae.std():.2f}%"
    )

    # 4) Scatter vs. identity line plots for each neuron.
    # The identity line spans the global min/max so every panel shares the
    # same reference segment.
    mn = min(w_true.min(), w_inferred.min())
    mx = max(w_true.max(), w_inferred.max())

    axs = _nb_panel_grid(D)
    for d in range(D):
        ax = axs[d]
        ax.scatter(w_true[:, d], w_inferred[:, d], s=20, alpha=0.6)
        ax.plot([mn, mx], [mn, mx], 'k--', lw=1)
        _nb_annotate_panel(
            ax, d, nmae[d], "True (centered)", "Inferred (centered)"
        )
        ax.set_aspect('auto')
    plt.show()

    # 5) For 1D inputs, also plot the weight functions vs. label value
    if label_params['no_of_outputs'] == 1:
        grid_pts = grid.squeeze(-1)  # [K]
        # matplotlib cannot consume CUDA tensors, so convert torch grids to
        # NumPy explicitly; NumPy grids pass through unchanged.
        if isinstance(grid_pts, torch.Tensor):
            grid_pts = grid_pts.detach().cpu().numpy()

        axs2 = _nb_panel_grid(D)
        for d in range(D):
            ax = axs2[d]
            # Plot true weights as dashed black, inferred as solid red
            ax.plot(grid_pts, w_true[:, d],  '--', label='True', color='k', lw=1)
            ax.plot(grid_pts, w_inferred[:, d], '-', label='Inferred', color='r', lw=1)
            _nb_annotate_panel(ax, d, nmae[d], "Label value", "Decoding weight")
            ax.legend(fontsize=8)
        plt.show()

def _hp_to_numpy(values):
    """Return *values* as a NumPy array, detaching torch tensors first."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _hp_as_columns(values):
    """Return *values* as a 2-D array, turning a 1-D input [D] into [D, 1]."""
    arr = _hp_to_numpy(values)
    return arr.reshape(-1, 1) if arr.ndim < 2 else arr


def _hp_length_scale_title(dim_index, n_dims):
    """Build the panel title for one length-scale dimension."""
    if n_dims == 1:
        return 'Length-scale (ℓ)'
    # 'x' / 'y' for the first two dimensions; a 1-based number beyond that so
    # no two panels end up with the same title.
    label = ('x', 'y')[dim_index] if dim_index < 2 else str(dim_index + 1)
    return f'Length-scale {label} (ℓ_{label})'


def _hp_stem_panel(ax, neurons, true_vals, inferred_vals, title, ylabel):
    """Draw true (blue) and inferred (red dashed) stems per neuron on *ax*."""
    ax.stem(neurons, true_vals, linefmt='b-', markerfmt='bo', basefmt=' ')
    ax.stem(neurons, inferred_vals, linefmt='r--', markerfmt='ro', basefmt=' ')
    ax.set_title(title)
    ax.set_xlabel('Neuron')
    ax.set_ylabel(ylabel)
    ax.legend(['True', 'Inferred'], fontsize=8)


def _hp_scatter_panel(ax, true_vals, inferred_vals, title, symbol):
    """Scatter inferred against true values on *ax* with an identity line."""
    ax.scatter(true_vals, inferred_vals, alpha=0.7, edgecolor='k')
    # Identity line for perfect recovery, spanning the joint data range
    lo = min(true_vals.min(), inferred_vals.min())
    hi = max(true_vals.max(), inferred_vals.max())
    ax.plot([lo, hi], [lo, hi], 'k--', lw=1)
    ax.set_title(title)
    ax.set_xlabel(f'True {symbol}')
    ax.set_ylabel(f'Inferred {symbol}')


def hyperparameter_comparisons_NB(NB_model, true_hyperparameters):
    """
    Compare true vs. inferred GP hyperparameters (length-scales, variance, noise)
    for 1D or 2D input spaces.

    Two single-row figures are produced and displayed with one ``plt.show()``:
    per-neuron stem plots of true and inferred values, followed by
    true-vs-inferred scatter plots with an identity line. Each figure has one
    panel per length-scale dimension plus one for the variance and one for
    the noise variance.

    Parameters
    ----------
    NB_model : dict
        Should contain key 'inferred_hyperparameters' mapping to a dict with:
          - 'len'   : Tensor [D] (1D) or [D,2] (2D) inferred length-scales
          - 'rho'   : Tensor [D] inferred process variances
          - 'noise' : Tensor [D] inferred noise variances
    true_hyperparameters : dict
        Ground-truth values with keys:
          - 'len'   : list or array [D] or [D,2]
          - 'rho'   : list or array [D]
          - 'noise' : list or array [D]
        Torch tensors are accepted here as well.

    Returns
    -------
    None
    """
    inferred = NB_model['inferred_hyperparameters']

    # Length-scales are normalised to [D, n_ls] independently of each other,
    # so a 1-D array on one side and a [D, 1] array on the other still line up.
    true_len = _hp_as_columns(true_hyperparameters['len'])
    inferred_len = _hp_as_columns(inferred['len'])

    # stem() needs 1-D heads, so flatten the per-neuron scalars
    true_rho = _hp_to_numpy(true_hyperparameters['rho']).reshape(-1)
    inferred_rho = _hp_to_numpy(inferred['rho']).reshape(-1)
    true_noise = _hp_to_numpy(true_hyperparameters['noise']).reshape(-1)
    inferred_noise = _hp_to_numpy(inferred['noise']).reshape(-1)

    # Number of length-scale dimensions is taken from the ground truth
    n_ls = true_len.shape[1]

    # Neuron indices for the x-axis of the stem plots
    neurons = np.arange(len(inferred_rho))

    # One entry per panel: (true, inferred, title, stem y-label, scatter symbol)
    panels = [
        (
            true_len[:, i],
            inferred_len[:, i],
            _hp_length_scale_title(i, n_ls),
            'Length-scale',
            'ℓ',
        )
        for i in range(n_ls)
    ]
    panels.append((true_rho, inferred_rho, 'Variance (ρ)', 'Variance', 'ρ'))
    panels.append(
        (true_noise, inferred_noise, 'Noise Variance (σ²)', 'Noise Variance', 'σ²')
    )
    n_plots = len(panels)  # length-scales + rho + noise

    # Figure 1: stem plots, true vs. inferred per neuron
    _, stem_axs = plt.subplots(
        1, n_plots, figsize=(3 * n_plots, 2), constrained_layout=True
    )
    stem_axs = np.atleast_1d(stem_axs).flatten()
    for ax, (true_vals, inf_vals, title, ylabel, _) in zip(stem_axs, panels):
        _hp_stem_panel(ax, neurons, true_vals, inf_vals, title, ylabel)

    # Figure 2: scatter plots, inferred against true
    _, scatter_axs = plt.subplots(
        1, n_plots, figsize=(3 * n_plots, 2), constrained_layout=True
    )
    scatter_axs = np.atleast_1d(scatter_axs).flatten()
    for ax, (true_vals, inf_vals, title, _, symbol) in zip(scatter_axs, panels):
        _hp_scatter_panel(ax, true_vals, inf_vals, title, symbol)

    plt.show()