import numpy as np
import itertools
from skimage.exposure import match_histograms
from sklearn.decomposition import PCA
from config import get_hparams
from hvae_backbone import utils, checkpoint
from hvae_backbone import init_globals
import torch
from torch.distributions import Laplace
from .task_posterior import laplace_fusion_mean, laplace_moment, gaussian_mixture_fusion
from scipy import stats
import multiprocessing as mp
from functools import partial

def whitenImages(images, filter_noise=True, pca=None):
    """
    Whiten images by rescaling their PCA coordinates and mapping them back.

    Args:
        images: Array passed to ``pca.fit`` / ``pca.transform`` (one sample per row).
        filter_noise (bool): If True, only the first ``int(n_components * pi / 4)``
            components are kept; the scale of every later component is set to 0.
        pca: Fitted PCA-like object exposing ``transform``, ``explained_variance_``
            and ``components_``. If None, a new ``PCA()`` is fitted on ``images``.

    Returns:
        tuple: (whitened_images, pca) where ``pca`` is the object that was used.
    """
    if pca is None:
        pca = PCA()
        pca.fit(images)

    projected = pca.transform(images)

    # Per-component scale: inverse standard deviation along each principal axis.
    inv_std = 1.0 / np.sqrt(pca.explained_variance_)
    if filter_noise:
        cutoff = int(pca.explained_variance_.shape[0] * np.pi / 4.0)
        inv_std[cutoff:] = 0

    # Scaling row k of components_ by inv_std[k] equals diag(inv_std) @ components_.
    scaled_axes = inv_std[:, np.newaxis] * pca.components_
    whitened = np.matmul(projected, scaled_axes)
    return whitened, pca

def _generate_single_grating(params):
    """
    Helper function to generate a single grating. Used for parallel processing.
    
    Args:
        params: Tuple of (angle, phase, spatial_frequency, contrast, spherical_mask, noise_std, seed_offset)
    
    Returns:
        tuple: (grating_image, angle, phase, spatial_frequency, contrast)
    """
    angle, phase, spatial_frequency, contrast, spherical_mask, noise_std, seed_offset = params
    
    # Set random seed for this process
    if seed_offset is not None:
        np.random.seed(0 + seed_offset)
    
    x = np.linspace(-1, 1, 40)
    y = np.linspace(-1, 1, 40)
    X, Y = np.meshgrid(x, y)
    
    grating = np.sin(
        2 * np.pi * spatial_frequency * (X * np.cos(np.radians(angle)) + Y * np.sin(np.radians(angle))) + phase
    )
    grating = grating * contrast

    # Apply spherical mask if requested
    if spherical_mask:
        # Create coordinate grid for the mask
        x_mask = np.linspace(-1, 1, 40)
        y_mask = np.linspace(-1, 1, 40)
        X_mask, Y_mask = np.meshgrid(x_mask, y_mask)
        # Create circular mask with diameter 40 (radius 1.0 in normalized coordinates = 20 pixels)
        mask = (X_mask**2 + Y_mask**2) <= 1.0**2
        mask = mask.astype(np.float32)
        grating = grating * mask

    if noise_std is not None:
        # Add noise to the image
        grating = grating + np.random.normal(0, noise_std, grating.shape)
    
    return grating, angle, phase, spatial_frequency, contrast

def generate_grating(spatial_frequency=3, contrast=1, show_plots=True, save_pdf=True, filename=None):
    """
    Draw five example gratings that share one spatial frequency and contrast.

    The five orientations and phases are sampled after seeding the global NumPy
    random state with 42, so the same examples are produced on every call.

    Args:
        spatial_frequency (float): Spatial frequency of the gratings (default: 3)
        contrast (float): Contrast of the gratings; also sets the color range
            of the plots to [-contrast, contrast] (default: 1)
        show_plots (bool): Whether to display the figure (default: True)
        save_pdf (bool): Whether to save the figure as a PDF in the ``canvas/plots``
            directory next to this file (default: True)
        filename (str or None): Optional filename for the PDF. If None, a
            timestamped name is used.

    Returns:
        tuple: (gratings, angles, phases) where:
            - gratings: numpy array of shape (5, 40, 40) containing the grating images
            - angles: list of the 5 sampled angles (degrees, in [0, 180))
            - phases: list of the 5 sampled phases (radians, in [0, 2*pi))
    """
    import os
    import time
    import matplotlib.pyplot as plt

    n_examples = 5

    # Fixed seed: the "random" examples are identical across calls.
    np.random.seed(42)
    angles = np.random.uniform(0, 180, n_examples)
    phases = np.random.uniform(0, 2*np.pi, n_examples)

    # No mask, no noise, no per-image seed for the examples.
    gratings = np.array([
        _generate_single_grating((ang, ph, spatial_frequency, contrast, False, None, None))[0]
        for ang, ph in zip(angles, phases)
    ])

    fig, axes = plt.subplots(1, n_examples, figsize=(20, 4))
    fig.suptitle(f'5 Example Gratings (Spatial Frequency: {spatial_frequency}, Contrast: {contrast})', fontsize=16)

    for i, ax in enumerate(axes):
        ax.imshow(gratings[i], cmap='gray', vmin=-contrast, vmax=contrast)
        ax.set_title(f'Angle: {angles[i]:.1f}°\nPhase: {phases[i]:.2f}')
        ax.axis('off')

    plt.tight_layout()

    if save_pdf:
        base_dir = os.path.join(os.path.dirname(__file__), 'canvas', 'plots')
        os.makedirs(base_dir, exist_ok=True)
        if filename is None:
            timestamp = int(time.time())
            filename = f'example_gratings_sf{spatial_frequency}_contrast{contrast}_{timestamp}.pdf'
        pdf_path = os.path.join(base_dir, filename)
        fig.savefig(pdf_path, format='pdf', bbox_inches='tight')
        print(f'Saved example gratings to {pdf_path}')

    if show_plots:
        plt.show()
    plt.close(fig)

    return gratings, angles.tolist(), phases.tolist()


def generate_gratings(pca_model=None, hist_match=True, whiten=True,
                      train_fraction=1, shuffle=True, noise_std=None, spherical_mask=False, n_cores=-1):
    """
    Generate the full grid of 40x40 gratings and split it into train and test sets.

    The grid is the Cartesian product of 36 angles (0-175 degrees in steps of 5),
    50 phases in [0, 2*pi], 4 spatial frequencies (1-4) and 21 contrasts in
    [0, 2], i.e. 151200 images.

    Args:
        pca_model: Fitted PCA used for whitening. If None, a PCA is fitted on the
            training images and the same PCA is applied to the test images.
        hist_match (bool): If truthy, match the histogram of the flattened images
            to samples from a standard normal distribution.
        whiten (bool): If truthy, whiten the images with ``whitenImages``.
        train_fraction (float): Fraction of the images placed in the training set.
        shuffle (bool): If truthy, shuffle the images (global NumPy random state)
            before splitting.
        noise_std (float or None): Standard deviation of additive Gaussian noise;
            image ``i`` uses seed ``i`` so that the noise is reproducible.
        spherical_mask (bool): Whether to zero the pixels outside the unit circle.
        n_cores (int): Number of worker processes; -1 uses all available cores.

    Returns:
        dict: ``train_images`` / ``test_images`` plus the matching parameter arrays
        under the singular keys ``{train,test}_angle``, ``{train,test}_phase``,
        ``{train,test}_spatial_frequency`` and ``{train,test}_contrast``.

    NOTE(review): with ``hist_match`` set and ``whiten`` unset the images are
    returned flattened as (n, 1600) instead of (n, 40, 40) - confirm whether
    callers rely on that.
    """
    # Define the parameter grid
    angles = np.arange(0, 180, 5)
    phases = np.linspace(0, 2*np.pi, 50)
    spatial_frequencies = np.arange(1, 5)
    contrasts = np.linspace(0, 2, 21)

    param_combinations = list(itertools.product(angles, phases, spatial_frequencies, contrasts))
    n_images = len(param_combinations)

    # Attach the shared options; the image index doubles as the noise seed so
    # that results do not depend on how the work is split across processes.
    params_with_extras = [
        (angle, phase, spatial_frequency, contrast, spherical_mask, noise_std,
         i if noise_std is not None else None)
        for i, (angle, phase, spatial_frequency, contrast) in enumerate(param_combinations)
    ]

    # Handle n_cores parameter: -1 means use all available cores
    if n_cores == -1:
        n_cores = mp.cpu_count()

    print(f"Generating {n_images} gratings using {n_cores} CPU cores...")
    with mp.Pool(processes=n_cores) as pool:
        # pool.map preserves the input order
        results = pool.map(_generate_single_grating, params_with_extras)

    # Unzip the per-image tuples into one array per field
    gratings, angle_vals, phase_vals, sf_vals, contrast_vals = zip(*results)
    images = np.array(gratings, dtype=np.float64)
    angles_array = np.array(angle_vals, dtype=np.float64)
    phases_array = np.array(phase_vals, dtype=np.float64)
    spatial_frequencies_array = np.array(sf_vals, dtype=np.float64)
    contrasts_array = np.array(contrast_vals, dtype=np.float64)

    if shuffle:
        # One permutation applied to images and labels keeps them aligned
        indices = np.arange(n_images)
        np.random.shuffle(indices)
        images = images[indices]
        angles_array = angles_array[indices]
        phases_array = phases_array[indices]
        spatial_frequencies_array = spatial_frequencies_array[indices]
        contrasts_array = contrasts_array[indices]

    # Split the data into training and test sets
    train_size = int(train_fraction * n_images)
    train_images = images[:train_size]
    test_images = images[train_size:]

    if hist_match:
        # Histogram matching against standard-normal samples.
        # Each 16384 x 16384 float64 reference occupies 2 GiB.
        train_images = match_histograms(train_images.reshape(train_images.shape[0], -1), np.random.randn(16384, 16384)).astype(np.float32)
        if train_fraction < 1:
            test_images = match_histograms(test_images.reshape(test_images.shape[0], -1), np.random.randn(16384, 16384)).astype(np.float32)
    elif whiten:
        # Whitening expects one flattened image per row
        train_images = train_images.reshape(train_images.shape[0], -1)
        if train_fraction < 1:
            test_images = test_images.reshape(test_images.shape[0], -1)

    if whiten:
        whitened_train, fitted_pca = whitenImages(train_images, filter_noise=True, pca=pca_model)
        train_images = whitened_train.reshape(train_images.shape[0], 40, 40).astype(np.float32)
        if train_fraction < 1:
            # Apply the PCA used for the training set (either the supplied one or
            # the one just fitted) so that train and test images go through the
            # same transform instead of fitting a second PCA on the test split.
            test_images = whitenImages(test_images, filter_noise=True, pca=fitted_pca)[0].reshape(test_images.shape[0], 40, 40).astype(np.float32)

    data = {
        "train_images": train_images,
        "test_images": test_images,
        "train_angle": angles_array[:train_size],
        "test_angle": angles_array[train_size:],
        "train_phase": phases_array[:train_size],
        "test_phase": phases_array[train_size:],
        "train_spatial_frequency": spatial_frequencies_array[:train_size],
        "test_spatial_frequency": spatial_frequencies_array[train_size:],
        "train_contrast": contrasts_array[:train_size],
        "test_contrast": contrasts_array[train_size:],
    }

    return data


