"""
Complete Research Pipeline: DeepONet vs FNO vs LSTM for Ocean Temperature Prediction
with Captum Interpretability Analysis and Multi-Seed Statistical Rigor

Novel Contributions:
1. Physics-Informed Neural Operators with PDE residuals for oceanographic prediction
2. Multi-method interpretability analysis (6 Captum methods including layer conductance)
3. Deep Ensemble uncertainty quantification with epistemic uncertainty estimation
4. Proper LSTM sequence processing for temporal pattern learning
5. Physics validation metrics (stratification, bounds, T-S relationships, PDE compliance)
6. Layer-wise attribution for internal representation analysis
7. Cross-method interpretability validation

Requirements:
pip install argopy xarray netCDF4 torch captum numpy pandas matplotlib seaborn scikit-learn tqdm scipy

"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy import stats
from tqdm import tqdm

# Captum for interpretability
from captum.attr import (
    IntegratedGradients, 
    Saliency, 
    DeepLift,
    GradientShap,
    LayerConductance,
    LayerGradientXActivation
)

# Set random seeds for reproducibility
def set_seed(seed: int):
    """Set all random seeds for reproducibility

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU generator plus every CUDA device). When CUDA is available,
    cuDNN is additionally switched to deterministic mode with auto-tuning
    (``benchmark``) disabled.

    Args:
        seed: Seed value applied to every generator.
    """
    # Imported locally because `random` is not among the file-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Called unconditionally, exactly as before; the CUDA-only cuDNN flags
    # below stay guarded.
    torch.cuda.manual_seed_all(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# =====================================================================
# DATA LOADING AND PREPROCESSING
# =====================================================================

class ArgoDataProcessor:
    """Process Argo float data for neural operator learning

    Fetches measurements with argopy, flattens them into a table and turns
    them into chronologically split, standardized tensors. The fitted
    scalers and the sequence settings are stored in ``self.scalers``.
    """

    # Accepted column spellings for each required variable, in lookup order.
    _COLUMN_ALIASES = {
        'lat': ('LATITUDE', 'latitude'),
        'lon': ('LONGITUDE', 'longitude'),
        'pres': ('PRES', 'pressure'),
        'psal': ('PSAL', 'salinity'),
        'temp': ('TEMP', 'temperature'),
    }

    # Split names in the order their keys appear in the returned dict.
    _SPLIT_NAMES = ('train', 'val', 'test')

    def __init__(self, region_bounds: List[float], date_range: List[str],
                 max_depth: float = 500.0):
        """
        Args:
            region_bounds: [lon_min, lon_max, lat_min, lat_max]
            date_range: ['start_date', 'end_date'] in 'YYYY-MM-DD' format
            max_depth: Maximum pressure/depth to consider
        """
        self.region_bounds = region_bounds
        self.date_range = date_range
        self.max_depth = max_depth
        self.scalers = {}

    def fetch_data(self):
        """
        Fetch real Argo data using argopy

        Returns:
            The object returned by the argopy fetcher's ``to_xarray()``.

        Raises:
            ImportError: If an import fails while fetching (typically argopy
                is not installed).
            RuntimeError: For any other failure; the original exception is
                chained as ``__cause__``.
        """
        try:
            from argopy import DataFetcher as ArgoDataFetcher

            print("Fetching real Argo data...")
            print(f"Region: {self.region_bounds}")
            print(f"Date range: {self.date_range}")
            print(f"Max depth: {self.max_depth} dbar")

            # Box layout passed to argopy:
            # [lon_min, lon_max, lat_min, lat_max, 0, max_depth, start, end]
            argo_loader = ArgoDataFetcher().region(
                [*self.region_bounds, 0, self.max_depth, *self.date_range]
            )
            ds = argo_loader.to_xarray()

            print("Successfully fetched Argo data!")
            print(f"Dataset variables: {list(ds.variables)}")
            print(f"Dataset dimensions: {dict(ds.dims)}")

            return ds
        except ImportError as e:
            raise ImportError(
                "argopy is not installed. Please install it with:\n"
                "pip install argopy xarray netCDF4"
            ) from e
        except Exception as e:
            print(f"Error fetching Argo data: {e}")
            print("This might be due to:")
            print("  - No internet connection")
            print("  - Argopy server issues")
            print("  - No data available for this region/time")
            print("\nPlease check:")
            print("  1. Internet connection is active")
            print("  2. Region bounds and date range have available data")
            print("  3. Argopy servers are operational")
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to fetch real Argo data: {e}") from e

    @classmethod
    def _find_column(cls, df, key: str) -> Optional[str]:
        """Return the first alias of ``key`` present in ``df``, else None."""
        return next(
            (name for name in cls._COLUMN_ALIASES[key] if name in df.columns),
            None,
        )

    def create_sequences(self, df, sequence_length: int = 5):
        """
        Create sliding-window sequences from Argo profiles for LSTM

        Rows are grouped by ``N_PROF``, ordered by pressure inside each
        profile, and every run of ``sequence_length`` consecutive rows
        becomes one sequence. The input frame is never modified.

        Args:
            df: DataFrame with Argo data
            sequence_length: Number of rows in each sequence

        Returns:
            DataFrame with ``SEQ_ID`` and ``SEQ_POS`` columns added. If no
            profile is long enough, the sorted rows are returned with one
            ``SEQ_ID`` per row and ``SEQ_POS`` set to 0.

        Raises:
            ValueError: If ``sequence_length`` is smaller than 1.
            KeyError: If ``N_PROF`` is absent and no latitude/longitude
                column exists to derive pseudo-profiles from.
        """
        if sequence_length < 1:
            raise ValueError("sequence_length must be at least 1")

        if 'N_PROF' not in df.columns:
            print("Warning: N_PROF not found, creating pseudo-profiles based on location")
            lat_col = self._find_column(df, 'lat')
            lon_col = self._find_column(df, 'lon')
            if lat_col is None or lon_col is None:
                raise KeyError(
                    "Cannot build pseudo-profiles: no latitude/longitude column found"
                )
            # Rows sharing one (lat, lon) pair form one pseudo-profile.
            # assign() returns a new frame, so the caller's data stays intact.
            df = df.assign(N_PROF=df.groupby([lat_col, lon_col]).ngroup())

        # Sort by profile and depth
        pres_col = self._find_column(df, 'pres') or 'pressure'
        df = df.sort_values(['N_PROF', pres_col])

        windows = []
        profile_groups = df.groupby('N_PROF')

        print(f"Creating sequences from {len(profile_groups)} profiles...")

        for prof_id, group in profile_groups:
            # The range is empty for profiles shorter than sequence_length.
            for start in range(len(group) - sequence_length + 1):
                window = group.iloc[start:start + sequence_length].copy()
                window['SEQ_ID'] = f"{prof_id}_{start}"
                window['SEQ_POS'] = range(sequence_length)
                windows.append(window)

        if windows:
            df_sequences = pd.concat(windows, ignore_index=True)
            print(f"Created {len(df_sequences) // sequence_length} sequences")
            return df_sequences

        print("Warning: Could not create sequences, falling back to single-step mode")
        df['SEQ_ID'] = df.index
        df['SEQ_POS'] = 0
        return df

    @staticmethod
    def _sequence_rows(seq_ids, sequence_length: int):
        """Return an (n_sequences, sequence_length) array of row positions.

        Sequences keep their order of first appearance and rows keep their
        original order inside a sequence. Sequences that do not contain
        exactly ``sequence_length`` rows are dropped.
        """
        codes, uniques = pd.factorize(seq_ids)
        counts = np.bincount(codes, minlength=len(uniques))
        # A stable sort groups the rows of each sequence while preserving
        # their relative order.
        order = np.argsort(codes, kind='stable')
        starts = np.cumsum(counts) - counts
        complete = counts == sequence_length
        offsets = np.arange(sequence_length)
        return order[starts[complete][:, None] + offsets[None, :]]

    @staticmethod
    def _temporal_masks(times, test_size: float, val_size: float):
        """Split chronologically: earliest samples train, latest test."""
        val_start = np.percentile(times, (1 - test_size - val_size) * 100)
        test_start = np.percentile(times, (1 - test_size) * 100)
        return {
            'train': times < val_start,
            'val': (times >= val_start) & (times < test_start),
            'test': times >= test_start,
        }

    @staticmethod
    def _require_nonempty(masks):
        """Raise ValueError if any of the split masks selects nothing."""
        if any(mask.sum() == 0 for mask in masks.values()):
            raise ValueError("One or more data splits are empty. Check your test_size and val_size parameters.")

    def _fit_scalers(self, trunk_train, branch_train, target_train):
        """Fit fresh StandardScalers on training rows only."""
        self.scalers['trunk'] = StandardScaler().fit(trunk_train)
        self.scalers['branch'] = StandardScaler().fit(branch_train)
        self.scalers['target'] = StandardScaler().fit(target_train)

    @classmethod
    def _pack_splits(cls, X_trunk, X_branch, y, masks):
        """Build the dict of train/val/test tensors from arrays and masks."""
        splits = {}
        for name in cls._SPLIT_NAMES:
            keep = masks[name]
            splits[f'X_trunk_{name}'] = torch.FloatTensor(X_trunk[keep])
            splits[f'X_branch_{name}'] = torch.FloatTensor(X_branch[keep])
            splits[f'y_{name}'] = torch.FloatTensor(y[keep])
        return splits

    def preprocess(self, data, test_size: float = 0.2, val_size: float = 0.1,
                   use_sequences: bool = False, sequence_length: int = 5):
        """
        Preprocess data for neural operator training
        Handles both xarray.Dataset and pandas.DataFrame

        Trunk features are (lat, lon, days since the first timestamp),
        branch features are (pressure, salinity) and the target is
        temperature. Data is split chronologically and the scalers are
        fitted on the training split only, so validation/test statistics
        never leak into training.

        Args:
            data: Input data (xarray.Dataset or pandas.DataFrame)
            test_size: Fraction for test set
            val_size: Fraction for validation set
            use_sequences: If True, create sequences for LSTM
            sequence_length: Length of sequences for LSTM

        Returns:
            Dict with keys ``X_trunk_*``, ``X_branch_*`` and ``y_*`` for
            ``train``/``val``/``test``. In sequence mode the inputs have
            shape (n_sequences, sequence_length, features) and the target is
            the temperature of the last row of each sequence. The fitted
            scalers are stored in ``self.scalers``.

        Raises:
            ValueError: If required columns or the time column are missing,
                no rows survive NaN removal, no complete sequence can be
                built, or any split is empty.
        """
        # Convert xarray Dataset to pandas DataFrame if needed
        if hasattr(data, 'to_dataframe'):
            print("Converting xarray Dataset to DataFrame...")
            df = data.to_dataframe().reset_index()
        else:
            df = data.copy()

        print(f"Original data size: {len(df)}")

        # Handle potential column name variations
        columns = {key: self._find_column(df, key) for key in self._COLUMN_ALIASES}
        if any(name is None for name in columns.values()):
            raise ValueError(f"Missing required columns. Found: {list(df.columns)}")
        lat_col, lon_col = columns['lat'], columns['lon']
        pres_col, psal_col, temp_col = columns['pres'], columns['psal'], columns['temp']

        # Remove NaN values; copy so the new columns below are written to an
        # independent frame.
        df = df.dropna(subset=[temp_col, psal_col, pres_col, lat_col, lon_col]).copy()
        print(f"After removing NaN: {len(df)}")

        if len(df) == 0:
            raise ValueError("No valid data remaining after removing NaN values")

        # Convert time to numeric (days since the earliest timestamp). One
        # origin is used for every input format; a constant shift has no
        # effect on the standardized feature or on the percentile split.
        if 'TIME' in df.columns:
            time_source = 'TIME'
        elif 'time' in df.columns:  # xarray often uses lowercase
            time_source = 'time'
        else:
            raise ValueError("No TIME or time column found in data")
        df['TIME'] = pd.to_datetime(df[time_source])
        df['TIME_NUMERIC'] = (df['TIME'] - df['TIME'].min()).dt.total_seconds() / 86400

        # Create sequences for LSTM if requested
        if use_sequences:
            df = self.create_sequences(df, sequence_length)
            # Store sequence information
            self.scalers['sequence_length'] = sequence_length
            self.scalers['use_sequences'] = True
        else:
            self.scalers['use_sequences'] = False

        # Extract raw (unscaled) features
        trunk_raw = df[[lat_col, lon_col, 'TIME_NUMERIC']].to_numpy(dtype=np.float32)
        branch_raw = df[[pres_col, psal_col]].to_numpy(dtype=np.float32)
        target_raw = df[[temp_col]].to_numpy(dtype=np.float32)
        times = df['TIME_NUMERIC'].to_numpy()

        if use_sequences and 'SEQ_ID' in df.columns:
            # Row positions of every complete sequence, gathered in one pass.
            rows = self._sequence_rows(df['SEQ_ID'], sequence_length)
            if rows.shape[0] == 0:
                raise ValueError(
                    f"No complete sequences of length {sequence_length} could be built; "
                    "reduce sequence_length or call preprocess with use_sequences=False."
                )

            # Each sequence is dated and labelled by its last row.
            last_rows = rows[:, -1]
            seq_times = times[last_rows].astype(np.float32)

            # Temporal split for sequences
            masks = self._temporal_masks(seq_times, test_size, val_size)
            self._require_nonempty(masks)

            train_rows = rows[masks['train']]
            self._fit_scalers(
                trunk_raw[train_rows.ravel()],
                branch_raw[train_rows.ravel()],
                target_raw[train_rows[:, -1]],
            )

            trunk_seq = self.scalers['trunk'].transform(trunk_raw)[rows]
            branch_seq = self.scalers['branch'].transform(branch_raw)[rows]
            target_seq = self.scalers['target'].transform(target_raw)[last_rows]

            data_splits = self._pack_splits(trunk_seq, branch_seq, target_seq, masks)
            print(f"Sequence splits: Train={masks['train'].sum()}, Val={masks['val'].sum()}, Test={masks['test'].sum()}")
            return data_splits

        # Temporal split to avoid data leakage (non-sequence mode)
        masks = self._temporal_masks(times, test_size, val_size)

        # Ensure all splits have data
        self._require_nonempty(masks)

        train_mask = masks['train']
        self._fit_scalers(trunk_raw[train_mask], branch_raw[train_mask], target_raw[train_mask])

        data_splits = self._pack_splits(
            self.scalers['trunk'].transform(trunk_raw),
            self.scalers['branch'].transform(branch_raw),
            self.scalers['target'].transform(target_raw),
            masks,
        )

        print(f"Data splits: Train={masks['train'].sum()}, Val={masks['val'].sum()}, Test={masks['test'].sum()}")

        return data_splits

# =====================================================================
# MODEL ARCHITECTURES
# =====================================================================

class DeepONet(nn.Module):
    """Deep Operator Network: two Tanh MLPs combined by an inner product.

    The branch MLP encodes the input-function measurements, the trunk MLP
    encodes the spatiotemporal coordinates, and the prediction is the dot
    product of both embeddings plus a learnable scalar bias.
    """

    def __init__(self, branch_dim: int, trunk_dim: int,
                 hidden_dim: int = 128, num_layers: int = 4,
                 use_physics_loss: bool = False):
        """
        Args:
            branch_dim: Number of branch input features
            trunk_dim: Number of trunk input features
            hidden_dim: Width of every hidden layer and of the embeddings
            num_layers: Number of Linear+Tanh pairs in each sub-network
            use_physics_loss: Flag stored on the module; not read here

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        super().__init__()

        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")

        self.use_physics_loss = use_physics_loss

        def tanh_mlp(in_features: int) -> nn.Sequential:
            # Layer widths: in_features -> hidden_dim (num_layers times).
            widths = [in_features] + [hidden_dim] * num_layers
            stack = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                stack.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
            return nn.Sequential(*stack)

        # Branch first, trunk second, so parameter initialization consumes
        # the random stream in a fixed order.
        self.branch_net = tanh_mlp(branch_dim)  # processes input functions
        self.trunk_net = tanh_mlp(trunk_dim)    # processes coordinates

        # Bias term
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x_branch, x_trunk):
        """
        Args:
            x_branch: (batch, branch_dim) - input function measurements
            x_trunk: (batch, trunk_dim) - spatiotemporal coordinates
        Returns:
            output: (batch, 1) - predicted temperature
        """
        branch_embedding = self.branch_net(x_branch)  # (batch, hidden_dim)
        trunk_embedding = self.trunk_net(x_trunk)     # (batch, hidden_dim)

        # Inner product over the embedding axis plus bias. The output is
        # deliberately left unclamped: a hard clamp would zero the gradient
        # outside its range.
        inner = (branch_embedding * trunk_embedding).sum(dim=1, keepdim=True)
        return inner + self.bias


