# nc_plot_claude.py (Modified for ax0)

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize  # For easier row-wise normalization
import copy
import matplotlib.ticker as ticker  # Import ticker for tick control, if needed for ax0

# Import from shared utility
from plot_config_utils import (
    BASE_COLORS,
    BASE_DEFAULT_PLOT_STYLE_CONFIG,
    _apply_font_settings,
    deep_update_style_config
)

# Module-wide style defaults: an independent copy of the shared base style,
# extended (and, where keys collide, overridden) with values for this script.
DEFAULT_STYLE_CONFIG = copy.deepcopy(BASE_DEFAULT_PLOT_STYLE_CONFIG)

# Figure sizes in inches (width, height).
DEFAULT_STYLE_CONFIG.update({
    'figsize_1x4_main': (40, 12),  # main figure: 1 row x 4 subplots
    'figsize_1x3_main': (32, 12),  # earlier 1x3 layout, kept for reference
    'figsize_1x2_digits': (10, 5),
})

# Marker scaling and transparency.
DEFAULT_STYLE_CONFIG.update({
    'legend_marker_scalefactor': 2,
    '3d_box_alpha': 0.05,
    'scatter_data_alpha': 0.4,
    'scatter_neuron_alpha': 0.8,
})

# Spacing, colormap and sampling.
DEFAULT_STYLE_CONFIG.update({
    'axis_labelpad': 40,
    'ax0_cmap': 'viridis',  # colormap for the ax0 image
    'title_pad_default': 15,  # default title padding for 2D plots such as ax0
    'num_samples_for_plots': 1000,  # samples drawn for ax0 and the PCA plots
})

# Font sizes.
DEFAULT_STYLE_CONFIG.update({
    'title_fontsize': 60,
    'axis_label_fontsize': 50,
    'tick_label_fontsize': 35,
})


# ---

# BASE_COLORS keys used for classes 0, 1 and 2 (in that order) in every panel.
_CLASS_COLOR_KEYS = ('blue', 'orange', 'green')


def _legend_handles(legend):
    """Return the handle artists of *legend* on both old and new Matplotlib.

    ``Legend.legendHandles`` was renamed to ``Legend.legend_handles``; try the
    new name first so no deprecation warning is triggered, then fall back.
    """
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = getattr(legend, 'legendHandles', [])
    return handles


def _build_legend(ax, s_cfg, marker_size, loc='upper right',
                  bbox_to_anchor=(1.1, 1.02), handles=None):
    """Create a styled legend on *ax* and resize its scatter markers.

    Args:
        ax: Axes that receives the legend.
        s_cfg: Resolved style configuration (reads the ``legend_*`` keys).
        marker_size: Size applied to every legend handle that supports it.
        loc: Legend location passed to ``ax.legend``.
        bbox_to_anchor: Anchor passed to ``ax.legend`` (may be ``None``).
        handles: Optional explicit handle list; omitted from the call when
            ``None`` so the axes' labelled artists are used.

    Returns:
        The created legend.
    """
    legend_kwargs = {
        'fontsize': s_cfg['legend_fontsize'],
        'frameon': s_cfg['legend_frameon'],
        'fancybox': s_cfg['legend_fancybox'],
        'edgecolor': s_cfg['legend_edgecolor'],
        'facecolor': s_cfg['legend_facecolor'],
        'framealpha': 0.9,
        'loc': loc,
        'shadow': s_cfg['legend_shadow'],
        'bbox_to_anchor': bbox_to_anchor,
    }
    if handles is not None:
        legend_kwargs['handles'] = handles

    legend = ax.legend(**legend_kwargs)
    legend.get_frame().set_linewidth(s_cfg['legend_frame_linewidth'])
    for handle in _legend_handles(legend):
        try:
            handle.set_sizes([marker_size])
        except AttributeError:
            # Handles without set_sizes (e.g. line handles) are left alone
            # unless they expose a private _sizes array.
            if hasattr(handle, '_sizes'):
                handle._sizes = np.array([marker_size])
    return legend


