import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any
import random
from utils import get_grid_uniform, get_fdprs

def plot_inferrred_hyperparameters_real_data(
    CMLR_model: torch.nn.Module,
    label_params: Dict[str, Any]
) -> None:
    """
    Plot histograms of inferred GP hyperparameters (length-scale and variance)
    recovered from a trained CMLR model on real data.

    Length-scales are mapped from ``CMLR_model.unconstrained_lengthscale``
    into ``[min_ls, max_ls]`` with a scaled sigmoid; variances are obtained as
    ``exp(CMLR_model.log_rho)``. One histogram is drawn per label dimension,
    followed by one histogram of the variances, in a single horizontal row.

    Parameters
    ----------
    CMLR_model : torch.nn.Module
        Trained CMLR model instance, containing:
          - unconstrained_lengthscale (Tensor [D, no_of_outputs]; a 1D
            Tensor [D] is also accepted and treated as one column)
          - log_rho                  (Tensor [D])
    label_params : dict
        Configuration dictionary with keys:
          - 'no_of_outputs': int (1 or 2), number of label dimensions
          - 'lengthscale_bounds': tuple (min_ls, max_ls), bounds for length-scale

    Returns
    -------
    None
        Displays a matplotlib figure of histogram plots.

    Raises
    ------
    ValueError
        If the length-scale tensor has fewer columns than ``no_of_outputs``.
    """
    ls_min, ls_max = label_params['lengthscale_bounds']
    n_outputs = label_params['no_of_outputs']

    # Scaled sigmoid maps the unconstrained parameters into [ls_min, ls_max].
    inferred_lengthscale = (
        ls_min
        + (ls_max - ls_min)
        * torch.sigmoid(CMLR_model.unconstrained_lengthscale)
    ).detach().cpu().numpy()
    # Promote a 1D [D] result to [D, 1] so it can be indexed by column
    # exactly like the [D, no_of_outputs] layout.
    if inferred_lengthscale.ndim == 1:
        inferred_lengthscale = inferred_lengthscale[:, None]
    if inferred_lengthscale.shape[1] < n_outputs:
        raise ValueError(
            f"unconstrained_lengthscale has {inferred_lengthscale.shape[1]} "
            f"column(s) but no_of_outputs is {n_outputs}"
        )

    # Variances; flattened so hist() always receives a single dataset.
    inferred_rho = torch.exp(CMLR_model.log_rho).detach().cpu().numpy().ravel()

    # One panel per length-scale dimension plus one for the variances.
    n_plots = n_outputs + 1
    _, axes = plt.subplots(
        1, n_plots,
        figsize=(6 * n_plots, 3),
        constrained_layout=True
    )
    # plt.subplots returns a bare Axes for a single panel; normalise to 1D array.
    axes = np.atleast_1d(axes).flatten()

    for i in range(n_outputs):
        ax = axes[i]
        ax.hist(
            inferred_lengthscale[:, i],
            bins=50,
            color='skyblue',
            edgecolor='black'
        )
        ax.set_title(
            'Length-scale (ℓ)' if n_outputs == 1 else f'Length-scale ℓ_{i+1}'
        )
        ax.set_xlabel('Length-scale')
        ax.set_ylabel('Frequency')

    # Last panel: histogram of the variances.
    ax_rho = axes[-1]
    ax_rho.hist(
        inferred_rho,
        bins=50,
        color='lightgreen',
        edgecolor='black'
    )
    ax_rho.set_title('Variance (ρ)')
    ax_rho.set_xlabel('Variance')
    ax_rho.set_ylabel('Frequency')

    plt.show()


def plot_sample_weights_real_data(
    CMLR_model: torch.nn.Module,
    label_params: Dict[str, Any],
    no_of_features_selected_plot: int = 50,
    grid_no_of_points: int = 10000,
) -> None:
    """
    Reconstruct and visualize inferred decoding weights for a subset of neurons
    on a uniform input grid, handling both 1D and 2D label spaces.

    The posterior mean ``mu_q`` is projected through the basis matrix ``Bmat``
    of a uniform label grid. Each feature's function is centered by
    subtracting its mean over the grid, and a random subset of features is
    plotted: line plots for 1D labels, or images sharing one colour scale and
    one colorbar for 2D labels.

    Parameters
    ----------
    CMLR_model : torch.nn.Module
        Trained CMLR model containing:
          - mu_q: Tensor [Bdim, D], posterior mean of Fourier-domain weights
    label_params : dict
        Configuration dictionary with keys:
          - 'no_of_outputs': int, either 1 or 2 dimensions
          - 'dimension_1_range': [min, max] for axis y1
          - 'dimension_2_range': [min, max] for axis y2 (if 2D)
    no_of_features_selected_plot : int, optional
        Maximum number of randomly chosen features to plot (default 50).
        Capped at the number of features D; nothing is plotted if it is <= 0.
    grid_no_of_points : int, optional
        Number of grid points requested from ``get_grid_uniform`` and used to
        rebuild the 2D image mesh (default 10000).

    Raises
    ------
    ValueError
        If ``no_of_outputs`` is not 1 or 2, or (2D case) if the grid values
        cannot be reshaped into the rebuilt mesh.
    """
    n_outputs = label_params['no_of_outputs']
    if n_outputs not in (1, 2):
        raise ValueError(f"no_of_outputs must be 1 or 2, got {n_outputs!r}")

    # 1) Uniform grid of input labels and its basis matrix.
    grid = get_grid_uniform(grid_no_of_points, label_params)  # [K, dims]
    fdprs_grid = get_fdprs(grid, label_params)

    # 2) Posterior mean functions on the grid. Bmat is moved onto mu_q's
    #    device so the product works even if they were built on different
    #    devices (e.g. model on GPU, grid basis on CPU).
    mu_q = CMLR_model.mu_q.detach().float()  # [Bdim, D]
    Bmat = fdprs_grid['Bmat'].detach().float().to(mu_q.device)  # [K, Bdim]
    w_inferred = (Bmat @ mu_q).cpu().numpy()  # [K, D]
    # Center each neuron's function by subtracting its mean over the grid.
    w_inferred -= w_inferred.mean(axis=0, keepdims=True)

    # 3) Random subset of features; capped at D so random.sample cannot raise.
    D = w_inferred.shape[1]
    n_selected = min(no_of_features_selected_plot, D)
    if n_selected <= 0:
        return
    selected_features = random.sample(range(D), n_selected)
    w_selected = w_inferred[:, selected_features]  # [K, n_selected]

    # 4) Subplot layout: at most 10 columns, as many rows as needed.
    ncols = min(10, n_selected)
    nrows = math.ceil(n_selected / ncols)

    if n_outputs == 1:
        _plot_weights_1d(grid, w_selected, selected_features, nrows, ncols)
    else:
        _plot_weights_2d(
            w_selected, selected_features, nrows, ncols,
            label_params, grid_no_of_points,
        )