class FNO1d(nn.Module):
    """Simplified 1D Fourier Neural Operator.

    Branch and trunk inputs are concatenated, projected to ``hidden_dim``,
    passed through a stack of ``FourierLayer`` blocks and projected to a
    single output value.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64,
                 num_layers: int = 4, modes: int = 16):
        """
        Args:
            input_dim: Combined width of branch and trunk features
            hidden_dim: Width of the hidden representation
            num_layers: Number of Fourier layers
            modes: Number of Fourier modes requested per layer

        Raises:
            ValueError: If ``num_layers`` or ``modes`` is smaller than 1.
        """
        super().__init__()

        # Checked in this order so the first failing argument is reported.
        for label, value in (("num_layers", num_layers), ("modes", modes)):
            if value < 1:
                raise ValueError(f"{label} must be at least 1")

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        blocks = []
        for _ in range(num_layers):
            blocks.append(FourierLayer(hidden_dim, modes))
        self.fourier_layers = nn.ModuleList(blocks)

        self.output_proj = nn.Linear(hidden_dim, 1)

    def forward(self, x_branch, x_trunk):
        """
        Args:
            x_branch: (batch, branch_dim)
            x_trunk: (batch, trunk_dim)
        Returns:
            output: (batch, 1)
        """
        # Join both inputs along the feature axis, then lift to hidden_dim.
        hidden = self.input_proj(torch.cat((x_branch, x_trunk), dim=-1))

        for fourier_block in self.fourier_layers:
            hidden = fourier_block(hidden)

        prediction = self.output_proj(hidden)
        return prediction


class FourierLayer(nn.Module):
    """Single Fourier layer for FNO

    Mixes the lowest Fourier modes of the feature vector with learnable
    complex weights, adds a parallel linear path and applies GELU.
    """

    def __init__(self, hidden_dim: int, modes: int):
        """
        Args:
            hidden_dim: Width of the layer's input and output
            modes: Requested number of Fourier modes; capped at
                ``hidden_dim // 2 + 1``, the number of rfft frequencies
        """
        super().__init__()

        self.modes = modes
        self.hidden_dim = hidden_dim

        # For rfft, we get hidden_dim//2 + 1 frequency components, so more
        # modes than that can never be used.
        max_modes = min(modes, hidden_dim // 2 + 1)
        self.max_modes = max_modes

        scale = 1 / (hidden_dim * max_modes)

        # Learnable weights in Fourier space, stored as separate real and
        # imaginary parts. Created real-first so seeded initialization
        # stays reproducible.
        self.weights_real = nn.Parameter(
            scale * torch.randn(hidden_dim, max_modes)
        )
        self.weights_imag = nn.Parameter(
            scale * torch.randn(hidden_dim, max_modes)
        )

        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        """
        Args:
            x: (batch, hidden_dim)
        Returns:
            output: (batch, hidden_dim)
        """
        # Unpacking also rejects inputs that are not 2-D.
        _, hidden_dim = x.shape

        # Linear path, added to the spectral path before the activation.
        x_linear = self.linear(x)

        # FFT along feature dimension
        x_ft = torch.fft.rfft(x, dim=1, norm='ortho')  # (batch, hidden_dim//2 + 1)

        # Can't use more modes than there are frequencies available.
        modes_to_use = min(self.max_modes, x_ft.shape[1])

        weights = torch.complex(
            self.weights_real[:, :modes_to_use],
            self.weights_imag[:, :modes_to_use]
        )  # (hidden_dim, modes_to_use)

        # Mix the retained modes: every output channel is a weighted sum of
        # the low-frequency coefficients.
        mixed = torch.einsum(
            'bm,dm->bd',
            x_ft[:, :modes_to_use],  # (batch, modes_to_use)
            weights                  # (hidden_dim, modes_to_use)
        )  # (batch, hidden_dim)

        # Inverse FFT straight from the mixed coefficients; no intermediate
        # zero-filled spectrum buffer is needed.
        x_conv = torch.fft.irfft(mixed, n=hidden_dim, dim=1, norm='ortho')

        # Combine spectral and spatial information
        return self.activation(x_conv + x_linear)


class LSTMPredictor(nn.Module):
    """
    LSTM regressor for temperature prediction.

    Branch and trunk features are concatenated per step, fed through a
    stacked LSTM, and the hidden state of the last step is mapped to a
    single value by a small MLP head.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128,
                 num_layers: int = 2, dropout: float = 0.2, use_sequences: bool = True):
        """
        Args:
            input_dim: Combined width of branch and trunk features
            hidden_dim: LSTM hidden size
            num_layers: Number of stacked LSTM layers
            dropout: Dropout probability for the head and, when more than
                one LSTM layer is used, between LSTM layers
            use_sequences: Whether inputs already carry a sequence axis

        Raises:
            ValueError: If ``num_layers`` < 1 or ``dropout`` is outside
                the half-open interval [0, 1).
        """
        super().__init__()

        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")
        if not (0 <= dropout < 1):
            raise ValueError("dropout must be between 0 and 1")

        self.use_sequences = use_sequences

        # Inter-layer dropout is only passed on for stacked LSTMs.
        if num_layers > 1:
            recurrent_dropout = dropout
        else:
            recurrent_dropout = 0

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=recurrent_dropout,
        )

        head_width = hidden_dim // 2
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, head_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_width, 1),
        )

    def forward(self, x_branch, x_trunk):
        """
        Args:
            x_branch: (batch, seq_len, branch_dim) if use_sequences else (batch, branch_dim)
            x_trunk: (batch, seq_len, trunk_dim) if use_sequences else (batch, trunk_dim)
        Returns:
            output: (batch, 1)
        """
        # Join branch and trunk along the feature axis in both modes.
        features = torch.cat([x_branch, x_trunk], dim=-1)
        if not self.use_sequences:
            # Flat inputs become a sequence of length one.
            features = features.unsqueeze(1)  # (batch, 1, total_features)

        hidden_states, _ = self.lstm(features)  # (batch, seq_len, hidden_dim)

        # Only the final step feeds the regression head.
        final_step = hidden_states[:, -1, :]
        return self.fc(final_step)


# =====================================================================
# TRAINING AND EVALUATION
# =====================================================================

class ModelTrainer:
    """Unified trainer for all models with comprehensive metrics

    Wraps a model that is called as ``model(x_branch, x_trunk)`` and
    provides epoch-level training, validation, early-stopped training and
    evaluation in the target's original scale.
    """

    def __init__(self, model, device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
                 use_physics_loss: bool = False, scalers: Optional[dict] = None):
        """
        Args:
            model: Module called as ``model(x_branch, x_trunk)``
            device: Device the model and every batch are moved to
            use_physics_loss: Train with PhysicsInformedLoss instead of MSE
            scalers: Optional dict whose 'branch' and 'target' entries are
                handed to the physics loss
        """
        self.model = model.to(device)
        self.device = device
        self.history = {'train_loss': [], 'val_loss': []}
        self.best_model_state = None
        self.use_physics_loss = use_physics_loss
        self.scalers = scalers  # Store scalers for physics loss

    def train_epoch(self, train_loader, optimizer, criterion):
        """Train for one epoch

        Args:
            train_loader: Iterable of (x_branch, x_trunk, y) batches
            optimizer: Optimizer over the model's parameters
            criterion: Loss callable; receives the inputs and scalers as
                well when the physics loss is active

        Returns:
            Mean loss over the batches that produced a finite loss, or
            ``inf`` if no batch did (so a fully diverged epoch is never
            reported as a perfect 0.0).
        """
        self.model.train()
        total_loss = 0
        num_batches = 0

        for x_branch, x_trunk, y in train_loader:
            x_branch = x_branch.to(self.device)
            x_trunk = x_trunk.to(self.device)
            y = y.to(self.device)

            # The physics loss receives the inputs themselves, so they are
            # marked as requiring gradients before the forward pass.
            if self.use_physics_loss:
                x_branch = x_branch.requires_grad_(True)
                x_trunk = x_trunk.requires_grad_(True)

            optimizer.zero_grad()
            outputs = self.model(x_branch, x_trunk)

            # PhysicsInformedLoss requires the inputs even when no scalers
            # were supplied; its scaler arguments default to None.
            wants_inputs = self.use_physics_loss and (
                self.scalers is not None
                or isinstance(criterion, PhysicsInformedLoss)
            )
            if wants_inputs:
                scalers = self.scalers if self.scalers is not None else {}
                loss = criterion(outputs, y, x_branch, x_trunk,
                                 scalers.get('branch'),
                                 scalers.get('target'))
            else:
                loss = criterion(outputs, y)

            # Skip NaN and infinite losses: back-propagating either would
            # corrupt the weights.
            if not torch.isfinite(loss):
                print("Warning: non-finite loss detected, skipping batch")
                continue

            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            return float('inf')
        return total_loss / num_batches

    def validate(self, val_loader, criterion):
        """Validate model

        Args:
            val_loader: Iterable of (x_branch, x_trunk, y) batches
            criterion: Loss callable taking (outputs, targets)

        Returns:
            Mean loss over the batches with a finite loss, or ``inf`` if
            there is none, so a diverged model can never be selected as
            the best checkpoint.
        """
        self.model.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for x_branch, x_trunk, y in val_loader:
                x_branch = x_branch.to(self.device)
                x_trunk = x_trunk.to(self.device)
                y = y.to(self.device)

                outputs = self.model(x_branch, x_trunk)
                loss = criterion(outputs, y)

                if torch.isfinite(loss):
                    total_loss += loss.item()
                    num_batches += 1

        if num_batches == 0:
            return float('inf')
        return total_loss / num_batches

    def train(self, train_loader, val_loader, epochs: int = 100,
              lr: float = 1e-3, patience: int = 15):
        """Train model with early stopping

        Uses Adam with weight decay, halves the learning rate when the
        validation loss plateaus, records both losses in ``self.history``
        and restores the weights of the best validation epoch at the end.

        Args:
            train_loader: Training batches
            val_loader: Validation batches (always scored with plain MSE)
            epochs: Maximum number of epochs
            lr: Initial learning rate
            patience: Epochs without improvement before stopping

        Raises:
            ValueError: If ``epochs`` < 1, ``lr`` <= 0 or ``patience`` < 1.
        """
        if epochs < 1:
            raise ValueError("epochs must be at least 1")
        if lr <= 0:
            raise ValueError("learning rate must be positive")
        if patience < 1:
            raise ValueError("patience must be at least 1")

        # Use physics-informed loss if enabled, otherwise standard MSE
        if self.use_physics_loss:
            criterion = PhysicsInformedLoss(
                lambda_data=1.0,
                lambda_physics=0.1,
                lambda_ts_relation=0.05,
                lambda_pde=0.1
            )
            print("Using Physics-Informed Loss with PDE residual")
        else:
            criterion = nn.MSELoss()

        # Validation always uses MSE so runs with and without the physics
        # terms remain comparable.
        val_criterion = nn.MSELoss()

        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5)
        # `verbose` is intentionally not passed: it is a deprecated
        # argument and omitting it keeps the default silent behaviour.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )

        best_val_loss = float('inf')
        patience_counter = 0

        pbar = tqdm(range(epochs), desc="Training")
        for epoch in pbar:
            train_loss = self.train_epoch(train_loader, optimizer, criterion)
            val_loss = self.validate(val_loader, val_criterion)

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)

            scheduler.step(val_loss)

            pbar.set_postfix({
                'train_loss': f'{train_loss:.4f}',
                'val_loss': f'{val_loss:.4f}'
            })

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # CPU clone, so later optimizer steps cannot alter the
                # snapshot.
                self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"\nEarly stopping at epoch {epoch+1}")
                    break

        # Restore best model
        if self.best_model_state is not None:
            self.model.load_state_dict({k: v.to(self.device) for k, v in self.best_model_state.items()})

    def evaluate(self, test_loader, scaler):
        """Comprehensive evaluation with multiple metrics

        Args:
            test_loader: Iterable of (x_branch, x_trunk, y) batches
            scaler: Fitted target scaler providing ``inverse_transform``

        Returns:
            (metrics, predictions, targets); ``metrics`` holds mse, rmse,
            mae, r2 and mape, all computed after inverse-transforming
            predictions and targets with ``scaler``.
        """
        self.model.eval()
        predictions = []
        targets = []

        with torch.no_grad():
            for x_branch, x_trunk, y in test_loader:
                x_branch = x_branch.to(self.device)
                x_trunk = x_trunk.to(self.device)

                outputs = self.model(x_branch, x_trunk)
                predictions.append(outputs.cpu().numpy())
                targets.append(y.numpy())

        predictions = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)

        # Inverse transform to original scale
        predictions = scaler.inverse_transform(predictions)
        targets = scaler.inverse_transform(targets)

        # Compute metrics
        mse_val = mean_squared_error(targets, predictions)
        mae_val = mean_absolute_error(targets, predictions)
        r2_val = r2_score(targets, predictions)

        # MAPE only over targets away from zero, to avoid dividing by zero.
        mask = np.abs(targets) > 1e-8
        mape_val = np.mean(np.abs((targets[mask] - predictions[mask]) / targets[mask])) * 100 if mask.sum() > 0 else 0.0

        metrics = {
            'mse': float(mse_val),
            'rmse': float(np.sqrt(mse_val)),
            'mae': float(mae_val),
            'r2': float(r2_val),
            'mape': float(mape_val)
        }

        return metrics, predictions, targets

    def evaluate_with_physics(self, test_loader, scaler, test_data_raw):
        """
        Comprehensive evaluation with physics validation

        Args:
            test_loader: Iterable of (x_branch, x_trunk, y) batches
            scaler: Fitted target scaler providing ``inverse_transform``
            test_data_raw: Mapping with tensors under 'X_branch' and
                'X_trunk'; column 0 of 'X_branch' is read as pressure

        Returns:
            (metrics, predictions, targets, physics_metrics), where
            ``physics_metrics`` merges the bounds, stratification and
            smoothness results returned by PhysicsValidator.
        """
        metrics, predictions, targets = self.evaluate(test_loader, scaler)

        # Extract raw features for physics validation
        x_branch_raw = test_data_raw['X_branch'].numpy()
        x_trunk_raw = test_data_raw['X_trunk'].numpy()

        # Physics validation
        validator = PhysicsValidator()

        # Temperature bounds
        bounds_check = validator.check_temperature_bounds(predictions)

        # Stratification (need pressure from branch input)
        pressures = x_branch_raw[:, 0]  # First column is pressure
        stratification_check = validator.check_stratification(
            predictions, pressures, x_trunk_raw
        )

        # Spatial smoothness
        smoothness_check = validator.check_spatial_smoothness(
            predictions, x_trunk_raw
        )

        # Combine all metrics
        physics_metrics = {
            **bounds_check,
            **stratification_check,
            **smoothness_check
        }

        return metrics, predictions, targets, physics_metrics


