import torch
import numpy as np
import matplotlib.pyplot as plt
import wandb
from collections import defaultdict
import os

class PDEResidualTracker:
    def __init__(self, max_samples=10, log_frequency=50, detailed_log_frequency=100):
        self.max_samples = max_samples
        self.log_frequency = log_frequency  
        self.detailed_log_frequency = detailed_log_frequency  
        
        # Global scatter plot data (persistent, stores all scatter data)
        self.global_scatter_steps = []
        self.global_scatter_sigmas = []
        self.global_scatter_residuals = []
        
        # Boundary residual tracking (extracted from PDE residual maps)
        self.global_boundary_steps = []
        self.global_boundary_sigmas = []
        self.global_boundary_residuals = []
        
        # For detailed sigma vs residual plots
        self.recent_sigmas = []
        self.recent_residuals = []
        self.recent_steps = []
        
        # Boundary residual tracking (2-pixel and 1-pixel boundaries)
        self.recent_boundary2_sigmas = []      # 2-pixel boundary (2 rows/cols from edges)
        self.recent_boundary2_residuals = []   # 2-pixel boundary residuals
        self.recent_boundary1_sigmas = []      # 1-pixel boundary (1 row/col from edges)  
        self.recent_boundary1_residuals = []   # 1-pixel boundary residuals
        
        # Store recent residual maps for visualization (keep only last few to save memory)
        self.recent_residual_maps = []         # Store actual residual tensors
        self.recent_residual_map_sigmas = []   # Corresponding sigma values
        self.recent_residual_map_steps = []    # Corresponding training steps
        
        # For training visualization storage
        self.recent_training_data = []         # Store recent training data for comprehensive visualization
        self.max_training_data = 3             # Keep last 3 batches for training visualization

    def _cleanup_figure(self, fig):
        """Properly cleanup matplotlib figure to prevent memory leaks"""
        if fig is not None:
            plt.figure(fig.number)
            plt.clf()
            plt.close(fig)
        
    def log_residuals(self, pde_residuals, sigmas, step):
        """Log residuals for first max_samples in batch"""
        # Always collect data for recent detailed plots
        batch_size = min(self.max_samples, pde_residuals.size(0))
        
        # Debug: Print tensor shapes to understand the data structure
        # if step % 100 == 0:  # Only print every 100 steps to avoid spam
        #     print(f"PDE residuals shape: {pde_residuals.shape}, sigmas shape: {sigmas.shape}")
        
        sample_residuals = torch.mean(torch.abs(pde_residuals[:batch_size].view(batch_size, -1)), dim=1)
        sample_sigmas = sigmas[:batch_size].squeeze()
        
        # Extract boundary residuals
        boundary2_residuals, boundary1_residuals = self._extract_boundary_residuals(
            pde_residuals[:batch_size].squeeze(), sigmas[:batch_size].squeeze() 
        )
        
        # Store recent data for detailed plots
        self.recent_sigmas.extend(sample_sigmas.detach().cpu().numpy())
        self.recent_residuals.extend(sample_residuals.detach().cpu().numpy())
        self.recent_steps.extend([step] * batch_size)
        
        # Store boundary residuals
        self.recent_boundary2_sigmas.extend(sample_sigmas.detach().cpu().numpy())
        self.recent_boundary2_residuals.extend(boundary2_residuals.detach().cpu().numpy())
        self.recent_boundary1_sigmas.extend(sample_sigmas.detach().cpu().numpy())
        self.recent_boundary1_residuals.extend(boundary1_residuals.detach().cpu().numpy())
        
        # Store residual maps for visualization (keep only last 10 to save memory)
        if len(self.recent_residual_maps) < 10:
            self.recent_residual_maps.append(pde_residuals[:batch_size].detach().cpu())
            self.recent_residual_map_sigmas.append(sigmas[:batch_size].detach().cpu())
            self.recent_residual_map_steps.append(step)
        
        # Log to global history every log_frequency steps
        if step % self.log_frequency != 0:
            return
            
        # Store data for global scatter plot (persistent)
        self.global_scatter_steps.extend([step] * batch_size)
        self.global_scatter_sigmas.extend(sample_sigmas.detach().cpu().numpy())
        self.global_scatter_residuals.extend(sample_residuals.detach().cpu().numpy())

    def _extract_boundary_residuals(self, pde_residuals, sigmas):
        """Extract boundary residuals from PDE residual maps
        
        Args:
            pde_residuals: Tensor of shape [batch_size, channels, height, width] or [batch_size, height, width]
            sigmas: Tensor of shape [batch_size]
            
        Returns:
            boundary2_residuals: Mean residuals from 2-pixel boundary [batch_size]
            boundary1_residuals: Mean residuals from 1-pixel boundary [batch_size]
        """
        # print(f"PDE residuals shape: {pde_residuals.shape}, sigmas shape: {sigmas.shape}")
        # breakpoint()
        batch_size, height, width = pde_residuals.shape
        
        # 2-pixel boundary: 2 rows/cols from edges
        # Top 2 rows, bottom 2 rows, left 2 cols, right 2 cols
        boundary2_mask = torch.zeros(height, width, dtype=torch.bool, device=pde_residuals.device)
        boundary2_mask[:2, :] = True      # Top 2 rows
        boundary2_mask[-2:, :] = True     # Bottom 2 rows  
        boundary2_mask[:, :2] = True      # Left 2 cols
        boundary2_mask[:, -2:] = True     # Right 2 cols
        
        # 1-pixel boundary: 1 row/col from edges
        boundary1_mask = torch.zeros(height, width, dtype=torch.bool, device=pde_residuals.device)
        boundary1_mask[0, :] = True       # Top row
        boundary1_mask[-1, :] = True      # Bottom row
        boundary1_mask[:, 0] = True       # Left col
        boundary1_mask[:, -1] = True      # Right col
        
        # Extract boundary residuals
        boundary2_residuals = []
        boundary1_residuals = []
        
        for i in range(batch_size):
            # 2-pixel boundary
            boundary2_pixels = pde_residuals[i][boundary2_mask]
            boundary2_residuals.append(torch.mean(torch.abs(boundary2_pixels)))
            
            # 1-pixel boundary  
            boundary1_pixels = pde_residuals[i][boundary1_mask]
            boundary1_residuals.append(torch.mean(torch.abs(boundary1_pixels)))
        
        return torch.stack(boundary2_residuals), torch.stack(boundary1_residuals)

    def _subsample_global_data(self, max_points=20):
        """Subsample global scatter data to show evenly spaced steps across training"""
        if not self.global_scatter_steps:
            return [], [], []
            
        steps = np.array(self.global_scatter_steps)
        sigmas = np.array(self.global_scatter_sigmas)
        residuals = np.array(self.global_scatter_residuals)
        
        # Get unique steps
        unique_steps = np.unique(steps)
        if len(unique_steps) <= max_points:
            return steps, sigmas, residuals
        
        # Generate target steps (evenly spaced)
        min_step = unique_steps.min()
        max_step = unique_steps.max()
        target_steps = np.linspace(min_step, max_step, max_points)
        
        subsampled_steps = []
        subsampled_sigmas = []
        subsampled_residuals = []
        
        for target_step in target_steps:
            # Find the closest actual step to our target
            closest_step_idx = np.argmin(np.abs(unique_steps - target_step))
            closest_step = unique_steps[closest_step_idx]
            
            # Get all data points for this step
            step_mask = steps == closest_step
            if np.any(step_mask):
                step_steps = steps[step_mask]
                step_sigmas = sigmas[step_mask]
                step_residuals = residuals[step_mask]
                
                # Take all samples from this step
                subsampled_steps.extend(step_steps)
                subsampled_sigmas.extend(step_sigmas)
                subsampled_residuals.extend(step_residuals)
        
        return np.array(subsampled_steps), np.array(subsampled_sigmas), np.array(subsampled_residuals)

    def create_global_scatter_plot(self, save_dir=None, max_columns=15):
        """Create global scatter plot showing entire training history - ORIGINAL VERSION (No Clipping)"""
        if not self.global_scatter_steps:
            return None
            
        # CLEAR MATPLOTLIB STATE
        plt.clf()
        plt.close('all')
            
        fig, ax = plt.subplots(figsize=(16, 10))
        
        # Subsample data to get representative points across training
        steps, sigmas, residuals = self._subsample_global_data(max_points=max_columns)
        
        if len(steps) == 0:
            return None
        
        # PLOT VERTICAL LINES FIRST (in background)
        unique_displayed_steps = np.unique(steps)
        if len(unique_displayed_steps) > 1:
            for step in unique_displayed_steps:
                ax.axvline(x=step, color='lightgray', alpha=0.6, linestyle='--', 
                          linewidth=1.0, zorder=1)
        
        # ORIGINAL SCATTER PLOT - NO CLIPPING
        scatter = ax.scatter(steps, sigmas, c=residuals, cmap='plasma',
                           alpha=0.8, s=75, edgecolors='black', linewidth=0.4,
                           zorder=2)
        
        ax.set_xlabel('Training Step', fontsize=14)
        ax.set_ylabel('Sigma Value (log scale)', fontsize=14)
        ax.set_title('PDE Residuals vs Training Progress (Global View - Full Range)', fontsize=16)
        ax.set_yscale('log')
        
        # Enhanced X-axis ticks
        if len(unique_displayed_steps) > 0:
            ax.set_xticks(unique_displayed_steps)
            ax.set_xticklabels([f'{int(step)}' for step in unique_displayed_steps], 
                              rotation=45, ha='right', fontsize=11)
            
            if len(unique_displayed_steps) > 1:
                step_range = unique_displayed_steps.max() - unique_displayed_steps.min()
                minor_step_interval = step_range / (len(unique_displayed_steps) * 5)
                minor_ticks = np.arange(unique_displayed_steps.min(), 
                                       unique_displayed_steps.max() + minor_step_interval, 
                                       minor_step_interval)
                ax.set_xticks(minor_ticks, minor=True)
        
        # Enhanced Y-axis ticks
        sigma_min = max(sigmas.min(), 1e-3)
        sigma_max = sigmas.max()
        
        major_powers = np.arange(np.floor(np.log10(sigma_min)), 
                                np.ceil(np.log10(sigma_max)) + 1)
        major_sigma_ticks = 10.0 ** major_powers
        
        minor_sigma_ticks = []
        for power in major_powers:
            base = 10.0 ** power
            for multiplier in [2, 3, 4, 5, 6, 7, 8, 9]:
                tick_val = base * multiplier
                if sigma_min <= tick_val <= sigma_max:
                    minor_sigma_ticks.append(tick_val)
        
        major_ticks_in_range = major_sigma_ticks[(major_sigma_ticks >= sigma_min) & 
                                                (major_sigma_ticks <= sigma_max)]
        ax.set_yticks(major_ticks_in_range)
        ax.set_yticklabels([f'{tick:.0e}' if tick >= 1 else f'{tick:.3f}' 
                           for tick in major_ticks_in_range], fontsize=12)
        
        if minor_sigma_ticks:
            ax.set_yticks(minor_sigma_ticks, minor=True)
        
        # COLORBAR with data range info
        cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
        cbar.set_label('Mean Absolute PDE Residual (Full Range)', fontsize=14)
        cbar.ax.tick_params(labelsize=11)
        
        # Enhanced grid
        ax.grid(True, which='major', alpha=0.3, linewidth=0.8, zorder=0)
        ax.grid(True, which='minor', alpha=0.1, linewidth=0.4, zorder=0)
        
        # Calculate step interval
        unique_logged_steps = np.unique(self.global_scatter_steps)
        
        if len(unique_logged_steps) > 1:
            logged_step_intervals = np.diff(unique_logged_steps)
            actual_log_frequency = int(np.median(logged_step_intervals)) if len(logged_step_intervals) > 0 else self.log_frequency
            
            if len(unique_displayed_steps) > 1:
                display_step_interval = int(np.median(np.diff(unique_displayed_steps)))
                effective_interval = display_step_interval
            else:
                effective_interval = actual_log_frequency
        else:
            effective_interval = 0

        # INFO BOX with full range statistics
        step_info = f'Steps: {", ".join([str(int(s)) for s in unique_displayed_steps[:5]])}{"..." if len(unique_displayed_steps) > 5 else ""}'
        
        info_text = (f'FULL RANGE VISUALIZATION:\n'
                    f'• Showing {len(steps)} points from {len(self.global_scatter_steps)} total\n'
                    f'• Residual range: {np.min(residuals):.3f} → {np.max(residuals):.2f}\n'
                    f'• Color scaling: No clipping (full data range)\n'
                    f'• Log frequency: every {self.log_frequency} steps\n'
                    f'• Display interval: every {effective_interval} training steps\n'
                    f'• {step_info}')
        
        # Position info box outside the plot
        fig.text(0.02, 0.02, info_text, fontsize=10, 
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.9, edgecolor='navy'),
                verticalalignment='bottom')
        
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.20)  
        
        if save_dir:
            plt.savefig(os.path.join(save_dir, 'pde_residual_scatter_global_full.png'), dpi=150, bbox_inches='tight')
        
        return fig

    def create_global_scatter_plot_clipped(self, save_dir=None, max_columns=15):
        """Create global scatter plot with RECOMMENDED PERCENTILE CLIPPING for better color discrimination"""
        if not self.global_scatter_steps:
            return None
            
        # CLEAR MATPLOTLIB STATE
        plt.clf()
        plt.close('all')
            
        fig, ax = plt.subplots(figsize=(16, 10))
        
        # Subsample data to get representative points across training
        steps, sigmas, residuals = self._subsample_global_data(max_points=max_columns)
        
        if len(steps) == 0:
            return None
        
        # PLOT VERTICAL LINES FIRST (in background)
        unique_displayed_steps = np.unique(steps)
        if len(unique_displayed_steps) > 1:
            for step in unique_displayed_steps:
                ax.axvline(x=step, color='lightgray', alpha=0.6, linestyle='--', 
                          linewidth=1.0, zorder=1)
        
        # ENHANCED SCATTER PLOT with RECOMMENDED CLIPPING
        # Calculate clipping percentiles for better color discrimination
        vmin_clip = np.percentile(residuals, 5)   # Bottom 5% clipped
        vmax_clip = np.percentile(residuals, 95)  # Top 5% clipped
        
        scatter = ax.scatter(steps, sigmas, c=residuals, cmap='plasma', 
                           alpha=0.8, s=80, edgecolors='black', linewidth=0.5,
                           zorder=2, vmin=vmin_clip, vmax=vmax_clip)
        
        ax.set_xlabel('Training Step', fontsize=14)
        ax.set_ylabel('Sigma Value (log scale)', fontsize=14)
        ax.set_title('PDE Residuals vs Training Progress (Global View - Enhanced Colors)', fontsize=16)
        ax.set_yscale('log')
        
        # Enhanced X-axis ticks (same as original)
        if len(unique_displayed_steps) > 0:
            ax.set_xticks(unique_displayed_steps)
            ax.set_xticklabels([f'{int(step)}' for step in unique_displayed_steps], 
                              rotation=45, ha='right', fontsize=11)
            
            if len(unique_displayed_steps) > 1:
                step_range = unique_displayed_steps.max() - unique_displayed_steps.min()
                minor_step_interval = step_range / (len(unique_displayed_steps) * 5)
                minor_ticks = np.arange(unique_displayed_steps.min(), 
                                       unique_displayed_steps.max() + minor_step_interval, 
                                       minor_step_interval)
                ax.set_xticks(minor_ticks, minor=True)
        
        # Enhanced Y-axis ticks (same as original)
        sigma_min = max(sigmas.min(), 1e-3)
        sigma_max = sigmas.max()
        
        major_powers = np.arange(np.floor(np.log10(sigma_min)), 
                                np.ceil(np.log10(sigma_max)) + 1)
        major_sigma_ticks = 10.0 ** major_powers
        
        minor_sigma_ticks = []
        for power in major_powers:
            base = 10.0 ** power
            for multiplier in [2, 3, 4, 5, 6, 7, 8, 9]:
                tick_val = base * multiplier
                if sigma_min <= tick_val <= sigma_max:
                    minor_sigma_ticks.append(tick_val)
        
        major_ticks_in_range = major_sigma_ticks[(major_sigma_ticks >= sigma_min) & 
                                                (major_sigma_ticks <= sigma_max)]
        ax.set_yticks(major_ticks_in_range)
        ax.set_yticklabels([f'{tick:.0e}' if tick >= 1 else f'{tick:.3f}' 
                           for tick in major_ticks_in_range], fontsize=12)
        
        if minor_sigma_ticks:
            ax.set_yticks(minor_sigma_ticks, minor=True)
        
        # ENHANCED COLORBAR with clipping info
        cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
        cbar.set_label('Mean Absolute PDE Residual (5%-95% Range)', fontsize=14)
        cbar.ax.tick_params(labelsize=11)
        
        # Enhanced grid
        ax.grid(True, which='major', alpha=0.3, linewidth=0.8, zorder=0)
        ax.grid(True, which='minor', alpha=0.1, linewidth=0.4, zorder=0)
        
        # Calculate step interval (same as original)
        unique_logged_steps = np.unique(self.global_scatter_steps)
        
        if len(unique_logged_steps) > 1:
            logged_step_intervals = np.diff(unique_logged_steps)
            actual_log_frequency = int(np.median(logged_step_intervals)) if len(logged_step_intervals) > 0 else self.log_frequency
            
            if len(unique_displayed_steps) > 1:
                display_step_interval = int(np.median(np.diff(unique_displayed_steps)))
                effective_interval = display_step_interval
            else:
                effective_interval = actual_log_frequency
        else:
            effective_interval = 0

        # ENHANCED INFO BOX with clipping statistics
        step_info = f'Steps: {", ".join([str(int(s)) for s in unique_displayed_steps[:5]])}{"..." if len(unique_displayed_steps) > 5 else ""}'
        
        # Calculate outlier statistics
        outliers_low = np.sum(residuals < vmin_clip)
        outliers_high = np.sum(residuals > vmax_clip)
        total_points = len(residuals)
        
        info_text = (f'ENHANCED COLOR VISUALIZATION:\n'
                    f'• Showing {len(steps)} points from {len(self.global_scatter_steps)} total\n'
                    f'• Full range: {np.min(residuals):.3f} → {np.max(residuals):.2f}\n'
                    f'• Color range: {vmin_clip:.3f} → {vmax_clip:.3f} (5%-95%)\n'
                    f'• Clipped outliers: {outliers_low + outliers_high}/{total_points} points ({100*(outliers_low + outliers_high)/total_points:.1f}%)\n'
                    f'• Log frequency: every {self.log_frequency} steps\n'
                    f'• Display interval: every {effective_interval} training steps\n'
                    f'• {step_info}')
        
        # Position info box outside the plot
        fig.text(0.02, 0.02, info_text, fontsize=10, 
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.9, edgecolor='navy'),
                verticalalignment='bottom')
        
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.20)  # More room for enhanced info box
        
        if save_dir:
            plt.savefig(os.path.join(save_dir, 'pde_residual_scatter_global_clipped.png'), dpi=150, bbox_inches='tight')
        
        return fig

    def create_global_scatter_plot_clipped_zoomed_log(self, save_dir=None, max_columns=15, sigma_max_threshold=1.0):
        """Create ZOOMED LOG SCALE global scatter plot focusing on lower sigma values (≤ 1.0) with percentile clipping"""
        if not self.global_scatter_steps:
            return None
            
        plt.clf()
        plt.close('all')
        
        fig, ax = plt.subplots(figsize=(16, 10))
        
        # Subsample data to get representative points across training
        steps, sigmas, residuals = self._subsample_global_data(max_points=max_columns)
        
        if len(steps) == 0:
            return None
        
        # FILTER DATA: Only keep sigma values below threshold
        zoom_mask = sigmas <= sigma_max_threshold
        if not np.any(zoom_mask):
            # No data in zoom range
            ax.text(0.5, 0.5, f'No data with σ ≤ {sigma_max_threshold}', 
                    ha='center', va='center', transform=ax.transAxes, fontsize=16,
                    bbox=dict(boxstyle='round', facecolor='lightgray'))
            ax.set_title(f'Zoomed View: σ ≤ {sigma_max_threshold} (No Data)', fontsize=16)
            return fig
        
        # Apply zoom filter
        zoom_steps = steps[zoom_mask]
        zoom_sigmas = sigmas[zoom_mask]
        zoom_residuals = residuals[zoom_mask]
        
        # PLOT VERTICAL LINES for steps that have data in zoom range
        unique_zoom_steps = np.unique(zoom_steps)
        if len(unique_zoom_steps) > 1:
            for step in unique_zoom_steps:
                ax.axvline(x=step, color='lightgray', alpha=0.6, linestyle='--', 
                          linewidth=1.0, zorder=1)
    
        # ZOOMED SCATTER PLOT with CLIPPING for better color discrimination
        vmin_clip = np.percentile(zoom_residuals, 5)   # Bottom 5% clipped
        vmax_clip = np.percentile(zoom_residuals, 95)  # Top 5% clipped
        
        scatter = ax.scatter(zoom_steps, zoom_sigmas, c=zoom_residuals, cmap='plasma', 
                           alpha=0.8, s=80, edgecolors='black', linewidth=0.5,
                           zorder=2, vmin=vmin_clip, vmax=vmax_clip)
        
        ax.set_xlabel('Training Step', fontsize=14)
        ax.set_ylabel('Sigma Value (log scale)', fontsize=14)
        ax.set_title(f'PDE Residuals vs Training Progress (ZOOMED LOG: σ ≤ {sigma_max_threshold})', fontsize=16)
        ax.set_yscale('log')
        
        # Enhanced X-axis ticks for zoom range
        if len(unique_zoom_steps) > 0:
            ax.set_xticks(unique_zoom_steps)
            ax.set_xticklabels([f'{int(step)}' for step in unique_zoom_steps], 
                              rotation=45, ha='right', fontsize=11)
            
            if len(unique_zoom_steps) > 1:
                step_range = unique_zoom_steps.max() - unique_zoom_steps.min()
                minor_step_interval = step_range / (len(unique_zoom_steps) * 5)
                minor_ticks = np.arange(unique_zoom_steps.min(), 
                                       unique_zoom_steps.max() + minor_step_interval, 
                                       minor_step_interval)
                ax.set_xticks(minor_ticks, minor=True)
    
        # Enhanced Y-axis ticks for zoomed sigma range
        sigma_min = max(zoom_sigmas.min(), 1e-4)  # Prevent log(0)
        sigma_max = min(zoom_sigmas.max(), sigma_max_threshold)
        
        # Create appropriate log ticks for zoomed range
        if sigma_min < 1e-2:
            major_powers = np.arange(np.floor(np.log10(sigma_min)), 1)  # Up to 10^0 = 1
        else:
            major_powers = np.arange(-2, 1)  # 0.01, 0.1, 1.0
    
        major_sigma_ticks = 10.0 ** major_powers
        major_ticks_in_range = major_sigma_ticks[(major_sigma_ticks >= sigma_min) & 
                                                (major_sigma_ticks <= sigma_max)]
        
        if len(major_ticks_in_range) > 0:
            ax.set_yticks(major_ticks_in_range)
            ax.set_yticklabels([f'{tick:.3f}' for tick in major_ticks_in_range], fontsize=12)
        
        # Set explicit y-axis limits for zoom
        ax.set_ylim(max(sigma_min*0.8, 1e-6), min(sigma_max*1.2, sigma_max_threshold*1.1))
        
        # COLORBAR with zoom info
        cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
        cbar.set_label(f'Mean Absolute PDE Residual (5%-95% Range, σ≤{sigma_max_threshold})', fontsize=14)
        cbar.ax.tick_params(labelsize=11)
        
        # Enhanced grid
        ax.grid(True, which='major', alpha=0.3, linewidth=0.8, zorder=0)
        ax.grid(True, which='minor', alpha=0.1, linewidth=0.4, zorder=0)
        
        # Calculate statistics for zoom range
        unique_logged_steps = np.unique(self.global_scatter_steps)
        
        if len(unique_logged_steps) > 1:
            if len(unique_zoom_steps) > 1:
                display_step_interval = int(np.median(np.diff(unique_zoom_steps)))
            else:
                display_step_interval = self.log_frequency
        else:
            display_step_interval = 0

        # ZOOM INFO BOX with statistics
        step_info = f'Steps: {", ".join([str(int(s)) for s in unique_zoom_steps[:5]])}{"..." if len(unique_zoom_steps) > 5 else ""}'
        total_points_original = len(steps)
        zoom_points = len(zoom_steps)
        zoom_percentage = 100 * zoom_points / total_points_original if total_points_original > 0 else 0
        
        # Calculate outlier statistics for zoom range
        outliers_low = np.sum(zoom_residuals < vmin_clip)
        outliers_high = np.sum(zoom_residuals > vmax_clip)
        
        info_text = (f'ZOOMED LOG VIEW (σ ≤ {sigma_max_threshold}):\n'
                    f'• Showing {zoom_points} points from {total_points_original} total ({zoom_percentage:.1f}%)\n'
                    f'• Sigma range: {zoom_sigmas.min():.4f} → {zoom_sigmas.max():.3f}\n'
                    f'• Residual range: {zoom_residuals.min():.4f} → {zoom_residuals.max():.3f}\n'
                    f'• Color range: {vmin_clip:.4f} → {vmax_clip:.4f} (5%-95%)\n'
                    f'• Clipped outliers: {outliers_low + outliers_high}/{zoom_points} points\n'
                    f'• Log frequency: every {self.log_frequency} steps\n'
                    f'• Display interval: every {display_step_interval} training steps\n'
                    f'• {step_info}')
        
        # Position info box
        fig.text(0.02, 0.02, info_text, fontsize=10, 
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgreen', alpha=0.9, edgecolor='darkgreen'),
                verticalalignment='bottom')
        
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.20)
        
        if save_dir:
            plt.savefig(os.path.join(save_dir, f'pde_residual_scatter_zoomed_log_sigma_{sigma_max_threshold}.png'), 
                       dpi=150, bbox_inches='tight')
        
        return fig

    def create_global_scatter_plot_clipped_zoomed_linear(self, save_dir=None, max_columns=15, sigma_max_threshold=1.0):
        """Create ZOOMED LINEAR SCALE global scatter plot focusing on lower sigma values (≤ 1.0) with percentile clipping"""
        if not self.global_scatter_steps:
            return None
            
        plt.clf()
        plt.close('all')
        
        fig, ax = plt.subplots(figsize=(16, 10))
        
        # Subsample data to get representative points across training
        steps, sigmas, residuals = self._subsample_global_data(max_points=max_columns)
        
        if len(steps) == 0:
            return None
        
        # FILTER DATA: Only keep sigma values below threshold
        zoom_mask = sigmas <= sigma_max_threshold
        if not np.any(zoom_mask):
            # No data in zoom range
            ax.text(0.5, 0.5, f'No data with σ ≤ {sigma_max_threshold}', 
                    ha='center', va='center', transform=ax.transAxes, fontsize=16,
                    bbox=dict(boxstyle='round', facecolor='lightgray'))
            ax.set_title(f'Zoomed Linear View: σ ≤ {sigma_max_threshold} (No Data)', fontsize=16)
            return fig
        
        # Apply zoom filter
        zoom_steps = steps[zoom_mask]
        zoom_sigmas = sigmas[zoom_mask]
        zoom_residuals = residuals[zoom_mask]
        
        # PLOT VERTICAL LINES for steps that have data in zoom range
        unique_zoom_steps = np.unique(zoom_steps)
        if len(unique_zoom_steps) > 1:
            for step in unique_zoom_steps:
                ax.axvline(x=step, color='lightgray', alpha=0.6, linestyle='--', 
                          linewidth=1.0, zorder=1)
    
        # ZOOMED SCATTER PLOT with CLIPPING for better color discrimination
        vmin_clip = np.percentile(zoom_residuals, 5)   # Bottom 5% clipped
        vmax_clip = np.percentile(zoom_residuals, 95)  # Top 5% clipped
        
        scatter = ax.scatter(zoom_steps, zoom_sigmas, c=zoom_residuals, cmap='plasma', 
                           alpha=0.8, s=80, edgecolors='black', linewidth=0.5,
                           zorder=2, vmin=vmin_clip, vmax=vmax_clip)
        
        ax.set_xlabel('Training Step', fontsize=14)
        ax.set_ylabel('Sigma Value (linear scale)', fontsize=14)
        ax.set_title(f'PDE Residuals vs Training Progress (ZOOMED LINEAR: σ ≤ {sigma_max_threshold})', fontsize=16)
        # NO log scale for Y-axis
    
        # Enhanced X-axis ticks for zoom range
        if len(unique_zoom_steps) > 0:
            ax.set_xticks(unique_zoom_steps)
            ax.set_xticklabels([f'{int(step)}' for step in unique_zoom_steps], 
                              rotation=45, ha='right', fontsize=11)
            
            if len(unique_zoom_steps) > 1:
                step_range = unique_zoom_steps.max() - unique_zoom_steps.min()
                minor_step_interval = step_range / (len(unique_zoom_steps) * 5)
                minor_ticks = np.arange(unique_zoom_steps.min(), 
                                       unique_zoom_steps.max() + minor_step_interval, 
                                       minor_step_interval)
                ax.set_xticks(minor_ticks, minor=True)
    
        # LINEAR Y-axis ticks for zoomed sigma range
        sigma_min = sigmas.min()
        sigma_max = min(sigmas.max(), sigma_max_threshold)
        sigma_range = sigma_max - sigma_min
        
        # Create appropriate linear ticks
        if sigma_range > 5:
            tick_interval = 1.0
            major_sigma_ticks = np.arange(0, sigma_max_threshold + tick_interval, tick_interval)
        elif sigma_range > 1:
            tick_interval = 0.5
            major_sigma_ticks = np.arange(0, sigma_max_threshold + tick_interval, tick_interval)
        elif sigma_range > 0.1:
            tick_interval = 0.1
            major_sigma_ticks = np.arange(0, sigma_max_threshold + tick_interval, tick_interval)
        else:
            tick_interval = sigma_range / 10
            major_sigma_ticks = np.arange(sigma_min, sigma_max + tick_interval, tick_interval)
    
        # Filter ticks to reasonable range
        major_ticks_in_range = major_sigma_ticks[(major_sigma_ticks >= max(0, sigma_min*0.9)) & 
                                                (major_sigma_ticks <= sigma_max_threshold*1.05)]
        
        if len(major_ticks_in_range) > 0:
            ax.set_yticks(major_ticks_in_range)
            ax.set_yticklabels([f'{tick:.2f}' for tick in major_ticks_in_range], fontsize=12)
        
        # Set explicit y-limits for better display
        y_padding = sigma_range * 0.05
        ax.set_ylim(max(0, sigma_min - y_padding), sigma_max + y_padding)
        
        # COLORBAR with linear scale info
        cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
        cbar.set_label(f'Mean Absolute PDE Residual (5%-95% Range, Linear σ≤{sigma_max_threshold})', fontsize=14)
        cbar.ax.tick_params(labelsize=11)
        
        # Enhanced grid
        ax.grid(True, which='major', alpha=0.3, linewidth=0.8, zorder=0)
        ax.grid(True, which='minor', alpha=0.1, linewidth=0.4, zorder=0)
        
        # Calculate step interval (same as other versions)
        unique_logged_steps = np.unique(self.global_scatter_steps)
        
        if len(unique_logged_steps) > 1:
            if len(unique_zoom_steps) > 1:
                display_step_interval = int(np.median(np.diff(unique_zoom_steps)))
            else:
                display_step_interval = self.log_frequency
        else:
            display_step_interval = 0

        # LINEAR ZOOM INFO BOX
        step_info = f'Steps: {", ".join([str(int(s)) for s in unique_zoom_steps[:5]])}{"..." if len(unique_zoom_steps) > 5 else ""}'
        total_points_original = len(steps)
        zoom_points = len(zoom_steps)
        zoom_percentage = 100 * zoom_points / total_points_original if total_points_original > 0 else 0
        
        # Calculate outlier statistics
        outliers_low = np.sum(zoom_residuals < vmin_clip)
        outliers_high = np.sum(zoom_residuals > vmax_clip)
        
        info_text = (f'ZOOMED LINEAR VIEW (σ ≤ {sigma_max_threshold}):\n'
                    f'• Showing {zoom_points} points from {total_points_original} total ({zoom_percentage:.1f}%)\n'
                    f'• Sigma range: {zoom_sigmas.min():.4f} → {zoom_sigmas.max():.3f}\n'
                    f'• Residual range: {zoom_residuals.min():.4f} → {zoom_residuals.max():.3f}\n'
                    f'• Color range: {vmin_clip:.4f} → {vmax_clip:.4f} (5%-95%)\n'
                    f'• Clipped outliers: {outliers_low + outliers_high}/{zoom_points} points\n'
                    f'• Log frequency: every {self.log_frequency} steps\n'
                    f'• Display interval: every {display_step_interval} training steps\n'
                    f'• {step_info}')
        
        # Position info box
        fig.text(0.02, 0.02, info_text, fontsize=10, 
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightcyan', alpha=0.9, edgecolor='darkcyan'),
                verticalalignment='bottom')
        
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.20)
        
        if save_dir:
            plt.savefig(os.path.join(save_dir, f'pde_residual_scatter_zoomed_linear_sigma_{sigma_max_threshold}.png'), 
                       dpi=150, bbox_inches='tight')
        
        return fig

    def create_comprehensive_analysis(self, save_dir=None):
        """Build the multi-panel summary figure for the global PDE residual data.

        The figure contains: a step-vs-sigma scatter coloured by residual (log
        sigma axis, colorbar in its own grid column, rescaled mean-residual
        trend overlay), per-step residual mean ± std, per-step mean sigma,
        residual binned by sigma, and a text box of summary metrics.

        Note: every open pyplot figure is closed before the new one is built.

        Args:
            save_dir: Optional directory. When truthy, the figure is also saved
                there as ``pde_comprehensive_analysis.png``.

        Returns:
            The matplotlib figure, or ``None`` when no global scatter data has
            been recorded or the subsampled data is empty.
        """
        if not self.global_scatter_steps:
            return None

        # Reset pyplot state (this closes all currently open pyplot figures).
        plt.clf()
        plt.close('all')

        # Fetch the data BEFORE creating the figure so the empty-data early
        # return cannot leave an orphaned figure registered with pyplot.
        steps, sigmas, residuals = self._subsample_global_data(max_points=25)
        if len(steps) == 0:
            return None

        unique_steps = np.unique(steps)
        # Per-step mean residual, shared by the trend overlay and panel 2.
        step_means = [np.mean(residuals[steps == s]) for s in unique_steps]

        def _safe_ratio(numerator, denominator, fallback):
            """Return numerator / denominator, or fallback if denominator is not > 0."""
            return numerator / denominator if denominator > 0 else fallback

        fig = plt.figure(figsize=(24, 14))
        # Column 2 is a narrow, dedicated slot for the main colorbar.
        gs = fig.add_gridspec(2, 5, height_ratios=[1, 1], width_ratios=[2.5, 0.2, 1, 1, 1])

        # 1. MAIN GLOBAL SCATTER PLOT (spans both rows, 1st column)
        ax1 = fig.add_subplot(gs[:, 0])

        # Vertical reference lines on roughly ten evenly spaced logged steps.
        for step in unique_steps[::max(1, len(unique_steps)//10)]:
            ax1.axvline(x=step, color='lightgray', alpha=0.4, linestyle='--', zorder=1)

        # Colour range clipped to the 10th-90th percentile for better contrast.
        scatter = ax1.scatter(steps, sigmas, c=residuals, cmap='plasma',
                            alpha=0.8, s=70, edgecolors='black', linewidth=0.3,
                            zorder=2, vmin=np.percentile(residuals, 10),
                            vmax=np.percentile(residuals, 90))

        ax1.set_xlabel('Training Step', fontsize=16)
        ax1.set_ylabel('Sigma Value (log scale)', fontsize=16)
        ax1.set_title('Global PDE Residual Analysis', fontsize=18, fontweight='bold')
        ax1.set_yscale('log')
        ax1.grid(True, alpha=0.3)

        # Colorbar in its dedicated column (2nd column)
        cbar_ax = fig.add_subplot(gs[:, 1])
        cbar1 = plt.colorbar(scatter, cax=cbar_ax)
        cbar1.set_label('Mean Absolute PDE Residual', fontsize=14)
        cbar1.ax.tick_params(labelsize=12)

        # Trend overlay drawn on the same axis (no twin axis): the per-step
        # mean residual is rescaled into the sigma range, so only its shape
        # is meaningful, not its y-values.
        if len(unique_steps) > 3:
            sigma_range = [sigmas.min(), sigmas.max()]
            trend_normalized = np.interp(step_means,
                                       [min(step_means), max(step_means)],
                                       [sigma_range[0]*0.5, sigma_range[1]*1.5])

            ax1.plot(unique_steps, trend_normalized, 'r-', linewidth=4, alpha=0.9,
                     marker='o', markersize=6, markerfacecolor='red',
                     markeredgecolor='darkred', label='Mean Trend (scaled)', zorder=3)
            ax1.legend(loc='upper right', fontsize=12)

        # 2. RESIDUAL EVOLUTION (top, 3rd column)
        ax2 = fig.add_subplot(gs[0, 2])
        if len(unique_steps) > 1:
            step_stds = [np.std(residuals[steps == s]) for s in unique_steps]

            ax2.errorbar(unique_steps, step_means, yerr=step_stds,
                        fmt='o-', capsize=4, linewidth=3, markersize=5, alpha=0.8,
                        color='blue', markerfacecolor='lightblue')
            ax2.set_xlabel('Training Step', fontsize=14)
            ax2.set_ylabel('Mean Residual ± Std', fontsize=14)
            ax2.set_title('Residual Evolution', fontsize=16, fontweight='bold')
            ax2.set_yscale('log')
            ax2.grid(True, alpha=0.3)
            ax2.tick_params(labelsize=12)

        # 3. SIGMA DISTRIBUTION EVOLUTION (top, 4th column)
        ax3 = fig.add_subplot(gs[0, 3])
        if len(unique_steps) > 1:
            sigma_means = [np.mean(sigmas[steps == s]) for s in unique_steps]
            ax3.plot(unique_steps, sigma_means, 'o-', linewidth=3,
                    markersize=6, alpha=0.8, color='orange', markerfacecolor='yellow')
            ax3.set_xlabel('Training Step', fontsize=14)
            ax3.set_ylabel('Mean Sigma', fontsize=14)
            ax3.set_title('Sigma Evolution', fontsize=16, fontweight='bold')
            ax3.set_yscale('log')
            ax3.grid(True, alpha=0.3)
            ax3.tick_params(labelsize=12)

        # 4. RESIDUAL vs SIGMA CORRELATION (bottom, 3rd column)
        ax4 = fig.add_subplot(gs[1, 2])

        bin_centers = []
        bin_means = []
        bin_stds = []
        bin_counts = []

        # Log-spaced bins need a positive upper edge; skip the panel otherwise.
        if sigmas.max() > 0:
            sigma_bins = np.logspace(np.log10(max(sigmas.min(), 1e-3)),
                                    np.log10(sigmas.max()), 8)
            # np.digitize returns len(bins) for values at/above the last edge,
            # which would silently drop the largest-sigma points; fold them
            # into the last bin instead. Index 0 (below the 1e-3 floor) stays
            # excluded.
            bin_indices = np.minimum(np.digitize(sigmas, sigma_bins),
                                     len(sigma_bins) - 1)

            for i in range(1, len(sigma_bins)):
                mask = bin_indices == i
                if np.any(mask):
                    # Geometric mean of the edges = centre of a log-spaced bin.
                    bin_centers.append(np.sqrt(sigma_bins[i-1] * sigma_bins[i]))
                    bin_means.append(np.mean(residuals[mask]))
                    bin_stds.append(np.std(residuals[mask]))
                    bin_counts.append(np.sum(mask))

        if bin_centers:
            ax4.errorbar(bin_centers, bin_means, yerr=bin_stds,
                        fmt='o-', capsize=5, linewidth=3, markersize=6,
                        color='green', markerfacecolor='lightgreen')
            ax4.set_xlabel('Sigma (bin centers)', fontsize=14)
            ax4.set_ylabel('Mean Residual', fontsize=14)
            ax4.set_title('Sigma vs Residual', fontsize=16, fontweight='bold')
            ax4.set_xscale('log')
            ax4.set_yscale('log')
            ax4.grid(True, alpha=0.3)
            ax4.tick_params(labelsize=12)

            for x, y, count in zip(bin_centers, bin_means, bin_counts):
                ax4.annotate(f'n={count}', (x, y), xytext=(4, 4),
                           textcoords='offset points', fontsize=10, alpha=0.8,
                           fontweight='bold')

        # 5. PERFORMANCE SUMMARY (spans bottom 4th and 5th columns)
        ax5 = fig.add_subplot(gs[1, 3:])
        ax5.axis('off')

        # Summary metrics. All ratios are guarded against non-positive
        # denominators so the text box never shows inf/nan.
        if len(residuals) > 20:
            # Equal-sized first/last quarter windows. (The unparenthesised
            # form -len(x)//4 floors away from zero and can make the trailing
            # window one element longer than the leading one.)
            quarter = len(residuals) // 4
            early_residuals = residuals[:quarter]
            recent_residuals = residuals[-quarter:]
            recent_mean = np.mean(recent_residuals)
            improvement = _safe_ratio(np.mean(early_residuals), recent_mean, 1.0)
            stability = _safe_ratio(np.std(recent_residuals), recent_mean, 0)
        else:
            improvement = 1.0
            stability = _safe_ratio(np.std(residuals), np.mean(residuals), 0)
        sigma_span = _safe_ratio(sigmas.max(), sigmas.min(), 1)
        residual_span = _safe_ratio(residuals.max(), residuals.min(), 1)

        # ASCII tags only (no emojis) so the monospace text box renders reliably.
        perf_text = f"""PERFORMANCE SUMMARY

            [DATA] Training Progress:
            • Total training steps: {len(np.unique(self.global_scatter_steps))}
            • Data points collected: {len(self.global_scatter_steps):,}
            • Current step: {unique_steps[-1] if len(unique_steps) > 0 else 0}
            • Training duration: {unique_steps[-1] - unique_steps[0] if len(unique_steps) > 1 else 0} steps

            [TARGET] Learning Metrics:
            • Performance improvement: {improvement:.2f}x better
            • Training stability: {stability:.4f} (lower = more stable)
            • Noise level coverage: {sigma_span:.1f}x range
            • Residual dynamic range: {residual_span:.1f}x

            [TREND] Current Performance:
            • Minimum residual: {np.min(residuals):.4f}
            • Maximum residual: {np.max(residuals):.3f}
            • Mean residual: {np.mean(residuals):.4f}
            • Std deviation: {np.std(residuals):.4f}
            • Median residual: {np.median(residuals):.4f}

            [SEARCH] Data Coverage:
            • Sigma range: {sigmas.min():.3f} → {sigmas.max():.1f}
            • Log frequency: every {self.log_frequency} steps
            • Samples per step: {self.max_samples}
            """

        ax5.text(0.05, 0.95, perf_text, transform=ax5.transAxes,
                fontsize=12, verticalalignment='top', family='monospace',
                bbox=dict(boxstyle='round,pad=0.8', facecolor='lightblue',
                         alpha=0.9, edgecolor='navy', linewidth=2))

        fig.suptitle('Comprehensive PDE Residual Analysis', fontsize=22, fontweight='bold', y=0.98)

        plt.tight_layout()
        # Leave head-room for the suptitle after tight_layout.
        plt.subplots_adjust(top=0.93)

        if save_dir:
            plt.savefig(os.path.join(save_dir, 'pde_comprehensive_analysis.png'),
                       dpi=150, bbox_inches='tight')

        return fig


    # Global boundary-residual scatter plot (linear sigma axis)
    def create_global_boundary_scatter_plot_linear(self, save_dir=None, max_columns=15):
        """Create the global boundary-residual scatter plot with a linear sigma axis.

        Plots training step (x) against sigma (y) for the subsampled boundary
        data, coloured by residual with the colour range clipped to the
        5th-95th percentile, and adds an info box summarising the data shown.

        Note: every open pyplot figure is closed before the new one is built.

        Args:
            save_dir: Optional directory. When truthy, the figure is also saved
                there as ``boundary_residual_scatter_global_linear.png``.
            max_columns: Forwarded to ``_subsample_global_boundary_data`` as
                its ``max_points`` argument.

        Returns:
            The matplotlib figure, or ``None`` when no boundary data has been
            recorded or the subsampled data is empty.
        """
        if not self.global_boundary_steps:
            return None

        # Reset pyplot state (this closes all currently open pyplot figures).
        plt.clf()
        plt.close('all')

        # Fetch the data BEFORE creating the figure so the empty-data early
        # return cannot leave an orphaned figure registered with pyplot.
        steps, sigmas, residuals = self._subsample_global_boundary_data(max_points=max_columns)
        if len(steps) == 0:
            return None
        steps = np.asarray(steps)
        sigmas = np.asarray(sigmas)
        residuals = np.asarray(residuals)

        fig, ax = plt.subplots(figsize=(16, 10))

        # Vertical guide line at every displayed step.
        unique_displayed_steps = np.unique(steps)
        if len(unique_displayed_steps) > 1:
            for step in unique_displayed_steps:
                ax.axvline(x=step, color='lightgray', alpha=0.6, linestyle='--',
                        linewidth=1.0, zorder=1)

        # Clip the colour range to the 5th-95th percentile so outliers do not
        # wash out the colormap.
        vmin_clip = np.percentile(residuals, 5)
        vmax_clip = np.percentile(residuals, 95)

        scatter = ax.scatter(steps, sigmas, c=residuals, cmap='viridis',
                        alpha=0.8, s=80, edgecolors='black', linewidth=0.5,
                        zorder=2, vmin=vmin_clip, vmax=vmax_clip)

        ax.set_xlabel('Training Step', fontsize=14)
        ax.set_ylabel('Sigma Value (linear scale)', fontsize=14)
        ax.set_title('Boundary Residuals vs Training Progress (Global View - Linear Scale)', fontsize=16)

        # Linear Y-axis ticks, spaced according to the sigma span.
        sigma_min = sigmas.min()
        sigma_max = sigmas.max()
        sigma_range = sigma_max - sigma_min

        if sigma_range > 5:
            tick_interval = 1.0
        elif sigma_range > 1:
            tick_interval = 0.5
        else:
            tick_interval = max(sigma_range / 10, 0.1)

        major_sigma_ticks = np.arange(0, sigma_max + tick_interval, tick_interval)
        major_ticks_in_range = major_sigma_ticks[major_sigma_ticks <= sigma_max*1.05]

        if len(major_ticks_in_range) > 0:
            ax.set_yticks(major_ticks_in_range)
            ax.set_yticklabels([f'{tick:.2f}' for tick in major_ticks_in_range], fontsize=12)

        # Explicit y-limits with 5% padding, never dipping below zero.
        y_padding = sigma_range * 0.05
        ax.set_ylim(max(0, sigma_min - y_padding), sigma_max + y_padding)

        # This global view has no sigma threshold, so the label names only
        # the percentile clipping.
        cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
        cbar.set_label('Mean Absolute Boundary Residual (5%-95% Range)', fontsize=14)
        cbar.ax.tick_params(labelsize=11)

        ax.grid(True, which='major', alpha=0.3, linewidth=0.8, zorder=0)
        ax.grid(True, which='minor', alpha=0.1, linewidth=0.4, zorder=0)

        # Typical spacing between displayed steps, derived from the BOUNDARY
        # data this plot actually shows.
        unique_logged_steps = np.unique(self.global_boundary_steps)

        if len(unique_logged_steps) > 1:
            if len(unique_displayed_steps) > 1:
                display_step_interval = int(np.median(np.diff(unique_displayed_steps)))
            else:
                display_step_interval = self.log_frequency
        else:
            display_step_interval = 0

        # Info box
        first_steps = ", ".join(str(int(s)) for s in unique_displayed_steps[:5])
        step_info = f'Steps: {first_steps}{"..." if len(unique_displayed_steps) > 5 else ""}'
        total_points = len(self.global_boundary_steps)
        shown_points = len(steps)
        shown_percentage = 100 * shown_points / total_points if total_points > 0 else 0

        # Points whose colour saturates because of the percentile clipping.
        clipped_points = int(np.sum(residuals < vmin_clip) + np.sum(residuals > vmax_clip))

        info_text = (f'GLOBAL BOUNDARY VIEW (linear σ):\n'
                    f'• Showing {shown_points} points from {total_points} total ({shown_percentage:.1f}%)\n'
                    f'• Sigma range: {sigma_min:.4f} → {sigma_max:.3f}\n'
                    f'• Residual range: {residuals.min():.4f} → {residuals.max():.3f}\n'
                    f'• Color range: {vmin_clip:.4f} → {vmax_clip:.4f} (5%-95%)\n'
                    f'• Clipped outliers: {clipped_points}/{shown_points} points\n'
                    f'• Log frequency: every {self.log_frequency} steps\n'
                    f'• Display interval: every {display_step_interval} training steps\n'
                    f'• {step_info}')

        fig.text(0.02, 0.02, info_text, fontsize=10,
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightcyan', alpha=0.9, edgecolor='darkcyan'),
                verticalalignment='bottom')

        plt.tight_layout()
        # Reserve space at the bottom for the info box.
        plt.subplots_adjust(bottom=0.20)

        if save_dir:
            plt.savefig(os.path.join(save_dir, 'boundary_residual_scatter_global_linear.png'),
                       dpi=150, bbox_inches='tight')

        return fig

    def create_boundary_training_evolution_comparison(self, save_dir=None, n_subplots=8):
        """Compare the boundary sigma-residual relationship across training steps.

        Builds two figures with one panel per selected training step: one with
        log-log axes and one with linear axes. Each panel shows the boundary
        scatter for that step, the Pearson correlation (of the logs in the
        log figure), and a least-squares trend line when there are more than
        five points. All panels of a figure share the same axis limits.

        Note: every open pyplot figure is closed before the new ones are built.

        Args:
            save_dir: Optional directory. When truthy, the figures are saved
                there as ``boundary_training_evolution_comparison_log.png``
                and ``boundary_training_evolution_comparison_linear.png``.
            n_subplots: Maximum number of training steps to show; when more
                steps were logged, an evenly spaced subset is selected.
                Values below 1 are treated as 1.

        Returns:
            ``(fig_log, fig_linear)``, or ``None`` when there is no boundary
            data or fewer than two distinct logged steps.
        """
        if not self.global_boundary_steps:
            return None

        # Reset pyplot state (this closes all currently open pyplot figures).
        plt.clf()
        plt.close('all')

        all_steps = np.array(self.global_boundary_steps)
        all_sigmas = np.array(self.global_boundary_sigmas)
        all_residuals = np.array(self.global_boundary_residuals)

        unique_steps = np.unique(all_steps)
        if len(unique_steps) < 2:
            return None

        # A zero/negative panel count would produce an empty subplot grid.
        n_subplots = max(1, int(n_subplots))
        if len(unique_steps) > n_subplots:
            step_indices = np.linspace(0, len(unique_steps)-1, n_subplots, dtype=int)
            selected_steps = unique_steps[step_indices]
        else:
            selected_steps = unique_steps

        n_rows = 2
        n_cols = int(np.ceil(len(selected_steps) / n_rows))

        # squeeze=False always yields a 2-D axes array, so flattening gives a
        # flat list of Axes for every grid shape (including a single column).
        figsize = (6*n_cols, 5*n_rows)
        fig_log, axes_log = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        fig_linear, axes_linear = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        axes_log = axes_log.flatten()
        axes_linear = axes_linear.flatten()

        # Global limits so every panel shares the same axes.
        all_sigma_min = np.min(all_sigmas)
        all_sigma_max = np.max(all_sigmas)
        all_residual_min = np.min(all_residuals)
        all_residual_max = np.max(all_residuals)

        # 10% padding for the linear panels.
        sigma_padding = (all_sigma_max - all_sigma_min) * 0.1
        residual_padding = (all_residual_max - all_residual_min) * 0.1

        # Errors that a degenerate fit/correlation can raise.
        fit_errors = (np.linalg.LinAlgError, ValueError, TypeError)

        def _correlation(x, y):
            """Pearson r of x and y; 0.0 if not finite, None if it cannot be computed."""
            try:
                r = np.corrcoef(x, y)[0, 1]
            except fit_errors:
                return None
            # A zero-variance input gives NaN, which would poison the average.
            return float(r) if np.isfinite(r) else 0.0

        def _fit_line(x, y):
            """Least-squares (slope, intercept), or None when the fit fails."""
            try:
                slope, intercept = np.polyfit(x, y, 1)
            except fit_errors:
                # The trend line is an optional decoration; skip it on failure.
                return None
            return slope, intercept

        # Exactly one correlation entry is recorded per step that has data.
        correlations_log = []
        correlations_linear = []

        for step, ax_log, ax_linear in zip(selected_steps, axes_log, axes_linear):
            # Data for this specific step
            step_mask = all_steps == step
            step_sigmas = all_sigmas[step_mask]
            step_residuals = all_residuals[step_mask]
            n_points = len(step_sigmas)

            if n_points == 0:
                # Placeholder text on both figures
                for ax in (ax_log, ax_linear):
                    ax.text(0.5, 0.5, f'No boundary data\nStep {step}', ha='center', va='center',
                            transform=ax.transAxes, fontsize=12)
                continue

            # =============== LOG SCALE PLOT ===============
            ax_log.scatter(step_sigmas, step_residuals, alpha=0.7, s=35,
                          color='darkgreen', edgecolors='black', linewidth=0.3)

            ax_log.set_xlabel('Sigma (log)', fontsize=11)
            ax_log.set_ylabel('Boundary Residual (log)', fontsize=11)
            ax_log.set_title(f'Step {step} - Log Scale\n(boundary n={n_points})',
                            fontsize=12, fontweight='bold', pad=15)
            ax_log.set_xscale('log')
            ax_log.set_yscale('log')
            ax_log.grid(True, alpha=0.3)

            # Shared log limits, floored at 1e-6 to stay strictly positive.
            ax_log.set_xlim(max(all_sigma_min*0.5, 1e-6), all_sigma_max*2)
            ax_log.set_ylim(max(all_residual_min*0.5, 1e-6), all_residual_max*2)

            log_corr = None
            if n_points > 2:
                # The small epsilon keeps the logarithm finite at zero.
                log_sigmas = np.log(step_sigmas + 1e-8)
                log_residuals = np.log(step_residuals + 1e-8)
                log_corr = _correlation(log_sigmas, log_residuals)
            correlations_log.append(0 if log_corr is None else log_corr)

            if log_corr is not None:
                ax_log.text(0.05, 0.90, f'r={log_corr:.2f}', transform=ax_log.transAxes,
                           fontsize=9, fontweight='bold',
                           bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', alpha=0.8))

                # Power-law trend line; needs a positive sigma range for logspace.
                if n_points > 5 and step_sigmas.min() > 0:
                    fit = _fit_line(log_sigmas, log_residuals)
                    if fit is not None:
                        slope, intercept = fit
                        sigma_grid = np.logspace(np.log10(step_sigmas.min()),
                                                 np.log10(step_sigmas.max()), 50)
                        # Same epsilon as in the fit, for consistency.
                        trend_residuals = np.exp(slope * np.log(sigma_grid + 1e-8) + intercept)
                        ax_log.plot(sigma_grid, trend_residuals, 'darkgreen',
                                   linestyle='--', alpha=0.8, linewidth=2)

            # =============== LINEAR SCALE PLOT ===============
            ax_linear.scatter(step_sigmas, step_residuals, alpha=0.7, s=35,
                             color='darkcyan', edgecolors='black', linewidth=0.3)

            ax_linear.set_xlabel('Sigma (linear)', fontsize=11)
            ax_linear.set_ylabel('Boundary Residual (linear)', fontsize=11)
            ax_linear.set_title(f'Step {step} - Linear Scale\n(boundary n={n_points})',
                               fontsize=12, fontweight='bold', pad=15)
            ax_linear.grid(True, alpha=0.3)

            # Shared linear limits
            ax_linear.set_xlim(all_sigma_min - sigma_padding, all_sigma_max + sigma_padding)
            ax_linear.set_ylim(all_residual_min - residual_padding, all_residual_max + residual_padding)

            linear_corr = None
            if n_points > 2:
                linear_corr = _correlation(step_sigmas, step_residuals)
            correlations_linear.append(0 if linear_corr is None else linear_corr)

            if linear_corr is not None:
                ax_linear.text(0.05, 0.90, f'r={linear_corr:.2f}', transform=ax_linear.transAxes,
                              fontsize=9, fontweight='bold',
                              bbox=dict(boxstyle='round,pad=0.3', facecolor='lightcyan', alpha=0.8))

                if n_points > 5:
                    fit = _fit_line(step_sigmas, step_residuals)
                    if fit is not None:
                        slope, intercept = fit
                        sigma_grid = np.linspace(step_sigmas.min(), step_sigmas.max(), 50)
                        ax_linear.plot(sigma_grid, slope * sigma_grid + intercept, 'darkcyan',
                                      linestyle='--', alpha=0.8, linewidth=2)

        # Hide the grid cells that have no step assigned.
        n_used = len(selected_steps)
        for spare_log, spare_linear in zip(axes_log[n_used:], axes_linear[n_used:]):
            spare_log.set_visible(False)
            spare_linear.set_visible(False)

        # =============== FIGURE FINALIZATIONS ===============
        # At least two unique steps are guaranteed by the early return above.
        total_duration = unique_steps[-1] - unique_steps[0]

        def _finalize(fig, scale_name, correlations, filename):
            """Add the title, fix the layout and optionally save one figure."""
            avg_corr = np.mean(correlations) if correlations else 0
            fig.suptitle(f'Boundary Training Evolution: Sigma vs Boundary Residual ({scale_name.upper()} SCALE)\n' +
                         f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                         f'Avg {scale_name} Correlation: {avg_corr:.3f}',
                         fontsize=14, fontweight='bold', y=0.96)
            # Figure-level calls avoid depending on pyplot's current figure.
            fig.tight_layout(rect=[0, 0.05, 1, 0.90])
            fig.subplots_adjust(top=0.85, bottom=0.10, left=0.08, right=0.95,
                                hspace=0.5, wspace=0.4)
            if save_dir:
                fig.savefig(os.path.join(save_dir, filename),
                            dpi=150, bbox_inches='tight', pad_inches=0.3,
                            facecolor='white', edgecolor='none')

        _finalize(fig_log, 'Log', correlations_log,
                  'boundary_training_evolution_comparison_log.png')
        _finalize(fig_linear, 'Linear', correlations_linear,
                  'boundary_training_evolution_comparison_linear.png')

        # Return both figures as a tuple
        return fig_log, fig_linear

    def log_to_wandb(self, step):
        """Log plots to wandb with proper cleanup - UPDATED WITH BOUNDARY PLOTS"""
        # Original PDE residual plots
        fig1 = self.create_global_scatter_plot()
        if fig1 is not None:
            wandb.log({"PDE_global_scatter/residual_scatter_global_full": wandb.Image(fig1)}, step=step)
            self._cleanup_figure(fig1)
        
        fig2 = self.create_global_scatter_plot_clipped()
        if fig2 is not None:
            wandb.log({"PDE_global_scatter/residual_scatter_global_enhanced": wandb.Image(fig2)}, step=step)
            self._cleanup_figure(fig2)
        
        # Zoomed PDE plots (if methods exist)
        if hasattr(self, 'create_global_scatter_plot_clipped_zoomed_log'):
            fig2_zoom_log = self.create_global_scatter_plot_clipped_zoomed_log(sigma_max_threshold=1.0)
            if fig2_zoom_log is not None:
                wandb.log({"PDE_global_scatter/residual_scatter_global_zoomed_log": wandb.Image(fig2_zoom_log)}, step=step)
                self._cleanup_figure(fig2_zoom_log)
    
        if hasattr(self, 'create_global_scatter_plot_clipped_zoomed_linear'):
            fig2_zoom_linear = self.create_global_scatter_plot_clipped_zoomed_linear(sigma_max_threshold=1.0)
            if fig2_zoom_linear is not None:
                wandb.log({"PDE_global_scatter/residual_scatter_global_zoomed_linear": wandb.Image(fig2_zoom_linear)}, step=step)
                self._cleanup_figure(fig2_zoom_linear)
    
        #Boundary residual global scatter plots
        fig_boundary_log = self.create_global_boundary_scatter_plot_log()
        if fig_boundary_log is not None:
            wandb.log({"Boundary_plots/residual_scatter_global_log": wandb.Image(fig_boundary_log)}, step=step)
            self._cleanup_figure(fig_boundary_log)
    
        fig_boundary_linear = self.create_global_boundary_scatter_plot_linear()
        if fig_boundary_linear is not None:
            wandb.log({"Boundary_plots/residual_scatter_global_linear": wandb.Image(fig_boundary_linear)}, step=step)
            self._cleanup_figure(fig_boundary_linear)
    
        # Other PDE plots
        fig3 = self.create_comprehensive_analysis()
        if fig3 is not None:
            wandb.log({"PDE_plots_other/comprehensive_analysis": wandb.Image(fig3)}, step=step)
            self._cleanup_figure(fig3)
    
        if len(self.recent_sigmas) > 0:
            fig4 = self.create_sigma_vs_residual_plot(step=step)
            if fig4 is not None:
                wandb.log({"PDE_plots_other/sigma_vs_residual_comprehensive": wandb.Image(fig4)}, step=step)
                self._cleanup_figure(fig4)
        
        # Training visualization
        if self.recent_training_data:
            latest_data = self.recent_training_data[-1]
            training_viz_figures = None

            # Check if data is for unified mode
        if 'ground_truth_a' in latest_data:
            training_viz_figures = self.create_training_comparison_visualization_unified(
                ground_truth_a=latest_data['ground_truth_a'],
                ground_truth_u=latest_data['ground_truth_u'],
                predictions_a=latest_data['predictions_a'],
                predictions_u=latest_data['predictions_u'],
                model_input_a=latest_data['model_input_a'],
                model_input_u=latest_data['model_input_u'],
                mask_a=latest_data['mask_a'],
                mask_u=latest_data['mask_u'],
                step=step,
                save_dir=None,  # Don't save locally
                n_samples=16,
                direction=latest_data['direction'],
                pde_loss_fn=getattr(self, '_cached_pde_loss_fn', None),
                sample_selection="random"
            )
            # The unified function returns a single figure, so we wrap it in a list
            for fig_idx, fig in enumerate(training_viz_figures):
                wandb.log({
                    f"training_visualization/fig_{fig_idx+1}": wandb.Image(fig)
                }, step=step)
                self._cleanup_figure(fig)
        else:
            # Original function for conditional mode
            training_viz_figures = self.create_training_comparison_visualization(
                input_data=latest_data['input_data'],
                ground_truth=latest_data['ground_truth'],
                predictions=latest_data['predictions'],
                step=step,
                save_dir=None,  # Don't save locally
                n_samples=4,  # 16 samples across 4 figures
                direction=latest_data['direction'],
                pde_loss_fn=getattr(self, '_cached_pde_loss_fn', None),
                sample_selection="random"  # Can be changed to "fixed" once implemented
            )
            if training_viz_figures is not None:
                # Log each figure separately
                for fig_idx, fig in enumerate(training_viz_figures):
                    wandb.log({
                        f"training_visualization/{latest_data['direction']}_fig_{fig_idx+1}": wandb.Image(fig)
                    }, step=step)
                    self._cleanup_figure(fig)
        
        # NEW: Log residual map visualizations to WandB
        if self.recent_residual_maps:
            # Create and log thresholded residual maps
            thresholded_fig = self.create_thresholded_residual_maps(
                self.recent_residual_maps[-1], 
                self.recent_residual_map_sigmas[-1], 
                self.recent_residual_map_steps[-1]
            )
            if thresholded_fig is not None:
                wandb.log({"Residual_Maps/thresholded_residual_maps": wandb.Image(thresholded_fig)}, step=step)
                self._cleanup_figure(thresholded_fig)
            
            # Create and log comprehensive heatmap analysis
            # heatmap_fig = self.create_residual_heatmap_analysis(
            #     self.recent_residual_maps[-1], 
            #     self.recent_residual_map_sigmas[-1], 
            #     self.recent_residual_map_steps[-1]
            # )
            # if heatmap_fig is not None:
            #     wandb.log({"Residual_Maps/residual_heatmap_analysis": wandb.Image(heatmap_fig)}, step=step)
            #     self._cleanup_figure(heatmap_fig)

        # PDE training evolution (all five versions)
        training_evolution_result = self.create_training_evolution_comparison()
        if training_evolution_result is not None:
            if isinstance(training_evolution_result, tuple):
                if len(training_evolution_result) == 5:
                    # New version with five figures
                    fig_log, fig_linear, fig_linear_clipped, fig_log_zoomed, fig_linear_zoomed = training_evolution_result
                    if fig_log is not None:
                        wandb.log({"PDE_plots/training_evolution_log_scale": wandb.Image(fig_log)}, step=step)
                        self._cleanup_figure(fig_log)
                    if fig_linear is not None:
                        wandb.log({"PDE_plots/training_evolution_linear_scale": wandb.Image(fig_linear)}, step=step)
                        self._cleanup_figure(fig_linear)
                    if fig_linear_clipped is not None:
                        wandb.log({"PDE_plots/training_evolution_linear_clipped": wandb.Image(fig_linear_clipped)}, step=step)
                        self._cleanup_figure(fig_linear_clipped)
                    if fig_log_zoomed is not None:
                        wandb.log({"PDE_plots/training_evolution_log_zoomed": wandb.Image(fig_log_zoomed)}, step=step)
                        self._cleanup_figure(fig_log_zoomed)
                    if fig_linear_zoomed is not None:
                        wandb.log({"PDE_plots/training_evolution_linear_zoomed": wandb.Image(fig_linear_zoomed)}, step=step)
                        self._cleanup_figure(fig_linear_zoomed)
                elif len(training_evolution_result) == 3:
                    # Legacy version with three figures
                    fig_log, fig_linear, fig_linear_clipped = training_evolution_result
                    if fig_log is not None:
                        wandb.log({"PDE_plots/training_evolution_log_scale": wandb.Image(fig_log)}, step=step)
                        self._cleanup_figure(fig_log)
                    if fig_linear is not None:
                        wandb.log({"PDE_plots/training_evolution_linear_scale": wandb.Image(fig_linear)}, step=step)
                        self._cleanup_figure(fig_linear)
                    if fig_linear_clipped is not None:
                        wandb.log({"PDE_plots/training_evolution_linear_clipped": wandb.Image(fig_linear_clipped)}, step=step)
                        self._cleanup_figure(fig_linear_clipped)
                elif len(training_evolution_result) == 2:
                    # Legacy version with two figures
                    fig_log, fig_linear = training_evolution_result
                    if fig_log is not None:
                        wandb.log({"PDE_plots/training_evolution_log_scale": wandb.Image(fig_log)}, step=step)
                        self._cleanup_figure(fig_log)
                    if fig_linear is not None:
                        wandb.log({"PDE_plots/training_evolution_linear_scale": wandb.Image(fig_linear)}, step=step)
                        self._cleanup_figure(fig_linear)
            else:
                # Single figure (backward compatibility)
                wandb.log({"PDE_plots/training_evolution_comparison": wandb.Image(training_evolution_result)}, step=step)
                self._cleanup_figure(training_evolution_result)
        
        # 🆕 NEW: Boundary training evolution plots
        boundary_evolution_result = self.create_boundary_training_evolution_comparison()
        if boundary_evolution_result is not None:
            if isinstance(boundary_evolution_result, tuple):
                fig_boundary_log, fig_boundary_linear = boundary_evolution_result
                if fig_boundary_log is not None:
                    wandb.log({"Boundary_plots/training_evolution_log_scale": wandb.Image(fig_boundary_log)}, step=step)
                    self._cleanup_figure(fig_boundary_log)
                if fig_boundary_linear is not None:
                    wandb.log({"Boundary_plots/training_evolution_linear_scale": wandb.Image(fig_boundary_linear)}, step=step)
                    self._cleanup_figure(fig_boundary_linear)
            else:
                # Single figure fallback
                wandb.log({"Boundary_plots/training_evolution": wandb.Image(boundary_evolution_result)}, step=step)
                self._cleanup_figure(boundary_evolution_result)
        

    def log_detailed_plots(self, step, save_dir=None):
        """Log detailed sigma vs residual plots more frequently"""
        if step % self.detailed_log_frequency == 0 and len(self.recent_sigmas) > 0:
            fig = self.create_sigma_vs_residual_plot(save_dir=save_dir, step=step)
            if fig is not None:
                wandb.log({f"PDE_detailed/sigma_vs_residual_step_{step}": wandb.Image(fig)}, step=step)
                self._cleanup_figure(fig)
    
    def clear_recent_history(self):
        """Clear recent history for detailed plots (keep some overlap)"""
        # Keep last 50 samples for continuity
        if len(self.recent_sigmas) > 50:
            self.recent_sigmas = self.recent_sigmas[-50:]
            self.recent_residuals = self.recent_residuals[-50:]
            self.recent_steps = self.recent_steps[-50:]
            
            # Also clear boundary residual history
            self.recent_boundary2_sigmas = self.recent_boundary2_sigmas[-50:]
            self.recent_boundary2_residuals = self.recent_boundary2_residuals[-50:]
            self.recent_boundary1_sigmas = self.recent_boundary1_sigmas[-50:]
            self.recent_boundary1_residuals = self.recent_boundary1_residuals[-50:]
            
            # Clear residual maps (keep only last 5 to save memory)
            if len(self.recent_residual_maps) > 5:
                self.recent_residual_maps = self.recent_residual_maps[-5:]
                self.recent_residual_map_sigmas = self.recent_residual_map_sigmas[-5:]
                self.recent_residual_map_steps = self.recent_residual_map_steps[-5:]
        
    def get_statistics(self):
        """Get current statistics for recent data including boundary residuals"""
        if not self.recent_sigmas:
            return {}
            
        sigmas = np.array(self.recent_sigmas)
        residuals = np.array(self.recent_residuals)
        
        # Basic statistics for interior residuals
        stats = {
            'sigma_mean': np.mean(sigmas),
            'sigma_std': np.std(sigmas),
            'residual_mean': np.mean(residuals),
            'residual_std': np.std(residuals),
            'residual_max': np.max(residuals),
            'residual_min': np.min(residuals),
            'residual_skewness': self._calculate_skewness_from_array(residuals),
            'residual_kurtosis': self._calculate_kurtosis_from_array(residuals),
        }
        
        # Add boundary residual statistics if available
        if self.recent_boundary2_residuals:
            boundary2_residuals = np.array(self.recent_boundary2_residuals)
            boundary1_residuals = np.array(self.recent_boundary1_residuals)
            
            stats.update({
                'boundary2_residual_mean': np.mean(boundary2_residuals),
                'boundary2_residual_std': np.std(boundary2_residuals),
                'boundary2_residual_max': np.max(boundary2_residuals),
                'boundary2_residual_skewness': self._calculate_skewness_from_array(boundary2_residuals),
                'boundary2_residual_kurtosis': self._calculate_kurtosis_from_array(boundary2_residuals),
                'boundary1_residual_mean': np.mean(boundary1_residuals),
                'boundary1_residual_std': np.std(boundary1_residuals),
                'boundary1_residual_max': np.max(boundary1_residuals),
                'boundary1_residual_skewness': self._calculate_skewness_from_array(boundary1_residuals),
                'boundary1_residual_kurtosis': self._calculate_kurtosis_from_array(boundary1_residuals),
            })
            
            # # Compute boundary vs interior ratios
            # if np.mean(residuals) > 0:
            #     stats['boundary2_vs_interior_ratio'] = np.mean(boundary2_residuals) / np.mean(residuals)
            #     stats['boundary1_vs_interior_ratio'] = np.mean(boundary1_residuals) / np.mean(residuals)
        
        if len(sigmas) > 1:
            # Compute correlation in log space
            try:
                log_sigmas = np.log(sigmas + 1e-8)  # Add small epsilon to avoid log(0)
                log_residuals = np.log(residuals + 1e-8)
                stats['correlation'] = np.corrcoef(log_sigmas, log_residuals)[0,1]
            except:
                stats['correlation'] = 0.0
        else:
            stats['correlation'] = 0.0
            
        return stats
    

    def create_sigma_vs_residual_plot(self, save_dir=None, step=None):
        """Enhanced sigma vs residual analysis with multiple views"""
        if not self.recent_sigmas:
            return None
            
        # CLEAR MATPLOTLIB STATE
        plt.clf()
        plt.close('all')
        
        # CREATE 2x3 SUBPLOT LAYOUT for comprehensive analysis
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        
        # Get data for CURRENT STEP ONLY (last 50 samples to focus on current behavior)
        n_recent = min(50, len(self.recent_sigmas))  # Smaller window for current step focus
        sigmas = np.array(self.recent_sigmas[-n_recent:])
        residuals = np.array(self.recent_residuals[-n_recent:])
        steps_recent = np.array(self.recent_steps[-n_recent:])
        
        # Filter to get data from current step only (or very recent steps)
        if step is not None:
            # Get data from last few steps (within 5 steps of current)
            recent_mask = np.abs(steps_recent - step) <= 5
            if np.any(recent_mask):
                sigmas = sigmas[recent_mask]
                residuals = residuals[recent_mask]
                steps_recent = steps_recent[recent_mask]
    
        if len(sigmas) == 0:
            plt.close(fig)
            return None
    
        # 1. SCATTER PLOT - LOG SCALE (Overview)
        ax1 = axes[0, 0]
        scatter1 = ax1.scatter(sigmas, residuals, c=steps_recent, cmap='plasma', 
                              alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
        ax1.set_xlabel('Sigma Value (log)', fontsize=12)
        ax1.set_ylabel('PDE Residual (log)', fontsize=12)
        ax1.set_title(f'Log-Scale Overview\n(Step ~{step})', fontsize=14, fontweight='bold')
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.grid(True, alpha=0.3)
    
        cbar1 = plt.colorbar(scatter1, ax=ax1, shrink=0.8)
        cbar1.set_label('Training Step', fontsize=10)
        
        # 2. SCATTER PLOT - LINEAR SCALE (Details)
        ax2 = axes[0, 1]
        scatter2 = ax2.scatter(sigmas, residuals, c=steps_recent, cmap='plasma', 
                              alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
        ax2.set_xlabel('Sigma Value (linear)', fontsize=12)
        ax2.set_ylabel('PDE Residual (linear)', fontsize=12)
        ax2.set_title(f'Linear-Scale Details\n(Step ~{step})', fontsize=14, fontweight='bold')
        ax2.grid(True, alpha=0.3)
    
        cbar2 = plt.colorbar(scatter2, ax=ax2, shrink=0.8)
        cbar2.set_label('Training Step', fontsize=10)
        
        # 3. BINNED ANALYSIS - ENHANCED
        ax3 = axes[0, 2]
        if len(sigmas) >= 5:  # Need minimum points for binning
            # Create adaptive bins based on data distribution
            n_bins = min(8, len(np.unique(sigmas)))
            if np.max(sigmas) / np.min(sigmas) > 100:  # Wide range, use log bins
                sigma_bins = np.logspace(np.log10(max(sigmas.min(), 1e-6)), 
                                       np.log10(sigmas.max()), n_bins)
                use_log = True
            else:  # Narrow range, use linear bins
                sigma_bins = np.linspace(sigmas.min(), sigmas.max(), n_bins)
                use_log = False
            
            bin_indices = np.digitize(sigmas, sigma_bins)
            
            bin_centers = []
            bin_means = []
            bin_stds = []
            bin_counts = []
            
            for i in range(1, len(sigma_bins)):
                mask = bin_indices == i
                if np.any(mask):
                    if use_log:
                        bin_centers.append(np.sqrt(sigma_bins[i-1] * sigma_bins[i]))  # Geometric mean
                    else:
                        bin_centers.append((sigma_bins[i-1] + sigma_bins[i]) / 2)  # Arithmetic mean
                    bin_means.append(np.mean(residuals[mask]))
                    bin_stds.append(np.std(residuals[mask]))
                    bin_counts.append(np.sum(mask))
            
            if bin_centers:
                # Plot with different colors based on bin size
                colors = plt.cm.viridis(np.array(bin_counts) / max(bin_counts))
                
                ax3.errorbar(bin_centers, bin_means, yerr=bin_stds, 
                            fmt='o-', capsize=5, linewidth=3, markersize=8,
                            color='blue', markerfacecolor='lightblue',
                            markeredgecolor='darkblue', markeredgewidth=2)
                
                # Add count annotations with better positioning
                for x, y, count, std in zip(bin_centers, bin_means, bin_counts, bin_stds):
                    ax3.annotate(f'n={count}', (x, y + std), xytext=(0, 10), 
                               textcoords='offset points', fontsize=9, 
                               ha='center', fontweight='bold',
                               bbox=dict(boxstyle='round,pad=0.2', facecolor='yellow', alpha=0.7))
                
                x_label = 'Sigma (log)' if use_log else 'Sigma Value'
                y_label = 'Mean Residual ± Std (log)' if use_log else 'Mean Residual ± Std'
                ax3.set_xlabel(x_label, fontsize=12)
                ax3.set_ylabel(y_label, fontsize=12)
                ax3.set_title(f'Binned Analysis\n({len(bin_centers)} bins)', 
                             fontsize=13, fontweight='bold', pad=20)
                
                if use_log:
                    ax3.set_xscale('log')
                    ax3.set_yscale('log')
                ax3.grid(True, alpha=0.3)
        else:
            ax3.text(0.5, 0.5, 'Insufficient data\nfor binning', 
                    ha='center', va='center', transform=ax3.transAxes,
                    fontsize=14, bbox=dict(boxstyle='round', facecolor='lightgray'))
            ax3.set_title('Binned Analysis\n(Need more data)', fontsize=14)
    
        # 4. HISTOGRAM - SIGMA DISTRIBUTION
        ax4 = axes[1, 0]
        ax4.hist(sigmas, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
        ax4.set_xlabel('Sigma Value', fontsize=12)
        ax4.set_ylabel('Frequency', fontsize=12)
        ax4.set_title(f'Sigma Distribution\n(n={len(sigmas)})', fontsize=14, fontweight='bold')
        ax4.grid(True, alpha=0.3)
        
        # Add statistics
        sigma_stats = f'Mean: {np.mean(sigmas):.3f}\nStd: {np.std(sigmas):.3f}\nRange: {np.min(sigmas):.3f}-{np.max(sigmas):.3f}'
        ax4.text(0.02, 0.98, sigma_stats, transform=ax4.transAxes, 
                verticalalignment='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        
        # 5. HISTOGRAM - RESIDUAL DISTRIBUTION
        ax5 = axes[1, 1]
        ax5.hist(residuals, bins=15, alpha=0.7, color='lightcoral', edgecolor='black')
        ax5.set_xlabel('PDE Residual', fontsize=12)
        ax5.set_ylabel('Frequency', fontsize=12)
        ax5.set_title(f'Residual Distribution\n(n={len(residuals)})', fontsize=14, fontweight='bold')
        ax5.grid(True, alpha=0.3)
        
        # Add statistics
        residual_stats = f'Mean: {np.mean(residuals):.4f}\nStd: {np.std(residuals):.4f}\nRange: {np.min(residuals):.4f}-{np.max(residuals):.4f}'
        ax5.text(0.02, 0.98, residual_stats, transform=ax5.transAxes, 
                verticalalignment='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        
        # 6. CORRELATION ANALYSIS
        ax6 = axes[1, 2]
        if len(sigmas) > 2:
            # Calculate correlation in both linear and log space
            try:
                linear_corr = np.corrcoef(sigmas, residuals)[0, 1]
                log_sigmas = np.log(sigmas + 1e-8)
                log_residuals = np.log(residuals + 1e-8)
                log_corr = np.corrcoef(log_sigmas, log_residuals)[0, 1]
                
                # Scatter plot with trend line
                ax6.scatter(sigmas, residuals, alpha=0.6, s=30, color='purple')
                
                # Add trend line
                z = np.polyfit(sigmas, residuals, 1)
                p = np.poly1d(z)
                x_trend = np.linspace(sigmas.min(), sigmas.max(), 100)
                ax6.plot(x_trend, p(x_trend), "r--", alpha=0.8, linewidth=2)
                
                ax6.set_xlabel('Sigma Value', fontsize=12)
                ax6.set_ylabel('PDE Residual', fontsize=12)
                ax6.set_title(f'Correlation Analysis\nLinear: {linear_corr:.3f}, Log: {log_corr:.3f}', 
                             fontsize=14, fontweight='bold')
                ax6.grid(True, alpha=0.3)
                
            except:
                ax6.text(0.5, 0.5, 'Correlation\ncalculation failed', 
                        ha='center', va='center', transform=ax6.transAxes,
                        fontsize=14)
        else:
            ax6.text(0.5, 0.5, 'Need more data\nfor correlation', 
                    ha='center', va='center', transform=ax6.transAxes,
                    fontsize=14)
    
        # OVERALL TITLE
        fig.suptitle('Comprehensive Sigma vs PDE Residual Analysis', fontsize=18, fontweight='bold', y=0.98)
        
        plt.tight_layout()
        plt.subplots_adjust(top=0.95)
        
        if save_dir:
            plt.savefig(os.path.join(save_dir, 'sigma_vs_residual_comprehensive_analysis.png'), 
                       dpi=150, bbox_inches='tight')
        
        return fig

    def log_boundary_residuals_from_maps(self, pde_residual_maps, sigmas, step, boundary_width=2):
        """
        Extract boundary residuals from PDE residual maps and log them
        
        Args:
            pde_residual_maps: torch.Tensor, PDE residual maps of shape [batch, height, width] or [batch, channels, height, width]
            sigmas: torch.Tensor, sigma values for each sample
            step: int, current training step
            boundary_width: int, width of boundary region in pixels
        """
        # Extract boundary pixels from residual maps
        boundary_residuals = self._extract_boundary_pixels(pde_residual_maps, boundary_width)
        
        # Calculate mean boundary residual per sample
        batch_size = min(self.max_samples, boundary_residuals.size(0))
        sample_boundary_residuals = torch.mean(torch.abs(boundary_residuals[:batch_size]), dim=1)  # [batch_size]
        sample_sigmas = sigmas[:batch_size].squeeze()
        
        # Log to global boundary history every log_frequency steps
        if step % self.log_frequency != 0:
            return
            
        # Store data for global boundary scatter plot (persistent)
        self.global_boundary_steps.extend([step] * batch_size)
        self.global_boundary_sigmas.extend(sample_sigmas.detach().cpu().numpy())
        self.global_boundary_residuals.extend(sample_boundary_residuals.detach().cpu().numpy())

    def _extract_boundary_pixels(self, residual_maps, boundary_width=2):
        """
        Extract boundary pixels from 2D residual maps
        
        Args:
            residual_maps: torch.Tensor of shape [batch_size, height, width] or [batch_size, channels, height, width]
            boundary_width: int, width of boundary region in pixels
        
        Returns:
            boundary_residuals: torch.Tensor of shape [batch_size, num_boundary_pixels]
        """
        # Handle different input shapes
        if residual_maps.dim() == 4:  # [batch, channels, height, width]
            # Take mean across channels if multiple channels
            residual_maps = torch.mean(residual_maps, dim=1)  # [batch, height, width]
        elif residual_maps.dim() == 3:  # [batch, height, width]
            pass  # Already correct shape
        else:
            raise ValueError(f"Expected 3D or 4D tensor, got {residual_maps.dim()}D")
        
        batch_size, height, width = residual_maps.shape
        
        # Create boundary mask
        boundary_mask = torch.zeros((height, width), dtype=torch.bool, device=residual_maps.device)
        
        # Top and bottom boundaries
        boundary_mask[:boundary_width, :] = True  # Top
        boundary_mask[-boundary_width:, :] = True  # Bottom
        
        # Left and right boundaries  
        boundary_mask[:, :boundary_width] = True  # Left
        boundary_mask[:, -boundary_width:] = True  # Right
        
        # Extract boundary pixels for each sample in batch
        boundary_residuals = []
        for i in range(batch_size):
            sample_residuals = residual_maps[i]  # [height, width]
            boundary_pixels = sample_residuals[boundary_mask]  # [num_boundary_pixels]
            boundary_residuals.append(boundary_pixels)
        
        # Stack into tensor [batch_size, num_boundary_pixels]
        boundary_residuals = torch.stack(boundary_residuals, dim=0)
        
        return boundary_residuals

    def _subsample_global_boundary_data(self, max_points=20):
        """Subsample global boundary scatter data to show evenly spaced steps across training"""
        if not self.global_boundary_steps:
            return [], [], []
            
        steps = np.array(self.global_boundary_steps)
        sigmas = np.array(self.global_boundary_sigmas)
        residuals = np.array(self.global_boundary_residuals)
        
        # Get unique steps
        unique_steps = np.unique(steps)
        if len(unique_steps) <= max_points:
            return steps, sigmas, residuals
        
        # Generate target steps (evenly spaced)
        min_step = unique_steps.min()
        max_step = unique_steps.max()
        target_steps = np.linspace(min_step, max_step, max_points)
        
        subsampled_steps = []
        subsampled_sigmas = []
        subsampled_residuals = []
        
        for target_step in target_steps:
            # Find the closest actual step to our target
            closest_step_idx = np.argmin(np.abs(unique_steps - target_step))
            closest_step = unique_steps[closest_step_idx]
            
            # Get all data points for this step
            step_mask = steps == closest_step
            if np.any(step_mask):
                step_steps = steps[step_mask]
                step_sigmas = sigmas[step_mask]
                step_residuals = residuals[step_mask]
                
                # Take all samples from this step
                subsampled_steps.extend(step_steps)
                subsampled_sigmas.extend(step_sigmas)
                subsampled_residuals.extend(step_residuals)
        
        return np.array(subsampled_steps), np.array(subsampled_sigmas), np.array(subsampled_residuals)

    def create_global_boundary_scatter_plot_log(self, save_dir=None, max_columns=15):
        """Create global boundary scatter plot - LOG SCALE"""
        if not self.global_boundary_steps:
            return None
                
        plt.clf()
        plt.close('all')
            
        fig, ax = plt.subplots(figsize=(16, 10))
        
        # Subsample data to get representative points across training
        steps, sigmas, residuals = self._subsample_global_boundary_data(max_points=max_columns)
        
        if len(steps) == 0:
            return None
        
        # PLOT VERTICAL LINES FIRST (in background)
        unique_displayed_steps = np.unique(steps)
        if len(unique_displayed_steps) > 1:
            for step in unique_displayed_steps:
                ax.axvline(x=step, color='lightgray', alpha=0.6, linestyle='--', 
                        linewidth=1.0, zorder=1)
        
        # ENHANCED SCATTER PLOT with CLIPPING for better color discrimination
        vmin_clip = np.percentile(residuals, 5)   # Bottom 5% clipped
        vmax_clip = np.percentile(residuals, 95)  # Top 5% clipped
        
        scatter = ax.scatter(steps, sigmas, c=residuals, cmap='viridis',
                        alpha=0.8, s=80, edgecolors='black', linewidth=0.5,
                        zorder=2, vmin=vmin_clip, vmax=vmax_clip)
        
        ax.set_xlabel('Training Step', fontsize=14)
        ax.set_ylabel('Sigma Value (log scale)', fontsize=14)
        ax.set_title('Boundary Residuals vs Training Progress (Global View - Log Scale)', fontsize=16)
        ax.set_yscale('log')
        
        # Enhanced X-axis ticks
        if len(unique_displayed_steps) > 0:
            ax.set_xticks(unique_displayed_steps)
            ax.set_xticklabels([f'{int(step)}' for step in unique_displayed_steps], 
                            rotation=45, ha='right', fontsize=11)
            
            if len(unique_displayed_steps) > 1:
                step_range = unique_displayed_steps.max() - unique_displayed_steps.min()
                minor_step_interval = step_range / (len(unique_displayed_steps) * 5)
                minor_ticks = np.arange(unique_displayed_steps.min(), 
                                    unique_displayed_steps.max() + minor_step_interval, 
                                    minor_step_interval)
                ax.set_xticks(minor_ticks, minor=True)
        
        # Enhanced Y-axis ticks
        sigma_min = max(sigmas.min(), 1e-3)
        sigma_max = sigmas.max()
        
        major_powers = np.arange(np.floor(np.log10(sigma_min)), 
                                np.ceil(np.log10(sigma_max)) + 1)
        major_sigma_ticks = 10.0 ** major_powers
        
        minor_sigma_ticks = []
        for power in major_powers:
            base = 10.0 ** power
            for multiplier in [2, 3, 4, 5, 6, 7, 8, 9]:
                tick_val = base * multiplier
                if sigma_min <= tick_val <= sigma_max:
                    minor_sigma_ticks.append(tick_val)
        
        major_ticks_in_range = major_sigma_ticks[(major_sigma_ticks >= sigma_min) & 
                                                (major_sigma_ticks <= sigma_max)]
        ax.set_yticks(major_ticks_in_range)
        ax.set_yticklabels([f'{tick:.0e}' if tick >= 1 else f'{tick:.3f}' 
                        for tick in major_ticks_in_range], fontsize=12)
        
        if minor_sigma_ticks:
            ax.set_yticks(minor_sigma_ticks, minor=True)
        
        # COLORBAR with data range info
        cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
        cbar.set_label('Mean Absolute Boundary Residual (Full Range)', fontsize=14)
        cbar.ax.tick_params(labelsize=11)
        
        # Enhanced grid
        ax.grid(True, which='major', alpha=0.3, linewidth=0.8, zorder=0)
        ax.grid(True, which='minor', alpha=0.1, linewidth=0.4, zorder=0)
        
        # Calculate step interval
        unique_logged_steps = np.unique(self.global_boundary_steps)
        
        if len(unique_logged_steps) > 1:
            logged_step_intervals = np.diff(unique_logged_steps)
            actual_log_frequency = int(np.median(logged_step_intervals)) if len(logged_step_intervals) > 0 else self.log_frequency
            
            if len(unique_displayed_steps) > 1:
                display_step_interval = int(np.median(np.diff(unique_displayed_steps)))
                effective_interval = display_step_interval
            else:
                effective_interval = actual_log_frequency
        else:
            effective_interval = 0

        # INFO BOX with full range statistics
        step_info = f'Steps: {", ".join([str(int(s)) for s in unique_displayed_steps[:5]])}{"..." if len(unique_displayed_steps) > 5 else ""}'
        
        info_text = (f'FULL RANGE VISUALIZATION:\n'
                    f'• Showing {len(steps)} points from {len(self.global_boundary_steps)} total\n'
                    f'• Boundary residual range: {np.min(residuals):.4f} → {np.max(residuals):.3f}\n'
                    f'• Color scaling: No clipping (full data range)\n'
                    f'• Log frequency: every {self.log_frequency} steps\n'
                    f'• Display interval: every {effective_interval} training steps\n'
                    f'• Extracted from PDE residual maps (boundary pixels)\n'
                    f'• {step_info}')
        
        # Position info box outside the plot
        fig.text(0.02, 0.02, info_text, fontsize=10, 
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.9, edgecolor='navy'),
                verticalalignment='bottom')
        
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.20)  
        
        if save_dir:
            plt.savefig(os.path.join(save_dir, 'boundary_residual_scatter_global_log.png'), dpi=150, bbox_inches='tight')
        
        return fig

    def create_boundary_training_evolution_comparison(self, save_dir=None, n_subplots=8):
        """Compare boundary sigma-residual relationship across training steps - LOG AND LINEAR SCALES"""
        if not self.global_boundary_steps:
            return None
        
        plt.clf()
        plt.close('all')
        
        all_steps = np.array(self.global_boundary_steps)
        all_sigmas = np.array(self.global_boundary_sigmas)
        all_residuals = np.array(self.global_boundary_residuals)
        
        unique_steps = np.unique(all_steps)
        if len(unique_steps) < 2:
            return None
        
        if len(unique_steps) > n_subplots:
            step_indices = np.linspace(0, len(unique_steps)-1, n_subplots, dtype=int)
            selected_steps = unique_steps[step_indices]
        else:
            selected_steps = unique_steps

        n_rows = 2
        n_cols = int(np.ceil(len(selected_steps) / n_rows))
        
        # CREATE TWO FIGURES - Log and Linear scales for boundary residuals
        fig_log, axes_log = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))
        fig_linear, axes_linear = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))

        if len(selected_steps) == 1:
            axes_log = [axes_log]
            axes_linear = [axes_linear]
        else:
            axes_log = axes_log.flatten()
            axes_linear = axes_linear.flatten()

        # CALCULATE GLOBAL LIMITS for shared axes
        all_sigma_min = np.min(all_sigmas)
        all_sigma_max = np.max(all_sigmas)
        all_residual_min = np.min(all_residuals)
        all_residual_max = np.max(all_residuals)
        
        # Add some padding for better visualization
        sigma_padding = (all_sigma_max - all_sigma_min) * 0.1
        residual_padding = (all_residual_max - all_residual_min) * 0.1

        # Store correlations for both versions
        correlations_log = []
        correlations_linear = []

        for i, step in enumerate(selected_steps):
            if i >= len(axes_log):
                break
            
            # Get data for this specific step
            step_mask = all_steps == step
            step_sigmas = all_sigmas[step_mask]
            step_residuals = all_residuals[step_mask]
            
            if len(step_sigmas) == 0:
                # Handle no data case for both plots
                for axes_set in [axes_log, axes_linear]:
                    axes_set[i].text(0.5, 0.5, f'No boundary data\nStep {step}', ha='center', va='center',
                                   transform=axes_set[i].transAxes, fontsize=12)
                continue
            
            # =============== LOG SCALE PLOT ===============
            ax_log = axes_log[i]
            ax_log.scatter(step_sigmas, step_residuals, alpha=0.7, s=35, 
                          color='darkgreen', edgecolors='black', linewidth=0.3)
            
            ax_log.set_xlabel('Sigma (log)', fontsize=11)
            ax_log.set_ylabel('Boundary Residual (log)', fontsize=11)
            ax_log.set_title(f'Step {step} - Log Scale\n(boundary n={len(step_sigmas)})', 
                            fontsize=12, fontweight='bold', pad=15)
            ax_log.set_xscale('log')
            ax_log.set_yscale('log')
            ax_log.grid(True, alpha=0.3)
            
            # SHARED LOG AXIS LIMITS
            ax_log.set_xlim(max(all_sigma_min*0.5, 1e-6), all_sigma_max*2)
            ax_log.set_ylim(max(all_residual_min*0.5, 1e-6), all_residual_max*2)
            
            # Log correlation and trend line
            if len(step_sigmas) > 2:
                try:
                    log_corr = np.corrcoef(np.log(step_sigmas + 1e-8), 
                                         np.log(step_residuals + 1e-8))[0,1]
                    correlations_log.append(log_corr)
                    
                    ax_log.text(0.05, 0.90, f'r={log_corr:.2f}', transform=ax_log.transAxes, 
                               fontsize=9, fontweight='bold',
                               bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', alpha=0.8))
                    
                    # Add trend line
                    if len(step_sigmas) > 5:
                        log_sigmas = np.log(step_sigmas + 1e-8)
                        log_residuals = np.log(step_residuals + 1e-8)
                        z = np.polyfit(log_sigmas, log_residuals, 1)
                        
                        sigma_range = np.logspace(np.log10(step_sigmas.min()), 
                                                np.log10(step_sigmas.max()), 50)
                        trend_residuals = np.exp(z[0] * np.log(sigma_range) + z[1])
                        ax_log.plot(sigma_range, trend_residuals, 'darkgreen', 
                                   linestyle='--', alpha=0.8, linewidth=2)
                except:
                    correlations_log.append(0)
            else:
                correlations_log.append(0)
            
            # =============== LINEAR SCALE PLOT ===============
            ax_linear = axes_linear[i]
            ax_linear.scatter(step_sigmas, step_residuals, alpha=0.7, s=35, 
                             color='darkcyan', edgecolors='black', linewidth=0.3)
            
            ax_linear.set_xlabel('Sigma (linear)', fontsize=11)
            ax_linear.set_ylabel('Boundary Residual (linear)', fontsize=11)
            ax_linear.set_title(f'Step {step} - Linear Scale\n(boundary n={len(step_sigmas)})', 
                               fontsize=12, fontweight='bold', pad=15)
            ax_linear.grid(True, alpha=0.3)
            
            # SHARED LINEAR AXIS LIMITS
            ax_linear.set_xlim(all_sigma_min - sigma_padding, all_sigma_max + sigma_padding)
            ax_linear.set_ylim(all_residual_min - residual_padding, all_residual_max + residual_padding)
            
            # Linear correlation and trend line
            if len(step_sigmas) > 2:
                try:
                    linear_corr = np.corrcoef(step_sigmas, step_residuals)[0,1]
                    correlations_linear.append(linear_corr)
                    
                    ax_linear.text(0.05, 0.90, f'r={linear_corr:.2f}', transform=ax_linear.transAxes, 
                                  fontsize=9, fontweight='bold',
                                  bbox=dict(boxstyle='round,pad=0.3', facecolor='lightcyan', alpha=0.8))
                    
                    # Add linear trend line
                    if len(step_sigmas) > 5:
                        z = np.polyfit(step_sigmas, step_residuals, 1)
                        sigma_range = np.linspace(step_sigmas.min(), step_sigmas.max(), 50)
                        trend_residuals = z[0] * sigma_range + z[1]
                        ax_linear.plot(sigma_range, trend_residuals, 'darkcyan', 
                                      linestyle='--', alpha=0.8, linewidth=2)
                except:
                    correlations_linear.append(0)
            else:
                correlations_linear.append(0)

        # Hide unused subplots for both figures
        for i in range(len(selected_steps), len(axes_log)):
            axes_log[i].set_visible(False)
            axes_linear[i].set_visible(False)

        # =============== FIGURE FINALIZATIONS ===============
        total_duration = unique_steps[-1] - unique_steps[0] if len(unique_steps) > 1 else 0

        # 1. LOG SCALE FIGURE
        avg_log_corr = np.mean(correlations_log) if correlations_log else 0
        
        fig_log.suptitle(f'Boundary Training Evolution: Sigma vs Boundary Residual (LOG SCALE)\n' + 
                        f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                        f'Avg Log Correlation: {avg_log_corr:.3f}',
                        fontsize=14, fontweight='bold', y=0.96)
        
        plt.figure(fig_log.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.90])
        plt.subplots_adjust(top=0.85, bottom=0.10, left=0.08, right=0.95, 
                        hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_log.number)
            plt.savefig(os.path.join(save_dir, 'boundary_training_evolution_comparison_log.png'), 
                    dpi=150, bbox_inches='tight', pad_inches=0.3,
                    facecolor='white', edgecolor='none')
        
        # 2. LINEAR SCALE FIGURE
        avg_linear_corr = np.mean(correlations_linear) if correlations_linear else 0
        
        fig_linear.suptitle(f'Boundary Training Evolution: Sigma vs Boundary Residual (LINEAR SCALE)\n' + 
                        f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                        f'Avg Linear Correlation: {avg_linear_corr:.3f}',
                        fontsize=14, fontweight='bold', y=0.96)
        
        plt.figure(fig_linear.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.90])
        plt.subplots_adjust(top=0.85, bottom=0.10, left=0.08, right=0.95, 
                        hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_linear.number)
            plt.savefig(os.path.join(save_dir, 'boundary_training_evolution_comparison_linear.png'), 
                    dpi=150, bbox_inches='tight', pad_inches=0.3,
                    facecolor='white', edgecolor='none')

        # Return both figures as a tuple
        return fig_log, fig_linear

    def create_training_evolution_comparison(self, save_dir=None, n_subplots=8):
        """Compare sigma-residual relationship across training steps - FIVE SCALES"""
        if not self.global_scatter_steps:
            return None
        
        plt.clf()
        plt.close('all')
        
        all_steps = np.array(self.global_scatter_steps)
        all_sigmas = np.array(self.global_scatter_sigmas)
        all_residuals = np.array(self.global_scatter_residuals)
        
        unique_steps = np.unique(all_steps)
        if len(unique_steps) < 2:
            return None
        
        if len(unique_steps) > n_subplots:
            step_indices = np.linspace(0, len(unique_steps)-1, n_subplots, dtype=int)
            selected_steps = unique_steps[step_indices]
        else:
            selected_steps = unique_steps

        n_rows = 2
        n_cols = int(np.ceil(len(selected_steps) / n_rows))
        
        # CREATE FIVE FIGURES - All existing plus two new zoomed versions
        fig_log, axes_log = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))
        fig_linear, axes_linear = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))
        fig_linear_clipped, axes_linear_clipped = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))
        fig_log_zoomed, axes_log_zoomed = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))          # NEW
        fig_linear_zoomed, axes_linear_zoomed = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))    # NEW
    
        if len(selected_steps) == 1:
            axes_log = [axes_log]
            axes_linear = [axes_linear]
            axes_linear_clipped = [axes_linear_clipped]
            axes_log_zoomed = [axes_log_zoomed]           # NEW
            axes_linear_zoomed = [axes_linear_zoomed]     # NEW
        else:
            axes_log = axes_log.flatten()
            axes_linear = axes_linear.flatten()
            axes_linear_clipped = axes_linear_clipped.flatten()
            axes_log_zoomed = axes_log_zoomed.flatten()           # NEW
            axes_linear_zoomed = axes_linear_zoomed.flatten()     # NEW
    
        # CALCULATE GLOBAL LIMITS for shared axes
        all_sigma_min = np.min(all_sigmas)
        all_sigma_max = np.max(all_sigmas)
        all_residual_min = np.min(all_residuals)
        all_residual_max = np.max(all_residuals)
        
        # ZOOM THRESHOLD
        sigma_zoom_threshold = 1.0
    
        # CLIPPING LIMITS for enhanced visualization
        sigma_vmin_clip = np.percentile(all_sigmas, 5)
        sigma_vmax_clip = np.percentile(all_sigmas, 95)
        residual_vmin_clip = np.percentile(all_residuals, 5)
        residual_vmax_clip = np.percentile(all_residuals, 95)
        
        # Add some padding for better visualization
        sigma_padding = (all_sigma_max - all_sigma_min) * 0.1
        residual_padding = (all_residual_max - all_residual_min) * 0.1
    
        # Store correlations for all five versions
        correlations_log = []
        correlations_linear = []
        correlations_linear_clipped = []
        correlations_log_zoomed = []        # NEW
        correlations_linear_zoomed = []     # NEW
    
        # Track zoomed data statistics
        zoomed_step_counts = []
    
        for i, step in enumerate(selected_steps):
            if i >= len(axes_log):
                break
            
            # Get data for this specific step
            step_mask = all_steps == step
            step_sigmas = all_sigmas[step_mask]
            step_residuals = all_residuals[step_mask]
            
            # CREATE ZOOMED DATA (σ ≤ 1.0) for new plots
            zoom_mask = step_sigmas <= sigma_zoom_threshold
            step_sigmas_zoomed = step_sigmas[zoom_mask]
            step_residuals_zoomed = step_residuals[zoom_mask]
            zoomed_step_counts.append(len(step_sigmas_zoomed))
            
            if len(step_sigmas) == 0:
                # Handle no data case for all five plots
                for axes_set in [axes_log, axes_linear, axes_linear_clipped, axes_log_zoomed, axes_linear_zoomed]:
                    axes_set[i].text(0.5, 0.5, f'No data\nStep {step}', ha='center', va='center',
                                   transform=axes_set[i].transAxes, fontsize=12)
                continue
            
            # =============== LOG SCALE PLOT (UNCHANGED) ===============
            ax_log = axes_log[i]
            ax_log.scatter(step_sigmas, step_residuals, alpha=0.7, s=35, 
                          color='blue', edgecolors='black', linewidth=0.3)
            
            ax_log.set_xlabel('Sigma (log)', fontsize=11)
            ax_log.set_ylabel('PDE Residual (log)', fontsize=11)
            ax_log.set_title(f'Step {step} - Log Scale\n(n={len(step_sigmas)})', 
                            fontsize=12, fontweight='bold', pad=15)
            ax_log.set_xscale('log')
            ax_log.set_yscale('log')
            ax_log.grid(True, alpha=0.3)
            
            # SHARED LOG AXIS LIMITS
            ax_log.set_xlim(max(all_sigma_min*0.5, 1e-6), all_sigma_max*2)
            ax_log.set_ylim(max(all_residual_min*0.5, 1e-6), all_residual_max*2)
            
            # Log correlation and trend line
            if len(step_sigmas) > 2:
                try:
                    log_corr = np.corrcoef(np.log(step_sigmas + 1e-8), 
                                         np.log(step_residuals + 1e-8))[0,1]
                    correlations_log.append(log_corr)
                    
                    ax_log.text(0.05, 0.90, f'r={log_corr:.2f}', transform=ax_log.transAxes, 
                               fontsize=9, fontweight='bold',
                               bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.8))
                    
                    # Add trend line
                    if len(step_sigmas) > 5:
                        log_sigmas = np.log(step_sigmas + 1e-8)
                        log_residuals = np.log(step_residuals + 1e-8)
                        z = np.polyfit(log_sigmas, log_residuals, 1)
                        
                        sigma_range = np.logspace(np.log10(step_sigmas.min()), 
                                                np.log10(step_sigmas.max()), 50)
                        trend_residuals = np.exp(z[0] * np.log(sigma_range) + z[1])
                        ax_log.plot(sigma_range, trend_residuals, 'r--', alpha=0.8, linewidth=2)
                except:
                    correlations_log.append(0)
            else:
                correlations_log.append(0)
            
            # =============== LINEAR SCALE PLOT (UNCHANGED) ===============
            ax_linear = axes_linear[i]
            ax_linear.scatter(step_sigmas, step_residuals, alpha=0.7, s=35, 
                             color='red', edgecolors='black', linewidth=0.3)
            
            ax_linear.set_xlabel('Sigma (linear)', fontsize=11)
            ax_linear.set_ylabel('PDE Residual (linear)', fontsize=11)
            ax_linear.set_title(f'Step {step} - Linear Scale\n(n={len(step_sigmas)})', 
                               fontsize=12, fontweight='bold', pad=15)
            ax_linear.grid(True, alpha=0.3)
            
            # SHARED LINEAR AXIS LIMITS
            ax_linear.set_xlim(all_sigma_min - sigma_padding, all_sigma_max + sigma_padding)
            ax_linear.set_ylim(all_residual_min - residual_padding, all_residual_max + residual_padding)
            
            # Linear correlation and trend line
            if len(step_sigmas) > 2:
                try:
                    linear_corr = np.corrcoef(step_sigmas, step_residuals)[0,1]
                    correlations_linear.append(linear_corr)
                    
                    ax_linear.text(0.05, 0.90, f'r={linear_corr:.2f}', transform=ax_linear.transAxes, 
                                  fontsize=9, fontweight='bold',
                                  bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', alpha=0.8))
                    
                    # Add linear trend line
                    if len(step_sigmas) > 5:
                        z = np.polyfit(step_sigmas, step_residuals, 1)
                        sigma_range = np.linspace(step_sigmas.min(), step_sigmas.max(), 50)
                        trend_residuals = z[0] * sigma_range + z[1]
                        ax_linear.plot(sigma_range, trend_residuals, 'g--', alpha=0.8, linewidth=2)
                except:
                    correlations_linear.append(0)
            else:
                correlations_linear.append(0)
        
            # =============== CLIPPED LINEAR SCALE PLOT (UNCHANGED) ===============
            ax_linear_clipped = axes_linear_clipped[i]
            
            # APPLY CLIPPING to current step data
            step_sigma_clipped = np.clip(step_sigmas, sigma_vmin_clip, sigma_vmax_clip)
            step_residual_clipped = np.clip(step_residuals, residual_vmin_clip, residual_vmax_clip)
            
            # Count outliers for this step
            sigma_outliers = np.sum((step_sigmas < sigma_vmin_clip) | (step_sigmas > sigma_vmax_clip))
            residual_outliers = np.sum((step_residuals < residual_vmin_clip) | (step_residuals > residual_vmax_clip))
            total_outliers = max(sigma_outliers, residual_outliers)  # Use max to avoid double counting
            
            # Color points by whether they were clipped
            colors = []
            for sig, res in zip(step_sigmas, step_residuals):
                if (sig < sigma_vmin_clip or sig > sigma_vmax_clip or 
                    res < residual_vmin_clip or res > residual_vmax_clip):
                    colors.append('orange')  # Clipped points
                else:
                    colors.append('purple')   # Normal points
            
            ax_linear_clipped.scatter(step_sigma_clipped, step_residual_clipped, 
                                     alpha=0.7, s=35, c=colors, edgecolors='black', linewidth=0.3)
            
            ax_linear_clipped.set_xlabel('Sigma (linear, clipped)', fontsize=11)
            ax_linear_clipped.set_ylabel('PDE Residual (linear, clipped)', fontsize=11)
            ax_linear_clipped.set_title(f'Step {step} - Linear Clipped\n(n={len(step_sigmas)}, {total_outliers} clipped)', 
                                       fontsize=12, fontweight='bold', pad=15)
            ax_linear_clipped.grid(True, alpha=0.3)
            
            # SHARED CLIPPED AXIS LIMITS
            clipped_sigma_padding = (sigma_vmax_clip - sigma_vmin_clip) * 0.05
            clipped_residual_padding = (residual_vmax_clip - residual_vmin_clip) * 0.05
            
            ax_linear_clipped.set_xlim(sigma_vmin_clip - clipped_sigma_padding, 
                                      sigma_vmax_clip + clipped_sigma_padding)
            ax_linear_clipped.set_ylim(residual_vmin_clip - clipped_residual_padding, 
                                      residual_vmax_clip + clipped_residual_padding)
            
            # Clipped linear correlation and trend line
            if len(step_sigma_clipped) > 2:
                try:
                    clipped_corr = np.corrcoef(step_sigma_clipped, step_residual_clipped)[0,1]
                    correlations_linear_clipped.append(clipped_corr)
                    
                    ax_linear_clipped.text(0.05, 0.90, f'r={clipped_corr:.2f}', 
                                          transform=ax_linear_clipped.transAxes, 
                                          fontsize=9, fontweight='bold',
                                          bbox=dict(boxstyle='round,pad=0.3', facecolor='lightcyan', alpha=0.8))
                    
                    # Add clipped trend line
                    if len(step_sigma_clipped) > 5:
                        z = np.polyfit(step_sigma_clipped, step_residual_clipped, 1)
                        sigma_range = np.linspace(step_sigma_clipped.min(), step_sigma_clipped.max(), 50)
                        trend_residuals = z[0] * sigma_range + z[1]
                        ax_linear_clipped.plot(sigma_range, trend_residuals, 'orange', 
                                              linestyle='--', alpha=0.8, linewidth=2)
                except:
                    correlations_linear_clipped.append(0)
            else:
                correlations_linear_clipped.append(0)
            
            # Add legend for clipped plot
            if i == 0:  # Only add legend to first subplot
                from matplotlib.patches import Patch
                legend_elements = [Patch(facecolor='purple', alpha=0.7, label='Normal points'),
                                  Patch(facecolor='orange', alpha=0.7, label='Clipped points')]
                ax_linear_clipped.legend(handles=legend_elements, loc='upper right', fontsize=9)
    
            # =============== 🆕 LOG SCALE ZOOMED PLOT (NEW) ===============
            ax_log_zoomed = axes_log_zoomed[i]
            
            if len(step_sigmas_zoomed) == 0:
                ax_log_zoomed.text(0.5, 0.5, f'No data σ≤{sigma_zoom_threshold}\nStep {step}', 
                                  ha='center', va='center', transform=ax_log_zoomed.transAxes, fontsize=11,
                                  bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
                ax_log_zoomed.set_title(f'Step {step} - Log Zoomed\n(no low σ data)', 
                                       fontsize=12, fontweight='bold', pad=15)
                correlations_log_zoomed.append(0)
            else:
                ax_log_zoomed.scatter(step_sigmas_zoomed, step_residuals_zoomed, alpha=0.7, s=35, 
                                     color='darkblue', edgecolors='black', linewidth=0.3)
                
                ax_log_zoomed.set_xlabel('Sigma (log, σ≤1.0)', fontsize=11)
                ax_log_zoomed.set_ylabel('PDE Residual (log)', fontsize=11)
                ax_log_zoomed.set_title(f'Step {step} - Log Zoomed σ≤1.0\n(n={len(step_sigmas_zoomed)}/{len(step_sigmas)})', 
                                       fontsize=12, fontweight='bold', pad=15)
                ax_log_zoomed.set_xscale('log')
                ax_log_zoomed.set_yscale('log')
                ax_log_zoomed.grid(True, alpha=0.3)
                
                # ZOOMED LOG AXIS LIMITS (focus on low sigma range)
                zoom_sigma_min = max(step_sigmas_zoomed.min()*0.8, 1e-6)
                zoom_sigma_max = min(step_sigmas_zoomed.max()*1.2, sigma_zoom_threshold*1.1)
                ax_log_zoomed.set_xlim(zoom_sigma_min, zoom_sigma_max)
                
                # Use residuals from zoomed data for Y-axis
                zoom_residual_min = max(step_residuals_zoomed.min()*0.8, 1e-6)
                zoom_residual_max = step_residuals_zoomed.max()*1.2
                ax_log_zoomed.set_ylim(zoom_residual_min, zoom_residual_max)
                
                # Zoomed log correlation and trend line
                if len(step_sigmas_zoomed) > 2:
                    try:
                        log_zoomed_corr = np.corrcoef(np.log(step_sigmas_zoomed + 1e-8), 
                                                     np.log(step_residuals_zoomed + 1e-8))[0,1]
                        correlations_log_zoomed.append(log_zoomed_corr)
                        
                        ax_log_zoomed.text(0.05, 0.90, f'r={log_zoomed_corr:.2f}', 
                                          transform=ax_log_zoomed.transAxes, 
                                          fontsize=9, fontweight='bold',
                                          bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.8))
                        
                        # Add zoomed trend line
                        if len(step_sigmas_zoomed) > 5:
                            log_sigmas_z = np.log(step_sigmas_zoomed + 1e-8)
                            log_residuals_z = np.log(step_residuals_zoomed + 1e-8)
                            z = np.polyfit(log_sigmas_z, log_residuals_z, 1)
                            
                            sigma_range_z = np.logspace(np.log10(step_sigmas_zoomed.min()), 
                                                       np.log10(step_sigmas_zoomed.max()), 50)
                            trend_residuals_z = np.exp(z[0] * np.log(sigma_range_z) + z[1])
                            ax_log_zoomed.plot(sigma_range_z, trend_residuals_z, 'navy', 
                                              linestyle='--', alpha=0.8, linewidth=2)
                    except:
                        correlations_log_zoomed.append(0)
                else:
                    correlations_log_zoomed.append(0)
        
            # =============== 🆕 LINEAR SCALE ZOOMED PLOT (NEW) ===============
            ax_linear_zoomed = axes_linear_zoomed[i]
            
            if len(step_sigmas_zoomed) == 0:
                ax_linear_zoomed.text(0.5, 0.5, f'No data σ≤{sigma_zoom_threshold}\nStep {step}', 
                                     ha='center', va='center', transform=ax_linear_zoomed.transAxes, fontsize=11,
                                     bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
                ax_linear_zoomed.set_title(f'Step {step} - Linear Zoomed\n(no low σ data)', 
                                          fontsize=12, fontweight='bold', pad=15)
                correlations_linear_zoomed.append(0)
            else:
                ax_linear_zoomed.scatter(step_sigmas_zoomed, step_residuals_zoomed, alpha=0.7, s=35, 
                                        color='darkred', edgecolors='black', linewidth=0.3)
                
                ax_linear_zoomed.set_xlabel('Sigma (linear, σ≤1.0)', fontsize=11)
                ax_linear_zoomed.set_ylabel('PDE Residual (linear)', fontsize=11)
                ax_linear_zoomed.set_title(f'Step {step} - Linear Zoomed σ≤1.0\n(n={len(step_sigmas_zoomed)}/{len(step_sigmas)})', 
                                          fontsize=12, fontweight='bold', pad=15)
                ax_linear_zoomed.grid(True, alpha=0.3)
                
                # ZOOMED LINEAR AXIS LIMITS
                zoom_sigma_padding = (step_sigmas_zoomed.max() - step_sigmas_zoomed.min()) * 0.1
                zoom_residual_padding = (step_residuals_zoomed.max() - step_residuals_zoomed.min()) * 0.1
                
                ax_linear_zoomed.set_xlim(max(0, step_sigmas_zoomed.min() - zoom_sigma_padding), 
                                         min(step_sigmas_zoomed.max() + zoom_sigma_padding, sigma_zoom_threshold*1.05))
                ax_linear_zoomed.set_ylim(max(0, step_residuals_zoomed.min() - zoom_residual_padding), 
                                         step_residuals_zoomed.max() + zoom_residual_padding)
                
                # Zoomed linear correlation and trend line
                if len(step_sigmas_zoomed) > 2:
                    try:
                        linear_zoomed_corr = np.corrcoef(step_sigmas_zoomed, step_residuals_zoomed)[0,1]
                        correlations_linear_zoomed.append(linear_zoomed_corr)
                        
                        ax_linear_zoomed.text(0.05, 0.90, f'r={linear_zoomed_corr:.2f}', 
                                             transform=ax_linear_zoomed.transAxes, 
                                             fontsize=9, fontweight='bold',
                                             bbox=dict(boxstyle='round,pad=0.3', facecolor='pink', alpha=0.8))
                        
                        # Add zoomed linear trend line
                        if len(step_sigmas_zoomed) > 5:
                            z = np.polyfit(step_sigmas_zoomed, step_residuals_zoomed, 1)
                            sigma_range_z = np.linspace(step_sigmas_zoomed.min(), step_sigmas_zoomed.max(), 50)
                            trend_residuals_z = z[0] * sigma_range_z + z[1]
                            ax_linear_zoomed.plot(sigma_range_z, trend_residuals_z, 'darkred', 
                                                 linestyle='--', alpha=0.8, linewidth=2)
                    except:
                        correlations_linear_zoomed.append(0)
                else:
                    correlations_linear_zoomed.append(0)
    
        # Hide unused subplots for all five figures
        for i in range(len(selected_steps), len(axes_log)):
            axes_log[i].set_visible(False)
            axes_linear[i].set_visible(False)
            axes_linear_clipped[i].set_visible(False)
            axes_log_zoomed[i].set_visible(False)           # NEW
            axes_linear_zoomed[i].set_visible(False)        # NEW
    
        # =============== FIGURE FINALIZATIONS ===============
        total_duration = unique_steps[-1] - unique_steps[0] if len(unique_steps) > 1 else 0
        
        # Calculate global zoom statistics
        all_zoom_mask = all_sigmas <= sigma_zoom_threshold
        total_zoom_points = np.sum(all_zoom_mask)
        zoom_percentage = 100 * total_zoom_points / len(all_sigmas) if len(all_sigmas) > 0 else 0
    
        # 1. LOG SCALE FIGURE (UNCHANGED)
        avg_log_corr = np.mean(correlations_log) if correlations_log else 0
        
        fig_log.suptitle(f'Training Evolution: Sigma vs PDE Residual (LOG SCALE)\n' + 
                        f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                        f'Avg Log Correlation: {avg_log_corr:.3f}',
                        fontsize=14, fontweight='bold', y=0.96)
        
        plt.figure(fig_log.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.90])
        plt.subplots_adjust(top=0.85, bottom=0.10, left=0.08, right=0.95, 
                           hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_log.number)
            plt.savefig(os.path.join(save_dir, 'training_evolution_comparison_log.png'), 
                       dpi=150, bbox_inches='tight', pad_inches=0.3,
                       facecolor='white', edgecolor='none')
        
        # 2. LINEAR SCALE FIGURE (UNCHANGED)
        avg_linear_corr = np.mean(correlations_linear) if correlations_linear else 0
        
        fig_linear.suptitle(f'Training Evolution: Sigma vs PDE Residual (LINEAR SCALE)\n' + 
                           f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                           f'Avg Linear Correlation: {avg_linear_corr:.3f}',
                           fontsize=14, fontweight='bold', y=0.96)
        
        plt.figure(fig_linear.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.90])
        plt.subplots_adjust(top=0.85, bottom=0.10, left=0.08, right=0.95, 
                           hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_linear.number)
            plt.savefig(os.path.join(save_dir, 'training_evolution_comparison_linear.png'), 
                       dpi=150, bbox_inches='tight', pad_inches=0.3,
                       facecolor='white', edgecolor='none')
        
        # 3. CLIPPED LINEAR SCALE FIGURE (UNCHANGED)
        avg_linear_clipped_corr = np.mean(correlations_linear_clipped) if correlations_linear_clipped else 0
        
        # Calculate global clipping statistics
        total_sigma_outliers = np.sum((all_sigmas < sigma_vmin_clip) | (all_sigmas > sigma_vmax_clip))
        total_residual_outliers = np.sum((all_residuals < residual_vmin_clip) | (all_residuals > residual_vmax_clip))
        total_points = len(all_sigmas)
        
        fig_linear_clipped.suptitle(f'Training Evolution: Sigma vs PDE Residual (LINEAR CLIPPED 5%-95%)\n' + 
                                   f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                                   f'Avg Clipped Correlation: {avg_linear_clipped_corr:.3f} | ' +
                                   f'Outliers: σ={total_sigma_outliers}/{total_points}, res={total_residual_outliers}/{total_points}',
                                   fontsize=13, fontweight='bold', y=0.97)
        
        plt.figure(fig_linear_clipped.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.88])
        plt.subplots_adjust(top=0.82, bottom=0.10, left=0.08, right=0.95, 
                           hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_linear_clipped.number)
            plt.savefig(os.path.join(save_dir, 'training_evolution_comparison_linear_clipped.png'), 
                       dpi=150, bbox_inches='tight', pad_inches=0.3,
                       facecolor='white', edgecolor='none')
        
        # 4. 🆕 LOG SCALE ZOOMED FIGURE (NEW)
        avg_log_zoomed_corr = np.mean([c for c in correlations_log_zoomed if c != 0]) if any(c != 0 for c in correlations_log_zoomed) else 0
        avg_zoom_count = np.mean(zoomed_step_counts) if zoomed_step_counts else 0
        
        fig_log_zoomed.suptitle(f'Training Evolution: Sigma vs PDE Residual (LOG ZOOMED σ≤{sigma_zoom_threshold})\n' + 
                               f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                               f'Avg Zoomed Log Correlation: {avg_log_zoomed_corr:.3f} | ' +
                               f'Low σ data: {total_zoom_points}/{len(all_sigmas)} ({zoom_percentage:.1f}%)',
                               fontsize=13, fontweight='bold', y=0.97)
        
        plt.figure(fig_log_zoomed.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.88])
        plt.subplots_adjust(top=0.82, bottom=0.10, left=0.08, right=0.95, 
                           hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_log_zoomed.number)
            plt.savefig(os.path.join(save_dir, 'training_evolution_comparison_log_zoomed.png'), 
                       dpi=150, bbox_inches='tight', pad_inches=0.3,
                       facecolor='white', edgecolor='none')
        
        # 5. 🆕 LINEAR SCALE ZOOMED FIGURE (NEW)
        avg_linear_zoomed_corr = np.mean([c for c in correlations_linear_zoomed if c != 0]) if any(c != 0 for c in correlations_linear_zoomed) else 0
        
        fig_linear_zoomed.suptitle(f'Training Evolution: Sigma vs PDE Residual (LINEAR ZOOMED σ≤{sigma_zoom_threshold})\n' + 
                                  f'Duration: {total_duration} steps | Analyzed: {len(selected_steps)} of {len(unique_steps)} total | ' +
                                  f'Avg Zoomed Linear Correlation: {avg_linear_zoomed_corr:.3f} | ' +
                                  f'Low σ data: {total_zoom_points}/{len(all_sigmas)} ({zoom_percentage:.1f}%)',
                                  fontsize=13, fontweight='bold', y=0.97)
        
        plt.figure(fig_linear_zoomed.number)
        plt.tight_layout(rect=[0, 0.05, 1, 0.88])
        plt.subplots_adjust(top=0.82, bottom=0.10, left=0.08, right=0.95, 
                           hspace=0.5, wspace=0.4)
        
        if save_dir:
            plt.figure(fig_linear_zoomed.number)
            plt.savefig(os.path.join(save_dir, 'training_evolution_comparison_linear_zoomed.png'), 
                       dpi=150, bbox_inches='tight', pad_inches=0.3,
                       facecolor='white', edgecolor='none')
    
        # Return all five figures as a tuple
        return fig_log, fig_linear, fig_linear_clipped, fig_log_zoomed, fig_linear_zoomed

    
    def create_sigma_vs_residual_single_step(self, save_dir=None, step=None):
        """Plot sigma against PDE residual for the samples around one training step.

        Builds a three-panel figure from the most recent (at most 50) tracked
        samples:

        1. Log-log scatter of sigma vs residual, annotated with the correlation
           of the log values.
        2. Mean ± std of the residual per sigma bin (needs at least 5 samples).
        3. Mean ± std of the residual for the low / medium / high sigma groups,
           split at the 33rd and 67th sigma percentiles.

        Args:
            save_dir: Directory the PNG is written to (created if missing).
                Nothing is saved when it is None or empty.
            step: Training step of interest. When given, only samples recorded
                within 2 steps of it are used; if no sample matches, all of the
                recent samples are plotted instead.

        Returns:
            The matplotlib figure, or None when no samples have been recorded.
        """
        if not self.recent_sigmas:
            return None

        # Reset pyplot's global state. NOTE: this closes every open pyplot
        # figure, not only figures created by this method.
        plt.clf()
        plt.close('all')

        fig, axes = plt.subplots(1, 3, figsize=(22, 8))

        n_recent = min(50, len(self.recent_sigmas))
        sigmas = np.array(self.recent_sigmas[-n_recent:])
        residuals = np.array(self.recent_residuals[-n_recent:])
        steps_recent = np.array(self.recent_steps[-n_recent:])

        if step is not None:
            recent_mask = np.abs(steps_recent - step) <= 2
            # Fall back to the whole recent window when nothing is close to `step`.
            if np.any(recent_mask):
                sigmas = sigmas[recent_mask]
                residuals = residuals[recent_mask]

        if len(sigmas) == 0:
            plt.close(fig)
            return None

        # 1. MAIN SCATTER PLOT
        ax1 = axes[0]
        ax1.scatter(sigmas, residuals, alpha=0.7, s=60, color='blue',
                    edgecolors='black', linewidth=0.5)
        ax1.set_xlabel('Sigma Value (log scale)', fontsize=12)
        ax1.set_ylabel('PDE Residual (log scale)', fontsize=12)
        ax1.set_title(f'Sigma vs PDE Residual\nStep {step} (n={len(sigmas)})',
                      fontsize=13, fontweight='bold', pad=20)
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.grid(True, alpha=0.3)

        if len(sigmas) > 2:
            # The annotation is best-effort: a failure here must not abort the
            # plot, but it is reported instead of being silently swallowed.
            try:
                log_corr = np.corrcoef(np.log(sigmas + 1e-8), np.log(residuals + 1e-8))[0, 1]
            except Exception as exc:
                print(f"Warning: could not compute log correlation: {exc}")
            else:
                ax1.text(0.05, 0.85, f'Log Correlation: {log_corr:.3f}\n(Power-law relationship)',
                         transform=ax1.transAxes, fontsize=10,
                         bbox=dict(boxstyle='round,pad=0.4', facecolor='yellow', alpha=0.8))

        # 2. BINNED ANALYSIS
        ax2 = axes[1]
        if len(sigmas) >= 5:
            sigma_min = sigmas.min()
            sigma_max = sigmas.max()

            # Number of bin EDGES; at least two so that one bin always exists
            # even when every sigma is identical.
            n_edges = max(2, min(6, len(np.unique(sigmas))))

            # Log-spaced bins when sigma spans more than two decades. Written
            # as a product so a smallest sigma of 0 does not divide by zero.
            use_log = bool(sigma_min >= 0 and sigma_max > 100 * sigma_min)
            if use_log:
                sigma_bins = np.logspace(np.log10(max(sigma_min, 1e-6)),
                                         np.log10(sigma_max), n_edges)
            else:
                sigma_bins = np.linspace(sigma_min, sigma_max, n_edges)

            # np.digitize assigns values equal to the last edge (i.e. the
            # largest sigma) to an overflow index and values below the first
            # edge (possible after the 1e-6 clamp) to index 0. Clip both into
            # the first/last real bin so no sample is dropped.
            bin_indices = np.clip(np.digitize(sigmas, sigma_bins), 1, len(sigma_bins) - 1)

            bin_centers, bin_means, bin_stds, bin_counts = [], [], [], []
            for i in range(1, len(sigma_bins)):
                mask = bin_indices == i
                if not np.any(mask):
                    continue
                lower, upper = sigma_bins[i - 1], sigma_bins[i]
                # Geometric centre for log bins, arithmetic centre otherwise.
                bin_centers.append(np.sqrt(lower * upper) if use_log else (lower + upper) / 2)
                bin_means.append(np.mean(residuals[mask]))
                bin_stds.append(np.std(residuals[mask]))
                bin_counts.append(np.sum(mask))

            if bin_centers:
                ax2.errorbar(bin_centers, bin_means, yerr=bin_stds,
                             fmt='o-', capsize=5, linewidth=2, markersize=6,
                             color='red', markerfacecolor='orange',
                             markeredgecolor='darkred', markeredgewidth=1)

                for x, y, count in zip(bin_centers, bin_means, bin_counts):
                    ax2.annotate(f'n={count}', (x, y), xytext=(0, 15),
                                 textcoords='offset points', fontsize=9,
                                 ha='center', fontweight='bold',
                                 bbox=dict(boxstyle='round,pad=0.2', facecolor='yellow', alpha=0.7))

                x_label = 'Sigma (log)' if use_log else 'Sigma Value'
                y_label = 'Mean Residual ± Std (log)' if use_log else 'Mean Residual ± Std'
                ax2.set_xlabel(x_label, fontsize=12)
                ax2.set_ylabel(y_label, fontsize=12)
                ax2.set_title(f'Binned Analysis\n({len(bin_centers)} bins)',
                              fontsize=13, fontweight='bold', pad=20)

                if use_log:
                    ax2.set_xscale('log')
                    ax2.set_yscale('log')
                ax2.grid(True, alpha=0.3)
        else:
            ax2.text(0.5, 0.5, 'Need ≥5 points\nfor binning',
                     ha='center', va='center', transform=ax2.transAxes, fontsize=11,
                     bbox=dict(boxstyle='round', facecolor='lightgray'))
            ax2.set_title('Binned Analysis\n(Insufficient data)', fontsize=13, pad=20)

        # 3. DISTRIBUTION COMPARISON (sigmas is known to be non-empty here)
        ax3 = axes[2]
        sigma_percentiles = np.percentile(sigmas, [33, 67])

        easy_mask = sigmas <= sigma_percentiles[0]
        medium_mask = (sigmas > sigma_percentiles[0]) & (sigmas <= sigma_percentiles[1])
        hard_mask = sigmas > sigma_percentiles[1]

        easy_range = f'≤{sigma_percentiles[0]:.2f}'
        medium_range = f'{sigma_percentiles[0]:.2f}-{sigma_percentiles[1]:.2f}'
        hard_range = f'>{sigma_percentiles[1]:.2f}'

        categories = [f'Low σ\n{easy_range}', f'Med σ\n{medium_range}', f'High σ\n{hard_range}']
        means = []
        stds = []
        counts = []

        for mask in (easy_mask, medium_mask, hard_mask):
            if np.any(mask):
                means.append(np.mean(residuals[mask]))
                stds.append(np.std(residuals[mask]))
                counts.append(np.sum(mask))
            else:
                # Keep a zero-height bar so the three categories stay aligned.
                means.append(0)
                stds.append(0)
                counts.append(0)

        bars = ax3.bar(categories, means, yerr=stds, capsize=4, alpha=0.7,
                       color=['green', 'yellow', 'red'],
                       edgecolor='black', linewidth=1)

        for bar, count in zip(bars, counts):
            height = bar.get_height()
            if height > 0:
                ax3.annotate(f'n={count}', xy=(bar.get_x() + bar.get_width() / 2, height),
                             xytext=(0, 5), textcoords="offset points",
                             ha='center', va='bottom', fontweight='bold', fontsize=10)

        ax3.set_ylabel('Mean PDE Residual', fontsize=12)
        ax3.set_title('Performance by\nDifficulty Level', fontsize=13, fontweight='bold', pad=20)
        ax3.grid(True, alpha=0.3, axis='y')
        ax3.tick_params(axis='x', labelsize=10)
        ax3.tick_params(axis='y', labelsize=10)

        fig.suptitle(f'Sigma vs PDE Residual Analysis - Step {step}',
                     fontsize=16, fontweight='bold', y=0.95)

        # Operate on `fig` directly rather than on pyplot's "current figure".
        fig.tight_layout(rect=[0, 0.05, 1, 0.92])
        fig.subplots_adjust(top=0.85, bottom=0.12, left=0.08, right=0.95, wspace=0.4)

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            filename = f'sigma_vs_residual_single_step_{step}.png'
            fig.savefig(os.path.join(save_dir, filename), dpi=150,
                        bbox_inches='tight', pad_inches=0.2,
                        facecolor='white', edgecolor='none')

        return fig

    def create_boundary_residual_analysis(self, save_dir=None, step=None):
        """Compare interior PDE residuals with the 2-pixel and 1-pixel boundary residuals.

        Builds a 2x2 figure from the most recent (at most 50) tracked samples:
        grouped bars per sample, sigma-vs-residual scatter for all three
        residual types, boundary/interior ratios against sigma, and a text
        panel with summary statistics.

        Args:
            save_dir: Directory the PNG is written to (created if missing).
                Nothing is saved when it is None or empty.
            step: Training step, used only in the title and the file name.

        Returns:
            The matplotlib figure, or None when there is nothing to plot.
        """
        if not self.recent_boundary2_residuals:
            return None

        # The four histories live in separate lists. Use the shortest one to
        # size the window so the arrays below always have equal length.
        n_recent = min(50,
                       len(self.recent_sigmas),
                       len(self.recent_residuals),
                       len(self.recent_boundary2_residuals),
                       len(self.recent_boundary1_residuals))
        # Must be checked before slicing: `values[-0:]` is the WHOLE list.
        if n_recent == 0:
            return None

        # Reset pyplot's global state. NOTE: this closes every open pyplot
        # figure, not only figures created by this method.
        plt.clf()
        plt.close('all')

        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        sigmas = np.array(self.recent_sigmas[-n_recent:])
        interior_residuals = np.array(self.recent_residuals[-n_recent:])
        boundary2_residuals = np.array(self.recent_boundary2_residuals[-n_recent:])
        boundary1_residuals = np.array(self.recent_boundary1_residuals[-n_recent:])

        # 1. COMPARISON PLOT: Interior vs Boundary residuals
        ax1 = axes[0, 0]
        x_pos = np.arange(len(sigmas))
        width = 0.25

        ax1.bar(x_pos - width, interior_residuals, width, label='Interior', alpha=0.7, color='blue')
        ax1.bar(x_pos, boundary2_residuals, width, label='2-pixel Boundary', alpha=0.7, color='orange')
        ax1.bar(x_pos + width, boundary1_residuals, width, label='1-pixel Boundary', alpha=0.7, color='red')

        ax1.set_xlabel('Sample Index', fontsize=12)
        ax1.set_ylabel('PDE Residual', fontsize=12)
        ax1.set_title('Residual Comparison: Interior vs Boundaries', fontsize=14, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 2. SCATTER PLOT: Sigma vs Residuals (all types)
        ax2 = axes[0, 1]
        ax2.scatter(sigmas, interior_residuals, alpha=0.7, s=50, label='Interior', color='blue')
        ax2.scatter(sigmas, boundary2_residuals, alpha=0.7, s=50, label='2-pixel Boundary', color='orange')
        ax2.scatter(sigmas, boundary1_residuals, alpha=0.7, s=50, label='1-pixel Boundary', color='red')

        ax2.set_xlabel('Sigma Value', fontsize=12)
        ax2.set_ylabel('PDE Residual', fontsize=12)
        ax2.set_title('Sigma vs Residuals: All Types', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # 3. RATIO ANALYSIS: Boundary/Interior ratios
        ax3 = axes[1, 0]
        interior_mean = np.mean(interior_residuals)
        if interior_mean > 0:
            # Skip samples whose interior residual is exactly zero; their ratio
            # is undefined and would otherwise show up as inf/nan points.
            nonzero = interior_residuals != 0
            boundary2_ratio = boundary2_residuals[nonzero] / interior_residuals[nonzero]
            boundary1_ratio = boundary1_residuals[nonzero] / interior_residuals[nonzero]

            ax3.scatter(sigmas[nonzero], boundary2_ratio, alpha=0.7, s=50,
                        label='2-pixel Boundary/Interior', color='orange')
            ax3.scatter(sigmas[nonzero], boundary1_ratio, alpha=0.7, s=50,
                        label='1-pixel Boundary/Interior', color='red')
            ax3.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, label='Equal to Interior')

            ax3.set_xlabel('Sigma Value', fontsize=12)
            ax3.set_ylabel('Boundary/Interior Ratio', fontsize=12)
            ax3.set_title('Boundary vs Interior Performance Ratio', fontsize=14, fontweight='bold')
            ax3.legend()
            ax3.grid(True, alpha=0.3)
        else:
            ax3.text(0.5, 0.5, 'No interior residuals\nfor ratio calculation',
                     ha='center', va='center', transform=ax3.transAxes, fontsize=14)
            ax3.set_title('Boundary vs Interior Ratio\n(No data)', fontsize=14)

        # 4. STATISTICS SUMMARY
        ax4 = axes[1, 1]
        ax4.axis('off')

        def describe(heading, values):
            """Format the per-series statistics section of the summary text."""
            return (f"{heading}\n"
                    f"  Mean: {np.mean(values):.4f}\n"
                    f"  Std:  {np.std(values):.4f}\n"
                    f"  Max:  {np.max(values):.4f}\n"
                    f"  Skewness: {self._calculate_skewness_from_array(values):.4f}\n"
                    f"  Kurtosis: {self._calculate_kurtosis_from_array(values):.4f}\n")

        def mean_ratio(values):
            """Ratio of mean boundary residual to mean interior residual."""
            if interior_mean == 0:
                return 'n/a'
            return f"{np.mean(values) / interior_mean:.2f}x"

        stats_text = "\n".join([
            "Boundary Residual Analysis Summary\n",
            describe("Interior Residuals:", interior_residuals),
            describe("2-pixel Boundary:", boundary2_residuals),
            describe("1-pixel Boundary:", boundary1_residuals),
            "Ratios (Boundary/Interior):",
            f"  2-pixel: {mean_ratio(boundary2_residuals)}",
            f"  1-pixel: {mean_ratio(boundary1_residuals)}",
            "",
            f"Sample Count: {len(sigmas)}",
        ])

        ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes,
                 verticalalignment='top', fontsize=10, fontfamily='monospace',
                 bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

        fig.suptitle(f'Boundary Residual Analysis - Step {step}',
                     fontsize=16, fontweight='bold', y=0.98)

        # Operate on `fig` directly rather than on pyplot's "current figure".
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            filename = f'boundary_residual_analysis_step_{step}.png'
            fig.savefig(os.path.join(save_dir, filename), dpi=150,
                        bbox_inches='tight', pad_inches=0.2)

        return fig

    def create_thresholded_residual_maps(self, pde_residuals, sigmas, step, save_dir=None, n_samples=4):
        """Plot residual maps next to their high-residual regions.

        For each selected sample one row with two panels is drawn: the signed
        residual map, and the residual magnitude restricted to pixels whose
        absolute value is at least the sample's mean absolute residual
        (all other pixels are set to 0).

        Args:
            pde_residuals: Tensor of shape [batch_size, channels, height, width]
                or [batch_size, height, width]. Multiple channels are averaged.
            sigmas: Per-sample sigma values, indexable by sample index.
            step: Current training step (used in the title and file name).
            save_dir: Directory the PNG is written to (created if missing).
                Nothing is saved when it is None or empty.
            n_samples: Maximum number of samples to visualize (default: 4).
                When the batch is larger, samples are drawn at random.

        Returns:
            The matplotlib figure, or None if there is nothing to plot
            (no data, unsupported tensor rank, or n_samples <= 0).
        """
        if pde_residuals is None or pde_residuals.numel() == 0:
            return None

        # Reduce the input to one 2-D map per sample.
        if pde_residuals.ndim == 4:
            batch_size, channels = pde_residuals.shape[:2]
            if channels > 1:
                residuals = torch.mean(pde_residuals, dim=1)  # [batch, height, width]
            else:
                residuals = pde_residuals.squeeze(1)
        elif pde_residuals.ndim == 3:
            batch_size = pde_residuals.shape[0]
            residuals = pde_residuals
        else:
            print(f"Unexpected tensor shape: {pde_residuals.shape}")
            return None

        n_samples = min(n_samples, batch_size)
        if n_samples <= 0:
            # plt.subplots cannot build a grid with zero rows.
            return None
        if batch_size > n_samples:
            sample_indices = torch.randperm(batch_size)[:n_samples]
        else:
            sample_indices = torch.arange(batch_size)

        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(n_samples, 2, figsize=(12, 6 * n_samples), squeeze=False)

        for row, sample_idx in enumerate(sample_indices):
            sample_residual = residuals[sample_idx]  # [height, width]
            sample_sigma = sigmas[sample_idx]
            sample_sigma_scalar = sample_sigma.item() if hasattr(sample_sigma, 'item') else float(sample_sigma)

            # Threshold on magnitude so that large negative residuals count too.
            abs_residuals = torch.abs(sample_residual)
            threshold = torch.mean(abs_residuals)
            threshold_value = threshold.item()

            # The right panel uses a sequential colormap starting at the
            # threshold, so it has to show magnitudes: signed values would make
            # strongly negative residuals indistinguishable from masked pixels.
            high_magnitude = torch.where(abs_residuals >= threshold,
                                         abs_residuals,
                                         torch.zeros_like(abs_residuals))

            residual_min = torch.min(sample_residual).item()
            residual_max = torch.max(sample_residual).item()
            max_abs_residual = torch.max(abs_residuals).item()

            # Panel 1: original signed residual map
            ax_orig = axes[row, 0]
            im_orig = ax_orig.imshow(sample_residual.detach().cpu().numpy(), cmap='RdBu_r',
                                     vmin=residual_min, vmax=residual_max)
            ax_orig.set_title(f'Sample {row+1}: Original Residual\nσ={sample_sigma_scalar:.3f}',
                              fontsize=12, fontweight='bold')
            ax_orig.set_xlabel('Width', fontsize=10)
            ax_orig.set_ylabel('Height', fontsize=10)
            fig.colorbar(im_orig, ax=ax_orig, shrink=0.8)

            # Panel 2: magnitude of the residual where |residual| >= threshold
            ax_high = axes[row, 1]
            im_high = ax_high.imshow(high_magnitude.detach().cpu().numpy(), cmap='Reds',
                                     vmin=threshold_value, vmax=max_abs_residual)
            ax_high.set_title(f'Sample {row+1}: High Residual Regions\n(|residual| ≥ {threshold_value:.4f})',
                              fontsize=12, fontweight='bold')
            ax_high.set_xlabel('Width', fontsize=10)
            ax_high.set_ylabel('Height', fontsize=10)
            fig.colorbar(im_high, ax=ax_high, shrink=0.8)

        fig.suptitle(f'Thresholded Residual Maps - Step {step}\n'
                     f'Left: Original residuals, Right: High residual regions (|residual| ≥ abs mean)',
                     fontsize=16, fontweight='bold', y=0.98)

        # Operate on `fig` directly rather than on pyplot's "current figure".
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            filename = f'thresholded_residual_maps_step_{step}.png'
            fig.savefig(os.path.join(save_dir, filename), dpi=150,
                        bbox_inches='tight', pad_inches=0.2)

        return fig

    def create_residual_heatmap_analysis(self, pde_residuals, sigmas, step, save_dir=None, n_samples=6):
        """Render four views of the PDE residual for a handful of samples.

        Every selected sample gets one row with four panels: the signed
        residual map, its absolute value, the absolute value restricted to the
        top 25% of pixels (75th-percentile threshold), and a histogram of the
        signed residual values.

        Args:
            pde_residuals: Tensor of shape [batch_size, channels, height, width]
                or [batch_size, height, width]. Multiple channels are averaged.
            sigmas: Per-sample sigma values, indexable by sample index.
            step: Current training step (used in the title and file name).
            save_dir: Directory to save the plot to; skipped when falsy.
            n_samples: Number of samples to visualize (default: 6). When the
                batch is larger, samples are drawn at random.

        Returns:
            matplotlib figure, or None if there is no data or the tensor rank
            is unsupported.
        """
        if pde_residuals is None or pde_residuals.numel() == 0:
            return None

        # Collapse the input to one 2-D map per sample.
        rank = pde_residuals.dim()
        if rank == 4:
            batch_size = pde_residuals.shape[0]
            if pde_residuals.shape[1] > 1:
                residuals = torch.mean(pde_residuals, dim=1)
            else:
                residuals = pde_residuals.squeeze(1)
        elif rank == 3:
            batch_size = pde_residuals.shape[0]
            residuals = pde_residuals
        else:
            print(f"Unexpected tensor shape: {pde_residuals.shape}")
            return None

        # Pick which samples to draw (random subset only if the batch is larger).
        n_samples = min(n_samples, batch_size)
        if batch_size > n_samples:
            sample_indices = torch.randperm(batch_size)[:n_samples]
        else:
            sample_indices = torch.arange(batch_size)

        # One row per sample, four columns; always keep a 2-D axes grid.
        fig, axes = plt.subplots(n_samples, 4, figsize=(24, 6*n_samples), squeeze=False)

        for row, sample_idx in enumerate(sample_indices):
            number = row + 1
            ax_signed, ax_abs, ax_top, ax_hist = axes[row]

            residual_map = residuals[sample_idx]
            sigma_entry = sigmas[sample_idx]
            if hasattr(sigma_entry, 'item'):
                sigma_value = sigma_entry.item()
            else:
                sigma_value = float(sigma_entry)

            residual_np = residual_map.detach().cpu().numpy()

            # Column 1: signed residual
            signed_img = ax_signed.imshow(residual_np, cmap='RdBu_r')
            ax_signed.set_title(f'Sample {number}: Original\nσ={sigma_value:.3f}', fontsize=11, fontweight='bold')
            plt.colorbar(signed_img, ax=ax_signed, shrink=0.8)

            # Column 2: magnitude
            magnitude = torch.abs(residual_map)
            abs_img = ax_abs.imshow(magnitude.detach().cpu().numpy(), cmap='Reds')
            ax_abs.set_title(f'Sample {number}: Absolute\nMax={magnitude.max().item():.4f}', fontsize=11, fontweight='bold')
            plt.colorbar(abs_img, ax=ax_abs, shrink=0.8)

            # Column 3: magnitude, keeping only pixels at or above the 75th percentile
            cutoff = torch.quantile(magnitude, 0.75)
            top_quarter = torch.where(magnitude >= cutoff, magnitude, torch.zeros_like(magnitude))
            top_img = ax_top.imshow(top_quarter.detach().cpu().numpy(), cmap='Reds')
            ax_top.set_title(f'Sample {number}: Top 25%\nThreshold={cutoff.item():.4f}', fontsize=11, fontweight='bold')
            plt.colorbar(top_img, ax=ax_top, shrink=0.8)

            # Column 4: histogram of signed values with zero and mean markers
            mean_residual = torch.mean(residual_map).item()
            ax_hist.hist(residual_np.flatten(), bins=30, alpha=0.7, color='skyblue', edgecolor='black')
            ax_hist.axvline(x=0, color='red', linestyle='--', alpha=0.8, label='Zero')
            ax_hist.axvline(x=mean_residual, color='green', linestyle='--', alpha=0.8, label='Mean')
            ax_hist.set_xlabel('Residual Value', fontsize=10)
            ax_hist.set_ylabel('Frequency', fontsize=10)
            ax_hist.set_title(f'Sample {number}: Distribution\nMean={mean_residual:.4f}', fontsize=11, fontweight='bold')
            ax_hist.legend(fontsize=8)
            ax_hist.grid(True, alpha=0.3)

        fig.suptitle(f'Residual Heatmap Analysis - Step {step}\n'
                     f'Multiple visualization techniques for residual analysis',
                     fontsize=16, fontweight='bold', y=0.98)

        plt.tight_layout()
        plt.subplots_adjust(top=0.92)

        if save_dir:
            target = os.path.join(save_dir, f'residual_heatmap_analysis_step_{step}.png')
            plt.savefig(target, dpi=150, bbox_inches='tight', pad_inches=0.2)

        return fig

    def _calculate_skewness(self, tensor):
        """Calculate skewness of a tensor
        
        Skewness measures the asymmetry of the distribution:
        - Positive: right-tailed (longer right tail)
        - Negative: left-tailed (longer left tail)
        - Zero: symmetric distribution
        
        Args:
            tensor: PyTorch tensor
            
        Returns:
            float: Skewness value
        """
        # Convert to numpy for calculation
        data = tensor.detach().cpu().numpy().flatten()
        
        if len(data) < 3:
            return 0.0
            
        mean = np.mean(data)
        std = np.std(data)
        
        if std == 0:
            return 0.0
            
        # Calculate skewness: E[(X - μ)³] / σ³
        skewness = np.mean(((data - mean) / std) ** 3)
        return float(skewness)

    def _calculate_kurtosis(self, tensor):
        """Calculate kurtosis of a tensor
        
        Kurtosis measures the "tailedness" of the distribution:
        - High kurtosis: heavy tails, more extreme values
        - Low kurtosis: light tails, fewer extreme values
        - Normal distribution has kurtosis ≈ 3
        
        Args:
            tensor: PyTorch tensor
            
        Returns:
            float: Kurtosis value (excess kurtosis, so normal dist = 0)
        """
        # Convert to numpy for calculation
        data = tensor.detach().cpu().numpy().flatten()
        
        if len(data) < 4:
            return 0.0
            
        mean = np.mean(data)
        std = np.std(data)
        
        if std == 0:
            return 0.0
            
        # Calculate kurtosis: E[(X - μ)⁴] / σ⁴ - 3 (excess kurtosis)
        kurtosis = np.mean(((data - mean) / std) ** 4) - 3
        return float(kurtosis)

    def _calculate_skewness_from_array(self, array):
        """Calculate skewness of a numpy array
        
        Args:
            array: numpy array
            
        Returns:
            float: Skewness value
        """
        if len(array) < 3:
            return 0.0
            
        mean = np.mean(array)
        std = np.std(array)
        
        if std == 0:
            return 0.0
            
        # Calculate skewness: E[(X - μ)³] / σ³
        skewness = np.mean(((array - mean) / std) ** 3)
        return float(skewness)

    def _calculate_kurtosis_from_array(self, array):
        """Calculate kurtosis of a numpy array
        
        Args:
            array: numpy array
            
        Returns:
            float: Kurtosis value (excess kurtosis)
        """
        if len(array) < 4:
            return 0.0
            
        mean = np.mean(array)
        std = np.std(array)
        
        if std == 0:
            return 0.0
            
        # Calculate kurtosis: E[(X - μ)⁴] / σ⁴ - 3 (excess kurtosis)
        kurtosis = np.mean(((array - mean) / std) ** 4) - 3
        return float(kurtosis)

    def create_training_comparison_visualization(self, input_data, ground_truth, predictions, 
                                               step, save_dir=None, n_samples=16, 
                                               direction="forward", pde_loss_fn=None, sample_selection="random"):
        """Create comprehensive training visualization with multiple figures showing input, ground-truth, 
        predictions, residual maps, and difference maps
        
        Creates 4 figures with 4 samples each (16 total samples) with 7 columns:
        1. Input
        2. Ground Truth  
        3. Prediction (GT scale)
        4. Prediction (Pred scale)
        5. Difference Map (|Pred - GT|)
        6. PDE Residual (signed values)
        7. Boundary Residual (absolute residual at 3-pixel boundary, interior set to 0)
        
        Args:
            input_data: Input tensor [batch, channels, height, width] or [batch, height, width]
            ground_truth: Ground truth tensor [batch, channels, height, width] or [batch, height, width]
            predictions: Prediction tensor [batch, channels, height, width] or [batch, height, width]
            step: Current training step
            save_dir: Directory to save plots
            n_samples: Number of samples to visualize (default 16)
            direction: "forward" or "inverse" for problem direction
            pde_loss_fn: PDE loss function for residual computation
            sample_selection: "random" for different samples each time, "fixed" for same samples (not implemented)
            
        Returns:
            list of matplotlib figures or None if no data
        """
        if input_data is None or ground_truth is None or predictions is None:
            return None
            
        # Ensure consistent tensor shapes
        def ensure_3d(tensor):
            if tensor.ndim == 4:
                if tensor.shape[1] == 1:
                    return tensor.squeeze(1)  # Remove channel dimension if single channel
                else:
                    return tensor.mean(dim=1)  # Average across channels
            return tensor
        
        input_3d = ensure_3d(input_data)
        gt_3d = ensure_3d(ground_truth) 
        pred_3d = ensure_3d(predictions)
        
        batch_size = min(input_3d.shape[0], gt_3d.shape[0], pred_3d.shape[0])
        n_samples = min(n_samples, batch_size)
        
        # Select samples based on sample_selection mode
        if sample_selection == "fixed":
            # Fixed sample selection is not yet implemented
            print("Warning: Fixed sample selection is not implemented yet. Using random selection instead.")
            sample_indices = torch.randperm(batch_size)[:n_samples]
        else:  # sample_selection == "random"
            # Select random samples from current batch
            sample_indices = torch.randperm(batch_size)[:n_samples]
        
        # Create 4 figures with 4 samples each (total 16 samples)
        # 7 columns: Input, Ground Truth, Prediction (GT scale), Prediction (Pred scale), Difference, Residual, Thresholded Residual
        figures = []
        samples_per_figure = 4
        n_figures = min(4, (n_samples + samples_per_figure - 1) // samples_per_figure)  # Ceiling division
        
        for fig_idx in range(n_figures):
            start_idx = fig_idx * samples_per_figure
            end_idx = min(start_idx + samples_per_figure, len(sample_indices))
            current_sample_indices = sample_indices[start_idx:end_idx]
            current_n_samples = len(current_sample_indices)
            
            # Create figure with 7 columns
            fig, axes = plt.subplots(current_n_samples, 7, figsize=(28, 4*current_n_samples))
            if current_n_samples == 1:
                axes = axes.reshape(1, -1)
            
            for i, sample_idx in enumerate(current_sample_indices):
                input_sample = input_3d[sample_idx]
                gt_sample = gt_3d[sample_idx]
                pred_sample = pred_3d[sample_idx]
                
                # Convert to numpy for visualization
                input_np = input_sample.detach().cpu().numpy()
                gt_np = gt_sample.detach().cpu().numpy()
                pred_np = pred_sample.detach().cpu().numpy()
                
                # Compute difference map
                diff_np = np.abs(pred_np - gt_np)
                
                # Compute PDE residual if function provided
                if pde_loss_fn is not None:
                    try:
                        # Prepare tensors for PDE computation
                        if direction == "forward":
                            pde_input = input_sample.unsqueeze(0).unsqueeze(1)  # [1, 1, H, W]
                            pde_pred = pred_sample.unsqueeze(0).unsqueeze(1)    # [1, 1, H, W]
                        else:  # inverse
                            pde_input = gt_sample.unsqueeze(0).unsqueeze(1)     # [1, 1, H, W]
                            pde_pred = pred_sample.unsqueeze(0).unsqueeze(1)    # [1, 1, H, W]
                        
                        # Compute PDE residual
                        from training.evaluation_utils import compute_pde_loss
                        pde_residual = compute_pde_loss(
                            pde_loss_fn=pde_loss_fn,
                            pde_direction=direction,
                            images_pred_denorm=pde_pred,
                            labels_denorm=pde_input,
                            device=input_sample.device
                        )
                        # Keep original signed values (don't take absolute)
                        residual_np = pde_residual.squeeze().detach().cpu().numpy()
                    except Exception as e:
                        print(f"Warning: Could not compute PDE residual: {e}")
                        residual_np = np.zeros_like(pred_np)
                else:
                    residual_np = np.zeros_like(pred_np)
                
                # Compute boundary residual (absolute values, only at boundaries - 3 rows/cols from edges)
                abs_residual_np = np.abs(residual_np)
                boundary_residual_np = np.zeros_like(abs_residual_np)
                
                # Set boundary regions (3 pixels from each edge)
                h, w = abs_residual_np.shape
                boundary_width = 3
                
                # Top and bottom boundaries
                boundary_residual_np[:boundary_width, :] = abs_residual_np[:boundary_width, :]  # Top
                boundary_residual_np[-boundary_width:, :] = abs_residual_np[-boundary_width:, :] # Bottom
                
                # Left and right boundaries  
                boundary_residual_np[:, :boundary_width] = abs_residual_np[:, :boundary_width]   # Left
                boundary_residual_np[:, -boundary_width:] = abs_residual_np[:, -boundary_width:] # Right
                
                # Determine colormaps and scales
                if direction == "forward":
                    input_cmap = 'viridis'
                    pred_cmap = 'jet'
                    error_cmap = 'inferno'
                    residual_cmap = 'coolwarm'
                else:  # inverse
                    input_cmap = 'plasma'
                    pred_cmap = 'RdBu_r'
                    error_cmap = 'magma'
                    residual_cmap = 'seismic'
                
                global_sample_idx = start_idx + i + 1  # Global sample number across all figures
                
                # Plot 1: Input
                im1 = axes[i, 0].imshow(input_np, cmap=input_cmap, interpolation='bilinear')
                axes[i, 0].set_title(f'Sample {global_sample_idx}: Input ({direction})', fontsize=9, fontweight='bold')
                axes[i, 0].axis('off')
                plt.colorbar(im1, ax=axes[i, 0], shrink=0.8)
                
                # Plot 2: Ground Truth
                im2 = axes[i, 1].imshow(gt_np, cmap=pred_cmap, interpolation='bilinear')
                axes[i, 1].set_title(f'Sample {global_sample_idx}: Ground Truth', fontsize=9, fontweight='bold')
                axes[i, 1].axis('off')
                plt.colorbar(im2, ax=axes[i, 1], shrink=0.8)
                
                # Plot 3: Prediction (GT scale)
                im3 = axes[i, 2].imshow(pred_np, cmap=pred_cmap, vmin=gt_np.min(), vmax=gt_np.max(), interpolation='bilinear')
                axes[i, 2].set_title(f'Sample {global_sample_idx}: Pred (GT scale)', fontsize=9, fontweight='bold')
                axes[i, 2].axis('off')
                plt.colorbar(im3, ax=axes[i, 2], shrink=0.8)
                
                # Plot 4: Prediction (Pred scale)
                im4 = axes[i, 3].imshow(pred_np, cmap=pred_cmap, vmin=pred_np.min(), vmax=pred_np.max(), interpolation='bilinear')
                axes[i, 3].set_title(f'Sample {global_sample_idx}: Pred (Pred scale)', fontsize=9, fontweight='bold')
                axes[i, 3].axis('off')
                plt.colorbar(im4, ax=axes[i, 3], shrink=0.8)
                
                # Plot 5: Difference Map
                im5 = axes[i, 4].imshow(diff_np, cmap=error_cmap, interpolation='bilinear')
                axes[i, 4].set_title(f'Sample {global_sample_idx}: |Pred - GT|', fontsize=9, fontweight='bold')
                axes[i, 4].axis('off')
                plt.colorbar(im5, ax=axes[i, 4], shrink=0.8)
                
                # Plot 6: PDE Residual Map (signed values)
                im6 = axes[i, 5].imshow(residual_np, cmap=residual_cmap, interpolation='bilinear')
                axes[i, 5].set_title(f'Sample {global_sample_idx}: PDE Residual', fontsize=9, fontweight='bold')
                axes[i, 5].axis('off')
                plt.colorbar(im6, ax=axes[i, 5], shrink=0.8)
                
                # Plot 7: Boundary Residual (absolute residual at boundaries only)
                im7 = axes[i, 6].imshow(boundary_residual_np, cmap='hot', interpolation='bilinear')
                axes[i, 6].set_title(f'Sample {global_sample_idx}: Boundary Residual\n(3px from edges)', fontsize=9, fontweight='bold')
                axes[i, 6].axis('off')
                plt.colorbar(im7, ax=axes[i, 6], shrink=0.8)
                
                # Compute and display metrics
                rel_l2 = np.linalg.norm(pred_np - gt_np) / np.linalg.norm(gt_np)
                mean_residual = np.mean(abs_residual_np)
                boundary_residual_mean = np.mean(boundary_residual_np[boundary_residual_np > 0])  # Mean of non-zero boundary values
                boundary_pixels = np.sum(boundary_residual_np > 0)
                total_pixels = abs_residual_np.size
                
                # Add metrics text
                metrics_text = f"L2: {rel_l2:.4f}\nMean |Res|: {mean_residual:.4f}\nBoundary Res: {boundary_residual_mean:.4f}\nBoundary px: {boundary_pixels}/{total_pixels}"
                axes[i, 6].text(0.02, 0.98, metrics_text, transform=axes[i, 6].transAxes,
                               verticalalignment='top', fontsize=7, fontfamily='monospace',
                               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
            # Overall title for this figure
            sample_mode_text = "Fixed Samples (Not Implemented)" if sample_selection == "fixed" else "Random Samples"
            fig.suptitle(f'Training Viz - Step {step} ({direction.capitalize()}) - Fig {fig_idx+1}/4 - {sample_mode_text}\n'
                        f'Input → GT → Pred(GT scale) → Pred(Pred scale) → |Diff| → Residual → Boundary Residual', 
                        fontsize=14, fontweight='bold', y=0.98)
            
            plt.tight_layout()
            plt.subplots_adjust(top=0.90)
            
            if save_dir:
                filename = f'training_visualization_{direction}_step_{step}_fig_{fig_idx+1}.png'
                plt.savefig(os.path.join(save_dir, filename), dpi=150, 
                           bbox_inches='tight', pad_inches=0.2)
            
            figures.append(fig)
        
        return figures

    def create_training_comparison_visualization_unified(self, ground_truth_a, ground_truth_u,
                                                         predictions_a, predictions_u, model_input_a, model_input_u,
                                                         mask_a, mask_u,
                                                         step, save_dir=None, n_samples=16,
                                                         direction="forward", pde_loss_fn=None, sample_selection="random"):
        """
        Create a comprehensive visualization for the 'unified' training mode, generating up to 4 figures.

        Args:
            ground_truth_a: Ground truth for PDE input (a)
            ground_truth_u: Ground truth for PDE output (u)
            predictions_a: Model prediction for input (a_pred)
            predictions_u: Model prediction for output (u_pred)
            model_input_a: Noisy input to the model for 'a'
            model_input_u: Noisy input to the model for 'u'
            step: Current training step
            save_dir: Directory to save plots
            n_samples: Total number of samples to visualize (e.g., 16)
            direction: 'forward' or 'inverse'
            pde_loss_fn: PDE loss function for residual calculation
            sample_selection: 'random' or 'first'
        """
        # print("--- Unified Visualization Debug ---")
        # print(f"Step: {step}")
        # print(f"ground_truth_a shape: {ground_truth_a.shape if hasattr(ground_truth_a, 'shape') else 'None'}")
        # print(f"ground_truth_u shape: {ground_truth_u.shape if hasattr(ground_truth_u, 'shape') else 'None'}")
        # print(f"predictions_a shape: {predictions_a.shape if hasattr(predictions_a, 'shape') else 'None'}")
        # print(f"predictions_u shape: {predictions_u.shape if hasattr(predictions_u, 'shape') else 'None'}")
        # print(f"model_input_a shape: {model_input_a.shape if hasattr(model_input_a, 'shape') else 'None'}")
        # print(f"model_input_u shape: {model_input_u.shape if hasattr(model_input_u, 'shape') else 'None'}")

        if ground_truth_a is None or ground_truth_u is None or predictions_a is None or predictions_u is None:
            print("Missing data for unified visualization.")
            return None

        batch_size = ground_truth_a.size(0)
        n_samples = min(n_samples, batch_size)
        # print(f"Batch size: {batch_size}, n_samples: {n_samples}")

        if sample_selection == "random":
            indices = torch.randperm(batch_size)[:n_samples]
        else:
            indices = torch.arange(n_samples)
        # print(f"Batch size: {batch_size}, n_samples: {n_samples}")

        figures = []
        samples_per_figure = 4
        n_figures = min(4, (n_samples + samples_per_figure - 1) // samples_per_figure)
        # print(f"Batch size: {batch_size}, n_samples: {n_samples}, n_figures: {n_figures}")

        for fig_idx in range(n_figures):
            start_idx = fig_idx * samples_per_figure
            end_idx = min(start_idx + samples_per_figure, len(indices))
            current_sample_indices = indices[start_idx:end_idx]
            current_n_samples = len(current_sample_indices)

            if current_n_samples == 0:
                continue

            n_cols = 11
            fig, axes = plt.subplots(current_n_samples, n_cols, figsize=(n_cols * 3, current_n_samples * 3.2))
            if current_n_samples == 1:
                axes = axes.reshape(1, -1)

            for i, sample_idx in enumerate(current_sample_indices):
                # Detach and move to CPU for the current sample
                gt_a = ground_truth_a[sample_idx].detach().cpu().numpy()
                gt_u = ground_truth_u[sample_idx].detach().cpu().numpy()
                pred_a = predictions_a[sample_idx].detach().cpu().numpy()
                pred_u = predictions_u[sample_idx].detach().cpu().numpy()
                model_in_a = model_input_a[sample_idx].detach().cpu().numpy()
                model_in_u = model_input_u[sample_idx].detach().cpu().numpy()
                # print("gt a shape: ", gt_a.shape)
                # print("mask shapes:", mask_a.shape, mask_u.shape)
                m_a = mask_a[sample_idx].detach().cpu().numpy()
                m_u = mask_u[sample_idx].detach().cpu().numpy()

                # print(f"Sample {sample_idx} shapes: gt_a={gt_a.shape}, gt_u={gt_u.shape}, pred_a={pred_a.shape}, pred_u={pred_u.shape}")
                # print(f"Sample {sample_idx} value ranges: gt_a=[{gt_a.min():.4f}, {gt_a.max():.4f}], gt_u=[{gt_u.min():.4f}, {gt_u.max():.4f}]")
                # print(f"Sample {sample_idx} value ranges: pred_a=[{pred_a.min():.4f}, {pred_a.max():.4f}], pred_u=[{pred_u.min():.4f}, {pred_u.max():.4f}]")


                # Calculate differences
                diff_a = gt_a[0] - pred_a[0]
                diff_u = gt_u[0] - pred_u[0]

                # Calculate PDE residuals for predictions
                pde_residuals_pred = np.zeros_like(pred_u[0])
                if pde_loss_fn is not None:
                    try:
                        from training.evaluation_utils import compute_pde_loss
                        pde_res_tensor = compute_pde_loss(
                            pde_loss_fn, "forward",
                            torch.from_numpy(pred_u).unsqueeze(0),
                            torch.from_numpy(pred_a).unsqueeze(0)
                        )
                        pde_residuals_pred = pde_res_tensor.squeeze().cpu().numpy()
                    except Exception as e:
                        print(f"Could not compute PDE residual for visualization: {e}")
                # Extract the correct 2D slices for plotting
                def get_2d_slice(array):
                    if array.ndim == 2:
                        return array
                    elif array.ndim == 3:
                        return array[0]  # Take first channel if multiple channels
                    elif array.ndim == 4:
                        return array[0, 0]  # Take first sample, first channel
                    else:
                        print(f"Unexpected array shape: {array.shape}")
                        return array.squeeze()  # Try to squeeze to 2D
                
                try:
                    gt_a_2d = get_2d_slice(gt_a)
                    gt_u_2d = get_2d_slice(gt_u)
                    pred_a_2d = get_2d_slice(pred_a)
                    pred_u_2d = get_2d_slice(pred_u)
                    model_in_a_2d = get_2d_slice(model_in_a)
                    model_in_u_2d = get_2d_slice(model_in_u)
                    diff_a_2d = get_2d_slice(diff_a)
                    diff_u_2d = get_2d_slice(diff_u)
                    mask_a_2d= get_2d_slice(m_a)
                    mask_u_2d= get_2d_slice(m_u)
                    # print(f"2D slices extracted: gt_a={gt_a_2d.shape}, gt_u={gt_u_2d.shape}, diff_a={diff_a_2d.shape}")
                

                    plots = {
                        "GT a": gt_a_2d, "GT u": gt_u_2d,
                        "Cond a": model_in_a_2d, "Cond u": model_in_u_2d,
                        "Pred a": pred_a_2d, "Pred u": pred_u_2d,
                        "Diff a": diff_a_2d, "Diff u": diff_u_2d,
                        "PDE Res": pde_residuals_pred,
                        "Mask a": mask_a_2d, "Mask u": mask_u_2d
                    }
                    
                    cmaps = {
                        "GT a": 'viridis', "GT u": 'jet',
                        "Cond a": 'viridis', "Cond u": 'jet',
                        "Pred a": 'viridis', "Pred u": 'jet',
                        "Diff a": 'coolwarm', "Diff u": 'coolwarm',
                        "PDE Res": 'hot',
                        "Mask a": 'binary', "Mask u": 'binary'
                    }

                    for j, (title, data) in enumerate(plots.items()):
                        ax = axes[i, j]
                        if title in ["Mask a", "Mask u"]:
                            # Ensure binary masks are displayed with same scale (0-1)
                            im = ax.imshow(data, cmap=cmaps[title], vmin=0, vmax=1)
                            # Add mask coverage percentage to title
                            coverage = np.mean(data) * 100  # Calculate percentage of 1s
                            ax.set_title(f'Sample {start_idx + i + 1}: {title}\n({coverage:.1f}% active)')
                        else:
                            # Regular plotting for non-mask data
                            im = ax.imshow(data, cmap=cmaps[title])
                            ax.set_title(f'Sample {start_idx + i + 1}: {title}')
                        fig.colorbar(im, ax=ax)
                        ax.set_xticks([])
                        ax.set_yticks([])
                except Exception as e:
                    print(f"Error during plotting: {e}")
                    traceback.print_exc()

            fig.suptitle(f'Unified Training Visualization - Step {step} - Fig {fig_idx+1}/{n_figures}', fontsize=16)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])

            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
                save_path = os.path.join(save_dir, f'unified_training_comparison_step_{step}_fig_{fig_idx+1}.png')
                plt.savefig(save_path)
                # print(f"Saved unified training visualization to {save_path}")

            # print(f"Figure {fig_idx+1} complete. Adding to figures list.")   
            figures.append(fig)
        # print(f"Visualization complete. Returning {len(figures)} figures.")
        return figures

    def log_training_data(self, input_data, ground_truth, predictions, pde_residuals, sigmas, step, direction="forward", pde_loss_fn=None):
        """Log training data for comprehensive visualization
        
        Args:
            input_data: Input tensor [batch, channels, height, width] or [batch, height, width] 
            ground_truth: Ground truth tensor [batch, channels, height, width] or [batch, height, width]
            predictions: Prediction tensor [batch, channels, height, width] or [batch, height, width]
            pde_residuals: PDE residual tensor [batch, channels, height, width] or [batch, height, width]
            sigmas: Sigma values [batch]
            step: Current training step
            direction: "forward" or "inverse"
            pde_loss_fn: PDE loss function to cache for visualization
        """
        # Cache the PDE loss function for later use in visualizations
        if pde_loss_fn is not None:
            self._cached_pde_loss_fn = pde_loss_fn
            
        # Store training data (keep only recent batches to save memory)
        training_data = {
            'input_data': input_data.detach().cpu(),
            'ground_truth': ground_truth.detach().cpu(), 
            'predictions': predictions.detach().cpu(),
            'pde_residuals': pde_residuals.detach().cpu(),
            'sigmas': sigmas.detach().cpu(),
            'step': step,
            'direction': direction
        }
        
        self.recent_training_data.append(training_data)
        
        # Keep only the most recent batches
        if len(self.recent_training_data) > self.max_training_data:
            self.recent_training_data.pop(0)

    def log_training_data_unified(self, ground_truth_a, ground_truth_u, model_input_a, model_input_u,
                                    predictions_a, predictions_u, mask_a, mask_u,
                                    pde_residuals, sigmas, step, direction="forward", pde_loss_fn=None):
        """Log training data for comprehensive visualization
        
        Args:
            input_data: Input tensor [batch, channels, height, width] or [batch, height, width] 
            ground_truth: Ground truth tensor [batch, channels, height, width] or [batch, height, width]
            predictions: Prediction tensor [batch, channels, height, width] or [batch, height, width]
            pde_residuals: PDE residual tensor [batch, channels, height, width] or [batch, height, width]
            sigmas: Sigma values [batch]
            step: Current training step
            direction: "forward" or "inverse"
            pde_loss_fn: PDE loss function to cache for visualization
        """
        # Cache the PDE loss function for later use in visualizations
        if pde_loss_fn is not None:
            self._cached_pde_loss_fn = pde_loss_fn
            
        # Store training data (keep only recent batches to save memory)
        training_data = {
            'ground_truth_a': ground_truth_a.detach().cpu(), 
            'ground_truth_u': ground_truth_u.detach().cpu(), 
            'model_input_a': model_input_a.detach().cpu(),
            'model_input_u': model_input_u.detach().cpu(),
            'predictions_a': predictions_a.detach().cpu(),
            'predictions_u': predictions_u.detach().cpu(),
            'pde_residuals': pde_residuals.detach().cpu(),
            'mask_a': mask_a.detach().cpu(),
            'mask_u': mask_u.detach().cpu(),
            'sigmas': sigmas.detach().cpu(),
            'step': step,
            'direction': direction
        }

        self.recent_training_data.append(training_data)
        
        # Keep only the most recent batches
        if len(self.recent_training_data) > self.max_training_data:
            self.recent_training_data.pop(0)