def _plot_weights_1d(grid, w_selected, selected_features, nrows, ncols) -> None:
    """Draw one line plot per selected feature against a 1D label grid."""
    # Accept a tensor (on any device) or an array-like grid; [K, 1] -> [K].
    grid_pts = torch.as_tensor(grid).detach().cpu().numpy().reshape(-1)
    n_panels = len(selected_features)

    # Common axis limits so all panels are directly comparable.
    x_min, x_max = float(grid_pts.min()), float(grid_pts.max())
    y_min, y_max = float(w_selected.min()), float(w_selected.max())

    fig, axs = plt.subplots(
        nrows, ncols,
        figsize=(2 * ncols, 1.5 * nrows),
        constrained_layout=True
    )
    # plt.subplots returns a bare Axes for a 1x1 layout; normalise to 1D array.
    axs = np.atleast_1d(axs).flatten()

    for i, feat_idx in enumerate(selected_features):
        ax = axs[i]
        ax.plot(
            grid_pts,
            w_selected[:, i],
            '-',
            color='r', lw=1,
            label='Inferred'
        )
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        # x-labels only on panels with no panel beneath them (this also
        # covers the row above a partially filled last row).
        if i + ncols >= n_panels:
            ax.set_xlabel("Label value")
        else:
            ax.set_xticks([])
        # y-labels only on the first column.
        if i % ncols == 0:
            ax.set_ylabel("Inferred weight")
        else:
            ax.set_yticks([])

        ax.set_title(f"Feature {feat_idx}", fontsize=12)
        ax.legend(fontsize=6)

    # Hide any unused subplots.
    for ax in axs[n_panels:]:
        ax.set_visible(False)

    plt.show()


def _plot_weights_2d(
    w_selected, selected_features, nrows, ncols, label_params, grid_no_of_points
) -> None:
    """Draw one image per selected feature over a 2D label grid with a shared colorbar."""
    # Shared colour limits across all features.
    vmin, vmax = float(w_selected.min()), float(w_selected.max())

    xmin, xmax = label_params['dimension_1_range']
    ymin, ymax = label_params['dimension_2_range']
    # Spacing chosen so the rectangle holds ~grid_no_of_points cells. The
    # rebuilt mesh must match the grid from get_grid_uniform
    # (NOTE(review): assumed row-major with x varying fastest — confirm).
    s = math.sqrt(((xmax - xmin) * (ymax - ymin)) / grid_no_of_points)
    nx = len(torch.arange(xmin, xmax - s / 2, step=s))
    ny = len(torch.arange(ymin, ymax - s / 2, step=s))
    n_grid = w_selected.shape[0]
    if n_grid != nx * ny:
        raise ValueError(
            f"Cannot reshape {n_grid} grid values into a {ny}x{nx} image; the "
            "mesh rebuilt here does not match the grid from get_grid_uniform"
        )

    n_panels = len(selected_features)
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(4 * ncols, 2.5 * nrows),
        sharex=True, sharey=True,
        constrained_layout=True
    )
    axes = np.atleast_1d(axes).flatten()

    for idx, feat_idx in enumerate(selected_features):
        ax = axes[idx]
        im = ax.imshow(
            w_selected[:, idx].reshape(ny, nx),
            origin='lower',
            extent=[xmin, xmax, ymin, ymax],
            aspect='equal',
            vmin=vmin, vmax=vmax,
            cmap='bwr'
        )
        ax.set_title(f"Feature {feat_idx}", fontsize=12)

        # With sharex/sharey the tick locators are shared, so set_xticks([])
        # on one panel would erase ticks on every panel. sharex/sharey already
        # hide inner tick labels; only add axis labels on the outer edge and
        # re-enable x tick labels on panels with an empty slot beneath them.
        if idx + ncols >= n_panels:
            ax.set_xlabel('x')
            ax.tick_params(axis='x', labelbottom=True)
        if idx % ncols == 0:
            ax.set_ylabel('y')

    # Hide any unused subplots.
    for ax in axes[n_panels:]:
        ax.set_visible(False)

    # A single colorbar for all subplots (all share vmin/vmax).
    fig.colorbar(
        im,
        ax=axes.tolist(),
        orientation='vertical',
        fraction=0.046,
        pad=0.04,
        label=r'$\mu_{TD}$'
    )

    plt.show()