# =====================================================================
# PHYSICS-INFORMED LOSS FUNCTIONS
# =====================================================================

class PhysicsInformedLoss(nn.Module):
    """
    Loss function with ocean physics constraints for XAI4Science
    Incorporates:
    1. Temperature-salinity relationships
    2. Physical bounds
    3. PDE residual (heat advection-diffusion equation)

    The bounds and T-S penalties are evaluated on temperatures mapped back to
    original units through a differentiable inverse transform, so they
    contribute gradients to the model instead of acting as constants.
    """
    
    def __init__(self, lambda_data=1.0, lambda_physics=0.1, lambda_ts_relation=0.05,
                 lambda_pde=0.01, kappa=1e-5):
        """
        Args:
            lambda_data: Weight for the data-fitting (MSE) term
            lambda_physics: Weight for the temperature-bounds penalty
            lambda_ts_relation: Weight for the T-S relationship penalty
            lambda_pde: Weight for the PDE residual
            kappa: Diffusivity used in the PDE residual (default 1e-5, the
                value previously hard-coded)
        """
        super().__init__()
        self.lambda_data = lambda_data  # Weight for data-fitting term
        self.lambda_physics = lambda_physics  # Weight for physics constraints
        self.lambda_ts_relation = lambda_ts_relation  # Weight for T-S relationship
        self.lambda_pde = lambda_pde  # Weight for PDE residual
        self.kappa = kappa  # Diffusivity in the PDE residual
        
        # Bookkeeping for the throttled debug print / one-time warning
        self._pde_debug_counter = 0
        self._gradient_warning_shown = False
        
    def forward(self, predictions, targets, inputs_branch, inputs_trunk=None, 
                scaler_branch=None, scaler_target=None):
        """
        Args:
            predictions: Model predictions (scaled)
            targets: Ground truth (scaled)
            inputs_branch: Branch inputs containing pressure and salinity (scaled)
            inputs_trunk: Trunk inputs containing lat, lon, time (scaled) - for PDE
            scaler_branch: Scaler for branch inputs (to get original salinity)
            scaler_target: Scaler for target (to get original temperature)

        Returns:
            Scalar tensor: weighted sum of the MSE, bounds, T-S and PDE terms.
            When either scaler is missing only the weighted MSE is returned.
        """
        # Standard MSE loss
        mse_loss = nn.functional.mse_loss(predictions, targets)
        
        # Physical constraints need original units; without both scalers all
        # physics terms are zero and only the data term remains
        if scaler_branch is None or scaler_target is None:
            return self.lambda_data * mse_loss
        
        # Salinity is input data only, so it is detached; the temperature keeps
        # its autograd history so the penalties can steer the model
        branch_original = self._inverse_transform(scaler_branch, inputs_branch.detach())
        temperature = self._inverse_transform(scaler_target, predictions)
        if temperature.dim() == 1:
            # Keep (batch, 1) layout so it lines up with the salinity column
            temperature = temperature.unsqueeze(1)
        
        salinity = branch_original[:, 1:2]  # Second column is salinity
        
        # T-S relationship violation (extreme T-S combinations are penalized)
        ts_violation = self._check_ts_relationship(temperature, salinity)
        
        # Temperature bounds violation (in original scale)
        bounds_violation = self._check_temperature_bounds(temperature)
        
        # PDE residual (if inputs have gradients and trunk is provided)
        pde_enabled = (
            inputs_trunk is not None
            and inputs_branch.requires_grad
            and inputs_trunk.requires_grad
        )
        if pde_enabled:
            pde_residual = self._compute_pde_residual(
                predictions, inputs_branch, inputs_trunk, 
                scaler_branch, scaler_target
            )
            # Debug output only occasionally to avoid spam
            self._pde_debug_counter += 1
            if self._pde_debug_counter % 100 == 0:  # Print every 100 batches
                print(f"\n  PDE residual: {pde_residual.item():.6f}, T-S: {ts_violation.item():.6f}, Bounds: {bounds_violation.item():.6f}")
        else:
            pde_residual = torch.tensor(0.0, device=predictions.device)
            if not self._gradient_warning_shown:
                print("Warning: PDE loss disabled - inputs don't have gradients enabled")
                self._gradient_warning_shown = True
        
        # Combined loss with PDE residual
        total_loss = (
            self.lambda_data * mse_loss +
            self.lambda_physics * bounds_violation +
            self.lambda_ts_relation * ts_violation +
            self.lambda_pde * pde_residual
        )
        
        return total_loss
    
    @staticmethod
    def _inverse_transform(scaler, values):
        """
        Map scaled values back to original units, staying inside autograd
        whenever the scaler exposes StandardScaler- or MinMaxScaler-style
        fitted attributes.

        Args:
            scaler: Fitted scaler object
            values: Scaled tensor whose last dimension indexes features

        Returns:
            Tensor in original units on the same device as ``values``. For
            unrecognized scalers the result comes from
            ``scaler.inverse_transform`` and carries no gradient.
        """
        def as_tensor(param):
            return torch.as_tensor(param, dtype=values.dtype, device=values.device)
        
        has_scale = hasattr(scaler, 'scale_')
        
        if has_scale and hasattr(scaler, 'mean_'):
            # StandardScaler layout: x = x_scaled * scale_ + mean_, where each
            # step is skipped if the scaler was configured without it
            restored = values
            if getattr(scaler, 'with_std', True) and scaler.scale_ is not None:
                restored = restored * as_tensor(scaler.scale_)
            if getattr(scaler, 'with_mean', True) and scaler.mean_ is not None:
                restored = restored + as_tensor(scaler.mean_)
            return restored
        
        if has_scale and hasattr(scaler, 'min_'):
            # MinMaxScaler layout: x = (x_scaled - min_) / scale_
            return (values - as_tensor(scaler.min_)) / as_tensor(scaler.scale_)
        
        # Unknown scaler type: NumPy round-trip (non-differentiable fallback)
        restored = scaler.inverse_transform(values.detach().cpu().numpy())
        return torch.from_numpy(restored).to(values.device)
    
    def _compute_pde_residual(self, predictions, inputs_branch, inputs_trunk, 
                              scaler_branch, scaler_target):
        """
        Compute PDE residual for ocean heat equation (simplified 1D vertical):
        ∂T/∂t - κ ∂²T/∂z² = 0
        
        where κ is ``self.kappa``, z is depth (pressure proxy), t is time
        
        CRITICAL: Requires inputs_branch and inputs_trunk to have gradients enabled
        
        NOTE: Derivatives are taken w.r.t. the FULL input tensors and the
        relevant columns are extracted afterwards (branch column 0 = pressure,
        trunk column 2 = time).
        NOTE(review): derivatives are w.r.t. the tensors as passed in (scaled
        inputs/outputs), while kappa is a physical value - confirm the
        intended units. ``scaler_branch`` / ``scaler_target`` are accepted but
        not used here.

        Returns:
            Mean squared residual as a scalar tensor; zero if gradients are
            disabled on the inputs or the computation fails.
        """
        # Verify gradients are enabled on FULL tensors
        if not inputs_branch.requires_grad or not inputs_trunk.requires_grad:
            print("Warning: PDE loss requires gradients on inputs, but they are disabled")
            return torch.tensor(0.0, device=predictions.device)
        
        try:
            # First derivatives w.r.t. all input dimensions in one pass
            grad_branch, grad_trunk = torch.autograd.grad(
                outputs=predictions,
                inputs=(inputs_branch, inputs_trunk),
                grad_outputs=torch.ones_like(predictions),
                create_graph=True,
                retain_graph=True
            )  # Shapes: (batch, branch_dim), (batch, trunk_dim)
            
            # Branch: [pressure, salinity] - we want ∂T/∂pressure
            dT_dz = grad_branch[:, 0:1]  # First column is pressure (depth proxy)
            
            # Trunk: [lat, lon, time] - we want ∂T/∂time
            dT_dt = grad_trunk[:, 2:3]  # Third column is time
            
            # Second derivative w.r.t. pressure. If the first derivative does
            # not depend on the inputs (e.g. a piecewise-linear model) the
            # second derivative is zero; that must not discard the ∂T/∂t term.
            d2T_dz2 = torch.zeros_like(dT_dz)
            if dT_dz.requires_grad:
                second = torch.autograd.grad(
                    outputs=dT_dz,
                    inputs=inputs_branch,
                    grad_outputs=torch.ones_like(dT_dz),
                    create_graph=True,
                    retain_graph=True,
                    allow_unused=True
                )[0]
                if second is not None:
                    d2T_dz2 = second[:, 0:1]  # Extract pressure column again
            
            # PDE residual: ∂T/∂t - κ ∂²T/∂z²
            pde_residual = dT_dt - self.kappa * d2T_dz2
            
            # Return mean squared residual
            return pde_residual.pow(2).mean()
            
        except Exception as e:
            # Deliberate best-effort: a failed PDE term must not break training
            print(f"Warning: PDE residual computation failed ({e}), returning zero")
            return torch.tensor(0.0, device=predictions.device)
    
    def _check_temperature_bounds(self, temperature):
        """
        Penalize temperatures outside physical bounds

        Returns the mean distance (in the units of ``temperature``) by which
        values fall below -2.0 or above 35.0.
        """
        temp_min, temp_max = -2.0, 35.0
        
        below_min = torch.relu(temp_min - temperature)
        above_max = torch.relu(temperature - temp_max)
        
        return (below_min + above_max).mean()
    
    def _check_ts_relationship(self, temperature, salinity):
        """
        Check Temperature-Salinity relationship
        Ocean water masses have characteristic T-S signatures

        Returns the sum of three mean penalties:
        - cold (< 5) and very salty (> 36) samples
        - warm (> 28) and very fresh (< 34) samples
        - salinity more than 3 away from 35
        """
        # Low temp (< 5°C) with very high salinity (> 36) is unusual
        unusual_cold_salty = torch.relu((salinity - 36.0) * torch.relu(5.0 - temperature))
        
        # Very high temp (> 28°C) with very low salinity (< 34) is unusual
        unusual_warm_fresh = torch.relu((temperature - 28.0) * torch.relu(34.0 - salinity))
        
        # Extreme salinity deviations from mean (35 PSU)
        salinity_extreme = torch.relu(torch.abs(salinity - 35.0) - 3.0)  # Beyond ±3 PSU
        
        violation = unusual_cold_salty.mean() + unusual_warm_fresh.mean() + salinity_extreme.mean()
        
        return violation


# =====================================================================
# UNCERTAINTY QUANTIFICATION (DEEP ENSEMBLE)
# =====================================================================

