
import itertools
import json
import os
import traceback
import warnings
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
import pandas as pd
from tqdm import tqdm

# Silence every warning, process-wide, for the rest of the run.
# NOTE(review): this also hides numerical/runtime warnings (e.g. from NaN losses
# during the LR sweep); consider narrowing it to specific categories.
warnings.simplefilter('ignore')

# ==================== Model Definitions ====================

class HomogeneousCNN(nn.Module):
    """Plain (non-residual) constant-width CNN with circular padding and He init.

    Each of the ``depth`` layers is a bias-free ``kernel_size`` convolution
    (stride 1, circular padding of ``kernel_size // 2``), followed by the
    activation and, when ``dropout_rate > 0``, ``Dropout2d``. For odd
    ``kernel_size`` the spatial size is preserved. Features are globally
    average-pooled and mapped to logits by a bias-free linear layer.

    Args:
        depth: number of conv layers.
        channels: output channels of every conv layer.
        kernel_size: convolution kernel size.
        num_classes: number of output logits.
        input_channels: channels of the input images.
        activation: ``'relu'`` or ``'gelu'``.
        dropout_rate: ``Dropout2d`` probability after each activation; 0 disables it.

    Raises:
        ValueError: if ``activation`` is not ``'relu'`` or ``'gelu'``.
    """

    def __init__(self, depth, channels=64, kernel_size=3, num_classes=10,
                 input_channels=3, activation='relu', dropout_rate=0.0):
        super().__init__()
        if activation not in ('relu', 'gelu'):
            # Any other value used to build a network with no nonlinearity at all
            # (a deep linear model) while still using GELU-style initialisation.
            raise ValueError(f"Unknown activation: {activation}")

        self.depth = depth
        self.channels = channels
        self.activation_type = activation
        self.dropout_rate = dropout_rate

        padding = kernel_size // 2
        layers = []
        in_channels = input_channels

        for _ in range(depth):
            conv = nn.Conv2d(in_channels, channels, kernel_size,
                             stride=1, padding=padding, padding_mode='circular', bias=False)

            # He (fan-in) init with std = sqrt(2 / fan_in) for both activations:
            # GELU uses the 'linear' gain (1) rescaled by sqrt(2), i.e. the ReLU gain.
            if activation == 'relu':
                nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
            else:
                nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='linear')
                with torch.no_grad():
                    conv.weight.mul_(np.sqrt(2.0))

            layers.append(conv)
            layers.append(nn.ReLU(inplace=True) if activation == 'relu' else nn.GELU())
            if dropout_rate > 0:
                layers.append(nn.Dropout2d(dropout_rate))

            in_channels = channels

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(channels, num_classes, bias=False)
        nn.init.kaiming_normal_(self.classifier.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def get_effective_depth(self):
        """Return the number of conv layers."""
        return self.depth


class ResNetBlock(nn.Module):
    """Post-activation residual block with optional batch norm and dropout.

    Computes ``dropout(act(bn(conv3x3(x)))) + shortcut(x)``. BN and dropout are
    skipped when disabled, and no activation follows the residual addition.

    The conv weights are He-initialised (std = sqrt(2 / fan_in) for both
    activations) and then divided by ``sqrt(total_blocks)``. Presumably this
    keeps the sum of ``total_blocks`` residual branches at a depth-independent
    scale.

    The shortcut is the identity when ``stride == 1`` and the channel count is
    unchanged. Otherwise it is a strided 1x1 conv: bias-free and followed by BN
    when ``use_bn``, with a bias otherwise.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        stride: stride of the 3x3 conv and of the projection shortcut.
        activation: ``'relu'`` or ``'gelu'``.
        total_blocks: number of residual blocks in the network; ``None`` or
            ``<= 0`` disables the ``1/sqrt`` scaling.
        dropout_rate: ``Dropout2d`` probability after the activation; 0 disables it.
        use_bn: add ``BatchNorm2d`` after the conv (the conv then has no bias).

    Raises:
        ValueError: if ``activation`` is not ``'relu'`` or ``'gelu'``.
    """

    def __init__(self, in_channels, out_channels, stride=1, activation='relu',
                 total_blocks=1, dropout_rate=0.0, use_bn=False):
        super().__init__()
        if activation not in ('relu', 'gelu'):
            # Any other value used to fall through to GELU silently.
            raise ValueError(f"Unknown activation: {activation}")

        self.activation_type = activation
        self.use_bn = use_bn

        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else None
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None

        # He initialisation; GELU uses the 'linear' gain rescaled by sqrt(2).
        if activation == 'relu':
            nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
        else:
            nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='linear')
            with torch.no_grad():
                self.conv.weight.mul_(np.sqrt(2.0))

        # Scale the residual branch by 1/sqrt(K), K = number of blocks.
        if total_blocks is not None and total_blocks > 0:
            with torch.no_grad():
                self.conv.weight.div_(np.sqrt(float(total_blocks)))

        # Identity shortcut unless the shape changes; then a 1x1 projection.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            if use_bn:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                    nn.BatchNorm2d(out_channels)
                )
            else:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1, stride, bias=True)
                )
            nn.init.kaiming_normal_(self.shortcut[0].weight, mode='fan_in',
                                    nonlinearity='relu' if activation == 'relu' else 'linear')

    def forward(self, x):
        """Return ``branch(x) + shortcut(x)`` (no activation after the sum)."""
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)

        out = F.relu(out) if self.activation_type == 'relu' else F.gelu(out)

        if self.dropout is not None:
            out = self.dropout(out)

        return out + self.shortcut(x)


