import math
import numpy as np
import torch
import matplotlib.pyplot as plt
from typing import Dict, Any
from utils import get_fdprs
from utils import get_grid_uniform

def decoding_weight_comparisons(
    CMLR_model: torch.nn.Module,
    w_true_FD: torch.Tensor,
    label_params: Dict[str, Any]
) -> None:
    """
    Compare inferred decoding weights to ground truth by reconstructing
    both on a uniform grid and visualizing errors.

    Both sets of Fourier-domain weights are projected onto the grid through
    the basis matrix returned by ``get_fdprs``; each neuron's reconstructed
    function is mean-centered, and a per-neuron normalized MAE
    (MAE divided by the range of the centered true function) is printed and
    annotated on the plots.  The function only prints and plots; it returns
    nothing.

    Parameters
    ----------
    CMLR_model : torch.nn.Module
        Trained CMLR model, expected to have attribute:
          - mu_q: Tensor [Bdim, D], posterior mean of Fourier weights.
    w_true_FD : torch.Tensor, shape [Bdim, D]
        Ground-truth Fourier weights for each of the D neurons.
    label_params : dict
        Configuration dict with keys:
          - 'no_of_outputs': 1 or 2, dimensionality of label space.
          - 'dimension_1_range': [min, max] for axis y1.
          - 'dimension_2_range': [min, max] for axis y2 (if 2D).

    Raises
    ------
    ValueError
        If the inferred and true weights reconstruct to different shapes,
        or if there are no neurons (D == 0) to compare.
    """
    # 1) Build a uniform grid of label values and its Fourier basis
    grid_no_of_points = 1000
    grid = get_grid_uniform(grid_no_of_points, label_params)   # [K, dims]
    fdprs_grid = get_fdprs(grid, label_params)
    Bmat = fdprs_grid['Bmat'].float()                          # [K, Bdim]

    # 2) Reconstruct inferred weight functions on the grid.
    # The model may live on a different device than the freshly built grid
    # basis, so align device and dtype before multiplying.
    mu_q = CMLR_model.mu_q.detach().to(device=Bmat.device, dtype=Bmat.dtype)
    w_inferred = (Bmat @ mu_q).detach().cpu().numpy()          # [K, D]
    # Center each neuron's function by subtracting its mean
    w_inferred = w_inferred - w_inferred.mean(axis=0, keepdims=True)

    # 3) Reconstruct true weight functions on the grid
    w_true_fd = w_true_FD.detach().to(device=Bmat.device, dtype=Bmat.dtype)
    w_true = (Bmat @ w_true_fd).detach().cpu().numpy()         # [K, D]
    w_true = w_true - w_true.mean(axis=0, keepdims=True)

    if w_inferred.shape != w_true.shape:
        # Without this check NumPy could silently broadcast the subtraction.
        raise ValueError(
            "Inferred and true weights reconstruct to different shapes: "
            f"{w_inferred.shape} vs {w_true.shape}"
        )

    D = w_true.shape[1]  # number of neurons/features
    if D == 0:
        raise ValueError("No neurons to compare (weights have zero columns).")
    eps = 1e-8           # small constant to avoid divide-by-zero

    # 4) Compute per-neuron normalized MAE
    err = w_inferred - w_true            # [K, D]
    mae = np.mean(np.abs(err), axis=0)   # [D]
    range_true = w_true.max(axis=0) - w_true.min(axis=0) + eps  # [D]
    nmae = mae / range_true              # [D]

    # Print overall error summary
    print(
        f"NMAE in inferred decoding weights over {D} neurons: "
        f"{100*nmae.mean():.2f}% ± {100*nmae.std():.2f}%"
    )

    # 5) Scatter vs. identity line plots for each neuron
    ncols = min(10, D)
    nrows = math.ceil(D / ncols)
    # Shared identity-line limits across all neurons
    mn = min(w_true.min(), w_inferred.min())
    mx = max(w_true.max(), w_inferred.max())
    # Styling shared by the NMAE annotation in both figures
    badge_style = dict(
        va='top', ha='left',
        fontsize=10,
        bbox=dict(boxstyle="round,pad=0.2", alpha=0.2)
    )

    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 layout,
    # so flatten() also works when there is a single neuron.
    fig, axs = plt.subplots(
        nrows, ncols,
        figsize=(4 * ncols, 2.5 * nrows),
        constrained_layout=True,
        squeeze=False
    )
    axs = axs.flatten()
    for d in range(D):
        ax = axs[d]
        ax.scatter(w_true[:, d], w_inferred[:, d], s=20, alpha=0.6)
        ax.plot([mn, mx], [mn, mx], 'k--', lw=1)
        ax.text(
            0.05, 0.90,
            f"NMAE = {100*nmae[d]:.1f}%",
            transform=ax.transAxes,
            **badge_style
        )
        ax.set_title(f"Neuron {d+1}", fontsize=12)
        ax.set_xlabel("True (centered)")
        ax.set_ylabel("Inferred (centered)")
        ax.set_aspect('auto')

    # Hide unused subplots
    for ax in axs[D:]:
        ax.set_visible(False)
    plt.show()

    # 6) For 1D inputs, also plot the weight functions vs. label value
    if label_params['no_of_outputs'] == 1:
        # Convert to a flat NumPy array so plotting works whether the grid is
        # a (possibly non-CPU) tensor or an array, shaped [K] or [K, 1].
        if isinstance(grid, torch.Tensor):
            grid_pts = grid.detach().cpu().numpy()
        else:
            grid_pts = np.asarray(grid)
        grid_pts = grid_pts.reshape(-1)  # [K]

        fig2, axs2 = plt.subplots(
            nrows, ncols,
            figsize=(4 * ncols, 2.5 * nrows),
            constrained_layout=True,
            squeeze=False
        )
        axs2 = axs2.flatten()
        for d in range(D):
            ax = axs2[d]
            # Plot true weights as dashed black, inferred as solid red
            ax.plot(grid_pts, w_true[:, d],  '--', label='True', color='k', lw=1)
            ax.plot(grid_pts, w_inferred[:, d], '-', label='Inferred', color='r', lw=1)
            ax.text(
                0.05, 0.90,
                f"NMAE = {100*nmae[d]:.1f}%",
                transform=ax.transAxes,
                **badge_style
            )
            ax.set_title(f"Neuron {d+1}", fontsize=12)
            ax.set_xlabel("Label value")
            ax.set_ylabel("Decoding weight")
            ax.legend(fontsize=8)

        # Hide extra subplots
        for ax in axs2[D:]:
            ax.set_visible(False)
        plt.show()