class DeepEnsemble:
    """
    Deep Ensemble for Uncertainty Quantification
    Uses multiple trained models to estimate epistemic uncertainty
    """
    
    def __init__(self, models: List[nn.Module], device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            models: List of trained models (from different random seeds)
            device: Device to run inference on
        """
        self.models = models
        self.device = device
        self.n_models = len(models)
        
        # Set all models to eval mode
        for model in self.models:
            model.eval()
            model.to(device)
    
    def predict_with_uncertainty(self, x_branch, x_trunk):
        """
        Compute predictions with epistemic uncertainty
        
        Args:
            x_branch: Branch inputs
            x_trunk: Trunk inputs
            
        Returns:
            mean_pred: Ensemble mean prediction
            std_pred: Across-member standard deviation (sample std); all
                zeros when the ensemble has a single member
            all_preds: All individual predictions stacked on a new leading
                dimension of size n_models
        """
        x_branch = x_branch.to(self.device)
        x_trunk = x_trunk.to(self.device)
        
        with torch.no_grad():
            member_preds = [model(x_branch, x_trunk) for model in self.models]
        
        # Stack predictions along a new leading "member" dimension
        all_preds = torch.stack(member_preds, dim=0)
        
        # Compute ensemble statistics
        mean_pred = all_preds.mean(dim=0)
        if all_preds.shape[0] > 1:
            std_pred = all_preds.std(dim=0)
        else:
            # The sample std of one member is NaN; report "no spread" instead
            std_pred = torch.zeros_like(mean_pred)
        
        return mean_pred, std_pred, all_preds
    
    @staticmethod
    def _std_to_original_scale(scaler, mean_scaled, std_scaled):
        """
        Express a scaled-space spread in the units of the inverse-transformed
        data by differencing ``inverse_transform(mean + std)`` and
        ``inverse_transform(mean)`` (exact for affine scalers).
        """
        center = np.asarray(mean_scaled, dtype=np.float64)
        spread = np.asarray(std_scaled, dtype=np.float64)
        shifted = scaler.inverse_transform(center + spread)
        return np.abs(shifted - scaler.inverse_transform(center))
    
    def evaluate_with_uncertainty(self, test_loader, scaler):
        """
        Evaluate ensemble on test set with uncertainty
        
        Args:
            test_loader: Iterable yielding (x_branch, x_trunk, y) batches
            scaler: Target scaler providing ``inverse_transform``
        
        Returns:
            Dictionary with:
            - 'metrics': mse, rmse, mae, r2, mape on the original scale
            - 'mean_predictions': flattened, original scale
            - 'std_predictions': flattened, in SCALED units (unchanged from
              earlier behavior)
            - 'std_predictions_original': flattened, same spread expressed in
              original units
            - 'targets': flattened, original scale
        """
        mean_batches = []
        std_batches = []
        target_batches = []
        
        for x_branch, x_trunk, y in test_loader:
            mean_pred, std_pred, _ = self.predict_with_uncertainty(x_branch, x_trunk)
            
            mean_batches.append(mean_pred.cpu())
            std_batches.append(std_pred.cpu())
            target_batches.append(y.cpu())
        
        # Concatenate all batches
        mean_predictions = torch.cat(mean_batches, dim=0).numpy()
        std_predictions = torch.cat(std_batches, dim=0).numpy()
        targets = torch.cat(target_batches, dim=0).numpy()
        
        # Inverse transform to original scale
        mean_predictions_original = scaler.inverse_transform(mean_predictions)
        targets_original = scaler.inverse_transform(targets)
        std_predictions_original = self._std_to_original_scale(
            scaler, mean_predictions, std_predictions
        )
        
        # Compute metrics on original scale
        mse_val = mean_squared_error(targets_original, mean_predictions_original)
        mae_val = mean_absolute_error(targets_original, mean_predictions_original)
        r2_val = r2_score(targets_original, mean_predictions_original)
        
        # Avoid division by zero in MAPE
        mask = np.abs(targets_original) > 1e-8
        mape_val = np.mean(np.abs((targets_original[mask] - mean_predictions_original[mask]) / 
                                   targets_original[mask])) * 100 if mask.sum() > 0 else 0.0
        
        metrics = {
            'mse': float(mse_val),
            'rmse': float(np.sqrt(mse_val)),
            'mae': float(mae_val),
            'r2': float(r2_val),
            'mape': float(mape_val)
        }
        
        return {
            'metrics': metrics,
            'mean_predictions': mean_predictions_original.flatten(),
            'std_predictions': std_predictions.flatten(),
            'std_predictions_original': std_predictions_original.flatten(),
            'targets': targets_original.flatten()
        }
    
    def uncertainty_vs_error_analysis(self, test_loader, scaler):
        """
        Analyze relationship between uncertainty and prediction error
        High correlation = good calibration (model knows when it's uncertain)
        
        Uncertainties are taken from 'std_predictions' (scaled units) while
        errors are in original units; the Pearson correlation and the bin
        membership are unaffected by a positive rescaling of the
        uncertainties.
        
        Returns:
            Dictionary with correlation (NaN when either series is constant),
            per-sample uncertainties/errors, per-bin means over 10 uncertainty
            quantile bins (empty bins are skipped), predictions, targets and
            metrics
        """
        results = self.evaluate_with_uncertainty(test_loader, scaler)
        
        # Compute absolute errors
        errors = np.abs(results['targets'] - results['mean_predictions'])
        uncertainties = results['std_predictions']
        
        # Pearson correlation is undefined for a constant series
        if uncertainties.size < 2 or np.std(uncertainties) == 0 or np.std(errors) == 0:
            correlation = float('nan')
        else:
            correlation = np.corrcoef(uncertainties, errors)[0, 1]
        
        # Calibration: split samples into uncertainty quantile bins
        n_bins = 10
        bin_edges = np.percentile(uncertainties, np.linspace(0, 100, n_bins + 1))
        
        bin_means_uncertainty = []
        bin_means_error = []
        
        last_bin = n_bins - 1
        for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
            if i == last_bin:
                # Closed upper edge so the largest uncertainties are binned too
                in_bin = (uncertainties >= lower) & (uncertainties <= upper)
            else:
                in_bin = (uncertainties >= lower) & (uncertainties < upper)
            if in_bin.sum() > 0:
                bin_means_uncertainty.append(uncertainties[in_bin].mean())
                bin_means_error.append(errors[in_bin].mean())
        
        return {
            'correlation': float(correlation),
            'uncertainties': uncertainties,
            'errors': errors,
            'bin_means_uncertainty': np.array(bin_means_uncertainty),
            'bin_means_error': np.array(bin_means_error),
            'mean_predictions': results['mean_predictions'],
            'targets': results['targets'],
            'metrics': results['metrics']
        }


# =====================================================================
# PHYSICS VALIDATION METRICS (FOR XAI4SCIENCE)
# =====================================================================

class PhysicsValidator:
    """Validate that model predictions respect ocean physics"""
    
    @staticmethod
    def check_temperature_bounds(predictions: np.ndarray, valid_min: float = -2.0,
                                 valid_max: float = 35.0) -> Dict[str, float]:
        """
        Check if temperatures are within physical bounds

        Args:
            predictions: Predicted temperatures (any shape)
            valid_min: Lowest accepted temperature
            valid_max: Highest accepted temperature

        Returns:
            Counts of values below/above the bounds, the violation rate over
            all elements, and the mean/std of the predictions. Empty input
            gives zero counts, a 0.0 rate and NaN mean/std.
        """
        values = np.asarray(predictions)
        total = values.size
        
        if total == 0:
            return {
                'violations_below_min': 0,
                'violations_above_max': 0,
                'violation_rate': 0.0,
                'mean_temp': float('nan'),
                'std_temp': float('nan')
            }
        
        below_min = int((values < valid_min).sum())
        above_max = int((values > valid_max).sum())
        
        return {
            'violations_below_min': below_min,
            'violations_above_max': above_max,
            'violation_rate': float((below_min + above_max) / total),
            'mean_temp': float(np.mean(values)),
            'std_temp': float(np.std(values))
        }
    
    @staticmethod
    def check_stratification(predictions: np.ndarray, pressures: np.ndarray,
                            coords: np.ndarray, tolerance: float = 0.01) -> Dict[str, float]:
        """
        Check if temperature decreases with depth (stable stratification)
        Groups by spatial location and checks vertical profile

        Args:
            predictions: Predicted temperatures, shape (n,) or (n, 1); for
                higher-rank input the first entry of each row is used
            pressures: Pressure per sample, used to order each profile
            coords: Per-sample coordinates; columns 0 and 1 are rounded to one
                decimal to form the profile key
            tolerance: Temperature increase allowed between consecutive
                depths before a pair counts as a violation

        Returns:
            Violation count, number of consecutive depth pairs examined, the
            violation rate (0.0 when there are no pairs) and the number of
            profiles.
        """
        from collections import defaultdict
        
        # One temperature per sample regardless of trailing singleton axes
        temps = np.asarray(predictions)
        if temps.ndim > 1:
            temps = temps[(slice(None),) + (0,) * (temps.ndim - 1)]
        
        # Group measurements by rounded location
        profiles = defaultdict(list)
        for temp, pres, coord in zip(temps, pressures, coords):
            lat_lon_key = (round(coord[0], 1), round(coord[1], 1))
            profiles[lat_lon_key].append((pres, temp))
        
        violations = 0
        total_pairs = 0
        
        # Check each profile for stratification
        for profile in profiles.values():
            if len(profile) < 2:
                continue
            
            # Sort by pressure (depth); ties keep their input order
            ordered = sorted(profile, key=lambda entry: entry[0])
            
            # Temperature should decrease or stay constant with depth
            for (_, shallow_temp), (_, deep_temp) in zip(ordered, ordered[1:]):
                if deep_temp > shallow_temp + tolerance:
                    violations += 1
                total_pairs += 1
        
        return {
            'stratification_violations': int(violations),
            'total_depth_pairs': int(total_pairs),
            'stratification_violation_rate': float(violations / max(total_pairs, 1)),
            'num_profiles': len(profiles)
        }
    
    @staticmethod
    def check_spatial_smoothness(predictions: np.ndarray, coords: np.ndarray,
                                 max_samples: int = 1000, radius: float = 1.0,
                                 jump_threshold: float = 5.0) -> Dict[str, float]:
        """
        Check if predictions are spatially smooth (no extreme jumps)

        Args:
            predictions: Predicted temperatures, shape (n,) or (n, k)
            coords: Per-sample coordinates; only columns 0 and 1 are used
            max_samples: Upper limit on the number of randomly drawn samples
                (drawn with ``np.random.choice``, so seed NumPy for
                reproducible results)
            radius: Pairs closer than this (and not co-located) are "nearby"
            jump_threshold: Temperature difference counted as a large jump

        Returns:
            Mean and max temperature difference between nearby pairs, the
            fraction of nearby pairs exceeding ``jump_threshold`` and the std
            of the nearby differences. All 0.0 when there are fewer than two
            samples or no nearby pairs.
        """
        from scipy.spatial.distance import pdist
        
        no_pairs = {
            'mean_nearby_temp_diff': 0.0,
            'max_nearby_temp_diff': 0.0,
            'large_jump_rate': 0.0,
            'spatial_roughness_score': 0.0
        }
        
        temps = np.asarray(predictions)
        locations = np.asarray(coords)
        n_points = len(temps)
        if n_points < 2:
            return no_pairs
        if temps.ndim == 1:
            temps = temps[:, None]  # pdist needs (n, features)
        
        # Sample subset for efficiency
        n_samples = min(max_samples, n_points)
        idx = np.random.choice(n_points, n_samples, replace=False)
        
        # Condensed pairwise vectors: each unordered pair appears once, which
        # leaves mean/max/rate/std identical to the dense symmetric matrices
        spatial_dist = pdist(locations[idx, :2])  # lat, lon only
        temp_diff = pdist(temps[idx])
        
        # Nearby but not co-located points
        nearby_mask = (spatial_dist > 0) & (spatial_dist < radius)
        n_nearby = int(nearby_mask.sum())
        if n_nearby == 0:
            return no_pairs
        
        nearby_temp_diffs = np.abs(temp_diff[nearby_mask])
        
        # Large temperature differences between nearby points indicate roughness
        large_jumps = int((nearby_temp_diffs > jump_threshold).sum())
        
        return {
            'mean_nearby_temp_diff': float(np.mean(nearby_temp_diffs)),
            'max_nearby_temp_diff': float(np.max(nearby_temp_diffs)),
            'large_jump_rate': float(large_jumps / n_nearby),
            'spatial_roughness_score': float(np.std(nearby_temp_diffs))
        }


# =====================================================================
# INTERPRETABILITY ANALYSIS
# =====================================================================

class InterpretabilityAnalyzer:
    """Captum-based interpretability analysis"""
    
    def __init__(self, model, device: str = 'cuda' if torch.cuda.is_available() else 'cpu', is_sequence_model: bool = False):
        self.model = model.to(device)
        self.device = device
        self.is_sequence_model = is_sequence_model
        
        # CRITICAL: LSTM with dropout requires training mode for gradients
        # Keep model in train mode for XAI (gradients won't flow in eval mode)
        if is_sequence_model:
            self.model.train()
            print("  ⚠️  LSTM in training mode for XAI (required for gradient computation)")
        else:
            self.model.eval()
        
    def _prepare_inputs(self, x_branch, x_trunk):
        """Prepare inputs for Captum (requires gradients)"""
        x_branch = x_branch.to(self.device).requires_grad_(True)
        x_trunk = x_trunk.to(self.device).requires_grad_(True)
        return x_branch, x_trunk
    
    def _aggregate_sequence_attributions(self, attributions):
        """
        Aggregate 3D sequence attributions to 2D for visualization.
        For sequence models: (batch, seq_len, features) → (batch, features)
        
        Strategy: Sum over time dimension - represents total feature importance
        across the entire sequence (methodologically valid for temporal models)
        """
        if len(attributions.shape) == 3:
            # Sum over sequence dimension (dim=1)
            return attributions.sum(dim=1)
        return attributions
    
    def _prepare_combined_input(self, x_branch, x_trunk):
        """
        Prepare combined input and forward function for Captum.
        Handles both 2D and 3D inputs automatically.
        
        Returns:
            combined_input: Concatenated tensor
            forward_func: Function that splits and forwards through model
            is_3d: Boolean indicating if inputs are 3D sequences
            split_dims: Dimensions for splitting (branch_features, trunk_features)
        """
        is_3d = len(x_branch.shape) == 3
        
        if is_3d:
            # 3D: (batch, seq_len, features)
            branch_features = x_branch.shape[2]
            trunk_features = x_trunk.shape[2]
            combined_input = torch.cat([x_branch, x_trunk], dim=-1)
            
            def forward_func(inputs):
                branch = inputs[:, :, :branch_features]
                trunk = inputs[:, :, branch_features:]
                return self.model(branch, trunk)
        else:
            # 2D: (batch, features)
            branch_features = x_branch.shape[1]
            trunk_features = x_trunk.shape[1]
            combined_input = torch.cat([x_branch, x_trunk], dim=1)
            
            def forward_func(inputs):
                branch, trunk = torch.split(inputs, [branch_features, trunk_features], dim=1)
                return self.model(branch, trunk)
        
        return combined_input, forward_func, is_3d, (branch_features, trunk_features)
    
    def _split_and_aggregate_attributions(self, attributions, is_3d, split_dims):
        """
        Split attributions back into branch/trunk and aggregate if 3D.
        
        Returns:
            branch_attr: (batch, branch_features) as numpy array
            trunk_attr: (batch, trunk_features) as numpy array
        """
        branch_features, trunk_features = split_dims
        
        if is_3d:
            # Split on feature dimension (dim=2)
            branch_attr = attributions[:, :, :branch_features]
            trunk_attr = attributions[:, :, branch_features:]
            
            # Aggregate over sequence dimension
            branch_attr = self._aggregate_sequence_attributions(branch_attr).detach().cpu().numpy()
            trunk_attr = self._aggregate_sequence_attributions(trunk_attr).detach().cpu().numpy()
        else:
            # Split on feature dimension (dim=1)
            branch_attr = attributions[:, :branch_features].detach().cpu().numpy()
            trunk_attr = attributions[:, branch_features:].detach().cpu().numpy()
        
        return branch_attr, trunk_attr
    
    def integrated_gradients(self, x_branch, x_trunk, n_steps: int = 50):
        """
        Attribute the model output to its inputs with Integrated Gradients.

        Works for 2D (batch, features) and 3D (batch, seq_len, features)
        inputs; 3D attributions are summed over the sequence axis. The
        baseline is an all-zero tensor. If Captum raises, the error and
        traceback are printed and all-zero attributions are returned instead.

        Args:
            x_branch: Branch inputs
            x_trunk: Trunk inputs
            n_steps: Number of integration steps passed to Captum

        Returns:
            (branch_attr, trunk_attr) as NumPy arrays
        """
        branch_in, trunk_in = self._prepare_inputs(x_branch, x_trunk)
        combined_input, forward_func, is_3d, split_dims = self._prepare_combined_input(
            branch_in, trunk_in
        )
        
        zero_baseline = torch.zeros_like(combined_input)
        explainer = IntegratedGradients(forward_func)
        
        try:
            attributions = explainer.attribute(combined_input, zero_baseline, n_steps=n_steps)
        except Exception as e:
            # Best-effort: report the failure and fall back to zeros
            print(f"  ⚠ Integrated Gradients failed: {e}")
            import traceback
            traceback.print_exc()
            attributions = torch.zeros_like(combined_input)
        
        return self._split_and_aggregate_attributions(attributions, is_3d, split_dims)
    
    def saliency_maps(self, x_branch, x_trunk):
        """
        Attribute the model output to its inputs with Captum's ``Saliency``.

        Works for 2D and 3D inputs; 3D attributions are summed over the
        sequence axis. If Captum raises, the error is printed and all-zero
        attributions are returned instead.

        Returns:
            (branch_attr, trunk_attr) as NumPy arrays
        """
        branch_in, trunk_in = self._prepare_inputs(x_branch, x_trunk)
        combined_input, forward_func, is_3d, split_dims = self._prepare_combined_input(
            branch_in, trunk_in
        )
        
        explainer = Saliency(forward_func)
        
        try:
            attributions = explainer.attribute(combined_input)
        except Exception as e:
            # Best-effort: report the failure and fall back to zeros
            print(f"  ⚠ Saliency Maps failed: {e}")
            attributions = torch.zeros_like(combined_input)
        
        return self._split_and_aggregate_attributions(attributions, is_3d, split_dims)
    
    def deeplift(self, x_branch, x_trunk):
        """
        Attribute the model output to its inputs with DeepLIFT.

        Works for 2D and 3D inputs; 3D attributions are summed over the
        sequence axis. DeepLIFT is given an ``nn.Module`` rather than a plain
        function, so the model is wrapped in a small module that accepts the
        concatenated tensor. The baseline is an all-zero tensor. If anything
        in the DeepLIFT step raises, the error and traceback are printed and
        all-zero attributions are returned instead.

        Returns:
            (branch_attr, trunk_attr) as NumPy arrays
        """
        class _SplitInputModule(nn.Module):
            """Module facade that splits the combined tensor for the model."""
            
            def __init__(self, base_model, branch_features, is_3d):
                super().__init__()
                self.base_model = base_model
                self.branch_features = branch_features
                self.is_3d = is_3d
            
            def forward(self, combined_input):
                cut = self.branch_features
                if not self.is_3d:
                    # 2D: features on the second dimension
                    return self.base_model(combined_input[:, :cut], combined_input[:, cut:])
                # 3D: features on the last dimension
                return self.base_model(combined_input[:, :, :cut], combined_input[:, :, cut:])
        
        branch_in, trunk_in = self._prepare_inputs(x_branch, x_trunk)
        combined_input, _, is_3d, split_dims = self._prepare_combined_input(branch_in, trunk_in)
        zero_baseline = torch.zeros_like(combined_input)
        
        try:
            wrapped = _SplitInputModule(self.model, split_dims[0], is_3d)
            # Train mode for sequence models, eval mode otherwise
            wrapped.train(bool(self.is_sequence_model))
            
            attributions = DeepLift(wrapped).attribute(combined_input, zero_baseline)
        except Exception as e:
            # Best-effort: report the failure and fall back to zeros
            print(f"  ⚠ DeepLIFT failed: {e}")
            import traceback
            traceback.print_exc()
            attributions = torch.zeros_like(combined_input)
        
        return self._split_and_aggregate_attributions(attributions, is_3d, split_dims)
    
    def gradient_shap(self, x_branch, x_trunk, n_samples: int = 50):
        """
        Attribute the model output to its inputs with GradientSHAP.

        Works for 2D and 3D inputs; 3D attributions are summed over the
        sequence axis. The baseline distribution is ``n_samples`` draws of
        Gaussian noise scaled by 0.1. If Captum raises, the error is printed
        and all-zero attributions are returned instead.

        Args:
            x_branch: Branch inputs
            x_trunk: Trunk inputs
            n_samples: Number of baseline draws, also passed to Captum as
                ``n_samples``

        Returns:
            (branch_attr, trunk_attr) as NumPy arrays
        """
        branch_in, trunk_in = self._prepare_inputs(x_branch, x_trunk)
        combined_input, forward_func, is_3d, split_dims = self._prepare_combined_input(
            branch_in, trunk_in
        )
        
        # Noise baselines share the per-sample shape of the combined input
        per_sample_shape = combined_input.shape[1:]
        noise = torch.randn(n_samples, *per_sample_shape, device=combined_input.device)
        baselines = noise * 0.1
        
        explainer = GradientShap(forward_func)
        
        try:
            attributions = explainer.attribute(combined_input, baselines=baselines, n_samples=n_samples)
        except Exception as e:
            # Best-effort: report the failure and fall back to zeros
            print(f"  ⚠ GradientSHAP failed: {e}")
            attributions = torch.zeros_like(combined_input)
        
        return self._split_and_aggregate_attributions(attributions, is_3d, split_dims)
    
    def layer_attribution(self, x_branch, x_trunk, layer_name: Optional[str] = None):
        """
        Compute layer-wise attributions to understand internal representations

        Uses Captum's ``LayerGradientXActivation`` on one named sub-module and
        averages the absolute attributions per sample.

        Args:
            x_branch: Branch inputs
            x_trunk: Trunk inputs
            layer_name: Name from ``model.named_modules()``. When omitted it is
                picked from the model's attributes: 'branch_net.0' if the model
                has ``branch_net``, 'fourier_layers.0.linear' if it has
                ``fourier_layers``, 'fc.0' if it has ``fc``.

        Returns:
            (aggregated, layer_name): ``aggregated`` is a NumPy array - for 2D
            inputs the mean over dim 1 with that dim kept, for 3D inputs the
            mean over the non-batch dims of the layer attribution.
            (None, None) if no layer can be resolved or attribution fails.
        """
        x_branch, x_trunk = self._prepare_inputs(x_branch, x_trunk)
        
        # Auto-detect appropriate layer if not specified
        if layer_name is None:
            if hasattr(self.model, 'branch_net'):
                layer_name = 'branch_net.0'  # DeepONet
            elif hasattr(self.model, 'fourier_layers'):
                layer_name = 'fourier_layers.0.linear'  # FNO
            elif hasattr(self.model, 'fc'):
                layer_name = 'fc.0'  # LSTM
            else:
                print(f"Warning: Could not auto-detect layer")
                return None, None
        
        # Look the layer up once; the same mapping feeds the debug listing
        modules_by_name = dict(self.model.named_modules())
        layer = modules_by_name.get(layer_name)
        if layer is None:
            print(f"Warning: Layer {layer_name} not found in model")
            # List available layers for debugging
            available = list(modules_by_name.keys())[:10]
            print(f"  Available layers (first 10): {available}")
            return None, None
        
        # Prepare combined input and forward function (handle 2D vs 3D)
        combined_input, forward_func, is_3d, split_dims = self._prepare_combined_input(x_branch, x_trunk)
        
        try:
            layer_gc = LayerGradientXActivation(forward_func, layer)
            magnitudes = layer_gc.attribute(combined_input).abs()
            
            if is_3d:
                # The layer output need not have the same rank as the model
                # input (e.g. a head applied after the sequence has been
                # reduced), so pick the reduction dims from the attribution
                # itself instead of assuming (batch, seq_len, features).
                if magnitudes.dim() >= 3:
                    reduce_dims = [1, 2]
                else:
                    reduce_dims = list(range(1, magnitudes.dim()))
                aggregated = magnitudes.mean(dim=reduce_dims) if reduce_dims else magnitudes
            else:
                aggregated = magnitudes.mean(dim=1, keepdim=True)
            
            return aggregated.detach().cpu().numpy(), layer_name
        except Exception as e:
            # Best-effort: report the failure instead of aborting the analysis
            print(f"  ⚠ Layer Attributions failed: {e}")
            return None, None
    
    def layer_conductance_analysis(self, x_branch, x_trunk):
        """
        Score the model's layers with Captum LayerConductance.

        Which layers are analysed depends on the attributes the model exposes:
        the Linear children of ``branch_net`` and ``trunk_net``, the ``linear``
        sub-module of every entry in ``fourier_layers``, and the Linear entries
        of ``fc`` (only when the model also has an ``lstm`` attribute).
        A layer whose attribution raises is reported on stdout and omitted.

        Returns:
            dict mapping a layer label (``branch_<name>``, ``trunk_<name>``,
            ``fourier_<idx>`` or ``lstm_fc_<idx>``) to the mean absolute
            conductance taken over the whole attribution tensor.
        """
        x_branch, x_trunk = self._prepare_inputs(x_branch, x_trunk)

        # Single combined tensor + matching forward function (handles 2D vs 3D);
        # the remaining two return values are not needed here.
        combined_input, forward_func, _is_3d, _split_dims = self._prepare_combined_input(x_branch, x_trunk)

        scores = {}

        def record(key, module, attr=None):
            # Attribute one layer and store its mean |conductance| under `key`.
            # The sub-module lookup happens inside the try so that a missing
            # attribute is reported like any other per-layer failure.
            try:
                target = module if attr is None else getattr(module, attr)
                attributions = LayerConductance(forward_func, target).attribute(combined_input)
                scores[key] = attributions.abs().mean().item()
            except Exception as e:
                print(f"    Warning: Conductance for {key} failed ({e})")

        model = self.model

        # DeepONet: branch and trunk networks are analysed separately
        if hasattr(model, 'branch_net'):
            print("  Analyzing branch network layers...")
            for name, child in model.branch_net.named_children():
                if isinstance(child, nn.Linear):
                    record(f'branch_{name}', child)

        if hasattr(model, 'trunk_net'):
            print("  Analyzing trunk network layers...")
            for name, child in model.trunk_net.named_children():
                if isinstance(child, nn.Linear):
                    record(f'trunk_{name}', child)

        # FNO: attribute the linear part of every Fourier layer
        if hasattr(model, 'fourier_layers'):
            print("  Analyzing Fourier layers...")
            for idx, fourier_layer in enumerate(model.fourier_layers):
                record(f'fourier_{idx}', fourier_layer, attr='linear')

        # LSTM: the recurrent layer itself is skipped (too complex for
        # conductance with 3D inputs); only the FC head is analysed.
        if hasattr(model, 'lstm'):
            print("  Analyzing LSTM layers...")

            if hasattr(model, 'fc'):
                for idx, fc_layer in enumerate(model.fc):
                    if isinstance(fc_layer, nn.Linear):
                        record(f'lstm_fc_{idx}', fc_layer)

            print("    Note: LSTM recurrent layer skipped (use other XAI methods)")

        return scores
    
    def analyze_all(self, x_branch, x_trunk):
        """
        Run all interpretability methods (6 methods for XAI4Science).

        Four input-attribution methods are run first; each stores a
        ``<prefix>_branch`` / ``<prefix>_trunk`` pair. If one of them raises,
        the failure is printed and zero arrays of shape
        ``(batch, n_features)`` are stored instead. The two layer-level
        analyses only add their keys (``layer_conductance``, and
        ``layer_attribution`` + ``layer_name``) when they produce a result.

        Returns:
            dict with the collected attribution results.
        """
        # Last dim is always features, first dim is the batch (2D and 3D inputs)
        n_samples = x_branch.shape[0]
        n_branch_feats = x_branch.shape[-1]
        n_trunk_feats = x_trunk.shape[-1]

        results = {}

        # (display label, result-key prefix, name of the method on self)
        input_methods = (
            ("Integrated Gradients", 'ig', 'integrated_gradients'),
            ("Saliency Maps", 'saliency', 'saliency_maps'),
            ("DeepLIFT", 'deeplift', 'deeplift'),
            ("GradientSHAP", 'gradshap', 'gradient_shap'),
        )

        for label, prefix, method_name in input_methods:
            branch_key = f'{prefix}_branch'
            trunk_key = f'{prefix}_trunk'
            try:
                print(f"Computing {label}...")
                # Method is resolved inside the try so lookup errors are
                # handled exactly like attribution errors.
                branch_attr, trunk_attr = getattr(self, method_name)(x_branch, x_trunk)
                results[branch_key] = branch_attr
                results[trunk_key] = trunk_attr
            except Exception as e:
                print(f"  ⚠ {label} failed: {e}")
                results[branch_key] = np.zeros((n_samples, n_branch_feats))
                results[trunk_key] = np.zeros((n_samples, n_trunk_feats))

        try:
            print("Computing Layer Conductance Analysis...")
            per_layer = self.layer_conductance_analysis(x_branch, x_trunk)
            if not per_layer or len(per_layer) == 0:
                print("  ⚠ No layer conductance results")
            else:
                results['layer_conductance'] = per_layer
        except Exception as e:
            print(f"  ⚠ Layer Conductance Analysis failed: {e}")

        try:
            print("Computing Layer Attributions...")
            layer_attr, layer_name = self.layer_attribution(x_branch, x_trunk)
            if layer_attr is not None:
                results.update(layer_attribution=layer_attr, layer_name=layer_name)
        except Exception as e:
            print(f"  ⚠ Layer Attributions failed: {e}")

        return results


# =====================================================================
# MULTI-SEED STATISTICAL ANALYSIS
# =====================================================================

class MultiSeedExperiment:
    """Run experiments with multiple seeds for statistical rigor.

    Every configured model is trained once per seed. The per-run result dicts
    are collected in ``self.results`` (model name -> list of runs) and can then
    be summarised with :meth:`statistical_analysis` and compared pairwise with
    :meth:`compare_models`.
    """

    def __init__(self, data_splits, num_seeds: int = 5, scalers=None):
        """
        Args:
            data_splits: Mapping holding ``X_branch_<split>``, ``X_trunk_<split>``
                and ``y_<split>`` for the ``train``, ``val`` and ``test`` splits.
            num_seeds: Number of seeds (``0 .. num_seeds - 1``) to run per model.
            scalers: Optional mapping of fitted scalers; it is handed to
                ``ModelTrainer`` and ``scalers['target']`` is used at evaluation
                time. When omitted, the module-level ``processor.scalers``
                (published by ``main()``) is used instead.

        Raises:
            ValueError: If ``num_seeds`` is smaller than 1.
        """
        if num_seeds < 1:
            raise ValueError("num_seeds must be at least 1")

        self.data_splits = data_splits
        self.num_seeds = num_seeds
        self.scalers = scalers
        self.results = {
            'DeepONet': [],
            'DeepONet-Physics': [],
            'FNO': [],
            'LSTM': []
        }

    def _resolve_scalers(self):
        """Return the scalers given at construction, else the global processor's."""
        if self.scalers is not None:
            return self.scalers
        # Legacy path: main() exposes the data processor as a module global.
        return processor.scalers

    def run_single_experiment(self, model_name: str, model_class, seed: int,
                            model_kwargs: dict, train_kwargs: dict,
                            data_splits=None):
        """Train and evaluate one model with one seed.

        Args:
            model_name: Label used in log output.
            model_class: Callable building the model from ``model_kwargs``.
            seed: Random seed applied via ``set_seed`` before anything else.
            model_kwargs: Keyword arguments for ``model_class``; its
                ``use_physics_loss`` entry (default False) is also forwarded
                to the trainer.
            train_kwargs: Keyword arguments for ``ModelTrainer.train``.
            data_splits: Optional splits to use instead of ``self.data_splits``
                (same key layout).

        Returns:
            dict with ``model``, ``trainer``, ``metrics``, ``predictions``,
            ``targets`` and the ``seed`` the run was produced with.
        """
        set_seed(seed)

        splits = self.data_splits if data_splits is None else data_splits

        # At least 1: a training set with fewer than 4 samples would otherwise
        # produce batch_size 0, which DataLoader rejects.
        n_train = len(splits['X_branch_train'])
        batch_size = max(1, min(256, n_train // 4))

        def make_loader(split, shuffle):
            dataset = TensorDataset(
                splits[f'X_branch_{split}'],
                splits[f'X_trunk_{split}'],
                splits[f'y_{split}']
            )
            return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)

        train_loader = make_loader('train', True)
        val_loader = make_loader('val', False)
        test_loader = make_loader('test', False)

        # Initialize model
        model = model_class(**model_kwargs)

        # Check if model should use physics loss
        use_physics_loss = model_kwargs.get('use_physics_loss', False)

        scalers = self._resolve_scalers()
        trainer = ModelTrainer(model, use_physics_loss=use_physics_loss, scalers=scalers)

        # Train
        print(f"\n{'='*60}")
        print(f"Training {model_name} (Seed {seed})")
        print(f"{'='*60}")
        trainer.train(train_loader, val_loader, **train_kwargs)

        # Evaluate
        metrics, preds, targets = trainer.evaluate(test_loader, scalers['target'])

        return {
            'model': model,
            'trainer': trainer,
            'metrics': metrics,
            'predictions': preds,
            'targets': targets,
            # Recorded so compare_models can pair runs of the same seed even
            # when some seeds failed for one of the models.
            'seed': seed
        }

    def run_all_models(self, epochs: int = 100):
        """Run all models with multiple seeds.

        A run that raises is reported and skipped; the remaining seeds and
        models are still processed.

        Args:
            epochs: Maximum number of training epochs per run.

        Returns:
            ``self.results`` (model name -> list of run dicts).
        """

        # Model configurations
        models_config = {
            'DeepONet': {
                'class': DeepONet,
                'kwargs': {
                    'branch_dim': 2,  # PRES, PSAL
                    'trunk_dim': 3,   # LAT, LON, TIME
                    'hidden_dim': 128,
                    'num_layers': 4,
                    'use_physics_loss': False
                }
            },
            'DeepONet-Physics': {
                'class': DeepONet,
                'kwargs': {
                    'branch_dim': 2,
                    'trunk_dim': 3,
                    'hidden_dim': 128,
                    'num_layers': 4,
                    'use_physics_loss': True  # Soft constraints via loss function
                }
            },
            'FNO': {
                'class': FNO1d,
                'kwargs': {
                    'input_dim': 5,  # branch + trunk
                    'hidden_dim': 64,
                    'num_layers': 4,
                    'modes': 16
                }
            },
            'LSTM': {
                'class': LSTMPredictor,
                'kwargs': {
                    'input_dim': 5,  # branch + trunk
                    'hidden_dim': 128,
                    'num_layers': 2,
                    'dropout': 0.2,
                    'use_sequences': True  # Enable proper sequence processing
                },
                'use_sequence_data': True  # Flag to use sequence dataset
            }
        }

        train_kwargs = {
            'epochs': epochs,
            'lr': 1e-3,
            'patience': 15
        }

        # Sequence splits are attached to the instance after construction by
        # the caller, so they may legitimately be absent.
        sequence_splits = getattr(self, 'data_splits_lstm', None)

        for model_name, config in models_config.items():
            print(f"\n{'#'*60}")
            print(f"# Running {model_name} with {self.num_seeds} seeds")
            print(f"{'#'*60}")

            # Use sequence data for LSTM if available
            if config.get('use_sequence_data', False) and sequence_splits is not None:
                print(f"  → Using sequence data for {model_name}")
                current_data_splits = sequence_splits
            else:
                current_data_splits = self.data_splits

            for seed in range(self.num_seeds):
                try:
                    # The splits are passed explicitly instead of temporarily
                    # swapping self.data_splits, so no state needs restoring.
                    result = self.run_single_experiment(
                        model_name,
                        config['class'],
                        seed,
                        config['kwargs'],
                        train_kwargs,
                        data_splits=current_data_splits
                    )
                except Exception as e:
                    # Deliberate best-effort: one failed seed must not abort
                    # the whole study.
                    print(f"Error training {model_name} with seed {seed}: {e}")
                    print("Continuing with next seed...")
                    continue
                self.results.setdefault(model_name, []).append(result)

        return self.results

    @staticmethod
    def _confidence_interval(values, confidence: float = 0.95):
        """Student-t confidence interval for the mean of ``values``.

        Collapses to ``(mean, mean)`` when there is a single sample or when
        the samples have no spread (``stats.t.interval`` yields NaN for a
        zero scale).
        """
        mean = float(values.mean())
        n = len(values)
        if n < 2:
            return (mean, mean)

        sem = float(stats.sem(values))
        if not np.isfinite(sem) or sem == 0.0:
            return (mean, mean)

        low, high = stats.t.interval(confidence, n - 1, loc=mean, scale=sem)
        return (float(low), float(high))

    def statistical_analysis(self):
        """Perform statistical analysis across seeds.

        Returns:
            dict mapping each model that has results to
            ``{'mean': {...}, 'std': {...}, 'ci_95': {metric: (low, high)}}``.
            Models without any result are reported and left out.
        """
        summary = {}

        for model_name, results in self.results.items():
            if len(results) == 0:
                print(f"Warning: No results for {model_name}")
                continue

            # One row per seed, one column per metric
            metrics_df = pd.DataFrame([r['metrics'] for r in results])

            summary[model_name] = {
                'mean': metrics_df.mean().to_dict(),
                'std': metrics_df.std().to_dict(),
                'ci_95': {
                    metric: self._confidence_interval(metrics_df[metric])
                    for metric in metrics_df.columns
                }
            }

        return summary

    def _paired_rmse(self, model1, model2):
        """Return two equally long RMSE lists whose entries belong together.

        Runs are matched on their recorded ``seed``. Results that carry no
        seed (e.g. produced by older code) are paired by position, which is
        only possible when both models have the same number of runs;
        otherwise ``None`` is returned.
        """
        runs1 = self.results[model1]
        runs2 = self.results[model2]

        if all('seed' in run for run in runs1) and all('seed' in run for run in runs2):
            by_seed1 = {run['seed']: run['metrics']['rmse'] for run in runs1}
            by_seed2 = {run['seed']: run['metrics']['rmse'] for run in runs2}
            shared = sorted(set(by_seed1) & set(by_seed2))
            return [by_seed1[s] for s in shared], [by_seed2[s] for s in shared]

        if len(runs1) == len(runs2):
            return ([run['metrics']['rmse'] for run in runs1],
                    [run['metrics']['rmse'] for run in runs2])

        return None

    def compare_models(self):
        """Statistical comparison between models (paired t-test on RMSE).

        Only runs that exist for both models of a pair enter the test, so a
        seed that failed for one model does not break the comparison.

        Returns:
            dict keyed ``"<model1>_vs_<model2>"`` with ``t_statistic``,
            ``p_value``, ``significant`` (p < 0.05) and ``mean_diff``
            (mean RMSE of model1 minus mean RMSE of model2).
        """
        comparisons = {}

        models = [m for m in self.results.keys() if len(self.results[m]) > 0]

        for i, model1 in enumerate(models):
            for model2 in models[i+1:]:
                paired = self._paired_rmse(model1, model2)

                if paired is None:
                    print(f"Warning: Cannot pair runs of {model1} and {model2} "
                          f"(different run counts and no seed information)")
                    continue

                rmse1, rmse2 = paired
                if len(rmse1) < 2:
                    print(f"Warning: Not enough samples to compare {model1} vs {model2}")
                    continue

                # Paired t-test
                t_stat, p_value = stats.ttest_rel(rmse1, rmse2)

                comparisons[f"{model1}_vs_{model2}"] = {
                    't_statistic': float(t_stat),
                    'p_value': float(p_value),
                    'significant': bool(p_value < 0.05),
                    'mean_diff': float(np.mean(rmse1) - np.mean(rmse2))
                }

        return comparisons


# =====================================================================
# VISUALIZATION
# =====================================================================

class ResultVisualizer:
    """Comprehensive visualization for results and interpretability.

    All methods are static. Plot methods optionally save the figure to
    ``save_path``, then show and close it.
    """

    # One color per model; cycled when there are more models than entries.
    _MODEL_COLORS = ('#2E86AB', '#A23B72', '#F18F01', '#C73E1D')

    @staticmethod
    def _non_empty(results_dict):
        """Return only the models that have at least one run."""
        return {name: runs for name, runs in results_dict.items() if len(runs) > 0}

    @staticmethod
    def _finalize(save_path: Optional[str] = None):
        """Lay out the current figure, optionally save it, then show and close it."""
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"  ✓ Saved: {save_path}")
        plt.show()
        plt.close()

    @staticmethod
    def plot_training_curves(results_dict, save_path: Optional[str] = None):
        """Plot training curves for all models.

        One panel per model that has results; every seed contributes one
        train and one validation curve read from ``result['trainer'].history``.

        Args:
            results_dict: Model name -> list of run dicts.
            save_path: Optional file path for the figure.
        """
        # Models without runs get no panel (and an all-empty input no figure),
        # matching the other plot methods.
        valid_results = ResultVisualizer._non_empty(results_dict)

        if len(valid_results) == 0:
            print("No results to plot")
            return

        # squeeze=False keeps `axes` two-dimensional even for a single model
        _, axes = plt.subplots(1, len(valid_results), figsize=(15, 4), squeeze=False)

        for ax, (model_name, results) in zip(axes[0], valid_results.items()):
            for seed_idx, result in enumerate(results):
                history = result['trainer'].history
                # Only the first seed is labelled so the legend has one entry
                # per curve type.
                first = seed_idx == 0
                ax.plot(history['train_loss'], alpha=0.3, color='blue',
                        label='Train' if first else None)
                ax.plot(history['val_loss'], alpha=0.3, color='red',
                        label='Val' if first else None)

            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss (MSE)')
            ax.set_title(f'{model_name} Training Curves')
            ax.legend()
            ax.grid(True, alpha=0.3)

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_metrics_comparison(summary_stats, save_path: Optional[str] = None):
        """Plot comparison of metrics across models with error bars.

        Args:
            summary_stats: Model name -> ``{'mean': {...}, 'std': {...}}``
                containing the metrics ``rmse``, ``mae``, ``r2`` and ``mape``.
            save_path: Optional file path for the figure.
        """
        metrics = ['rmse', 'mae', 'r2', 'mape']
        models = list(summary_stats.keys())

        if len(models) == 0:
            print("No models to plot")
            return

        # Explicit per-bar colors; the palette covers all four configured
        # models and wraps around beyond that.
        palette = ResultVisualizer._MODEL_COLORS
        bar_colors = [palette[i % len(palette)] for i in range(len(models))]

        _, axes = plt.subplots(2, 2, figsize=(14, 10))
        x = np.arange(len(models))

        for ax, metric in zip(axes.flatten(), metrics):
            means = [summary_stats[model]['mean'][metric] for model in models]
            stds = [summary_stats[model]['std'][metric] for model in models]

            bars = ax.bar(x, means, yerr=stds, capsize=5, alpha=0.7, color=bar_colors)

            ax.set_xlabel('Model')
            ax.set_ylabel(metric.upper())
            ax.set_title(f'{metric.upper()} Comparison (Mean ± Std)')
            ax.set_xticks(x)
            ax.set_xticklabels(models, rotation=0)
            ax.grid(True, alpha=0.3, axis='y')

            # Add value labels on bars
            for bar, mean, std in zip(bars, means, stds):
                ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                       f'{mean:.3f}\n±{std:.3f}',
                       ha='center', va='bottom', fontsize=9)

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_predictions_vs_actual(results_dict, save_path: Optional[str] = None):
        """Plot predicted vs actual temperatures (first run of each model).

        Args:
            results_dict: Model name -> list of run dicts with ``predictions``,
                ``targets`` and ``metrics`` (``r2``, ``rmse``).
            save_path: Optional file path for the figure.
        """
        valid_results = ResultVisualizer._non_empty(results_dict)

        if len(valid_results) == 0:
            print("No results to plot")
            return

        _, axes = plt.subplots(1, len(valid_results), figsize=(15, 4), squeeze=False)

        for ax, (model_name, results) in zip(axes[0], valid_results.items()):
            # Use first seed for visualization
            first_run = results[0]
            preds = first_run['predictions'].flatten()
            targets = first_run['targets'].flatten()

            # Scatter plot
            ax.scatter(targets, preds, alpha=0.3, s=10, color='navy')

            # Perfect prediction line spanning the joint value range
            min_val = min(targets.min(), preds.min())
            max_val = max(targets.max(), preds.max())
            ax.plot([min_val, max_val], [min_val, max_val],
                   'r--', linewidth=2, label='Perfect Prediction')

            r2 = first_run['metrics']['r2']
            rmse = first_run['metrics']['rmse']

            ax.set_xlabel('Actual Temperature (°C)')
            ax.set_ylabel('Predicted Temperature (°C)')
            ax.set_title(f'{model_name}\nR²={r2:.3f}, RMSE={rmse:.3f}')
            ax.legend()
            ax.grid(True, alpha=0.3)

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_residuals(results_dict, save_path: Optional[str] = None):
        """Plot residual analysis (first run of each model).

        Top row: residuals (target - prediction) against predictions.
        Bottom row: histogram of the residuals.

        Args:
            results_dict: Model name -> list of run dicts with ``predictions``
                and ``targets``.
            save_path: Optional file path for the figure.
        """
        valid_results = ResultVisualizer._non_empty(results_dict)

        if len(valid_results) == 0:
            print("No results to plot")
            return

        _, axes = plt.subplots(2, len(valid_results), figsize=(15, 8), squeeze=False)

        for idx, (model_name, results) in enumerate(valid_results.items()):
            preds = results[0]['predictions'].flatten()
            targets = results[0]['targets'].flatten()
            residuals = targets - preds

            # Residual scatter
            ax_scatter = axes[0, idx]
            ax_scatter.scatter(preds, residuals, alpha=0.3, s=10, color='navy')
            ax_scatter.axhline(y=0, color='r', linestyle='--', linewidth=2)
            ax_scatter.set_xlabel('Predicted Temperature (°C)')
            ax_scatter.set_ylabel('Residuals (°C)')
            ax_scatter.set_title(f'{model_name} - Residual Plot')
            ax_scatter.grid(True, alpha=0.3)

            # Residual histogram
            ax_hist = axes[1, idx]
            ax_hist.hist(residuals, bins=50, alpha=0.7, color='navy', edgecolor='black')
            ax_hist.axvline(x=0, color='r', linestyle='--', linewidth=2)
            ax_hist.set_xlabel('Residuals (°C)')
            ax_hist.set_ylabel('Frequency')
            ax_hist.set_title(f'Residual Distribution (μ={np.mean(residuals):.3f}°C)')
            ax_hist.grid(True, alpha=0.3)

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_interpretability_analysis(attr_results, feature_names_branch,
                                      feature_names_trunk, save_path: Optional[str] = None):
        """Plot Captum attribution results: 4 methods plus a consensus column.

        Columns 0-3 show the mean absolute attribution per feature for
        Integrated Gradients, Saliency, DeepLIFT and GradientSHAP (branch
        inputs on the top row, trunk inputs on the bottom row). The last
        column shows the mean ± std of those four importances per feature.

        Args:
            attr_results: dict with ``<method>_branch`` / ``<method>_trunk``
                arrays for ``ig``, ``saliency``, ``deeplift`` and ``gradshap``;
                the last axis is taken as the feature axis.
            feature_names_branch: Labels of the branch features.
            feature_names_trunk: Labels of the trunk features.
            save_path: Optional file path for the figure.
        """
        methods = ['ig', 'saliency', 'deeplift', 'gradshap']
        method_labels = ['Integrated Gradients', 'Saliency Maps', 'DeepLIFT', 'GradientSHAP']

        # One extra column for the consensus panel. Previously the consensus
        # was drawn on column 3, on top of the GradientSHAP bars.
        n_cols = len(methods) + 1
        _, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10))

        def mean_abs_per_feature(attr):
            # Average |attribution| over every axis except the last (feature)
            # one; for 2D input this equals mean(axis=0).
            magnitudes = np.abs(np.asarray(attr))
            return magnitudes.reshape(-1, magnitudes.shape[-1]).mean(axis=0)

        # (row, result-key suffix, feature labels, bar color, title suffix)
        rows = (
            (0, 'branch', feature_names_branch, 'steelblue',
             'Branch Network\n(Input Functions)'),
            (1, 'trunk', feature_names_trunk, 'coral',
             'Trunk Network\n(Spatiotemporal Coords)'),
        )

        # Per-method importances, kept for the consensus column
        importances = {'branch': [], 'trunk': []}

        for col, (method, label) in enumerate(zip(methods, method_labels)):
            for row, part, names, color, subtitle in rows:
                importance = mean_abs_per_feature(attr_results[f'{method}_{part}'])
                importances[part].append(importance)

                ax = axes[row, col]
                ax.bar(names, importance, color=color, alpha=0.7)
                ax.set_title(f'{label} - {subtitle}')
                ax.set_ylabel('Mean Absolute Attribution')
                ax.tick_params(axis='x', rotation=45)
                ax.grid(True, alpha=0.3, axis='y')

        # Cross-method comparison (synthesis plot) in the dedicated last column
        for row, part, names, _color, _subtitle in rows:
            stacked = np.vstack(importances[part])  # (n_methods, n_features)

            ax = axes[row, n_cols - 1]
            ax.bar(names, stacked.mean(axis=0), yerr=stacked.std(axis=0),
                   capsize=5, color='purple', alpha=0.7)
            ax.set_title('Consensus Importance\n(Mean ± Std across methods)')
            ax.set_ylabel('Attribution Magnitude')
            ax.tick_params(axis='x', rotation=45)
            ax.grid(True, alpha=0.3, axis='y')

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_layer_conductance_comparison(attr_results_dict, save_path: Optional[str] = None):
        """
        Plot layer conductance comparison across models
        Shows which layers contribute most to predictions in each architecture

        Args:
            attr_results_dict: Model name -> attribution dict; only entries
                with a ``layer_conductance`` mapping (layer label -> value)
                are plotted.
            save_path: Optional file path for the figure.
        """
        # Extract layer conductance data
        conductance_data = {
            model_name: attr_data['layer_conductance']
            for model_name, attr_data in attr_results_dict.items()
            if 'layer_conductance' in attr_data
        }

        if len(conductance_data) == 0:
            print("No layer conductance data to plot")
            return

        # Color code by network type; the first tag found in the label wins
        tag_colors = (
            ('branch', '#2E86AB'),   # Blue for branch
            ('trunk', '#A23B72'),    # Purple for trunk
            ('fourier', '#F18F01'),  # Orange for Fourier
            ('lstm', '#C73E1D'),     # Red for LSTM
        )

        # Create subplots for each model
        n_models = len(conductance_data)
        _, axes = plt.subplots(1, n_models, figsize=(5*n_models, 6), squeeze=False)

        for ax, (model_name, layer_cond) in zip(axes[0], conductance_data.items()):
            layers = list(layer_cond.keys())
            values = list(layer_cond.values())

            colors = [
                next((color for tag, color in tag_colors if tag in layer), 'gray')
                for layer in layers
            ]

            positions = range(len(layers))
            bars = ax.bar(positions, values, color=colors, alpha=0.7)
            ax.set_xlabel('Layer')
            ax.set_ylabel('Mean Conductance')
            ax.set_title(f'{model_name}\nLayer Conductance')
            ax.set_xticks(positions)
            ax.set_xticklabels(layers, rotation=45, ha='right', fontsize=8)
            ax.grid(True, alpha=0.3, axis='y')

            # Add value labels on bars
            for bar, value in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                       f'{value:.3f}',
                       ha='center', va='bottom', fontsize=8)

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_statistical_comparison(comparison_results, save_path: Optional[str] = None):
        """Plot statistical comparison between models.

        Args:
            comparison_results: Comparison name -> dict with ``p_value`` and
                ``mean_diff``.
            save_path: Optional file path for the figure.
        """
        if len(comparison_results) == 0:
            print("No comparisons to plot")
            return

        comparisons = list(comparison_results.keys())
        p_values = [comparison_results[comp]['p_value'] for comp in comparisons]
        mean_diffs = [comparison_results[comp]['mean_diff'] for comp in comparisons]

        _, (ax_p, ax_diff) = plt.subplots(1, 2, figsize=(14, 5))

        # P-values: green when below the 0.05 threshold
        p_colors = ['green' if p < 0.05 else 'red' for p in p_values]
        ax_p.barh(comparisons, p_values, color=p_colors, alpha=0.7)
        ax_p.axvline(x=0.05, color='black', linestyle='--', linewidth=2, label='α=0.05')
        ax_p.set_xlabel('P-value')
        ax_p.set_title('Statistical Significance (Paired t-test)')
        ax_p.legend()
        ax_p.grid(True, alpha=0.3, axis='x')

        # Mean RMSE differences: green when the first model has the lower RMSE
        diff_colors = ['green' if d < 0 else 'red' for d in mean_diffs]
        ax_diff.barh(comparisons, mean_diffs, color=diff_colors, alpha=0.7)
        ax_diff.axvline(x=0, color='black', linestyle='--', linewidth=2)
        ax_diff.set_xlabel('Mean RMSE Difference (°C)')
        ax_diff.set_title('Performance Difference (Negative = First Model Better)')
        ax_diff.grid(True, alpha=0.3, axis='x')

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def plot_uncertainty_analysis(uq_results_dict, save_path: Optional[str] = None):
        """
        Plot uncertainty quantification results
        Shows uncertainty vs error correlation and calibration

        Args:
            uq_results_dict: Model name -> dict with ``uncertainties``,
                ``errors``, ``correlation``, ``bin_means_uncertainty`` and
                ``bin_means_error`` (array values; ``.min()`` / ``.max()``
                are called on them).
            save_path: Optional file path for the figure.
        """
        n_models = len(uq_results_dict)
        if n_models == 0:
            print("No UQ results to plot")
            return

        _, axes = plt.subplots(2, n_models, figsize=(5*n_models, 10), squeeze=False)

        for idx, (model_name, uq_data) in enumerate(uq_results_dict.items()):
            # Top row: Uncertainty vs Error scatter
            ax_scatter = axes[0, idx]
            uncertainties = uq_data['uncertainties']
            errors = uq_data['errors']
            correlation = uq_data['correlation']

            # Points are colored by their uncertainty value
            scatter = ax_scatter.scatter(
                uncertainties, errors,
                alpha=0.3, s=10, c=uncertainties,
                cmap='viridis', edgecolors='none'
            )

            # Add trend line (least-squares straight line)
            trend = np.poly1d(np.polyfit(uncertainties, errors, 1))
            x_line = np.linspace(uncertainties.min(), uncertainties.max(), 100)
            ax_scatter.plot(x_line, trend(x_line), "r--", linewidth=2,
                            label=f'Trend (r={correlation:.3f})')

            ax_scatter.set_xlabel('Epistemic Uncertainty (°C)')
            ax_scatter.set_ylabel('Absolute Error (°C)')
            ax_scatter.set_title(f'{model_name}\nUncertainty vs Error')
            ax_scatter.legend()
            ax_scatter.grid(True, alpha=0.3)
            plt.colorbar(scatter, ax=ax_scatter, label='Uncertainty (°C)')

            # Bottom row: Calibration plot (binned uncertainty vs error)
            ax_calib = axes[1, idx]
            bin_means_unc = uq_data['bin_means_uncertainty']
            bin_means_err = uq_data['bin_means_error']

            ax_calib.plot(bin_means_unc, bin_means_err, 'o-', linewidth=2,
                          markersize=8, label='Observed')
            # Identity line: mean error equal to mean uncertainty
            diagonal = [bin_means_unc.min(), bin_means_unc.max()]
            ax_calib.plot(diagonal, diagonal, 'r--', linewidth=2, label='Perfect Calibration')

            ax_calib.set_xlabel('Mean Uncertainty (°C)')
            ax_calib.set_ylabel('Mean Absolute Error (°C)')
            ax_calib.set_title(f'{model_name}\nCalibration Plot')
            ax_calib.legend()
            ax_calib.grid(True, alpha=0.3)

        ResultVisualizer._finalize(save_path)

    @staticmethod
    def create_summary_report(summary_stats, comparison_results):
        """Create a comprehensive summary report (printed to stdout).

        Args:
            summary_stats: Model name -> ``{'mean': {...}, 'std': {...}}`` with
                the metrics ``rmse``, ``mae``, ``r2`` and ``mape``.
            comparison_results: ``"<model1>_vs_<model2>"`` -> dict with
                ``t_statistic``, ``p_value``, ``significant`` and ``mean_diff``.
        """
        print("\n" + "="*80)
        print(" "*25 + "COMPREHENSIVE ANALYSIS REPORT")
        print("="*80 + "\n")

        # Performance Summary
        print("📊 PERFORMANCE SUMMARY (Mean ± Std across seeds)")
        print("-"*80)

        if len(summary_stats) == 0:
            print("No models to summarize")
            return

        # `model_stats` rather than `stats`, which would shadow scipy.stats
        metrics_df = pd.DataFrame({
            model: {
                'RMSE': f"{model_stats['mean']['rmse']:.4f} ± {model_stats['std']['rmse']:.4f}",
                'MAE': f"{model_stats['mean']['mae']:.4f} ± {model_stats['std']['mae']:.4f}",
                'R²': f"{model_stats['mean']['r2']:.4f} ± {model_stats['std']['r2']:.4f}",
                'MAPE (%)': f"{model_stats['mean']['mape']:.2f} ± {model_stats['std']['mape']:.2f}"
            }
            for model, model_stats in summary_stats.items()
        }).T

        print(metrics_df.to_string())
        print("\n")

        # Statistical Comparisons
        if len(comparison_results) > 0:
            print("📈 STATISTICAL COMPARISONS (Paired t-tests)")
            print("-"*80)

            for comp_name, comp_data in comparison_results.items():
                # Split on the first separator only, so unpacking cannot fail
                model1, _, model2 = comp_name.partition('_vs_')
                significance = "✓ SIGNIFICANT" if comp_data['significant'] else "✗ Not Significant"
                # Negative difference: the first model has the lower mean RMSE
                better_model = model1 if comp_data['mean_diff'] < 0 else model2

                print(f"\n{model1} vs {model2}:")
                print(f"  • t-statistic: {comp_data['t_statistic']:.4f}")
                print(f"  • p-value: {comp_data['p_value']:.4f}")
                print(f"  • Mean RMSE difference: {abs(comp_data['mean_diff']):.4f}°C")
                print(f"  • Result: {significance}")
                print(f"  • Better Model: {better_model}")

        print("\n" + "="*80 + "\n")