class PreActResNet(nn.Module):
    """Constant-width (64-channel) residual network of ``depth`` ``ResNetBlock``s.

    NOTE: despite the name, this is not a pre-activation ResNet. Each
    ``ResNetBlock`` applies conv -> optional BN -> activation -> optional
    dropout and only then adds the shortcut. Every block uses stride 1 and
    receives ``total_blocks=depth``, so its conv weights are scaled by
    ``1/sqrt(depth)``. Output logits come from global average pooling and a
    bias-free linear layer.

    Args:
        depth: number of residual blocks (must be >= 1).
        num_classes: number of output logits.
        input_channels: channels of the input images (the first block projects
            them to 64 channels).
        activation: forwarded to every block.
        dropout_rate: forwarded to every block.
        use_bn: forwarded to every block.

    Raises:
        ValueError: if ``depth < 1``.
    """

    def __init__(self, depth, num_classes=10, input_channels=3, activation='relu',
                 dropout_rate=0.0, use_bn=False):
        super().__init__()
        if depth < 1:
            raise ValueError("PreActResNet depth must be >= 1")

        self.depth = depth
        self.activation_type = activation

        width = 64
        blocks = []
        for index in range(depth):
            # Only the first block sees the raw image channels.
            block_in = input_channels if index == 0 else width
            blocks.append(ResNetBlock(block_in, width, stride=1,
                                      activation=activation, total_blocks=depth,
                                      dropout_rate=dropout_rate, use_bn=use_bn))

        self.features = nn.Sequential(*blocks)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(width, num_classes, bias=False)
        nn.init.kaiming_normal_(self.classifier.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled.flatten(1))

    def get_effective_depth(self):
        """Return the number of residual blocks."""
        return self.depth


# ==================== Data Loading ====================