def hyperparameter_comparisons(
    CMLR_model: torch.nn.Module,
    true_hyperparameters: Dict[str, Any]
) -> None:
    """
    Compare inferred vs. true GP hyperparameters via summary statistics
    and visualizations (stem plots and scatter-vs-identity).

    Length-scales are handled per label dimension: a 1D ground truth of
    shape [D] is treated as a single column so it lines up with the
    model's [D, 1] length-scale tensor instead of broadcasting against it.
    Errors are normalized by the range of the true values across neurons
    (plus a small epsilon), so the NMAE is only informative when the true
    values actually vary between neurons.  The function only prints and
    plots; it returns nothing.

    Parameters
    ----------
    CMLR_model : torch.nn.Module
        Trained GP model with attributes:
          - unconstrained_lengthscale : Tensor [D, no_of_outputs]
          - log_rho                   : Tensor [D]
          - ls_bounds                 : (min_ls, max_ls)
    true_hyperparameters : dict
        Ground-truth hyperparameters:
          - 'len': array-like of shape [D] (1D) or [D, 2] (2D)
          - 'rho': array-like of shape [D]

    Raises
    ------
    ValueError
        If there are no neurons, or if the true hyperparameters do not
        match the shapes of the inferred ones.
    """
    # --- A) Recover inferred hyperparameters from variational parameters ---
    ls_min, ls_max = CMLR_model.ls_bounds
    # Map unconstrained → (ls_min, ls_max) via sigmoid
    inferred_ls = (
        ls_min + (ls_max - ls_min) * torch.sigmoid(
            CMLR_model.unconstrained_lengthscale
        )
    ).detach().cpu().numpy()
    # Recover variances ρ via exponential; flatten so a [D, 1] parameter
    # cannot broadcast against a [D] ground truth.
    inferred_rho = (
        torch.exp(CMLR_model.log_rho).detach().cpu().numpy().reshape(-1)
    )  # [D]
    D = inferred_rho.shape[0]
    if D == 0:
        raise ValueError("Model has no neurons (log_rho is empty).")

    # --- B) Load true hyperparameters as numpy arrays ---
    true_rho = np.asarray(true_hyperparameters['rho'], dtype=float).reshape(-1)
    true_len = np.asarray(true_hyperparameters['len'], dtype=float)
    if true_rho.shape[0] != D:
        raise ValueError(
            f"'rho' has {true_rho.shape[0]} entries but the model has {D} neurons"
        )
    if true_len.size % D != 0 or inferred_ls.size % D != 0:
        raise ValueError(
            f"Length-scale sizes (true {true_len.shape}, inferred "
            f"{inferred_ls.shape}) are not compatible with {D} neurons"
        )
    # Bring both to [D, n_dims].  Subtracting a [D] array from a [D, 1]
    # array would otherwise broadcast to a meaningless [D, D] matrix.
    inferred_ls = inferred_ls.reshape(D, -1)
    true_len = true_len.reshape(D, -1)
    if inferred_ls.shape != true_len.shape:
        raise ValueError(
            "Inferred and true length-scales differ in shape: "
            f"{inferred_ls.shape} vs {true_len.shape}"
        )
    n_dims = true_len.shape[1]

    # --- C) Compute and print NMAE for variance ρ ---
    eps = 1e-8
    abs_err_rho = np.abs(inferred_rho - true_rho)   # [D]
    range_rho = np.ptp(true_rho) + eps              # scalar range
    nmae_rho = abs_err_rho / range_rho              # [D]
    print(
        f"Variance (ρ) NMAE over {D} neurons: "
        f"{100*nmae_rho.mean():.2f}% ± {100*nmae_rho.std():.2f}%"
    )

    # --- D) Compute and print NMAE for length-scale ℓ (per dimension) ---
    ls_names = ['ℓ'] if n_dims == 1 else [f'ℓ_{i+1}' for i in range(n_dims)]
    abs_err_ls = np.abs(inferred_ls - true_len)     # [D, n_dims]
    range_ls = np.ptp(true_len, axis=0) + eps       # [n_dims]
    nmae_ls = abs_err_ls / range_ls[None, :]        # [D, n_dims]
    for i, ls_name in enumerate(ls_names):
        print(
            f"Length-scale ({ls_name}) NMAE over {D} neurons: "
            f"{100*nmae_ls[:, i].mean():.2f}% ± {100*nmae_ls[:, i].std():.2f}%"
        )

    # --- E) Prepare plotting layout ---
    # One panel per ℓ-dimension plus a final one for ρ; each entry is
    # (title, stem-plot y-label, true values [D], inferred values [D]).
    panels = [
        (ls_name, 'Length-scale', true_len[:, i], inferred_ls[:, i])
        for i, ls_name in enumerate(ls_names)
    ]
    panels.append(('ρ', 'Variance', true_rho, inferred_rho))
    n_plots = len(panels)
    neuron_idx = np.arange(D)

    # --- F) Stem plots: true vs. inferred values per neuron ---
    fig1, axs1 = plt.subplots(
        1, n_plots,
        figsize=(6*n_plots, 3),
        constrained_layout=True
    )
    axs1 = np.atleast_1d(axs1).flatten()
    for ax, (title, ylabel, true_vals, inf_vals) in zip(axs1, panels):
        ax.stem(neuron_idx, true_vals, linefmt='b-', markerfmt='bo',
                basefmt=' ', label='True')
        ax.stem(neuron_idx, inf_vals, linefmt='r--', markerfmt='ro',
                basefmt=' ', label='Inferred')
        ax.set_title(title)
        ax.set_xlabel('Neuron')
        ax.set_ylabel(ylabel)
        ax.legend(fontsize=8)
    plt.show()

    # --- G) Scatter-vs-identity plots to assess bias/spread ---
    fig2, axs2 = plt.subplots(
        1, n_plots,
        figsize=(4*n_plots, 3),
        constrained_layout=True
    )
    axs2 = np.atleast_1d(axs2).flatten()
    for ax, (title, _, true_vals, inf_vals) in zip(axs2, panels):
        ax.scatter(true_vals, inf_vals, alpha=0.7, s=20)
        # The identity line must span both the true and the inferred values,
        # otherwise it can be truncated or even run backwards.
        mn = min(true_vals.min(), inf_vals.min())
        mx = max(true_vals.max(), inf_vals.max())
        ax.plot([mn, mx], [mn, mx], 'k--', lw=1)
        ax.set_title(title)
        ax.set_xlabel('True')
        ax.set_ylabel('Inferred')
    plt.show()