# =====================================================================
# MAIN EXECUTION
# =====================================================================

# Model families handled by the pipeline, in reporting order.
_MODEL_NAMES = ('DeepONet', 'DeepONet-Physics', 'FNO', 'LSTM')

# Attribution methods whose per-feature results are exported to CSV.
_ATTRIBUTION_METHODS = ('ig', 'saliency', 'deeplift', 'gradshap')


def _test_tensor_dataset(splits, limit=None):
    """Wrap the test tensors of `splits` in a TensorDataset.

    Args:
        splits: Mapping holding 'X_branch_test', 'X_trunk_test' and 'y_test'.
        limit: Keep only the first `limit` rows; None keeps every row.
    """
    return TensorDataset(
        splits['X_branch_test'][:limit],
        splits['X_trunk_test'][:limit],
        splits['y_test'][:limit],
    )


def _run_interpretability(results, data_splits, data_splits_lstm, max_samples=100):
    """Run the attribution analysis for one trained model per family.

    The first seed's model of each family is analysed on at most
    `max_samples` test rows. The LSTM is fed the sequence splits when they
    contain test data; every other model uses the point-wise splits.

    Returns:
        Dict mapping model name to whatever `analyzer.analyze_all` returns.
    """
    attr_results = {}

    for model_name in _MODEL_NAMES:
        # .get() so a family missing from `results` is reported, not a KeyError.
        model_runs = results.get(model_name, [])
        if not model_runs:
            print(f"\n⚠ No {model_name} models available for interpretability analysis")
            continue

        print(f"\nAnalyzing {model_name}...")
        # Representative model = first seed; runs are not ranked by metric here.
        first_seed_model = model_runs[0]['model']

        is_sequence_model = (model_name == 'LSTM')
        analyzer = InterpretabilityAnalyzer(first_seed_model, is_sequence_model=is_sequence_model)

        if is_sequence_model and 'X_branch_test' in data_splits_lstm:
            splits = data_splits_lstm
        else:
            splits = data_splits

        sample_size = min(max_samples, len(splits['X_branch_test']))
        if sample_size == 0:
            # DataLoader rejects batch_size=0, so bail out with a warning instead.
            print(f"\n⚠ Empty test split; skipping interpretability analysis for {model_name}")
            continue

        # A single full-size batch yields the stacked sample tensors.
        test_loader = DataLoader(_test_tensor_dataset(splits, sample_size), batch_size=sample_size)
        x_branch, x_trunk, _ = next(iter(test_loader))

        attr_results[model_name] = analyzer.analyze_all(x_branch, x_trunk)

    return attr_results