def _style_3d_axes(ax, s_cfg):
    """Apply the shared spine, tick, pane and grid styling to one 3D axes."""
    for spine in ax.spines.values():
        spine.set_linewidth(s_cfg['axes_linewidth'])
    ax.tick_params(axis='both', which='major', labelsize=s_cfg['tick_label_fontsize'],
                   width=s_cfg['axes_linewidth'], pad=s_cfg.get('tick_pad_3d', 10))

    axes_xyz = (ax.xaxis, ax.yaxis, ax.zaxis)
    for axis in axes_xyz:
        axis.pane.fill = False
        axis.pane.set_edgecolor(BASE_COLORS['gray'])
        axis.pane.set_alpha(s_cfg['3d_box_alpha'])

    ax.grid(True, linestyle=s_cfg['grid_linestyle'], linewidth=s_cfg['grid_linewidth'],
            alpha=s_cfg['grid_alpha'], color=BASE_COLORS['light_gray'])

    # Use axis.set_pane_color directly: the w_xaxis/w_yaxis/w_zaxis aliases no
    # longer exist in recent Matplotlib releases.
    # NOTE(review): applied last, as before; this may overwrite the pane edge
    # colour/alpha set above - confirm that is intended.
    pane_alpha = s_cfg.get('pane_alpha', 0.05)
    for axis, shade in zip(axes_xyz, (0.95, 0.93, 0.91)):
        axis.set_pane_color((shade, shade, shade, pane_alpha))


def _scatter_digit_classes(ax, points, labels, s_cfg):
    """Scatter the first three columns of *points* per class 0, 1, 2.

    Returns:
        The three scatter artists, in class order, for use as legend handles.
    """
    artists = []
    for digit, color_key in enumerate(_CLASS_COLOR_KEYS):
        members = points[labels == digit]
        artists.append(ax.scatter(
            members[:, 0], members[:, 1], members[:, 2],
            label=f'Digit {digit}', color=BASE_COLORS[color_key],
            s=s_cfg['3d_data_markersize'], edgecolors=BASE_COLORS['white'],
            linewidth=0.0, alpha=s_cfg['scatter_data_alpha']))
    return artists


def _plot_gram_matrix(fig, ax, samples, labels, s_cfg):
    """Draw the cosine-similarity matrix of *samples*, grouped by class.

    Rows are reordered so class 0 comes first, then 1, then 2; samples with
    any other label are dropped. Each row is L2-normalised before the product.
    """
    class_order = np.concatenate([np.where(labels == digit)[0] for digit in range(3)])
    unit_rows = normalize(samples[class_order, :], norm='l2', axis=1)
    gram_matrix = unit_rows @ unit_rows.T

    im = ax.imshow(gram_matrix, cmap=s_cfg.get('ax0_cmap', 'viridis'), aspect=1,
                   interpolation=s_cfg.get('ax0_interpolation', 'nearest'))
    cb = fig.colorbar(im, ax=ax, shrink=0.5, aspect=s_cfg.get('ax0_cbar_aspect', 15))
    cb.ax.tick_params(labelsize=s_cfg['tick_label_fontsize'] * 0.8)

    ax.set_title('(a) Digits Correlation', fontsize=s_cfg['title_fontsize'], fontweight='bold',
                 pad=s_cfg['title_pad_default'])
    ax.set_xlabel('Sample Index', fontsize=s_cfg['axis_label_fontsize'])
    ax.set_ylabel('Sample Index', fontsize=s_cfg['axis_label_fontsize'])
    ax.tick_params(axis='both', which='major', labelsize=s_cfg['tick_label_fontsize'])
    for spine in ax.spines.values():
        spine.set_linewidth(s_cfg['axes_linewidth'])
        spine.set_color(BASE_COLORS.get('gray', '#555555'))