def get_data_loaders(dataset_name, batch_size=128, root='../data', num_workers=2):
    """Build train/test DataLoaders for CIFAR-10, CIFAR-100 or the 'imagenet' proxy.

    ``'imagenet'`` is not ImageNet. It is CIFAR-100 resized with
    ``transforms.Resize(64)`` and normalised with ImageNet channel statistics.
    Datasets are downloaded into ``root`` if missing. No augmentation is applied.

    Args:
        dataset_name: ``'cifar10'``, ``'cifar100'`` or ``'imagenet'``.
        batch_size: batch size of both loaders.
        root: dataset directory (default ``'../data'``, relative to the CWD).
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        ``(trainloader, testloader, input_channels, num_classes)``. The train
        loader shuffles; the test loader does not.

    Raises:
        ValueError: if ``dataset_name`` is not recognised.
    """
    # name -> (torchvision.datasets class name, mean, std, resize size or None, num_classes)
    specs = {
        'cifar10': ('CIFAR10', (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), None, 10),
        'cifar100': ('CIFAR100', (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), None, 100),
        # CIFAR-100 upsampled to 64 px with ImageNet statistics, used as an ImageNet proxy.
        'imagenet': ('CIFAR100', (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 64, 100),
    }
    # Validate before touching torchvision so bad names fail fast.
    if dataset_name not in specs:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    class_name, mean, std, resize, num_classes = specs[dataset_name]

    steps = [] if resize is None else [transforms.Resize(resize)]
    steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    transform = transforms.Compose(steps)

    dataset_cls = getattr(torchvision.datasets, class_name)
    trainset = dataset_cls(root=root, train=True, download=True, transform=transform)
    testset = dataset_cls(root=root, train=False, download=True, transform=transform)
    input_channels = 3  # every supported dataset is RGB

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    return trainloader, testloader, input_channels, num_classes


# ==================== Training Functions ====================

def train_one_epoch(model, dataloader, learning_rate, device, optimizer_type='sgd', max_batches=None):
    """Train ``model`` for (at most) one pass over ``dataloader``; return the mean batch loss.

    A fresh optimizer is created on every call, so no optimizer state carries
    over between calls.

    Args:
        model: module trained in place (switched to train mode). It is expected
            to already live on ``device``.
        dataloader: iterable of ``(inputs, targets)`` batches.
        learning_rate: optimizer learning rate.
        device: device every batch is moved to.
        optimizer_type: ``'sgd'`` (plain SGD, no momentum) or ``'adam'``.
        max_batches: train on at most this many batches. ``None`` or ``0``
            means the whole loader.

    Returns:
        The average of the per-batch training losses seen during this pass.
        This is not the loss after the last step. If the running loss becomes
        non-finite (NaN/inf), training stops immediately and that non-finite
        mean is returned, because the parameters are already corrupted.

    Raises:
        ValueError: unknown ``optimizer_type``, or the loader yielded no batches.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()

    if optimizer_type == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    elif optimizer_type == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_type}")

    # islice stops without fetching an extra batch. A falsy max_batches keeps
    # the "no limit" meaning.
    batches = itertools.islice(dataloader, max_batches) if max_batches else dataloader

    total_loss = 0.0
    num_batches = 0

    for inputs, targets in batches:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Once diverged, the mean can never become finite again; stop wasting compute.
        if not np.isfinite(total_loss):
            break

    if num_batches == 0:
        raise ValueError("dataloader produced no batches")
    return total_loss / num_batches


def grid_search_lr_detailed(model_class, model_kwargs, dataloader, lr_range, device,
                            optimizer_type='sgd', num_trials=3, max_batches=100, seed=None):
    """
    Grid-search the learning rate that minimises the mean short-run training loss.

    For every learning rate, ``num_trials`` freshly built models
    (``model_class(**model_kwargs)``) are trained with ``train_one_epoch`` on up
    to ``max_batches`` batches, and the returned losses are averaged. The
    learning rate with the lowest mean wins. Ties keep the earliest entry of
    ``lr_range``. A NaN or inf mean (any diverged trial) never wins.

    Args:
        model_class: model constructor.
        model_kwargs: keyword arguments for ``model_class``.
        dataloader: training batches passed to ``train_one_epoch``.
        lr_range: iterable of learning rates to try.
        device: device the models are moved to.
        optimizer_type: ``'sgd'`` or ``'adam'``.
        num_trials: independent restarts per learning rate.
        max_batches: training batches per restart.
        seed: if not ``None``, ``torch.manual_seed(seed + trial)`` is called
            before each restart's model is built. Restart ``t`` then gets the
            same initialisation at every learning rate. For a DataLoader without
            its own generator, it typically also gets the same shuffle order
            (common random numbers). ``None`` leaves the global RNG stream
            untouched, which is the original behaviour.

    Returns:
        ``(best_lr, best_loss, all_lr_losses)``.
        ``all_lr_losses`` maps ``float(lr)`` to ``{'mean', 'std', 'trials'}``;
        ``std`` is the population standard deviation. ``best_lr`` is ``None``
        and ``best_loss`` is ``inf`` when no learning rate produced a mean loss
        below ``inf``.
    """
    best_lr = None
    best_loss = float('inf')

    # {lr: {'mean': x, 'std': x, 'trials': [...]}}
    all_lr_losses = {}

    for lr in tqdm(lr_range, desc=f"Grid search (depth={model_kwargs.get('depth', 'N/A')})"):
        trial_losses = []

        for trial in range(num_trials):
            if seed is not None:
                torch.manual_seed(seed + trial)
            model = model_class(**model_kwargs).to(device)
            trial_losses.append(train_one_epoch(model, dataloader, lr, device,
                                                optimizer_type, max_batches))
            # Free the model (and cached CUDA memory) before building the next one.
            del model
            torch.cuda.empty_cache()

        avg_loss = np.mean(trial_losses)
        std_loss = np.std(trial_losses)

        all_lr_losses[float(lr)] = {
            'mean': float(avg_loss),
            'std': float(std_loss),
            'trials': [float(l) for l in trial_losses]
        }

        # NaN never compares smaller, so a learning rate with a diverged trial is skipped.
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_lr = lr

    return best_lr, best_loss, all_lr_losses


# ==================== Main Experiment Function ====================

# Exponent predicted by theory for the optimal-learning-rate vs depth power law.
_THEORETICAL_ALPHA = -1.5

# Depth sweep used when ``run_experiment`` is called without ``depths``.
_DEFAULT_DEPTHS = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                   18, 20, 22, 24, 26, 28, 30)

# Depths whose LR-loss curves are drawn in the "Loss Landscapes by Depth" panel.
# Only depths actually present in the sweep are plotted.
_LANDSCAPE_DEPTHS = (3, 6, 10, 15, 20, 30)


def _experiment_name(model_type, dataset_name, activation, optimizer_type, dropout_rate, use_bn):
    """Return the file-name stem identifying one experiment configuration."""
    dropout_str = 'dropout' if dropout_rate > 0 else 'nodropout'
    exp_name = f"{model_type}_{dataset_name}_{activation}_{optimizer_type}_{dropout_str}"
    if model_type == 'resnet':
        exp_name += '_bn' if use_bn else '_nobn'
    return exp_name


def _model_spec(model_type, depth, num_classes, input_channels, activation,
                dropout_rate, use_bn):
    """Return ``(model_class, model_kwargs)`` for one depth of the sweep."""
    if model_type == 'cnn':
        return HomogeneousCNN, {
            'depth': depth,
            'channels': 64,
            'kernel_size': 3,
            'num_classes': num_classes,
            'input_channels': input_channels,
            'activation': activation,
            'dropout_rate': dropout_rate,
        }
    if model_type == 'resnet':
        return PreActResNet, {
            'depth': depth,
            'num_classes': num_classes,
            'input_channels': input_channels,
            'activation': activation,
            'dropout_rate': dropout_rate,
            'use_bn': use_bn,
        }
    raise ValueError(f"Unknown model type: {model_type}")


def _fit_power_law(depths, lrs):
    """Least-squares fit of ``log(lr) = alpha * log(depth) + log(k)``.

    Returns ``(alpha, k, r_squared)``. All three are NaN when fewer than two
    points are given, because the line is then undetermined.
    """
    if len(depths) < 2:
        nan = float('nan')
        return nan, nan, nan
    log_depths = np.log(depths).reshape(-1, 1)
    log_lrs = np.log(lrs)
    reg = LinearRegression().fit(log_depths, log_lrs)
    return reg.coef_[0], np.exp(reg.intercept_), reg.score(log_depths, log_lrs)


def _save_results(output_dir, exp_name, experiment_info, all_results, detailed_lr_data,
                  num_trials):
    """Write the summary CSV, the detailed JSON and the flat LR-loss CSV."""
    # 1. Summary: one row per depth (NaN best LR for depths where everything diverged).
    df_summary = pd.DataFrame(all_results)
    df_summary['log10_depth'] = np.log10(df_summary['depth'])
    df_summary['log10_lr'] = np.log10(df_summary['best_lr'])
    summary_path = os.path.join(output_dir, f"{exp_name}_summary.csv")
    df_summary.to_csv(summary_path, index=False)
    print(f"Saved summary to: {summary_path}")

    # 2. Detailed LR-loss data per depth (JSON; NaN is written as the JSON extension NaN).
    detailed_data = {
        'experiment_info': experiment_info,
        'depth_results': {
            str(r['depth']): {
                'best_lr': r['best_lr'],
                'best_loss': r['best_loss'],
                'all_lr_losses': detailed_lr_data[r['depth']],
            }
            for r in all_results
        },
    }
    detailed_path = os.path.join(output_dir, f"{exp_name}_detailed.json")
    with open(detailed_path, 'w') as f:
        json.dump(detailed_data, f, indent=2)
    print(f"Saved detailed data to: {detailed_path}")

    # 3. Every (depth, lr) pair as a flat CSV. The trial_1..trial_3 columns are
    #    always present; more columns are added when num_trials > 3.
    num_trial_cols = max(3, num_trials)
    flat_data = []
    for depth, lr_losses in detailed_lr_data.items():
        for lr, loss_data in lr_losses.items():
            trials = loss_data['trials']
            row = {
                'depth': depth,
                'learning_rate': float(lr),
                'mean_loss': loss_data['mean'],
                'std_loss': loss_data['std'],
            }
            for i in range(num_trial_cols):
                row[f'trial_{i + 1}'] = trials[i] if i < len(trials) else None
            flat_data.append(row)

    flat_path = os.path.join(output_dir, f"{exp_name}_all_lr_losses.csv")
    pd.DataFrame(flat_data).to_csv(flat_path, index=False)
    print(f"Saved all LR-loss data to: {flat_path}")


def _plot_results(output_dir, exp_name, all_results, detailed_lr_data,
                  fit_depths, fit_lrs, global_alpha, global_k):
    """Draw the four diagnostic panels and save them to ``<exp_name>_plots.png``.

    ``fit_depths`` and ``fit_lrs`` contain only the depths whose best LR is
    usable. ``all_results`` (one dict per depth) feeds the best-loss panel and
    the best-LR markers.
    """
    has_points = len(fit_depths) > 0
    # The theory line is anchored at the first usable depth in sweep order.
    k_theory = (fit_lrs[0] * fit_depths[0] ** -_THEORETICAL_ALPHA
                if has_points else float('nan'))

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    try:
        # Plot 1: Depth vs Optimal LR (log-log)
        ax1 = axes[0, 0]
        if has_points:
            ax1.scatter(fit_depths, fit_lrs, s=80, alpha=0.7, label='Grid Search', zorder=3)
            depth_range = np.linspace(fit_depths.min(), fit_depths.max(), 100)
            if np.isfinite(global_alpha):
                ax1.plot(depth_range, global_k * depth_range ** global_alpha, 'r--',
                         linewidth=2, label=f'Fit: η ∝ L^({global_alpha:.3f})')
            ax1.plot(depth_range, k_theory * depth_range ** _THEORETICAL_ALPHA, 'g-.',
                     linewidth=2, label='Theory: η ∝ L^(-1.5)')
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.set_xlabel('Depth L', fontsize=12)
        ax1.set_ylabel('Optimal Learning Rate η*', fontsize=12)
        ax1.set_title(f'Depth-LR Scaling (α = {global_alpha:.3f})', fontsize=14)
        ax1.legend(fontsize=10)
        ax1.grid(True, alpha=0.3)

        # Plot 2: Loss landscapes for selected depths
        ax2 = axes[0, 1]
        best_by_depth = {r['depth']: r for r in all_results}
        selected_depths = [d for d in _LANDSCAPE_DEPTHS if d in detailed_lr_data]
        colors = plt.cm.viridis(np.linspace(0, 1, len(selected_depths)))
        for color, depth in zip(colors, selected_depths):
            curve = sorted(detailed_lr_data[depth].items())  # (lr, stats) by increasing lr
            ax2.plot([lr for lr, _ in curve], [stats['mean'] for _, stats in curve],
                     color=color, linewidth=1.5, label=f'L={depth}', alpha=0.8)
            best = best_by_depth[depth]
            ax2.scatter([best['best_lr']], [best['best_loss']], color=color, s=100,
                        marker='*', zorder=5, edgecolors='black')
        ax2.set_xscale('log')
        ax2.set_xlabel('Learning Rate', fontsize=12)
        ax2.set_ylabel('Loss', fontsize=12)
        ax2.set_title('Loss Landscapes by Depth', fontsize=14)
        ax2.legend(fontsize=9, loc='upper right')
        ax2.grid(True, alpha=0.3)

        # Plot 3: Relative deviation from the anchored theory line
        ax3 = axes[1, 0]
        if has_points:
            predicted_lrs = k_theory * fit_depths ** _THEORETICAL_ALPHA
            relative_errors = (fit_lrs - predicted_lrs) / predicted_lrs * 100
            positions = range(len(fit_depths))
            ax3.bar(positions, relative_errors, alpha=0.7,
                    color=['green' if e > 0 else 'red' for e in relative_errors])
            ax3.set_xticks(positions)
            ax3.set_xticklabels([str(d) for d in fit_depths], rotation=45, fontsize=8)
        ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
        ax3.axhline(y=10, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax3.axhline(y=-10, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax3.set_xlabel('Depth L', fontsize=12)
        ax3.set_ylabel('Relative Error (%)', fontsize=12)
        ax3.set_title('Deviation from Theory', fontsize=14)
        ax3.grid(True, alpha=0.3, axis='y')

        # Plot 4: Best loss vs depth (all depths; NaN entries leave gaps)
        ax4 = axes[1, 1]
        ax4.plot([r['depth'] for r in all_results], [r['best_loss'] for r in all_results],
                 'o-', markersize=6, linewidth=1.5)
        ax4.set_xlabel('Depth L', fontsize=12)
        ax4.set_ylabel('Best Loss', fontsize=12)
        ax4.set_title('Best Loss vs Depth', fontsize=14)
        ax4.grid(True, alpha=0.3)

        fig.suptitle(f'{exp_name}', fontsize=16, fontweight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        plot_path = os.path.join(output_dir, f"{exp_name}_plots.png")
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    print(f"Saved plots to: {plot_path}")


def run_experiment(model_type, dataset_name, activation='relu', optimizer_type='sgd',
                   dropout_rate=0.0, use_bn=False, device='cuda',
                   output_dir='./experiment_results', depths=None, lr_range=None,
                   num_trials=3, max_batches=100):
    """
    Run a complete depth sweep and save its results, tables and plots.

    For every depth, ``grid_search_lr_detailed`` finds the learning rate with
    the lowest mean short-run training loss. A power law
    ``lr* = k * depth**alpha`` is then fitted in log-log space and compared with
    the theoretical exponent -1.5. Only depths whose best learning rate is
    finite and positive enter the fit. Depths where every learning rate
    diverged are kept in the saved data with NaN best LR/loss instead of
    aborting the run.

    Args:
        model_type: 'cnn' or 'resnet'
        dataset_name: 'cifar10', 'cifar100', or 'imagenet'
        activation: 'relu' or 'gelu'
        optimizer_type: 'sgd' or 'adam'
        dropout_rate: float (0.0 = no dropout, 0.1 = 10% dropout)
        use_bn: bool (only for resnet)
        device: 'cuda' or 'cpu'
        output_dir: directory to save results
        depths: depths to sweep (default: 3..16, then 18..30 in steps of 2)
        lr_range: learning rates to try (default: ``np.logspace(-5, 0, 80)``)
        num_trials: random restarts averaged per learning rate
        max_batches: training batches per restart

    Returns:
        dict with keys 'experiment_name', 'global_alpha', 'r_squared' (NaN when
        fewer than two depths are usable), 'all_results' and 'detailed_lr_data'.

    Raises:
        ValueError: empty ``depths`` or unknown ``model_type``. Both are checked
            before any data is loaded. ``get_data_loaders`` also raises for an
            unknown dataset.
    """
    depths = list(_DEFAULT_DEPTHS if depths is None else depths)
    if not depths:
        raise ValueError("depths must contain at least one depth")
    lr_range = np.logspace(-5, 0, 80) if lr_range is None else lr_range

    os.makedirs(output_dir, exist_ok=True)
    exp_name = _experiment_name(model_type, dataset_name, activation, optimizer_type,
                                dropout_rate, use_bn)

    print(f"\n{'='*70}")
    print(f"EXPERIMENT: {exp_name}")
    print(f"{'='*70}")
    print(f"Model: {model_type.upper()}")
    print(f"Dataset: {dataset_name.upper()}")
    print(f"Activation: {activation.upper()}")
    print(f"Optimizer: {optimizer_type.upper()}")
    print(f"Dropout: {dropout_rate}")
    if model_type == 'resnet':
        print(f"BatchNorm: {use_bn}")
    print(f"{'='*70}\n")

    # Fail before the (slow) dataset download/load rather than inside the sweep.
    if model_type not in ('cnn', 'resnet'):
        raise ValueError(f"Unknown model type: {model_type}")

    trainloader, _, input_channels, num_classes = get_data_loaders(dataset_name)

    all_results = []
    detailed_lr_data = {}  # {depth: {lr: {'mean', 'std', 'trials'}}}

    for depth in depths:
        print(f"\n--- Testing depth L={depth} ---")
        model_class, model_kwargs = _model_spec(model_type, depth, num_classes, input_channels,
                                                activation, dropout_rate, use_bn)

        best_lr, best_loss, lr_losses = grid_search_lr_detailed(
            model_class,
            model_kwargs,
            trainloader,
            lr_range,
            device,
            optimizer_type=optimizer_type,
            num_trials=num_trials,
            max_batches=max_batches
        )

        if best_lr is None:
            # Every learning rate diverged; float(None) used to crash the whole run here.
            best_lr = best_loss = float('nan')
            print(f"Depth {depth}: no learning rate gave a finite loss; excluded from the fit")
        else:
            print(f"Depth {depth}: Best LR = {best_lr:.6f}, Best Loss = {best_loss:.4f}")

        all_results.append({
            'depth': depth,
            'best_lr': float(best_lr),
            'best_loss': float(best_loss)
        })
        detailed_lr_data[depth] = lr_losses

    # Fit the global power law on the usable depths only (log needs finite, positive LRs).
    depths_array = np.array([r['depth'] for r in all_results])
    lrs_array = np.array([r['best_lr'] for r in all_results])
    valid = np.isfinite(lrs_array) & (lrs_array > 0)
    fit_depths, fit_lrs = depths_array[valid], lrs_array[valid]
    global_alpha, global_k, r_squared = _fit_power_law(fit_depths, fit_lrs)
    deviation = abs(global_alpha - _THEORETICAL_ALPHA)

    print(f"\n{'='*70}")
    print("RESULTS SUMMARY")
    print(f"{'='*70}")
    if len(fit_depths) < 2:
        print("Fewer than two depths with a usable best LR; power-law fit skipped.")
    print(f"Fitted exponent α: {global_alpha:.4f}")
    print(f"Theoretical α: {_THEORETICAL_ALPHA}")
    print(f"Deviation: {deviation:.4f} ({deviation / abs(_THEORETICAL_ALPHA) * 100:.1f}%)")
    print(f"R²: {r_squared:.4f}")
    print(f"{'='*70}\n")

    experiment_info = {
        'model_type': model_type,
        'dataset': dataset_name,
        'activation': activation,
        'optimizer': optimizer_type,
        'dropout_rate': dropout_rate,
        'use_bn': use_bn if model_type == 'resnet' else None,
        'timestamp': datetime.now().isoformat(),
        'fitted_alpha': float(global_alpha),
        'theoretical_alpha': _THEORETICAL_ALPHA,
        'r_squared': float(r_squared)
    }
    _save_results(output_dir, exp_name, experiment_info, all_results, detailed_lr_data,
                  num_trials)
    _plot_results(output_dir, exp_name, all_results, detailed_lr_data,
                  fit_depths, fit_lrs, global_alpha, global_k)

    return {
        'experiment_name': exp_name,
        'global_alpha': global_alpha,
        'r_squared': r_squared,
        'all_results': all_results,
        'detailed_lr_data': detailed_lr_data
    }


# ==================== Batch Experiment Runner ====================

def run_missing_experiments(device='cuda', output_dir='./experiment_results'):
    """
    Run the experiments still missing from the CNN / ResNet ablation study.

    Each configuration below is run through ``run_experiment``. A failing
    configuration does not abort the batch. Its traceback is printed, and it is
    recorded with status ``'failed: <message>'``. A one-row-per-experiment
    summary is written to ``<output_dir>/all_experiments_summary.csv`` and
    printed.

    Args:
        device: device passed to every ``run_experiment`` call.
        output_dir: directory for per-experiment outputs and the summary CSV.

    Returns:
        List of dicts with keys 'config', 'alpha', 'r_squared' and 'status'.
    """

    experiments_to_run = [
        # ===== CNN Missing Experiments =====
        # CNN + SGD + GELU + without dropout + CIFAR-100 (fair comparison)
        {'model_type': 'cnn', 'dataset_name': 'cifar100', 'activation': 'gelu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},

        # CNN + SGD + ReLU + with dropout + CIFAR-10 (isolates the effect of dropout)
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': False},

        # CNN + SGD + GELU + with dropout + CIFAR-10 (completes the 2x2 grid)
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'gelu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': False},

        # ===== ResNet Missing Experiments =====
        # ResNet + Adam + ReLU + without dropout + CIFAR-100 (Adam across datasets)
        {'model_type': 'resnet', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'adam', 'dropout_rate': 0.0, 'use_bn': False},
    ]

    results = []

    for exp_config in experiments_to_run:
        try:
            result = run_experiment(
                device=device,
                output_dir=output_dir,
                **exp_config
            )
        except Exception as e:
            # Top-level batch boundary: keep the traceback, record the failure, continue.
            print(f"ERROR in experiment {exp_config}: {e}")
            traceback.print_exc()
            results.append({
                'config': exp_config,
                'alpha': None,
                'r_squared': None,
                'status': f'failed: {str(e)}'
            })
        else:
            results.append({
                'config': exp_config,
                'alpha': result['global_alpha'],
                'r_squared': result['r_squared'],
                'status': 'success'
            })

    # Save summary of all experiments
    summary_df = pd.DataFrame([
        {
            'model': r['config']['model_type'],
            'dataset': r['config']['dataset_name'],
            'activation': r['config']['activation'],
            'optimizer': r['config']['optimizer_type'],
            'dropout': r['config']['dropout_rate'],
            'batchnorm': r['config']['use_bn'],
            'fitted_alpha': r['alpha'],
            'r_squared': r['r_squared'],
            'status': r['status']
        }
        for r in results
    ])

    # The directory may not exist yet if every experiment failed early.
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, 'all_experiments_summary.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f"\n{'='*70}")
    print(f"All experiments completed! Summary saved to: {summary_path}")
    print(f"{'='*70}")
    print(summary_df.to_string())

    return results


def run_complete_ablation(device='cuda', output_dir='./experiment_results'):
    """
    Run the complete ablation study for the paper (all CNN and ResNet configurations).

    Each configuration is run through ``run_experiment``. A failing
    configuration does not abort the study. Its traceback is printed, and its
    alpha and R² are reported as 'N/A'. The formatted summary is written to
    ``<output_dir>/complete_ablation_summary.csv`` and printed.

    Args:
        device: device passed to every ``run_experiment`` call.
        output_dir: directory for per-experiment outputs and the summary CSV.

    Returns:
        List of dicts with keys 'config', 'alpha', 'r_squared' and 'status'.
    """

    all_experiments = [
        # ===== CNN Baseline =====
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},
        {'model_type': 'cnn', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},

        # ===== CNN Activation Ablation =====
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'gelu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},
        {'model_type': 'cnn', 'dataset_name': 'cifar100', 'activation': 'gelu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},

        # ===== CNN Optimizer Ablation =====
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'adam', 'dropout_rate': 0.0, 'use_bn': False},

        # ===== CNN Dropout Ablation =====
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': False},
        {'model_type': 'cnn', 'dataset_name': 'cifar10', 'activation': 'gelu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': False},

        # ===== ResNet Baseline =====
        {'model_type': 'resnet', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},
        {'model_type': 'resnet', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': False},

        # ===== ResNet BatchNorm Ablation =====
        {'model_type': 'resnet', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': True},
        {'model_type': 'resnet', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.0, 'use_bn': True},

        # ===== ResNet Dropout Ablation =====
        {'model_type': 'resnet', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': False},
        {'model_type': 'resnet', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': False},

        # ===== ResNet BN + Dropout =====
        {'model_type': 'resnet', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': True},
        {'model_type': 'resnet', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'sgd', 'dropout_rate': 0.1, 'use_bn': True},

        # ===== ResNet Optimizer Ablation =====
        {'model_type': 'resnet', 'dataset_name': 'cifar10', 'activation': 'relu',
         'optimizer_type': 'adam', 'dropout_rate': 0.0, 'use_bn': False},
        {'model_type': 'resnet', 'dataset_name': 'cifar100', 'activation': 'relu',
         'optimizer_type': 'adam', 'dropout_rate': 0.0, 'use_bn': False},
    ]

    results = []

    for i, exp_config in enumerate(all_experiments):
        print(f"\n{'#'*70}")
        print(f"# Experiment {i+1}/{len(all_experiments)}")
        print(f"{'#'*70}")

        try:
            result = run_experiment(
                device=device,
                output_dir=output_dir,
                **exp_config
            )
        except Exception as e:
            # Top-level batch boundary: keep the traceback, record the failure, continue.
            print(f"ERROR: {e}")
            traceback.print_exc()
            results.append({
                'config': exp_config,
                'alpha': None,
                'r_squared': None,
                'status': f'failed: {str(e)}'
            })
        else:
            results.append({
                'config': exp_config,
                'alpha': result['global_alpha'],
                'r_squared': result['r_squared'],
                'status': 'success'
            })

    # Create final summary table
    summary_data = []
    for r in results:
        c = r['config']
        dropout_str = 'with' if c['dropout_rate'] > 0 else 'wo'
        bn_str = 'bn' if c.get('use_bn', False) else 'none'

        summary_data.append({
            'model': c['model_type'],
            'optimizer': c['optimizer_type'],
            'activation': c['activation'],
            'norm': bn_str if c['model_type'] == 'resnet' else 'none',
            'dropout': dropout_str,
            'dataset': c['dataset_name'],
            # `is not None`, not truthiness: a genuine 0.0 must not be shown as 'N/A'.
            'alpha': f"{r['alpha']:.3f}" if r['alpha'] is not None else 'N/A',
            'r_squared': f"{r['r_squared']:.3f}" if r['r_squared'] is not None else 'N/A'
        })

    summary_df = pd.DataFrame(summary_data)
    # The directory may not exist yet if every experiment failed early.
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, 'complete_ablation_summary.csv')
    summary_df.to_csv(summary_path, index=False)

    print(f"\n{'='*70}")
    print("COMPLETE ABLATION STUDY RESULTS")
    print(f"{'='*70}")
    print(summary_df.to_string())
    print(f"\nSaved to: {summary_path}")

    return results


# ==================== Main ====================

if __name__ == "__main__":
    # Use the GPU whenever PyTorch can see one.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    output_dir = './experiment_results'
    os.makedirs(output_dir, exist_ok=True)

    # Alternative entry points (swap in place of the call below):
    #   run_missing_experiments(device, output_dir)        # only the not-yet-run configs
    #   result = run_experiment(model_type='cnn', dataset_name='cifar10',
    #                           activation='relu', optimizer_type='sgd',
    #                           dropout_rate=0.0, use_bn=False,
    #                           device=device, output_dir=output_dir)  # one config

    # Default: the full ablation grid.
    run_complete_ablation(device, output_dir)