def _run_uncertainty_quantification(results, data_splits, data_splits_lstm,
                                    data_processor, min_members=3):
    """Build a DeepEnsemble per model family and analyse uncertainty vs error.

    A family is skipped (with a warning) when it has fewer than `min_members`
    trained models. The LSTM ensemble is evaluated on the sequence test split,
    all others on the point-wise test split.

    Returns:
        Dict mapping model name to the result of
        `DeepEnsemble.uncertainty_vs_error_analysis`.
    """
    uq_results = {}

    test_loader_uq = DataLoader(_test_tensor_dataset(data_splits), batch_size=256)
    test_loader_lstm = DataLoader(_test_tensor_dataset(data_splits_lstm), batch_size=256)

    for model_name in _MODEL_NAMES:
        model_runs = results.get(model_name, [])
        if len(model_runs) < min_members:
            print(f"\n⚠ Not enough models for {model_name} ensemble "
                  f"(need ≥{min_members}, got {len(model_runs)})")
            continue

        print(f"\nCreating ensemble for {model_name}...")
        ensemble = DeepEnsemble([run['model'] for run in model_runs])

        current_test_loader = test_loader_lstm if model_name == 'LSTM' else test_loader_uq
        uq_analysis = ensemble.uncertainty_vs_error_analysis(
            current_test_loader, data_processor.scalers['target']
        )
        uq_results[model_name] = uq_analysis

        print(f"  Correlation (uncertainty vs error): {uq_analysis['correlation']:.3f}")
        print(f"  Ensemble RMSE: {uq_analysis['metrics']['rmse']:.3f}°C")
        print(f"  Mean uncertainty: {uq_analysis['uncertainties'].mean():.3f}°C")

    return uq_results