def plot_figures(style_config_overrides=None):  # Keep original name
    """
    Plot figures using a configuration-driven style.

    Loads ``result.pt`` from the working directory (a dict with the tensors
    ``center``, ``x``, ``y``, ``H``, ``V`` and ``W``) and writes:

    * ``mnist_main_plots_with_gramian.png`` / ``.pdf`` - a 1x4 figure with
      (a) the class-ordered similarity matrix of centred inputs, (b) the
      first three rows of ``V`` per neuron, (c) a 3-component PCA of ``x``
      and (d) a 3-component PCA of ``H`` with the classifier directions.
    * ``mnist_digits_side_by_side_{k}.png`` for ``k`` in 0..2 - the mean
      ``W`` column of the neurons assigned to class ``k`` next to the mean
      input of class ``k``, both shown as 28x28 images.

    The plotted subsample is drawn with ``np.random.permutation`` and is not
    seeded here, so panels (a), (c) and (d) vary between runs.

    Args:
        style_config_overrides: Optional mapping merged into a deep copy of
            ``DEFAULT_STYLE_CONFIG`` via ``deep_update_style_config``.

    Returns:
        None. The figures are saved to disk and closed.
    """
    s_cfg = copy.deepcopy(DEFAULT_STYLE_CONFIG)
    if style_config_overrides:
        deep_update_style_config(s_cfg, style_config_overrides)

    _apply_font_settings(s_cfg['font_settings'])

    # Load data.
    # NOTE(security): torch.load unpickles the file; only use trusted results.
    # map_location keeps loading possible on machines without the saving device.
    D = torch.load('result.pt', map_location='cpu')
    center = D['center'].detach().cpu().numpy()
    x_orig_np = D['x'].detach().cpu().numpy()
    y_orig_np = D['y'].detach().cpu().numpy()
    H_np = D['H'].detach().cpu().numpy()
    V_np = D['V'].detach().cpu().numpy()
    W = D['W'].detach().cpu().numpy()

    # Class assigned to each neuron (column of V). Converted to NumPy so it can
    # be used directly as a mask on the NumPy arrays below.
    neuron_label = torch.argmax(D['V'], dim=0).detach().cpu().numpy()

    # --- Main Figure with ax0 and 3D Plots ---
    # One common subsample is shared by ax0 and both PCA panels.
    n_total = x_orig_np.shape[0]
    num_samples_for_plot = min(s_cfg['num_samples_for_plots'], n_total)
    sub_sample_index = np.random.permutation(n_total)[:num_samples_for_plot]
    y_sampled = y_orig_np[sub_sample_index]

    fig_3d = plt.figure(figsize=s_cfg.get('figsize_1x4_main', (40, 12)), constrained_layout=True)
    # 1 row x 4 columns; 'main_fig_width_ratios' lets ax0 take a different width.
    gs = fig_3d.add_gridspec(1, 4, wspace=0.15,
                             width_ratios=s_cfg.get('main_fig_width_ratios', [1, 1, 1, 1]))

    ax0 = fig_3d.add_subplot(gs[0, 0])  # 2D similarity matrix
    ax1 = fig_3d.add_subplot(gs[0, 1], projection='3d')
    ax2 = fig_3d.add_subplot(gs[0, 2], projection='3d')
    ax3 = fig_3d.add_subplot(gs[0, 3], projection='3d')

    # --- ax0: similarity matrix of the centred, subsampled inputs ---
    _plot_gram_matrix(fig_3d, ax0, x_orig_np[sub_sample_index, :] - center, y_sampled, s_cfg)

    # --- Shared styling of the 3D panels ---
    title_pad_for_3d = s_cfg.get('title_pad_3d', s_cfg['title_pad_default'])
    for ax_3d in (ax1, ax2, ax3):
        _style_3d_axes(ax_3d, s_cfg)

    # Only ax1 gets a reduced tick count.
    ax1.xaxis.set_major_locator(ticker.MaxNLocator(nbins=s_cfg.get('max_ticks_ax1_x', 5)))
    ax1.yaxis.set_major_locator(ticker.MaxNLocator(nbins=s_cfg.get('max_ticks_ax1_y', 5)))
    ax1.zaxis.set_major_locator(ticker.MaxNLocator(nbins=s_cfg.get('max_ticks_ax1_z', 5)))

    label_font = s_cfg['axis_label_fontsize']
    label_pad = s_cfg['axis_labelpad']
    for ax_pca in (ax2, ax3):
        ax_pca.set_xlabel('PC1', fontsize=label_font, labelpad=label_pad)
        ax_pca.set_ylabel('PC2', fontsize=label_font, labelpad=label_pad)
        ax_pca.set_zlabel('PC3', fontsize=label_font, labelpad=label_pad)

    # --- ax1: first three rows of V, coloured by each neuron's class ---
    ax1.view_init(elev=18, azim=38)
    ax1.set_xlabel('v[1]', fontsize=label_font, labelpad=label_pad, rotation=0)
    ax1.set_ylabel('v[2]', fontsize=label_font, labelpad=label_pad, rotation=0)
    ax1.set_zlabel('v[3]', fontsize=label_font, labelpad=label_pad, rotation=0)

    neuron_markersize = s_cfg['3d_neuron_markersize']
    for class_idx, color_key in enumerate(_CLASS_COLOR_KEYS):
        members = neuron_label == class_idx
        # Raw string: the label is LaTeX and its backslashes must stay literal.
        ax1.scatter(V_np[0, members], V_np[1, members], V_np[2, members],
                    label=r'$\!v_j,\ j\in \mathcal{N}_%d$' % (class_idx + 1),
                    color=BASE_COLORS[color_key], marker='x', s=neuron_markersize,
                    linewidth=3, alpha=s_cfg['scatter_neuron_alpha'])

    # Dashed reference directions drawn from the origin.
    ref_line_lw = s_cfg['plot_linewidth']
    for tip in ((1, -0.5, -0.5), (-0.5, 1, -0.5), (-0.5, -0.5, 1)):
        ax1.plot([0, tip[0]], [0, tip[1]], [0, tip[2]], linestyle='--', linewidth=ref_line_lw,
                 color=BASE_COLORS['gray'], alpha=0.7)

    legend1_loc = s_cfg.get('legend_loc_ax1', 'upper right')
    legend1_bbox = s_cfg.get('legend_bbox_to_anchor_ax1',
                             (1.1, 1.02) if legend1_loc == 'upper right' else None)
    _build_legend(ax1, s_cfg, neuron_markersize * s_cfg['legend_marker_scalefactor'],
                  loc=legend1_loc, bbox_to_anchor=legend1_bbox)

    # --- ax2: PCA fitted on all inputs, plotted for the common subsample ---
    data_legend_size = s_cfg['3d_data_markersize'] * s_cfg['legend_marker_scalefactor']
    pca_x = PCA(n_components=3).fit(x_orig_np)
    feature_x_sampled = pca_x.transform(x_orig_np)[sub_sample_index, :]
    _scatter_digit_classes(ax2, feature_x_sampled, y_sampled, s_cfg)
    ax2.set_aspect('auto')
    ax2.view_init(elev=18, azim=38)
    _build_legend(ax2, s_cfg, data_legend_size)

    # --- ax3: PCA of H plus the projected classifier directions ---
    pca_h = PCA(n_components=3).fit(H_np)
    feature_h_sampled = pca_h.transform(H_np)[sub_sample_index, :]
    # Project the rows of V (scaled by 3) together with a zero vector whose
    # image marks where the lines start. The zero row takes its width from V
    # instead of a hard-coded neuron count.
    classifier_h = pca_h.transform(np.vstack((V_np, np.zeros((1, V_np.shape[1])))) * 3)
    origin_h = classifier_h[-1]

    scatter_handles = _scatter_digit_classes(ax3, feature_h_sampled, y_sampled, s_cfg)
    for class_idx in range(3):
        tip_h = classifier_h[class_idx]
        ax3.plot([origin_h[0], tip_h[0]], [origin_h[1], tip_h[1]], [origin_h[2], tip_h[2]],
                 '--', linewidth=ref_line_lw)
    # Proxy artist that supplies the single 'Classifiers' legend entry.
    classifier_proxy = ax3.plot([0], [0], linestyle='--', color='gray', linewidth=ref_line_lw,
                                label='Classifiers')

    ax3.set_aspect('auto')
    ax3.view_init(elev=18, azim=38)
    _build_legend(ax3, s_cfg, data_legend_size, handles=[*scatter_handles, classifier_proxy[0]])

    ax1.set_title('(b) Neuron Weights Alignment', fontsize=s_cfg['title_fontsize'], fontweight='bold',
                  pad=title_pad_for_3d)
    ax2.set_title('(c) PCA of Input Data', fontsize=s_cfg['title_fontsize'], fontweight='bold',
                  pad=title_pad_for_3d)
    ax3.set_title('(d) PCA of Last-layer Feature', fontsize=s_cfg['title_fontsize'], fontweight='bold',
                  pad=title_pad_for_3d)

    # Layout is handled by constrained_layout=True on the figure.
    fig_3d.savefig('mnist_main_plots_with_gramian.png', transparent=False, facecolor=s_cfg['save_facecolor'],
                   dpi=s_cfg['save_dpi'], bbox_inches='tight')
    fig_3d.savefig('mnist_main_plots_with_gramian.pdf', transparent=False, facecolor=s_cfg['save_facecolor'],
                   bbox_inches='tight')
    plt.close(fig_3d)

    # --- Digit Comparison Plots ---
    for k in range(3):
        avg_neuron = np.mean(W[:, neuron_label == k], axis=1)
        avg_data = np.mean(x_orig_np[y_orig_np == k, :], axis=0)
        # Rescale the mean neuron to the norm of the mean data so both images
        # share the same intensity scale.
        avg_neuron = avg_neuron * np.linalg.norm(avg_data) / np.linalg.norm(avg_neuron)

        fig, axes = plt.subplots(1, 2, figsize=s_cfg.get('figsize_1x2_digits', (10, 5)))

        # Left: mean neuron weights; right: mean data. `center` is added back
        # before reshaping to 28x28.
        axes[0].imshow(np.reshape(avg_neuron + center, (28, 28)), cmap='gray', vmin=0, vmax=1)
        axes[0].axis('off')
        axes[1].imshow(np.reshape(avg_data + center, (28, 28)), cmap='gray', vmin=0, vmax=1)
        axes[1].axis('off')

        fig.tight_layout()
        fig.savefig(f'mnist_digits_side_by_side_{k}.png', dpi=300)
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: assemble the style overrides, then render everything.
    font = 'times new roman'

    overrides = {}
    if font in ('times new roman', 'serif'):
        # Put Times New Roman first, followed by the configured serif fonts
        # (an empty list when the defaults define none), without duplicates.
        other_serifs = [
            name
            for name in DEFAULT_STYLE_CONFIG['font_settings'].get('serif', [])
            if name != 'Times New Roman'
        ]
        overrides['font_settings'] = {
            'family': 'serif',
            'serif': ['Times New Roman', *other_serifs],
        }
    elif font == 'sans-serif':
        overrides['font_settings'] = {'family': 'sans-serif'}

    figure_overrides = {
        # Layout
        'main_fig_width_ratios': [0.75, 1, 1, 1],  # ax0 slightly narrower than the 3D panels
        # Marker sizes and line widths
        '3d_neuron_markersize': 200,
        '3d_data_markersize': 200,
        'plot_linewidth': 4,
        'axes_linewidth': 2,
        # Title padding
        'title_pad_default': 85,
        'title_pad_3d': 0,  # padding for the 3D panel titles
        # Sampling and colormap
        'num_samples_for_plots': 500,  # fewer samples: faster, less dense ax0
        'ax0_cmap': 'coolwarm',
        # Ticks
        'tick_pad_3d': 4,  # tighter tick padding on the 3D panels
        'max_ticks_ax1_x': 4,  # tick limits that apply to ax1 only
        'max_ticks_ax1_y': 4,
        'max_ticks_ax1_z': 4,
        # Legend fonts
        'legend_fontsize': 35,
        'legend_title_fontsize': 35,  # per the original note: used by nc_plot.py
    }
    overrides.update(figure_overrides)

    print("Plotting figures for nc_plot.py (plot_nc_mnist.py)...")
    plot_figures(style_config_overrides=overrides)
    print("\nFinished plotting for nc_plot.py (plot_nc_mnist.py).")