def load_model():
    """
    Restore the model from the experiment checkpoint and move it to the configured device.

    The hyper-parameters come from ``get_hparams()``; the checkpoint is obtained
    through ``utils.load_experiment_for('test', None, log_params)``.

    Returns:
        The model returned by the checkpoint, moved to ``model_params.device``.

    Raises:
        AssertionError: If no checkpoint object was loaded.
    """
    p = init_globals(get_hparams())

    # Only the checkpoint object is needed; the returned path is ignored.
    checkpoint_obj, _ = utils.load_experiment_for('test', None, p.log_params)
    device = p.model_params.device

    assert checkpoint_obj is not None
    model = checkpoint_obj.get_model()
    print(f'Model Checkpoint is loaded from {p.log_params.load_from_eval}')

    model.summary()
    return model.to(device)


def plot_image(image):
    """
    Show one image in grayscale with the color range fixed to [-1, 1].

    Args:
        image: Image data accepted by ``matplotlib.pyplot.imshow``.
    """
    from matplotlib import pyplot as plt

    plt.imshow(image, vmin=-1, vmax=1, cmap='gray')
    plt.axis('off')
    plt.show()

def plot_images(images):
    """
    Show a row of images in grayscale with the color range fixed to [-1, 1].

    Args:
        images: Array whose first axis indexes the images; ``images.shape[0]``
            subplots are created side by side.
    """
    import matplotlib.pyplot as plt

    # squeeze=False always returns a 2-D array of Axes, so a single image
    # (where subplots would otherwise return a bare Axes) is handled as well.
    fig, axes = plt.subplots(1, images.shape[0], figsize=(20, 2), squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(image, cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    plt.show()


def find_active_dimensions(model, data, num_active_dims=6):
    """
    Pick the ``y`` dimensions whose predicted scale is smallest on the training images.

    The model is run on ``data["train_images"]``; the ``scale`` of the distribution
    stored at ``output[1]["y"][1]`` is averaged over the batch axis and the
    dimensions are ranked in ascending order of that mean.

    Args:
        model (torch.nn.Module): Model with at least one parameter (its device is
            used for the input tensor).
        data (dict): Must contain ``"train_images"``.
        num_active_dims (int): Number of dimensions to return (default: 6).

    Returns:
        numpy.ndarray: Indices of the ``num_active_dims`` dimensions with the
        smallest mean scale, smallest first.
    """
    device = next(model.parameters()).device
    inputs = torch.tensor(data["train_images"], dtype=torch.float32).to(device)

    # Inference only: no autograd graph is needed for the whole training set.
    with torch.no_grad():
        output = model(inputs)

    # .cpu() is required before .numpy() when the model lives on an accelerator.
    y_std = output[1]["y"][1].scale.detach().cpu().numpy()
    y_scales = y_std.mean(axis=0)

    return np.argsort(y_scales)[:num_active_dims]


def get_constrained_means(output, data, active_dims, angle=None, phase=None, spatial_frequency=None, contrast=None):
    """
    Select the ``y`` means and training images whose stimulus parameters match the query.

    A constraint left as None is ignored. Matching is approximate:
    angle within 1, phase within 0.01, spatial frequency within 0.5 and
    contrast within 0.05 of the requested value.

    Args:
        output: Model output; the distribution at ``output[1]["y"][1]`` provides ``loc``.
        data (dict): Must contain ``train_images``, ``train_angle``, ``train_phase``,
            ``train_spatial_frequency`` and ``train_contrast``.
        active_dims: Column indices of ``loc`` to keep.
        angle, phase, spatial_frequency, contrast: Optional target values.

    Returns:
        tuple: (means, images) restricted to the matching rows.
    """
    # .cpu() is required before .numpy() when the output lives on an accelerator.
    means = output[1]["y"][1].loc.detach().cpu().numpy()
    means = means[:, active_dims]

    # (data key, requested value, tolerance) for every supported constraint
    constraints = (
        ("train_angle", angle, 1),
        ("train_phase", phase, 0.01),
        ("train_spatial_frequency", spatial_frequency, 0.5),
        ("train_contrast", contrast, 0.05),
    )

    mask = np.ones(data["train_angle"].shape, dtype=bool)
    for key, target, tolerance in constraints:
        if target is not None:
            mask &= np.abs(data[key] - target) <= tolerance

    # The same mask is applied to both arrays so rows stay aligned
    return means[mask], data["train_images"][mask]

def get_constrained_images(data, angle=None, phase=None, spatial_frequency=None, contrast=None):
    """
    Return the training images whose stimulus parameters match the query.

    A constraint left as None is ignored. Matching is approximate:
    angle within 1, phase within 0.01, spatial frequency within 0.5 and
    contrast within 0.05 of the requested value.

    Args:
        data (dict): Must contain ``train_images``, ``train_angle``, ``train_phase``,
            ``train_spatial_frequency`` and ``train_contrast``.
        angle, phase, spatial_frequency, contrast: Optional target values.

    Returns:
        The rows of ``data["train_images"]`` that satisfy every given constraint.
    """
    # (data key, requested value, tolerance) for every supported constraint
    constraints = (
        ("train_angle", angle, 1),
        ("train_phase", phase, 0.01),
        ("train_spatial_frequency", spatial_frequency, 0.5),
        ("train_contrast", contrast, 0.05),
    )

    keep = np.ones(data["train_angle"].shape, dtype=bool)
    for key, target, tolerance in constraints:
        if target is None:
            continue
        keep &= np.abs(data[key] - target) <= tolerance

    return data["train_images"][keep]

def get_constrained_std(output, data, active_dims, angle=None, phase=None, spatial_frequency=None, contrast=None):
    """
    Select the ``y`` scales of the training images whose stimulus parameters match the query.

    A constraint left as None is ignored. Matching uses the same tolerances as
    ``get_constrained_means`` and ``get_constrained_images`` (angle within 1,
    phase within 0.01, spatial frequency within 0.5, contrast within 0.05), so
    the rows returned here line up with the rows those functions return for the
    same query. Exact float equality, in contrast, silently drops rows whenever
    the requested value is a rounded version of the stored one.

    Args:
        output: Model output; the distribution at ``output[1]["y"][1]`` provides ``scale``.
        data (dict): Must contain ``train_angle``, ``train_phase``,
            ``train_spatial_frequency`` and ``train_contrast``.
        active_dims: Column indices of ``scale`` to keep.
        angle, phase, spatial_frequency, contrast: Optional target values.

    Returns:
        numpy.ndarray: The scales restricted to the matching rows and ``active_dims``.
    """
    # .cpu() is required before .numpy() when the output lives on an accelerator.
    std = output[1]["y"][1].scale.detach().cpu().numpy()
    std = std[:, active_dims]

    # (data key, requested value, tolerance) for every supported constraint
    constraints = (
        ("train_angle", angle, 1),
        ("train_phase", phase, 0.01),
        ("train_spatial_frequency", spatial_frequency, 0.5),
        ("train_contrast", contrast, 0.05),
    )

    mask = np.ones(data["train_angle"].shape, dtype=bool)
    for key, target, tolerance in constraints:
        if target is not None:
            mask &= np.abs(data[key] - target) <= tolerance

    # Keep only the rows that satisfy every constraint
    return std[mask]



def calculate_avg_abs_z_by_angle(data, model, spatial_frequency=None, contrast=None, abs=True,
                                 prior_images=None, one_layer=False, alpha=0, use_prior_z=False,
                                 avg_abs_prior=False, active_dims=None, weights=True, moment='mean', threshold=0.0,
                                 component_weights=None, natural_prior_weight=None,
                                 mixture_of_gaussians=False, n_mixture_samples=20):
    """
    Average the model's ``z`` response over all images that share an orientation.

    Train and test images are pooled after optionally restricting them to a
    spatial frequency and/or contrast (both matched within 0.05). For every
    distinct angle the matching images are pushed through the model and the
    per-component mean of the (optionally absolute) response is stored.

    Args:
        data (dict): Output of the grating generator with the ``train_*`` / ``test_*``
            keys for images, angle, spatial_frequency and contrast.
        model: Model to evaluate.
        spatial_frequency, contrast: Optional values used to filter the images.
        abs (bool): If True, take the absolute value of the responses before averaging.
        prior_images: If None, the model is applied directly and ``laplace_moment``
            is evaluated on the distribution at ``[1]["z"][1]``; otherwise one of
            the ``apply_model_with_prior*`` helpers is used.
        one_layer (bool): Selects ``apply_model_with_prior_one_layer`` over
            ``apply_model_with_prior`` when ``prior_images`` is given.
        alpha, use_prior_z, avg_abs_prior, active_dims, weights, mixture_of_gaussians,
        n_mixture_samples: Forwarded to ``apply_model_with_prior``.
        moment, threshold: Forwarded to ``laplace_moment`` or to
            ``apply_model_with_prior_one_layer``.
        component_weights, natural_prior_weight: Forwarded to
            ``apply_model_with_prior_one_layer``.

    Returns:
        dict: Maps ``int(angle)`` to the array of average responses per ``z`` component.
    """
    # Masks for spatial frequency and contrast using approximate matching
    mask_train = np.ones(data["train_spatial_frequency"].shape, dtype=bool)
    mask_test = np.ones(data["test_spatial_frequency"].shape, dtype=bool)
    if spatial_frequency is not None:
        mask_train = np.abs(data["train_spatial_frequency"] - spatial_frequency) <= 0.05
        mask_test = np.abs(data["test_spatial_frequency"] - spatial_frequency) <= 0.05
    if contrast is not None:
        mask_train &= np.abs(data["train_contrast"] - contrast) <= 0.05
        mask_test &= np.abs(data["test_contrast"] - contrast) <= 0.05

    # Pool the selected training and test samples
    all_images = np.concatenate((data["train_images"][mask_train], data["test_images"][mask_test]))
    all_angles = np.concatenate((data["train_angle"][mask_train], data["test_angle"][mask_test]))

    avg_abs_z_by_angle = {}
    for angle in np.unique(all_angles):
        images = all_images[all_angles == angle]

        if prior_images is None:
            # No prior: evaluate the model directly on the images
            device = next(model.parameters()).device
            z_dist = model(torch.tensor(images, dtype=torch.float32).to(device))[1]["z"][1]
            output_array = laplace_moment(z_dist, moment=moment, threshold=threshold).cpu().detach().numpy()
        elif one_layer:
            output_array = apply_model_with_prior_one_layer(
                model, images, prior_images,
                moment=moment, threshold=threshold,
                component_weights=component_weights,
                natural_prior_weight=natural_prior_weight,
            )
        else:
            output_array = apply_model_with_prior(
                model, images, prior_images,
                alpha=alpha, use_prior_z=use_prior_z,
                avg_abs_prior=avg_abs_prior,
                active_dims=active_dims, weights=weights,
                mixture_of_gaussians=mixture_of_gaussians,
                n_mixture_samples=n_mixture_samples,
            )

        if abs:
            output_array = np.abs(output_array)

        # Mean over the images of this angle, keyed by the integer angle
        avg_abs_z_by_angle[int(angle)] = np.mean(output_array, axis=0)

    return avg_abs_z_by_angle

def plot_orientation_tuning(avg_abs_z_by_angle, idx):
    """
    Plot the orientation tuning curve of a single ``z`` component.

    Args:
        avg_abs_z_by_angle (dict): Maps angle to an array of average responses.
        idx (int): Index of the ``z`` component to plot.
    """
    import matplotlib.pyplot as plt

    # Angles in ascending order with the response of component idx at each one
    angles = sorted(avg_abs_z_by_angle)
    responses = [avg_abs_z_by_angle[angle][idx] for angle in angles]

    plt.plot(angles, responses, label=f'z_{idx}')
    plt.xlabel('Angle')
    plt.ylabel('Average Absolute Value of "z" Variables')
    plt.legend()
    plt.show()

def find_most_responsive_angles_threshold(avg_abs_z_by_angle, decay_threshold=0.25, return_max_values=False):
    """
    Find the most responsive angles for each z component based on orientation tuning curves using threshold-based method.

    A component is assigned its peak angle only if the response at 40 degrees to
    either side of the peak and at 90 degrees from the peak (all modulo 180) is
    below ``decay_threshold`` times the peak response.

    Args:
        avg_abs_z_by_angle (dict): Dictionary where keys are angles and values are arrays of average absolute z values.
            The sampled angles must include peak-40, peak+40 and peak+90 (modulo 180).
        decay_threshold (float): Maximum allowed ratio between the off-peak and the peak response (default: 0.25).
        return_max_values (bool): If True, also return the maximum value for each z component (default: False).

    Returns:
        numpy.ndarray or tuple:
            - If return_max_values=False: Array of most responsive angles (NaN for non-responsive components)
            - If return_max_values=True: Tuple of (most_responsive_angles, max_values) where max_values contains the peak values

    Raises:
        ValueError: If the dictionary is empty or a required neighbouring angle was not sampled.
    """
    if not avg_abs_z_by_angle:
        raise ValueError("avg_abs_z_by_angle is empty")

    angles = list(avg_abs_z_by_angle.keys())
    position_of = {angle: pos for pos, angle in enumerate(angles)}

    # Rows follow the dictionary order of the angles, columns are z components.
    # Taking the size from the first entry avoids requiring a key equal to 0.
    tuning = np.asarray([avg_abs_z_by_angle[angle] for angle in angles])
    n_components = tuning.shape[1]

    def response_at(values, angle):
        try:
            return values[position_of[angle]]
        except KeyError:
            raise ValueError(f"angle {angle} is not among the sampled angles") from None

    most_responsive_angles = np.zeros(n_components)
    max_values = np.zeros(n_components)

    for i in range(n_components):
        values = tuning[:, i]
        max_index = np.argmax(values)
        max_angle = angles[max_index]
        peak = values[max_index]
        max_values[i] = peak

        # Check if the peak is steep enough on both sides
        left_value = response_at(values, (max_angle - 40) % 180)
        right_value = response_at(values, (max_angle + 40) % 180)
        is_tuned = left_value / peak < decay_threshold and right_value / peak < decay_threshold

        if is_tuned:
            # ... and that the response has also decayed at the orthogonal orientation
            decay_value = response_at(values, (max_angle + 90) % 180)
            is_tuned = decay_value / peak < decay_threshold

        most_responsive_angles[i] = max_angle if is_tuned else np.nan

    if return_max_values:
        return most_responsive_angles, max_values
    return most_responsive_angles


def find_most_responsive_angles_von_mises(avg_abs_z_by_angle, return_max_values=False, maxfev=10000, threshold=0.1):
    """
    Find the most responsive angles for each z component using a scaled von Mises function (not a probability distribution).

    Each tuning curve is fitted by least squares with
    ``A * exp(kappa * cos(2 * (theta - mu))) + B``. The R2 of a constant (mean)
    fit is 0 by definition, so the fit is accepted when its R2 exceeds
    ``threshold``; the peak position ``mu`` (in degrees, modulo 180) is then
    reported. Components whose fit fails or is rejected get NaN.

    Args:
        avg_abs_z_by_angle (dict): Dictionary where keys are angles (in degrees) and values are arrays of average absolute z values.
        return_max_values (bool): If True, also return the amplitude (A) for each z component as max_value.
        maxfev (int): Maximum number of function evaluations for curve_fit (default: 10000).
        threshold (float): Minimum R2 improvement over constant fit to accept von Mises fit (default: 0.1).

    Returns:
        numpy.ndarray or tuple:
            - If return_max_values=False: Array of most responsive angles (NaN for non-responsive components)
            - If return_max_values=True: Tuple of (most_responsive_angles, max_values) where max_values contains the amplitude (A)
    """
    from scipy.optimize import curve_fit
    try:
        from joblib import Parallel, delayed
    except ImportError:
        # joblib only speeds the fits up; fall back to a serial loop without it.
        Parallel = delayed = None

    angles_deg = np.array(list(avg_abs_z_by_angle.keys()))
    angles_rad = np.radians(angles_deg)
    n_z = len(next(iter(avg_abs_z_by_angle.values())))

    def von_mises(theta, A, kappa, mu, B):
        # The factor 2 makes the curve periodic in 180 degrees (orientation).
        return A * np.exp(kappa * np.cos(2 * (theta - mu))) + B

    def fit_one_z(i):
        """Return (peak angle in degrees, amplitude) or (nan, nan) for component i."""
        y = np.array([v[i] for v in avg_abs_z_by_angle.values()])
        # Initial guesses: amplitude=max(y)-min(y), kappa=1, mu=angle of max, baseline=min(y)
        p0 = [np.max(y) - np.min(y), 1.0, angles_rad[np.argmax(y)], np.min(y)]
        bounds = ([0, 0, 0, -np.inf], [np.inf, 50, 2*np.pi, np.inf])
        try:
            popt, _ = curve_fit(von_mises, angles_rad, y, p0=p0, bounds=bounds, maxfev=maxfev)
            y_fit = von_mises(angles_rad, *popt)
            ss_res = np.sum((y - y_fit) ** 2)
            ss_tot = np.sum((y - np.mean(y)) ** 2)
            r2_vonmises = 1 - ss_res / ss_tot if ss_tot > 0 else 0
        except Exception:
            # Deliberate best effort: a component that cannot be fitted is
            # reported as non-responsive instead of aborting the whole analysis.
            return np.nan, np.nan

        if r2_vonmises > threshold:
            return np.degrees(popt[2]) % 180, popt[0]
        return np.nan, np.nan

    if Parallel is None or n_z <= 1:
        # Nothing to parallelise (or joblib unavailable)
        results = [fit_one_z(i) for i in range(n_z)]
    else:
        results = Parallel(n_jobs=-1, prefer="processes")(
            delayed(fit_one_z)(i) for i in range(n_z)
        )

    most_responsive_angles = np.array([r[0] for r in results])
    if return_max_values:
        max_values = np.array([r[1] for r in results])
        return most_responsive_angles, max_values
    return most_responsive_angles


def find_most_responsive_angles(avg_abs_z_by_angle, method='threshold', decay_threshold=0.25, return_max_values=False, maxfev=10000, threshold=0.5, filter_by_max_values=False, max_values_percentile=90):
    """
    Find the most responsive angles for each z component using specified method.

    Args:
        avg_abs_z_by_angle (dict): Dictionary where keys are angles and values are arrays of average absolute z values.
        method (str): Method to use ('threshold' or 'von_mises', default: 'threshold').
        decay_threshold (float): Threshold for determining if a peak is steep enough (only for threshold method, default: 0.25).
        return_max_values (bool): If True, also return the maximum value for each z component (default: False).
        maxfev (int): Maximum number of function evaluations for curve_fit (only for von_mises method, default: 10000).
        threshold (float): Minimum R2 improvement over constant fit to accept von Mises fit (only for von_mises method, default: 0.5).
        filter_by_max_values (bool): If True, components whose max_value lies ABOVE the percentile threshold have their angles set to NaN (default: False).
        max_values_percentile (float): Percentile of the non-NaN max_values used as the filtering threshold (default: 90).

    Returns:
        numpy.ndarray or tuple:
            - If return_max_values=False: Array of most responsive angles (NaN for non-responsive components)
            - If return_max_values=True: Tuple of (most_responsive_angles, max_values) where max_values contains the peak values

    Raises:
        ValueError: If ``method`` is neither 'threshold' nor 'von_mises'.

    NOTE(review): the parameter name suggests removing components with LOW
    max_values, but the code (and its log message) removes the ones ABOVE the
    percentile - confirm which direction is intended.
    """
    # The max values are needed for filtering even if the caller does not want them back
    need_max_values = filter_by_max_values or return_max_values

    if method == 'threshold':
        result = find_most_responsive_angles_threshold(avg_abs_z_by_angle, decay_threshold, return_max_values=need_max_values)
    elif method == 'von_mises':
        result = find_most_responsive_angles_von_mises(avg_abs_z_by_angle, return_max_values=need_max_values, maxfev=maxfev, threshold=threshold)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'threshold' or 'von_mises'.")

    if not filter_by_max_values:
        return result

    if not isinstance(result, tuple):
        # This shouldn't happen since we called with need_max_values=True
        raise ValueError("max_values not available for filtering")
    most_responsive_angles, max_values = result

    # Calculate the percentile threshold, ignoring NaNs
    valid_mask = ~np.isnan(max_values)
    valid_max_values = max_values[valid_mask]
    if valid_max_values.size:
        percentile_threshold = np.percentile(valid_max_values, max_values_percentile)
    else:
        # Every fit failed: there is nothing to rank, so nothing is filtered
        percentile_threshold = np.nan

    # Mask for high values among valid indices
    high_value_mask = np.zeros_like(max_values, dtype=bool)
    high_value_mask[valid_mask] = valid_max_values > percentile_threshold

    most_responsive_angles[high_value_mask] = np.nan

    print(f"Filtered out {np.sum(high_value_mask)} components with max_values above {percentile_threshold:.4f} (percentile {max_values_percentile})")

    if return_max_values:
        return most_responsive_angles, max_values
    return most_responsive_angles

def response_by_orientation(image, model, most_responsive_angles, prior_image=None,
                            absolute=True, natural_prior=False, alpha=0, use_prior_z=False, avg_abs_prior=False,
                            active_dims=None, weights=True, one_layer=False, return_output_laplace=False,
                            c_adjustment=False, prior_scale_factor=1.0, image_scale_factor=1.0,
                            return_component_norms=False, bin_size=5, moment='mean', threshold=0.0,
                            component_weights=None, natural_prior_weight=None, mixture_of_gaussians=False,
                            n_mixture_samples=20, return_variability=False):
    """
    Average the model's z responses within groups of components that share a preferred orientation.

    The per-image, per-component responses are obtained from one of four paths:
      * ``prior_image is None`` and ``natural_prior is True``: ``apply_model_with_natural_prior``;
      * ``prior_image is None`` otherwise: ``laplace_moment`` of the z distribution the model
        returns at ``[1]["z"][1]``;
      * ``prior_image`` given and ``one_layer is False``: ``apply_model_with_prior``;
      * ``prior_image`` given otherwise: ``apply_model_with_prior_one_layer``.

    Args:
        image (numpy.ndarray): The input image, or a batch of images.
        model (torch.nn.Module): The trained model.
        most_responsive_angles (array-like): Preferred orientation of every z component; NaN entries
            are ignored.
        prior_image (optional): Prior image(s) (or, for the one-layer path, distribution(s)) to use.
        absolute (bool): If exactly ``True``, take absolute values of the responses before averaging;
            any truthy value also makes the stored per-orientation averages absolute.
        natural_prior (bool): Whether to use the natural prior (only when ``prior_image`` is None).
        alpha (float): Weight for combining y_from_image and y_prior (forwarded to
            ``apply_model_with_prior``).
        use_prior_z: Forwarded to ``apply_model_with_prior``.
        avg_abs_prior (bool): Forwarded to ``apply_model_with_prior``.
        active_dims (array-like, optional): Forwarded to ``apply_model_with_prior``.
        weights (bool): Forwarded to ``apply_model_with_prior``.
        one_layer (bool): Anything other than exactly ``False`` selects the one-layer prior path.
        return_output_laplace (bool): If True, also return a zero-mean Laplace distribution whose scale
            is the image-averaged response of each component.
        c_adjustment (bool): If True, that scale is additionally multiplied by its own mean.
        prior_scale_factor (float): Scale factor for the prior (one-layer path).
        image_scale_factor (float): Scale factor for the image z distribution scale (one-layer path).
        return_component_norms (bool): If True, also return the normalization constants reported by
            the one-layer path. On every other path these are ``None``.
        bin_size (float, optional): If provided, angles are rounded to the nearest multiple of this
            size and wrapped with ``% 180``. If None, the angles are used as given.
        moment (str): Which moment to compute: 'mean' (E[X], default), 'abs' (E[|X|]), or
            'relu' (E[ReLU(X - x0)]).
        threshold (float): Threshold value for relu moment (x0 in E[ReLU(X - x0)]). Default is 0.0.
        component_weights (array-like or None, optional): Optional weights for each mixture component
            (one-layer path).
        natural_prior_weight (float or None, optional): Weight of an additional unit Laplace component
            (one-layer path).
        mixture_of_gaussians (bool): Forwarded to ``apply_model_with_prior``.
        n_mixture_samples (int): Forwarded to ``apply_model_with_prior``.
        return_variability (bool): If True, also return, per orientation, the standard deviation
            across images of the per-image mean response of that orientation's components.

    Returns:
        dict or tuple: ``orientation_responses`` (orientation -> average response) on its own when no
        extra output is requested; otherwise a tuple that starts with ``orientation_responses`` and
        then contains, in this order and only if requested: the Laplace distribution
        (``return_output_laplace``), ``component_norms`` and ``component_norms_per_dim``
        (``return_component_norms``), and the orientation -> std dictionary (``return_variability``).
    """
    # Only the one-layer prior path produces normalization constants. Binding them up front means
    # return_component_norms=True yields None on the other paths instead of an UnboundLocalError.
    component_norms = None
    component_norms_per_dim = None

    if prior_image is None:
        # Add an extra dimension to the image to simulate a batch if it is a single image
        if image.ndim == 2:
            image = np.expand_dims(image, axis=0)

        if natural_prior is True:
            output = apply_model_with_natural_prior(model, image, moment=moment, threshold=threshold)
        else:
            device = next(model.parameters()).device
            z_dist = model(torch.tensor(image, dtype=torch.float32).to(device))[1]["z"][1]
            output = laplace_moment(z_dist, moment=moment, threshold=threshold).cpu().detach().numpy()
    elif one_layer is False:
        output = apply_model_with_prior(
            model, image, prior_image,
            alpha=alpha, use_prior_z=use_prior_z, avg_abs_prior=avg_abs_prior,
            active_dims=active_dims, weights=weights,
            mixture_of_gaussians=mixture_of_gaussians,
            n_mixture_samples=n_mixture_samples,
        )
    else:
        one_layer_result = apply_model_with_prior_one_layer(
            model, image, prior_image,
            prior_scale_factor=prior_scale_factor, image_scale_factor=image_scale_factor,
            return_component_norms=return_component_norms,
            moment=moment, threshold=threshold,
            component_weights=component_weights, natural_prior_weight=natural_prior_weight,
        )
        if return_component_norms:
            output, component_norms, component_norms_per_dim = one_layer_result
        else:
            output = one_layer_result

    if absolute is True:
        output = np.abs(output)

    # Keep the per-image outputs for the variability estimate; the mean across images is the
    # central tendency used for the tuning curve.
    output_mean_across_images = np.mean(output, axis=0)

    # Accept any array-like (list, tuple, ndarray); for an ndarray this keeps its dtype.
    angles = np.asarray(most_responsive_angles)
    if bin_size is not None:
        digitized_angles = np.array(angles)  # private copy, the caller's array is never modified
        valid_mask = ~np.isnan(angles)
        if np.any(valid_mask):
            # Round to the nearest bin centre, then wrap into [0, 180)
            digitized_angles[valid_mask] = np.round(angles[valid_mask] / bin_size) * bin_size
            digitized_angles[valid_mask] = digitized_angles[valid_mask] % 180
    else:
        digitized_angles = angles

    orientation_responses = {}
    orientation_stds = {} if return_variability else None

    for orientation in np.unique(digitized_angles[~np.isnan(digitized_angles)]):
        # z components whose (binned) preferred orientation is the current one
        indices = np.flatnonzero(digitized_angles == orientation)
        avg_response = np.mean(output_mean_across_images[indices])
        orientation_responses[orientation] = np.abs(avg_response) if absolute else avg_response
        if return_variability:
            # Spread, across images, of the population mean for this orientation
            per_image_population_means = np.mean(output[:, indices], axis=1)
            orientation_stds[orientation] = np.std(per_image_population_means)

    # Assemble the return value in the fixed order callers unpack it in.
    results = [orientation_responses]

    if return_output_laplace:
        # Zero-mean Laplace whose per-component scale is the image-averaged response
        if c_adjustment:
            scale = output_mean_across_images * np.mean(output_mean_across_images)
        else:
            scale = output_mean_across_images
        scale = torch.tensor(scale, dtype=torch.float32, device=next(model.parameters()).device)
        results.append(torch.distributions.Laplace(loc=0, scale=scale))

    if return_component_norms:
        results.extend([component_norms, component_norms_per_dim])

    if return_variability:
        results.append(orientation_stds)

    if len(results) == 1:
        return orientation_responses
    return tuple(results)

def plot_avg_response_by_orientation(orientation_responses, comparison_orientation_responses=None, title="Average Response by Orientation"):
    """
    Show the average response per orientation together with a smoothing-spline fit.

    Args:
        orientation_responses (dict): Maps each orientation to its average response.
        comparison_orientation_responses (dict, optional): A second orientation -> average response
            mapping; when non-empty it is drawn (with its own spline fit) on the same axes.
        title (str): Title of the figure.
    """
    import matplotlib.pyplot as plt
    from scipy.interpolate import UnivariateSpline

    def _curve_and_fit(responses):
        # Orientation-sorted samples plus a spline (s=0.2) evaluated on 500 evenly spaced points
        xs = sorted(responses.keys())
        ys = [responses[x] for x in xs]
        fitted = UnivariateSpline(xs, ys, s=0.2)
        dense_xs = np.linspace(min(xs), max(xs), 500)
        return xs, ys, dense_xs, fitted(dense_xs)

    # The primary fit is computed before the figure is opened
    xs, ys, dense_xs, dense_ys = _curve_and_fit(orientation_responses)

    plt.figure(figsize=(8, 6))
    plt.plot(xs, ys, marker='o', linestyle='-', color='b', label='Avg Response')
    plt.plot(dense_xs, dense_ys, linestyle='--', color='r', label='Spline Fit')

    # Optional second curve on the same axes
    if comparison_orientation_responses:
        cmp_xs, cmp_ys, cmp_dense_xs, cmp_dense_ys = _curve_and_fit(comparison_orientation_responses)
        plt.plot(cmp_xs, cmp_ys, marker='x', linestyle='-', color='g', label='Comparison Avg Response')
        plt.plot(cmp_dense_xs, cmp_dense_ys, linestyle='--', color='orange', label='Comparison Spline Fit')

    plt.xlabel('Orientation (degrees)')
    plt.ylabel('Average Response')
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.show()



def _fuse_z_posterior_with_prior(z_output, use_prior_z):
    """Fuse, image by image, the z distribution at ``z_output[1][1]`` (base) with the one at
    ``z_output[1][0]`` (prior) via ``laplace_fusion_mean``; the prior scale is multiplied by
    ``use_prior_z / (1 - use_prior_z)`` first. Returns a list with one fused mean per image."""
    z_prior, z_post = z_output[1][0], z_output[1][1]
    mu_base = z_post.loc.detach().cpu().numpy()
    sigma_base = z_post.scale.detach().cpu().numpy()
    mu_prior = z_prior.loc.detach().cpu().numpy()
    sigma_prior = z_prior.scale.detach().cpu().numpy() * (use_prior_z / (1.0 - use_prior_z))

    return [
        laplace_fusion_mean(
            mu_base[i],
            sigma_base[i],
            mu_prior[i:i + 1],     # single mixture component, shape (1, N)
            sigma_prior[i:i + 1],  # single mixture component, shape (1, N)
            moment="mean",
            divide_unit_laplace=False,
        )
        for i in range(mu_base.shape[0])
    ]


def _apply_model_with_prior_mixture(model, image, prior_image, y_from_image, y_scale_from_image,
                                    device, use_prior_z, fuse_z, n_mixture_samples):
    """Mixture-of-Gaussians mode of ``apply_model_with_prior``: per image, sample y from the mixture
    built by ``gaussian_mixture_fusion`` and average the resulting z means over the samples."""
    z_block = model.blocks[3]

    # y distribution of the prior images (if any); each prior image becomes one mixture component
    if prior_image is not None:
        if prior_image.ndim == 2:
            prior_image = np.expand_dims(prior_image, axis=0)
        prior_tensor = torch.tensor(prior_image, dtype=torch.float32).to(device)
        y_prior_dist = model(prior_tensor, stop_at='y')[1]["y"][1]
        prior_mus, prior_scales = y_prior_dist.loc, y_prior_dist.scale
    else:
        prior_mus, prior_scales = None, None

    outputs_list = []
    for b in range(image.shape[0]):
        base_mu = y_from_image[b]
        base_var = y_scale_from_image[b] ** 2

        if prior_mus is not None:
            comp_mus_t = prior_mus.clone()
            comp_vars_t = (prior_scales ** 2).clone()
        else:
            # No prior components: the mixture is built from the base distribution alone
            comp_mus_t = torch.empty((0, base_mu.shape[0]), device=device)
            comp_vars_t = torch.empty_like(comp_mus_t)

        mixture_dist, _, _, _ = gaussian_mixture_fusion(
            base_mu=base_mu,
            base_var=base_var,
            comp_mus=comp_mus_t,
            comp_vars=comp_vars_t,
        )
        y_samples = mixture_dist.sample((n_mixture_samples,))
        img_tensor = torch.tensor(image[b:b + 1], dtype=torch.float32).to(device)

        sample_outputs = []
        for s in range(n_mixture_samples):
            # A fresh hidden dict is requested for every sample, exactly as before.
            hidden_b = model(img_tensor, stop_at='hiddens')[0]
            hidden_b['y'] = y_samples[s].unsqueeze(0)
            z_out = z_block(hidden_b)
            if fuse_z:
                sample_outputs.append(_fuse_z_posterior_with_prior(z_out, use_prior_z)[0])
            else:
                sample_outputs.append(z_out[1][1].loc.detach().cpu().numpy()[0])

        # Average over all samples for this image
        outputs_list.append(np.mean(np.stack(sample_outputs, axis=0), axis=0))

    return np.stack(outputs_list, axis=0)


def apply_model_with_prior(model, image, prior_image, alpha=0, use_prior_z=False, avg_abs_prior=False,
                           active_dims=None, weights=True, mixture_of_gaussians=False, n_mixture_samples=20):
    """
    Apply the model with prior image(s) and compute the resulting z output.

    For every prior image, the y mean obtained from the image and the y mean obtained from the prior
    image are mixed as ``alpha * y_image + (1 - alpha) * y_prior``, written into the hidden state and
    passed through ``model.blocks[3]``. The per-prior results are then averaged, optionally weighted.

    All model calls run under ``torch.no_grad()`` and every tensor is moved to the CPU before it is
    converted to NumPy, so the function also works for models that live on an accelerator.

    Args:
        model (torch.nn.Module): The trained model.
        image (numpy.ndarray): The input image (2-D) or a batch of images.
        prior_image (numpy.ndarray): The prior image (2-D) or a batch of prior images. May be None
            only in mixture-of-gaussians mode.
        alpha (float): Weight for combining y_from_image and y_prior.
                       Should be in the range [0, 1].
        use_prior_z: ``None``/``False`` (default) returns the mean of the z distribution at
            ``[1][1]`` of the block output. Otherwise it must be a scalar strictly between 0 and 1,
            and that distribution is fused with the one at ``[1][0]`` via ``laplace_fusion_mean``.
        avg_abs_prior (bool): If True, take absolute values of the per-prior outputs before averaging.
        active_dims (array-like, optional): If given, every other dimension of the image's y mean is
            set to zero.
        weights (bool): If True, weight every prior image, per input image, by the softmax over prior
            images of ``sum(log q(y_prior) - log p(y_prior))``, where q and p are the two y
            distributions (``["y"][1]`` and ``["y"][0]``) the model returns for the input image.
            If False, use a plain mean.
        mixture_of_gaussians (bool): If True, use mixture-of-gaussians mode.
        n_mixture_samples (int): Number of samples to draw from mixture distribution.

    Returns:
        numpy.ndarray: The computed output, one row per input image.

    Raises:
        ValueError: If ``use_prior_z`` is neither ``None``/``False`` nor strictly between 0 and 1.
    """
    fuse_z = use_prior_z is not None and use_prior_z is not False
    if fuse_z and not (0 < use_prior_z < 1):
        raise ValueError("use_prior_z must be a scalar between 0 and 1 (exclusive)")

    # Promote a single image to a batch of one
    if image.ndim == 2:
        image = np.expand_dims(image, axis=0)

    device = next(model.parameters()).device
    image_tensor = torch.tensor(image, dtype=torch.float32).to(device)
    z_block = model.blocks[3]

    # Only NumPy arrays leave this function, so no autograd graph is needed.
    with torch.no_grad():
        computed_hidden = model(image_tensor, stop_at='hiddens')[0]

        # One forward pass provides both y distributions of the input image. They do not depend on
        # the prior image, so they are computed once instead of once per prior image.
        # NOTE(review): this assumes the forward pass is deterministic for a fixed input.
        y_dists = model(image_tensor, stop_at='y')[1]["y"]
        y_prior_of_image, y_post = y_dists[0], y_dists[1]

        y_from_image = y_post.loc
        y_scale_from_image = y_post.scale

        # Zero out the non-active dimensions, if requested
        if active_dims is not None:
            y_active = y_from_image[:, active_dims]
            y_from_image = torch.zeros_like(y_from_image)
            y_from_image[:, active_dims] = y_active

        if mixture_of_gaussians:
            return _apply_model_with_prior_mixture(
                model, image, prior_image, y_from_image, y_scale_from_image,
                device, use_prior_z, fuse_z, n_mixture_samples,
            )

        # Promote a single prior image to a batch of one
        if prior_image.ndim == 2:
            prior_image = np.expand_dims(prior_image, axis=0)

        batch_outputs = []
        unnorm_log_weights = []

        for single_prior in prior_image:
            prior_tensor = torch.tensor(np.expand_dims(single_prior, axis=0), dtype=torch.float32).to(device)
            y_prior = model(prior_tensor, stop_at='y')[1]["y"][1].loc

            if weights:
                unnorm_log_weights.append(
                    torch.sum(y_post.log_prob(y_prior) - y_prior_of_image.log_prob(y_prior), dim=1)
                )

            # Mix the image's y with the prior's y (broadcast over the image batch)
            computed_hidden["y"] = alpha * y_from_image + (1 - alpha) * y_prior.expand(image.shape[0], -1)
            z_output = z_block(computed_hidden)

            z_loc = z_output[1][1].loc.detach().cpu().numpy()
            if fuse_z:
                fused_means = _fuse_z_posterior_with_prior(z_output, use_prior_z)
                # Same dtype as the z mean, as when the fused values went through a torch tensor
                batch_outputs.append(np.array(fused_means).astype(z_loc.dtype))
            else:
                batch_outputs.append(z_loc)

        # Axis 0: prior images, axis 1: input images, axis 2: z components
        stacked_outputs = np.stack(batch_outputs, axis=0)
        if avg_abs_prior:
            stacked_outputs = np.abs(stacked_outputs)

        if not weights:
            return np.mean(stacked_outputs, axis=0)

        # Numerically stable normalization of the log weights over the prior axis
        norm_weights = torch.softmax(torch.stack(unnorm_log_weights), dim=0).cpu().numpy()

    reshaped_weights = norm_weights[:, :, np.newaxis]
    numerator = np.sum(stacked_outputs * reshaped_weights, axis=0)
    denominator = np.sum(reshaped_weights, axis=0)

    # Avoid division by zero just in case
    return numerator / np.maximum(denominator, 1e-8)



def apply_model_with_natural_prior(model, image, moment='mean', threshold=0.0):
    """
    Compute a moment of the z distribution obtained when y is drawn from the model's y prior.

    A sample of the y distribution the model returns at ``[1]["y"][0]`` for the image replaces the
    hidden state's ``"y"`` entry; the hidden state is then passed through ``model.blocks[3]`` and
    ``laplace_moment`` is evaluated on the distribution found at ``[1][1]`` of the block output.
    The result depends on that random sample.

    Args:
        model (torch.nn.Module): The trained model.
        image (numpy.ndarray): The input image (2-D) or a batch of images.
        moment (str): Moment passed to ``laplace_moment`` ('mean' by default).
        threshold (float): Threshold passed to ``laplace_moment``. Default is 0.0.

    Returns:
        numpy.ndarray: The requested moment, converted to a NumPy array on the CPU.
    """
    # Promote a single image to a batch of one
    if image.ndim == 2:
        image = np.expand_dims(image, axis=0)

    device = next(model.parameters()).device
    image_tensor = torch.tensor(image, dtype=torch.float32).to(device)
    z_block = model.blocks[3]

    # Only a NumPy array leaves this function, so no autograd graph is needed.
    with torch.no_grad():
        computed_hidden = model(image_tensor, stop_at='hiddens')[0]
        y_prior = model(image_tensor, stop_at='y')[1]["y"][0].sample()

        # Broadcast the sampled y to the image batch and inject it into the hidden state
        computed_hidden["y"] = y_prior.expand(image.shape[0], -1)
        z_output = z_block(computed_hidden)

        moment_values = laplace_moment(z_output[1][1], moment=moment, threshold=threshold)

    # .cpu() is required before .numpy() when the model lives on an accelerator
    return moment_values.detach().cpu().numpy()


def _laplace_components_from_distribution(dist, scale_factor):
    """Return ``(loc, scale * scale_factor)`` of ``dist`` as NumPy arrays; 1-D parameters get a
    leading axis so that a single distribution counts as one mixture component."""
    loc = dist.loc.detach().cpu().numpy()
    scale = dist.scale.detach().cpu().numpy() * scale_factor
    if loc.ndim == 1:
        loc = np.expand_dims(loc, axis=0)
        scale = np.expand_dims(scale, axis=0)
    return loc, scale


def apply_model_with_prior_one_layer(model, image, prior_image, prior_scale_factor=1.0, image_scale_factor=1.0, return_component_norms=False, moment='mean', threshold=0.0, component_weights=None, natural_prior_weight=None):
    """
    Apply a one-layer model with a prior and compute the fused output using Laplace fusion.

    This function is designed for one-layer models that have 'z' distributions but not 'y' distributions.
    It computes z distributions from the image and either processes prior images through the model
    or uses parameters from provided PyTorch distribution(s) directly. Every image of the batch is
    fused with the full set of prior components by ``laplace_fusion_mean``.

    Args:
        model (torch.nn.Module): The trained one-layer model.
        image (numpy.ndarray): The input image (2-D) or a batch of images.
        prior_image (numpy.ndarray, torch.distributions.Distribution, or list/tuple of torch.distributions.Distribution):
            Either prior images to process through the model, a single PyTorch distribution object with loc and scale parameters,
            or a non-empty list/tuple of PyTorch distribution objects to be used as mixture components.
        prior_scale_factor (float): Factor to multiply the prior z distribution scale by.
                                   Values < 1.0 shrink the scale, > 1.0 expand it. Default is 1.0 (no change).
        image_scale_factor (float): Factor to multiply the image z distribution scale by.
                                   Values < 1.0 shrink the scale, > 1.0 expand it. Default is 1.0 (no change).
        return_component_norms (bool): If True, also return the normalization constants for each component.
                                      Default is False.
        moment (str): Which moment to compute: 'mean' (E[X], default), 'abs' (E[|X|]), or 'relu' (E[ReLU(X - x0)]).
        threshold (float): Threshold value for relu moment (x0 in E[ReLU(X - x0)]). Default is 0.0.
        component_weights (array-like or None, optional): Optional non-negative weights for each mixture component.
            If None (default), all components are weighted equally (weight = 1).
            Only relative magnitudes matter, so the array does not need to sum to 1.
        natural_prior_weight (float or None, optional): If not None, adds a unit Laplace component
            (zero mean, unit scale) to the mixture with the specified weight. Must be positive if provided.

    Returns:
        numpy.ndarray or tuple: If return_component_norms is False:
                                   Array with the fused result of every image in the batch.
                               If return_component_norms is True:
                                   Tuple (fused_results, component_norms, component_norms_per_dim) where
                                   the last two are lists holding, for every image in the batch, the
                                   second and third value returned by ``laplace_fusion_mean``.

    Raises:
        ValueError: If ``prior_image`` is an empty list or tuple.
    """
    # Promote a single image to a batch of one
    if image.ndim == 2:
        image = np.expand_dims(image, axis=0)

    device = next(model.parameters()).device
    image_tensor = torch.tensor(image, dtype=torch.float32).to(device)

    # Only NumPy arrays are used below, so no autograd graph is needed.
    with torch.no_grad():
        image_output = model(image_tensor)

    # z distribution of the image acts as the base distribution of the fusion
    z_image_mean = image_output[1]["z"][1].loc.detach().cpu().numpy()
    z_image_scale = image_output[1]["z"][1].scale.detach().cpu().numpy() * image_scale_factor

    if isinstance(prior_image, torch.distributions.Distribution):
        # Parameters are taken directly from the distribution
        z_prior_mean, z_prior_scale = _laplace_components_from_distribution(prior_image, prior_scale_factor)
    elif isinstance(prior_image, (list, tuple)) and all(
        isinstance(dist, torch.distributions.Distribution) for dist in prior_image
    ):
        if len(prior_image) == 0:
            raise ValueError("prior_image must contain at least one distribution")

        # Every distribution contributes its component(s); stack them along the component axis
        per_dist_params = [
            _laplace_components_from_distribution(dist, prior_scale_factor) for dist in prior_image
        ]
        z_prior_mean = np.concatenate([loc for loc, _ in per_dist_params], axis=0)
        z_prior_scale = np.concatenate([scale for _, scale in per_dist_params], axis=0)

        print(f"Using {len(prior_image)} Laplace distributions as mixture components")
    else:
        # Prior images: run them through the model, each image becomes one mixture component
        if prior_image.ndim == 2:
            prior_image = np.expand_dims(prior_image, axis=0)

        prior_tensor = torch.tensor(prior_image, dtype=torch.float32).to(device)
        with torch.no_grad():
            prior_output = model(prior_tensor)

        z_prior_mean = prior_output[1]["z"][1].loc.detach().cpu().numpy()
        z_prior_scale = prior_output[1]["z"][1].scale.detach().cpu().numpy() * prior_scale_factor

    fused_results = []
    component_norms_list = []
    component_norms_per_dim_list = []

    # The prior components are identical for every image; only the base distribution changes.
    for i in range(image.shape[0]):
        print(f"Processing image {i} of {image.shape[0]}")
        fusion = laplace_fusion_mean(
            z_image_mean[i],
            z_image_scale[i],
            z_prior_mean,
            z_prior_scale,
            return_component_norms=return_component_norms,
            moment=moment,
            threshold=threshold,
            component_weights=component_weights,
            natural_prior_weight=natural_prior_weight,
        )
        if return_component_norms:
            fused_mean, comp_norms, comp_norms_per_dim = fusion
            component_norms_list.append(comp_norms)
            component_norms_per_dim_list.append(comp_norms_per_dim)
        else:
            fused_mean = fusion
        fused_results.append(fused_mean)

    fused_results = np.array(fused_results)

    if return_component_norms:
        return fused_results, component_norms_list, component_norms_per_dim_list
    return fused_results

def _gather_prior_images_for_angles(data, angles, spatial_frequency, contrast, label):
    """
    Collect and concatenate the constrained images found for each angle.

    Args:
        data: The dataset dictionary passed through to get_constrained_images
        angles: Iterable of angles to look up
        spatial_frequency: Spatial frequency constraint (may be None)
        contrast: Contrast constraint (may be None)
        label: Text used in the progress messages (e.g. "prior", "tuning prior")

    Returns:
        The per-angle image arrays concatenated along axis 0, or None when no
        angle yielded any image.
    """
    per_angle_images = []
    for angle in angles:
        angle_images = get_constrained_images(data, angle=angle,
                                              spatial_frequency=spatial_frequency,
                                              contrast=contrast)
        if len(angle_images) > 0:
            per_angle_images.append(angle_images)
            print(f"Found {len(angle_images)} {label} images for angle {angle}")

    if not per_angle_images:
        return None

    combined = np.concatenate(per_angle_images, axis=0)
    print(f"Combined {len(combined)} total {label} images")
    return combined


def orientation_tuning_analysis(data, model, stimulus_angle, prior_angles=None, spatial_frequency=None, contrast=None, one_layer=True, tuning_prior_angles=None, moment='mean', threshold=0.0, component_weights=None, natural_prior_weight=None, mixture_of_gaussians=False, n_mixture_samples=20):
    """
    Perform orientation tuning analysis by combining multiple analysis steps.

    The steps are: (1) select the stimulus images, (2) select the prior images
    used for the final response, (3) select the prior images used for the
    tuning estimate, (4) compute the tuning with calculate_avg_abs_z_by_angle,
    (5) derive the most responsive angles, and (6) compute the response by
    orientation for the stimulus images.

    Args:
        data: The dataset dictionary containing train/test images and parameters
        model: The trained model
        stimulus_angle: The angle of the stimulus to analyze
        prior_angles: Angles to use as priors for the final response. A list,
            tuple or NumPy array is accepted; None or an empty sequence means
            no prior images are used.
        spatial_frequency: Spatial frequency parameter
        contrast: Contrast parameter
        one_layer: Boolean indicating whether to use one layer processing
        tuning_prior_angles: Optional angles whose images are used as prior when
            computing the tuning (if None, the tuning is calculated without prior images)
        moment (str): Which moment to compute: 'mean' (E[X], default), 'abs' (E[|X|]), or 'relu' (E[ReLU(X - x0)]).
        threshold (float): Threshold value for relu moment (x0 in E[ReLU(X - x0)]). Default is 0.0.
        component_weights (array-like or None, optional): Optional non-negative weights for each mixture component.
            If None (default), all components are weighted equally (weight = 1).
            Only relative magnitudes matter, so the array does not need to sum to 1.
        natural_prior_weight (float or None, optional): If not None, adds a unit Laplace component
            (zero mean, unit scale) to the mixture with the specified weight. Must be positive if provided.
        mixture_of_gaussians: Forwarded unchanged to calculate_avg_abs_z_by_angle
            and response_by_orientation.
        n_mixture_samples: Forwarded unchanged to calculate_avg_abs_z_by_angle
            and response_by_orientation.

    Returns:
        tuple: (initial_tuning, most_responsive_angles, final_tuning) where:
            - initial_tuning: Response by orientation from first step
            - most_responsive_angles: Array of most responsive angles
            - final_tuning: Response by orientation using most responsive angles
              (an empty dict when no stimulus image matches the constraints)
    """

    # Step 1: Get constrained images for the stimulus angle
    stimulus_images = get_constrained_images(data, angle=stimulus_angle,
                                             spatial_frequency=spatial_frequency,
                                             contrast=contrast)
    print(f"Found {len(stimulus_images)} stimulus images for angle {stimulus_angle}")

    # Step 2: Determine prior images if prior_angles are provided.
    # A plain truthiness test raises ValueError on a multi-element NumPy array,
    # so arrays are checked by size; every other input keeps its truthiness.
    if isinstance(prior_angles, np.ndarray):
        has_prior_angles = prior_angles.size > 0
    else:
        has_prior_angles = bool(prior_angles)

    prior_images = None
    if has_prior_angles:
        prior_images = _gather_prior_images_for_angles(
            data, prior_angles, spatial_frequency, contrast, "prior")

    # Step 3: Determine tuning prior images if tuning_prior_angles is provided
    tuning_prior_images = None
    if tuning_prior_angles is not None:
        tuning_prior_images = _gather_prior_images_for_angles(
            data, tuning_prior_angles, spatial_frequency, contrast, "tuning prior")

    # Step 4: Call calculate_avg_abs_z_by_angle.
    # tuning_prior_images is None whenever tuning_prior_angles is None (or no
    # image matched), which is exactly the "no prior" case.
    print("Calculating average absolute z by angle...")
    initial_tuning = calculate_avg_abs_z_by_angle(data, model,
                                                  spatial_frequency=spatial_frequency,
                                                  contrast=contrast,
                                                  prior_images=tuning_prior_images,
                                                  one_layer=one_layer,
                                                  moment=moment,
                                                  threshold=threshold,
                                                  component_weights=component_weights,
                                                  natural_prior_weight=natural_prior_weight,
                                                  mixture_of_gaussians=mixture_of_gaussians,
                                                  n_mixture_samples=n_mixture_samples)

    print(f"Calculated tuning for {len(initial_tuning)} angles")

    # Step 5: Find most responsive angles
    print("Finding most responsive angles...")
    result = find_most_responsive_angles(initial_tuning)
    # The helper may return either the angles alone or an (angles, max_values)
    # tuple; only the angles are needed here.
    most_responsive_angles = result[0] if isinstance(result, tuple) else result
    num_responsive = np.sum(~np.isnan(most_responsive_angles))
    print(f"Found {num_responsive} responsive angles out of {len(most_responsive_angles)} total")

    # Step 6: Calculate final response_by_orientation using the most responsive angles
    print("Calculating final response by orientation...")
    if len(stimulus_images) > 0:
        # NOTE(review): component_weights and natural_prior_weight are not
        # forwarded here, unlike in step 4 - confirm this is intentional.
        orientation_responses = response_by_orientation(stimulus_images, model,
                                                        most_responsive_angles,
                                                        prior_image=prior_images,
                                                        one_layer=one_layer,
                                                        moment=moment,
                                                        threshold=threshold,
                                                        mixture_of_gaussians=mixture_of_gaussians,
                                                        n_mixture_samples=n_mixture_samples)
        print(f"Final tuning calculated for {len(orientation_responses)} orientations")
    else:
        orientation_responses = {}
        print("No stimulus images found, final tuning is empty")

    return initial_tuning, most_responsive_angles, orientation_responses

def plot_z_samples(model, data, angle, spatial_frequency=None, contrast=None, index=0, num_samples=10000, use_samples=False, samples_per_image=10, fit_distributions=False):
    """
    Plot one component of the "z" latent variable for a constrained subset of the data.

    The constrained images are fed through the model, the selected z component
    is shown as a histogram and as a per-sample trace, and summary statistics
    are printed. With fit_distributions=True, Gaussian and Laplacian fits are
    overlaid on the histogram and compared in a Q-Q plot.

    Args:
        model: The trained model
        data: The dataset dictionary containing train/test images and parameters
        angle: The angle to constrain the data to
        spatial_frequency: Spatial frequency parameter (optional)
        contrast: Contrast parameter (optional)
        index: Which dimension/component of z to plot (default: 0)
        num_samples: Maximum number of images to use for plotting (default: 10000)
        use_samples: If True, draw samples from z distributions; if False, use means (default: False)
        samples_per_image: Number of samples to draw per image when use_samples=True (default: 10)
        fit_distributions: If True, fit Gaussian and Laplacian distributions and compare (default: False)

    Returns:
        None. The function returns early (after printing a message) when no
        image matches the constraints or when index is out of range.
    """
    import matplotlib.pyplot as plt

    # Get constrained images based on the specified parameters
    constrained_images = get_constrained_images(data, angle=angle,
                                                spatial_frequency=spatial_frequency,
                                                contrast=contrast)

    if len(constrained_images) == 0:
        print(f"No images found for angle={angle}, spatial_frequency={spatial_frequency}, contrast={contrast}")
        return

    print(f"Found {len(constrained_images)} constrained images")

    # Limit the number of images if there are too many
    if len(constrained_images) > num_samples:
        indices = np.random.choice(len(constrained_images), num_samples, replace=False)
        constrained_images = constrained_images[indices]
        print(f"Randomly selected {num_samples} samples for plotting")

    # Feed the constrained images through the model
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(torch.tensor(constrained_images, dtype=torch.float32).to(device))

    # Extract z values from the model output
    z_dist = output[1]["z"][1]
    if use_samples:
        # Each draw has the same shape as the distribution's batch; the draws
        # are concatenated along axis 0, so the result stays 2-D with
        # samples_per_image * num_images rows.
        z_samples = np.concatenate(
            [z_dist.sample().cpu().detach().numpy() for _ in range(samples_per_image)],
            axis=0,
        )
        print(f"Drew {samples_per_image} samples per image, total: {z_samples.shape[0]} samples")
    else:
        # Use the means of the z distribution
        z_samples = z_dist.loc.cpu().detach().numpy()
        print(f"Using distribution means, total: {z_samples.shape[0]} samples")

    # Check if the index is valid
    if index >= z_samples.shape[1]:
        print(f"Index {index} is out of range. z has {z_samples.shape[1]} dimensions.")
        return

    # Extract the values for the specified index
    z_index_samples = z_samples[:, index]

    # Fit distributions if requested
    fit_results = fit_and_compare_distributions(z_index_samples) if fit_distributions else None
    show_fits = bool(fit_results)

    # Create the plot
    plt.figure(figsize=(15, 6) if fit_distributions else (10, 6))
    n_subplots = 3 if fit_distributions else 2
    sample_type = "samples" if use_samples else "means"

    # Plot histogram (normalised to a density when fits are overlaid)
    plt.subplot(1, n_subplots, 1)
    plt.hist(z_index_samples, bins=30, alpha=0.7, edgecolor='black', density=fit_distributions)
    plt.xlabel(f'z_{index} values')
    plt.ylabel('Density' if fit_distributions else 'Frequency')
    plt.title(f'Histogram of z_{index} {sample_type}\n(angle={angle}, sf={spatial_frequency}, contrast={contrast})')
    plt.grid(True, alpha=0.3)

    # Overlay fitted distributions if requested
    if show_fits:
        x_range = np.linspace(z_index_samples.min(), z_index_samples.max(), 100)

        # Plot Gaussian fit
        gaussian_params = fit_results['gaussian']['params']
        gaussian_pdf = stats.norm.pdf(x_range, *gaussian_params)
        plt.plot(x_range, gaussian_pdf, 'r-', linewidth=2, label=f'Gaussian (AIC: {fit_results["gaussian"]["aic"]:.2f})')

        # Plot Laplacian fit
        laplace_params = fit_results['laplace']['params']
        laplace_pdf = stats.laplace.pdf(x_range, *laplace_params)
        plt.plot(x_range, laplace_pdf, 'g-', linewidth=2, label=f'Laplacian (AIC: {fit_results["laplace"]["aic"]:.2f})')

        plt.legend()

        # Add text showing which fit is better
        better_fit = fit_results['better_fit']
        plt.text(0.02, 0.98, f'Better fit: {better_fit}', transform=plt.gca().transAxes,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    # Plot the values against their position in the dataset
    plt.subplot(1, n_subplots, 2)
    if len(z_index_samples) > 1000:
        # For large datasets, use scatter plot without lines
        plt.scatter(range(len(z_index_samples)), z_index_samples, alpha=0.6, s=1)
    else:
        # For smaller datasets, use line plot
        plt.plot(z_index_samples, 'o-', alpha=0.7, markersize=3)
    plt.xlabel('Sample index')
    plt.ylabel(f'z_{index} values')
    plt.title(f'z_{index} {sample_type} over dataset\n(angle={angle}, sf={spatial_frequency}, contrast={contrast})')
    plt.grid(True, alpha=0.3)

    # Plot Q-Q plots if distributions were fitted
    if show_fits:
        plt.subplot(1, n_subplots, 3)

        sorted_samples = np.sort(z_index_samples)
        n_samples = len(sorted_samples)
        # Plotting positions (i - 0.5) / n: the i-th order statistic is paired
        # with the probability it actually estimates. A fixed [0.01, 0.99]
        # grid would clip the theoretical tails and make any large sample look
        # heavy-tailed. The positions stay strictly inside (0, 1), so ppf is
        # always finite.
        theoretical_quantiles = (np.arange(1, n_samples + 1) - 0.5) / n_samples

        # Gaussian Q-Q plot
        gaussian_params = fit_results['gaussian']['params']
        gaussian_theoretical = stats.norm.ppf(theoretical_quantiles, *gaussian_params)
        plt.scatter(gaussian_theoretical, sorted_samples, alpha=0.6, s=5, color='red', label='Gaussian')

        # Laplacian Q-Q plot
        laplace_params = fit_results['laplace']['params']
        laplace_theoretical = stats.laplace.ppf(theoretical_quantiles, *laplace_params)
        plt.scatter(laplace_theoretical, sorted_samples, alpha=0.6, s=5, color='green', label='Laplacian')

        # Add diagonal reference line
        min_val = min(sorted_samples.min(), gaussian_theoretical.min(), laplace_theoretical.min())
        max_val = max(sorted_samples.max(), gaussian_theoretical.max(), laplace_theoretical.max())
        plt.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.8, label='Perfect fit')

        plt.xlabel('Theoretical Quantiles')
        plt.ylabel('Sample Quantiles')
        plt.title('Q-Q Plot Comparison')
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print some statistics
    print(f"\nStatistics for z_{index}:")
    print(f"Mean: {np.mean(z_index_samples):.4f}")
    print(f"Std: {np.std(z_index_samples):.4f}")
    print(f"Min: {np.min(z_index_samples):.4f}")
    print(f"Max: {np.max(z_index_samples):.4f}")
    print(f"Number of samples: {len(z_index_samples)}")

    # Print distribution fitting results
    if show_fits:
        print(f"\nDistribution Fitting Results:")
        print(f"Gaussian - AIC: {fit_results['gaussian']['aic']:.4f}, BIC: {fit_results['gaussian']['bic']:.4f}")
        print(f"  Parameters: μ={fit_results['gaussian']['params'][0]:.4f}, σ={fit_results['gaussian']['params'][1]:.4f}")
        print(f"Laplacian - AIC: {fit_results['laplace']['aic']:.4f}, BIC: {fit_results['laplace']['bic']:.4f}")
        print(f"  Parameters: loc={fit_results['laplace']['params'][0]:.4f}, scale={fit_results['laplace']['params'][1]:.4f}")
        print(f"Better fit: {fit_results['better_fit']} (lower AIC)")


def fit_and_compare_distributions(data):
    """
    Fit a Gaussian and a Laplacian to 1D data and rank them by AIC.

    Both families are fitted with their scipy ``fit`` method. For each one the
    log-likelihood of the data under the fitted parameters is computed, and
    from it the AIC (2k - 2 logL) and BIC (k ln(n) - 2 logL) with k = 2.

    Args:
        data: 1D numpy array of data points

    Returns:
        dict: ``{'gaussian': {...}, 'laplace': {...}, 'better_fit': str}`` where
        each inner dict has the keys 'params', 'loglik', 'aic' and 'bic', and
        'better_fit' is "Gaussian" when its AIC is strictly lower, otherwise
        "Laplacian".
    """
    n_free_params = 2  # both families are described by a location and a scale

    def _summarise(family):
        # Fit one scipy.stats family and score the fit.
        fitted = family.fit(data)
        log_likelihood = np.sum(family.logpdf(data, *fitted))
        return {
            'params': fitted,
            'loglik': log_likelihood,
            'aic': 2 * n_free_params - 2 * log_likelihood,
            'bic': n_free_params * np.log(len(data)) - 2 * log_likelihood,
        }

    gaussian_summary = _summarise(stats.norm)
    laplace_summary = _summarise(stats.laplace)

    # Lower AIC wins; a tie goes to the Laplacian.
    if gaussian_summary['aic'] < laplace_summary['aic']:
        winner = "Gaussian"
    else:
        winner = "Laplacian"

    return {
        'gaussian': gaussian_summary,
        'laplace': laplace_summary,
        'better_fit': winner,
    }

def fit_laplace_to_z_components(model, data, angle, spatial_frequency=None, contrast=None, use_samples=False, samples_per_image=10, scale_divisor=1.0, use_sigmoid_transform=False, aggregated_images=None):
    """
    Fit a zero-centred Laplace distribution to every z component of a set of images.

    The images are either selected from ``data`` by the given constraints
    (when ``angle`` is not None) or taken from ``aggregated_images``. They are
    passed through the model, and for each z component the scale is set to the
    mean absolute value of that component, optionally divided and squashed.

    Args:
        model: The trained model
        data: The dataset dictionary containing train/test images and parameters
        angle: The angle to constrain the data to (if None, use aggregated_images)
        spatial_frequency: Spatial frequency parameter (optional)
        contrast: Contrast parameter (optional)
        use_samples: If True, draw samples from z distributions; if False, use means (default: False)
        samples_per_image: Number of samples to draw per image when use_samples=True (default: 10)
        scale_divisor: Factor to divide the fitted scale by (default: 1.0, no change)
        use_sigmoid_transform: If True, apply the transformation scale/(1+scale)
            after the division (default: False)
        aggregated_images: Pre-computed images to use when angle=None (optional)

    Returns:
        torch.distributions.Laplace: Laplace distribution with zero loc and the
        fitted per-component scales, placed on the model's device

    Raises:
        ValueError: If angle is None and aggregated_images is not given, or if
            the selected set of images is empty.
    """
    # Select the images and prepare the message used if the selection is empty
    if angle is None:
        if aggregated_images is None:
            raise ValueError("aggregated_images must be provided when angle is None")
        images = aggregated_images
        empty_message = "No aggregated images provided"
    else:
        images = get_constrained_images(data, angle=angle,
                                        spatial_frequency=spatial_frequency,
                                        contrast=contrast)
        empty_message = f"No images found for angle={angle}, spatial_frequency={spatial_frequency}, contrast={contrast}"

    if len(images) == 0:
        raise ValueError(empty_message)

    print(f"Found {len(images)} constrained images")

    # Run the model on its own device without tracking gradients
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(torch.tensor(images, dtype=torch.float32).to(device))

    posterior = output[1]["z"][1]
    if use_samples:
        # Independent draws are concatenated along axis 0, giving
        # samples_per_image * num_images rows.
        draws = [posterior.sample().cpu().detach().numpy() for _ in range(samples_per_image)]
        z_samples = np.concatenate(draws, axis=0)
        print(f"Drew {samples_per_image} samples per image, total: {z_samples.shape[0]} samples")
    else:
        z_samples = posterior.loc.cpu().detach().numpy()
        print(f"Using distribution means, total: {z_samples.shape[0]} samples")

    z_dim = z_samples.shape[1]

    # Describe the post-processing once; it is the same for every component
    applied_transforms = []
    if scale_divisor != 1.0:
        applied_transforms.append(f"÷{scale_divisor}")
    if use_sigmoid_transform:
        applied_transforms.append("sigmoid")

    zero_loc = 0.0  # loc is fixed at zero, only the scale is estimated
    fitted_locs = []
    fitted_scales = []
    for component in range(z_dim):
        # For Laplace(0, b) the maximum-likelihood estimate of b is mean(|x|)
        raw_scale = np.mean(np.abs(z_samples[:, component]))

        final_scale = raw_scale / scale_divisor
        if use_sigmoid_transform:
            final_scale = final_scale / (1 + final_scale)

        fitted_locs.append(zero_loc)
        fitted_scales.append(final_scale)

        # Only the first few components are reported
        if component >= 5:
            continue
        if applied_transforms:
            transform_str = " (" + ", ".join(applied_transforms) + ")"
            print(f"z_{component}: loc={zero_loc:.4f} (enforced), original scale={raw_scale:.4f}, adjusted scale={final_scale:.4f}{transform_str}")
        else:
            print(f"z_{component}: loc={zero_loc:.4f} (enforced), scale={raw_scale:.4f}")

    # Build the distribution on the model's device
    laplace_dist = torch.distributions.Laplace(
        loc=torch.tensor(fitted_locs, dtype=torch.float32).to(device),
        scale=torch.tensor(fitted_scales, dtype=torch.float32).to(device),
    )

    print(f"Fitted Laplace distributions for {z_dim} z components")

    return laplace_dist

def calculate_avg_z_scale(model, data, angle, spatial_frequency=None, contrast=None):
    """
    Average the z-distribution scale over all images matching the constraints.

    The matching images are passed through the model in a single forward call
    (no prior argument is supplied), and the ``scale`` of the resulting z
    distribution is averaged over the image axis.

    Args:
        model: The trained model
        data: The dataset dictionary containing train/test images and parameters
        angle: The angle to constrain the data to
        spatial_frequency: Spatial frequency parameter (optional)
        contrast: Contrast parameter (optional)

    Returns:
        numpy.ndarray: Average scale per z dimension

    Raises:
        ValueError: If no image matches the given constraints.
    """
    images = get_constrained_images(data, angle=angle,
                                    spatial_frequency=spatial_frequency,
                                    contrast=contrast)

    n_images = len(images)
    if n_images == 0:
        raise ValueError(f"No images found for angle={angle}, spatial_frequency={spatial_frequency}, contrast={contrast}")

    print(f"Found {n_images} constrained images")

    # Forward pass on the model's own device, without gradient tracking
    target_device = next(model.parameters()).device
    with torch.no_grad():
        batch = torch.tensor(images, dtype=torch.float32).to(target_device)
        model_output = model(batch)

    # Scale of the z distribution for every image, then the mean over images
    per_image_scales = model_output[1]["z"][1].scale.cpu().detach().numpy()
    avg_scales = np.mean(per_image_scales, axis=0)

    lowest, highest = np.min(avg_scales), np.max(avg_scales)
    print(f"Calculated average scales for {len(avg_scales)} z dimensions")
    print(f"Scale range: [{lowest:.4f}, {highest:.4f}]")

    return avg_scales

def pca_analysis_z(model, data, angle=None, phase=None, spatial_frequency=None, contrast=None):
    """
    Run PCA on the z means the model produces for a constrained set of images.

    The images matching the constraints are passed through the model, the
    ``loc`` of the z distribution is collected for every image, and a full PCA
    is fitted to those vectors. A per-component variance report and the
    leading components are printed.

    Args:
        model: The trained model
        data: The dataset dictionary containing train/test images and parameters
        angle: The angle to constrain the data to (optional)
        phase: The phase to constrain the data to (optional)
        spatial_frequency: Spatial frequency parameter (optional)
        contrast: Contrast parameter (optional)

    Returns:
        tuple: (variances, pca_vectors) where:
            - variances: numpy array of explained variances in decreasing order
            - pca_vectors: numpy array of PCA components (eigenvectors)

    Raises:
        ValueError: If no image matches the given constraints.
    """
    from sklearn.decomposition import PCA

    images = get_constrained_images(data, angle=angle,
                                    phase=phase,
                                    spatial_frequency=spatial_frequency,
                                    contrast=contrast)

    if len(images) == 0:
        raise ValueError(f"No images found for the specified constraints: angle={angle}, phase={phase}, spatial_frequency={spatial_frequency}, contrast={contrast}")

    print(f"Found {len(images)} constrained images")

    # Forward pass on the model's own device, without gradient tracking
    target_device = next(model.parameters()).device
    with torch.no_grad():
        model_output = model(torch.tensor(images, dtype=torch.float32).to(target_device))

    z_means = model_output[1]["z"][1].loc.cpu().detach().numpy()

    print(f"Extracted z means with shape: {z_means.shape}")
    print(f"z dimension: {z_means.shape[1]}")

    # Full PCA (all components kept)
    fitted_pca = PCA()
    fitted_pca.fit(z_means)
    variances = fitted_pca.explained_variance_
    pca_vectors = fitted_pca.components_

    total_variance = np.sum(variances)

    print("\nPCA Analysis Results:")
    print(f"Total variance explained: {total_variance:.4f}")
    print(f"Number of components: {len(variances)}")

    # Per-component share of the variance plus the running total
    print("\nExplained variances (decreasing order):")
    running_variance = 0
    for rank, component_variance in enumerate(variances, start=1):
        running_variance += component_variance
        share = (component_variance / total_variance) * 100
        running_share = (running_variance / total_variance) * 100
        print(f"PC{rank}: {component_variance:.4f} ({share:.2f}%) - Cumulative: {running_share:.2f}%")

    print(f"\nPCA components (eigenvectors) shape: {pca_vectors.shape}")
    print("First 5 components (first 10 dimensions):")
    for rank, component in enumerate(pca_vectors[:5], start=1):
        print(f"PC{rank}: {component[:10]}")

    return variances, pca_vectors

    