def _generate_visualizations(visualizer, output_dir, results, summary_stats,
                             comparison_results, attr_results, uq_results):
    """Render every figure of the pipeline into `output_dir` via `visualizer`."""
    print("  • Plotting training curves...")
    visualizer.plot_training_curves(results, save_path=f'{output_dir}/training_curves.png')

    print("  • Plotting metrics comparison...")
    visualizer.plot_metrics_comparison(summary_stats, save_path=f'{output_dir}/metrics_comparison.png')

    print("  • Plotting predictions vs actual...")
    visualizer.plot_predictions_vs_actual(results, save_path=f'{output_dir}/predictions_vs_actual.png')

    print("  • Plotting residual analysis...")
    visualizer.plot_residuals(results, save_path=f'{output_dir}/residual_analysis.png')

    if attr_results:
        print("  • Plotting interpretability analysis...")
        feature_names_branch = ['Pressure (dbar)', 'Salinity (PSU)']
        feature_names_trunk = ['Latitude', 'Longitude', 'Time']

        for model_name, attr_data in attr_results.items():
            # '-' is replaced so file names stay simple identifiers.
            safe_name = model_name.replace('-', '_')
            visualizer.plot_interpretability_analysis(
                attr_data,
                feature_names_branch,
                feature_names_trunk,
                save_path=f'{output_dir}/interpretability_{safe_name}.png'
            )

    print("  • Plotting statistical comparison...")
    visualizer.plot_statistical_comparison(
        comparison_results, save_path=f'{output_dir}/statistical_comparison.png')

    if uq_results:
        print("  • Plotting uncertainty analysis...")
        visualizer.plot_uncertainty_analysis(
            uq_results, save_path=f'{output_dir}/uncertainty_analysis.png')

    if attr_results:
        print("  • Plotting layer conductance comparison...")
        visualizer.plot_layer_conductance_comparison(
            attr_results, save_path=f'{output_dir}/layer_conductance_comparison.png')


def _save_metric_tables(output_dir, results, summary_stats, comparison_results):
    """Write summary, per-seed, comparison and prediction CSV files."""
    # Summary statistics (mean/std per model). `model_stats` rather than
    # `stats`, to avoid shadowing the scipy.stats import at file level.
    summary_rows = [
        {
            'Model': model_name,
            'RMSE_Mean': model_stats['mean']['rmse'],
            'RMSE_Std': model_stats['std']['rmse'],
            'MAE_Mean': model_stats['mean']['mae'],
            'MAE_Std': model_stats['std']['mae'],
            'R2_Mean': model_stats['mean']['r2'],
            'R2_Std': model_stats['std']['r2'],
            'MAPE_Mean': model_stats['mean']['mape'],
            'MAPE_Std': model_stats['std']['mape'],
        }
        for model_name, model_stats in summary_stats.items()
    ]
    pd.DataFrame(summary_rows).to_csv(f'{output_dir}/summary_statistics.csv', index=False)
    print(f"  ✓ Saved: {output_dir}/summary_statistics.csv")

    # One row per (model, seed index) with that run's metrics.
    detailed_rows = [
        {'Model': model_name, 'Seed': seed_idx, **run['metrics']}
        for model_name, model_runs in results.items()
        for seed_idx, run in enumerate(model_runs)
    ]
    pd.DataFrame(detailed_rows).to_csv(f'{output_dir}/detailed_results.csv', index=False)
    print(f"  ✓ Saved: {output_dir}/detailed_results.csv")

    # Pairwise statistical comparisons.
    if comparison_results:
        comparison_rows = [
            {
                'Comparison': comp_name,
                't_statistic': comp_data['t_statistic'],
                'p_value': comp_data['p_value'],
                'significant': comp_data['significant'],
                'mean_diff': comp_data['mean_diff'],
            }
            for comp_name, comp_data in comparison_results.items()
        ]
        pd.DataFrame(comparison_rows).to_csv(
            f'{output_dir}/statistical_comparisons.csv', index=False)
        print(f"  ✓ Saved: {output_dir}/statistical_comparisons.csv")

    # Predictions of the first seed's run of each model family.
    for model_name, model_runs in results.items():
        if not model_runs:
            continue
        first_run = model_runs[0]
        actual = first_run['targets'].flatten()
        predicted = first_run['predictions'].flatten()
        pd.DataFrame({
            'Actual_Temperature': actual,
            'Predicted_Temperature': predicted,
            'Residual': actual - predicted,
        }).to_csv(f'{output_dir}/predictions_{model_name}.csv', index=False)
        print(f"  ✓ Saved: {output_dir}/predictions_{model_name}.csv")


def _save_interpretability_tables(output_dir, attr_results):
    """Write per-feature attribution and layer-conductance CSV files.

    Does nothing when `attr_results` is empty. An attribution method whose
    '<method>_branch' / '<method>_trunk' entries are absent is skipped with
    a warning instead of aborting the export.
    """
    if not attr_results:
        return

    networks = (
        ('Branch', 'branch', ['Pressure', 'Salinity']),
        ('Trunk', 'trunk', ['Latitude', 'Longitude', 'Time']),
    )

    for model_variant, attr_data in attr_results.items():
        interp_rows = []

        for method in _ATTRIBUTION_METHODS:
            if any(f'{method}_{suffix}' not in attr_data for _, suffix, _ in networks):
                print(f"  ⚠ Skipping {method} for {model_variant}: attributions not available")
                continue

            for network_label, suffix, feature_names in networks:
                attribution = attr_data[f'{method}_{suffix}']
                for i, feat in enumerate(feature_names):
                    # Index the LAST axis: identical to [:, i] for 2-D
                    # attributions, and still selects a feature column if the
                    # array carries an extra axis before it.
                    # NOTE(review): assumes features are on the last axis for
                    # sequence-model attributions — confirm against analyze_all.
                    magnitude = np.abs(attribution[..., i])
                    interp_rows.append({
                        'Model': model_variant,
                        'Method': method,
                        'Network': network_label,
                        'Feature': feat,
                        'Mean_Attribution': magnitude.mean(),
                        'Std_Attribution': magnitude.std(),
                    })

        safe_name = model_variant.replace('-', '_')
        pd.DataFrame(interp_rows).to_csv(
            f'{output_dir}/interpretability_{safe_name}.csv', index=False)
        print(f"  ✓ Saved: {output_dir}/interpretability_{safe_name}.csv")

    # Layer conductance is optional per model, hence the membership test.
    conductance_rows = [
        {
            'Model': model_variant,
            'Layer': layer_name,
            'Mean_Conductance': conductance_val,
        }
        for model_variant, attr_data in attr_results.items()
        if 'layer_conductance' in attr_data
        for layer_name, conductance_val in attr_data['layer_conductance'].items()
    ]
    if conductance_rows:
        pd.DataFrame(conductance_rows).to_csv(
            f'{output_dir}/layer_conductance_analysis.csv', index=False)
        print(f"  ✓ Saved: {output_dir}/layer_conductance_analysis.csv")


def _print_final_summary(output_dir, num_seeds):
    """Print the closing banner: key findings, output file list, next steps."""
    print("\n✅ PIPELINE COMPLETE!\n")
    print("="*80)
    print("\n📊 Key Findings:")
    print(f"  • All models trained with {num_seeds} different seeds")
    print("  • 4 models compared: DeepONet, DeepONet-Physics, FNO, LSTM")
    print("  • Physics-informed version uses T-S relationship constraints")
    print("  • Hard temperature bounds + soft T-S penalty in loss function")
    print("  • Statistical significance tested with paired t-tests")
    print("  • Interpretability: 4 feature attribution + layer conductance")
    print("  • Cross-model feature importance comparison")
    print("  • Layer-wise contribution analysis for all architectures")
    print(f"\n📁 All results saved to: {output_dir}/")
    print("  • Training curves: training_curves.png")
    print("  • Metrics comparison: metrics_comparison.png")
    print("  • Predictions: predictions_vs_actual.png")
    print("  • Residuals: residual_analysis.png")
    print("  • Interpretability (DeepONet): interpretability_DeepONet.png")
    print("  • Interpretability (DeepONet-Physics): interpretability_DeepONet_Physics.png")
    print("  • Interpretability (FNO): interpretability_FNO.png")
    print("  • Interpretability (LSTM): interpretability_LSTM.png")
    print("  • Layer conductance: layer_conductance_comparison.png")
    print("  • Statistical tests: statistical_comparison.png")
    print("  • Summary stats: summary_statistics.csv")
    print("  • Detailed results: detailed_results.csv")
    print("  • Predictions per model: predictions_*.csv")
    print("  • Interpretability data: interpretability_*.csv (all 4 models)")
    print("  • Layer conductance data: layer_conductance_analysis.csv")
    print("\n💡 Next Steps:")
    print("  • Compare DeepONet vs DeepONet-Physics performance")
    print("  • Analyze interpretability differences with/without physics")
    print("  • Increase epochs for better convergence")
    print("  • Try different regions or time periods")
    print("  • Experiment with different architectures")
    print("="*80 + "\n")


def main():
    """Run the full pipeline: data, training, statistics, XAI, UQ, plots, CSVs.

    Side effects:
        * Rebinds the module-level name `processor` to the ArgoDataProcessor
          created here (the experiment code is expected to read it).
        * Creates the 'results' directory and writes figures and CSV files
          into it.

    Returns:
        dict with keys 'results', 'summary_stats', 'comparison_results',
        'attr_results', 'uq_results' and 'data_splits'.
    """
    import os

    global processor  # Make accessible to experiment

    print("="*80)
    print(" "*15 + "DEEPONET vs FNO vs LSTM RESEARCH PIPELINE")
    print(" "*20 + "with Captum Interpretability Analysis")
    print("="*80 + "\n")

    # Configuration
    REGION_BOUNDS = [-60, -50, 20, 30]  # North Atlantic
    DATE_RANGE = ['2011-01-01', '2011-12-31']
    NUM_SEEDS = 10  # Increased from 5 for stronger statistical significance
    EPOCHS = 50  # Reduced for demo; increase to 100+ for full training

    # Create output directory for saving results
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
    print(f"📁 Results will be saved to: {output_dir}/\n")

    # Step 1: Data Loading and Preprocessing
    print("📥 Step 1: Loading and Preprocessing Argo Data")
    print("-"*80)

    processor = ArgoDataProcessor(REGION_BOUNDS, DATE_RANGE, max_depth=500.0)
    df = processor.fetch_data()  # Fetch real Argo data only

    print("\nCreating datasets:")
    print("  - Standard (for DeepONet, FNO): point-wise predictions")
    print("  - Sequences (for LSTM): temporal patterns (5-step windows)")

    # Standard data for DeepONet and FNO
    data_splits = processor.preprocess(df, test_size=0.2, val_size=0.1, use_sequences=False)

    # Sequence data for LSTM
    data_splits_lstm = processor.preprocess(df, test_size=0.2, val_size=0.1,
                                            use_sequences=True, sequence_length=5)

    # Store sequence data on processor for later access (UQ, XAI)
    processor.data_splits_lstm = data_splits_lstm

    print("\n✓ Data preprocessing complete!\n")

    # Step 2: Multi-Seed Training
    print("🚀 Step 2: Training Models with Multiple Seeds")
    print("-"*80)

    experiment = MultiSeedExperiment(data_splits, num_seeds=NUM_SEEDS)
    # Store LSTM sequence data for later use during training
    experiment.data_splits_lstm = data_splits_lstm
    results = experiment.run_all_models(epochs=EPOCHS)

    print("\n✓ All models trained!\n")

    # Step 3: Statistical Analysis
    print("📊 Step 3: Statistical Analysis")
    print("-"*80)

    summary_stats = experiment.statistical_analysis()
    comparison_results = experiment.compare_models()

    print("\n✓ Statistical analysis complete!\n")

    # Step 4: Interpretability Analysis (All Models)
    print("🔍 Step 4: Interpretability Analysis (All Models)")
    print("-"*80)

    attr_results = _run_interpretability(results, data_splits, data_splits_lstm)

    if attr_results:
        print(f"\n✓ Interpretability analysis complete for {len(attr_results)} models!\n")
    else:
        print("\n⚠ No models available for interpretability analysis\n")

    # Step 4.5: Uncertainty Quantification with Deep Ensembles
    print("🎲 Step 4.5: Uncertainty Quantification (Deep Ensembles)")
    print("-"*80)

    uq_results = _run_uncertainty_quantification(
        results, data_splits, data_splits_lstm, processor
    )

    if uq_results:
        print(f"\n✓ Uncertainty quantification complete for {len(uq_results)} models!\n")
    else:
        print("\n⚠ No models available for UQ analysis\n")

    # Step 5: Visualization
    print("📈 Step 5: Generating Visualizations")
    print("-"*80)

    visualizer = ResultVisualizer()
    _generate_visualizations(visualizer, output_dir, results, summary_stats,
                             comparison_results, attr_results, uq_results)

    print("\n✓ All visualizations complete!\n")

    # Step 6: Save Results to CSV
    print("💾 Step 6: Saving Results to CSV")
    print("-"*80)

    _save_metric_tables(output_dir, results, summary_stats, comparison_results)
    _save_interpretability_tables(output_dir, attr_results)

    print("\n")

    # Step 7: Summary Report
    print("📋 Step 7: Generating Summary Report")
    print("-"*80)

    visualizer.create_summary_report(summary_stats, comparison_results)

    _print_final_summary(output_dir, NUM_SEEDS)

    return {
        'results': results,
        'summary_stats': summary_stats,
        'comparison_results': comparison_results,
        'attr_results': attr_results,
        # Previously computed but discarded; exposed so callers can use it.
        'uq_results': uq_results,
        'data_splits': data_splits
    }


if __name__ == "__main__":
    # Run the complete pipeline only when executed as a script, not on import
    output = main()
    
    # `output` is the dict returned by main(); its documented entries:
    # output['results'] - All model results across seeds
    # output['summary_stats'] - Statistical summary
    # output['comparison_results'] - Model comparisons
    # output['attr_results'] - Interpretability attributions
    # output['data_splits'] - Non-sequence data splits used to build the experiment
