import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
import pandas as pd
from tqdm import tqdm
import warnings
import os
import torchaudio
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
warnings.filterwarnings('ignore')

# ==================== Model Definitions ====================

class HomogeneousCNN(nn.Module):
    """Plain stack of identical conv + activation layers with a linear head.

    Every layer is a bias-free, stride-1 convolution that uses circular
    padding (``padding_mode='circular'``) and He-normal initialisation.
    The feature map is globally average-pooled and fed to a bias-free
    linear classifier.

    Args:
        depth: Number of conv + activation layers.
        channels: Output channels of every conv layer.
        kernel_size: Kernel size; the padding is ``kernel_size // 2``.
        num_classes: Number of classifier outputs.
        input_channels: Channels of the input tensor.
        activation: ``'relu'`` or ``'gelu'``. Any other value is initialised
            like ``'gelu'`` but no activation module is inserted.
    """

    def __init__(self, depth, channels, kernel_size=3, num_classes=10,
                 input_channels=3, activation='relu'):
        super().__init__()
        self.depth = depth
        self.channels = channels
        self.activation_type = activation

        use_relu = activation == 'relu'
        pad = kernel_size // 2
        modules = []
        prev_channels = input_channels

        for _ in range(depth):
            conv_layer = nn.Conv2d(
                prev_channels, channels, kernel_size,
                stride=1, padding=pad, padding_mode='circular', bias=False,
            )

            if use_relu:
                nn.init.kaiming_normal_(conv_layer.weight, mode='fan_in', nonlinearity='relu')
            else:
                # Unit-gain He init rescaled by sqrt(2), i.e. the same variance
                # as the ReLU initialisation above.
                nn.init.kaiming_normal_(conv_layer.weight, mode='fan_in', nonlinearity='linear')
                with torch.no_grad():
                    conv_layer.weight.data *= np.sqrt(2.0)
            modules.append(conv_layer)

            if use_relu:
                modules.append(nn.ReLU(inplace=True))
            elif activation == 'gelu':
                modules.append(nn.GELU())

            prev_channels = channels

        self.features = nn.Sequential(*modules)

        # Global average pooling followed by a unit-gain He-initialised head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(channels, num_classes, bias=False)
        nn.init.kaiming_normal_(self.classifier.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return class logits for a batch of images."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def get_effective_depth(self):
        """Return the number of conv layers (``depth``)."""
        return self.depth


class BasicBlock(nn.Module):
    """Residual block computing ``activation(conv3x3(x)) + shortcut(x)``.

    The 3x3 convolution is bias-free and He-initialised, then divided by
    ``total_blocks`` when that value is a positive number. The shortcut is
    the identity unless the stride or channel count changes, in which case
    it is a He-initialised, bias-free 1x1 convolution.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the convolution (and of the projection shortcut).
        activation: ``'relu'`` selects ReLU; any other value selects GELU.
        total_blocks: Divisor applied to the conv weights after init;
            ``None`` or a non-positive value disables the scaling.
    """

    def __init__(self, in_channels, out_channels, stride=1, activation='relu', total_blocks=1):
        super().__init__()
        self.activation_type = activation

        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self._he_init(self.conv.weight, activation)

        if total_blocks is not None and total_blocks > 0:
            with torch.no_grad():
                self.conv.weight.data /= float(total_blocks)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            )
            self._he_init(self.shortcut[0].weight, activation)
        else:
            # An empty Sequential acts as the identity mapping.
            self.shortcut = nn.Sequential()

    @staticmethod
    def _he_init(weight, activation):
        """He-normal init; the non-ReLU path uses unit gain rescaled by sqrt(2)."""
        if activation == 'relu':
            nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='relu')
            return
        nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='linear')
        with torch.no_grad():
            weight.data *= np.sqrt(2.0)

    def forward(self, x):
        """Apply conv + activation and add the shortcut branch."""
        act = F.relu if self.activation_type == 'relu' else F.gelu
        return act(self.conv(x)) + self.shortcut(x)


class PreActResNet(nn.Module):
    """Normalisation-free residual network built from ``BasicBlock`` units.

    The first block maps the input channels to 64; every following block
    keeps 64 channels. All blocks use stride 1 and receive
    ``total_blocks=depth``. The features are globally average-pooled and
    classified by a bias-free linear layer.

    Args:
        depth: Number of residual blocks.
        num_classes: Number of classifier outputs.
        input_channels: Channels of the input tensor.
        activation: Activation name forwarded to every ``BasicBlock``.
    """

    def __init__(self, depth, num_classes=10, input_channels=3, activation='relu'):
        super().__init__()
        self.depth = depth
        self.activation_type = activation

        width = 64
        blocks = []
        for idx in range(depth):
            # Only the first block sees the raw input channels.
            block_in = input_channels if idx == 0 else width
            blocks.append(
                BasicBlock(block_in, width, stride=1, activation=activation, total_blocks=depth)
            )
        self.features = nn.Sequential(*blocks)

        # Same head as the plain CNN: global average pool + linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(width, num_classes, bias=False)
        nn.init.kaiming_normal_(self.classifier.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return class logits for a batch of images."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def get_effective_depth(self):
        """Return the number of residual blocks (``depth``)."""
        return self.depth


# ==================== Vision Transformer Model Definitions ====================

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and project them.

    Each ``patch_size x patch_size`` patch is flattened in (row, column,
    channel) order and mapped to ``embed_dim`` by a linear layer with
    unit-gain He initialisation.

    Args:
        img_size: Side length used to compute ``n_patches`` (square input).
        patch_size: Side length of one patch.
        in_channels: Channels of the input image.
        embed_dim: Size of each output token.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        patch_dim = patch_size * patch_size * in_channels
        self.projection = nn.Linear(patch_dim, embed_dim)
        nn.init.kaiming_normal_(self.projection.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, N, embed_dim)`` tokens."""
        batch, chans, height, width = x.shape
        p = self.patch_size

        # (B, C, H, W) -> (B, C, H/P, P, W/P, P)
        patches = x.view(batch, chans, height // p, p, width // p, p)
        # -> (B, H/P, W/P, P, P, C), made contiguous so it can be flattened
        patches = patches.permute(0, 2, 4, 3, 5, 1).contiguous()
        # -> (B, N, P*P*C)
        tokens = patches.view(batch, -1, p * p * chans)
        return self.projection(tokens)


class MultiHeadSelfAttention(nn.Module):
    """Bias-free multi-head scaled dot-product self-attention.

    The module returns only the attention output; any residual connection
    is added by the caller.

    Args:
        embed_dim: Token dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim=768, num_heads=12, dropout_rate=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

        # Unit-gain He initialisation for both projections (qkv first).
        for layer in (self.qkv, self.proj):
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, embed_dim)``."""
        batch_size, seq_len, embed_dim = x.shape

        # (B, S, 3*E) -> (3, B, heads, S, head_dim)
        packed = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        queries, keys, values = packed[0], packed[1], packed[2]

        # Scaled dot-product scores, softmax over the key axis.
        scores = (queries @ keys.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # (B, heads, S, head_dim) -> (B, S, E)
        context = (weights @ values).transpose(1, 2)
        context = context.reshape(batch_size, seq_len, embed_dim)
        return self.proj(context)


class MLP(nn.Module):
    """Two-layer bias-free feed-forward network: ``fc2(dropout(gelu(fc1(x))))``.

    The module itself adds no residual; the caller is responsible for it.

    Args:
        embed_dim: Input and output feature size.
        mlp_ratio: Hidden width is ``int(embed_dim * mlp_ratio)``.
        dropout_rate: Dropout probability applied after the activation.
    """

    def __init__(self, embed_dim=768, mlp_ratio=4, dropout_rate=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

        # He init: gain sqrt(2) for the layer feeding the activation,
        # unit gain for the output layer.
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.fc2(hidden)


class TransformerBlock(nn.Module):
    """Post-norm transformer block: attention and MLP sub-layers.

    Each sub-layer output is added to its input and the sum is passed
    through a LayerNorm (normalisation comes after the residual addition).

    Args:
        embed_dim: Token dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-width multiplier of the MLP.
        dropout_rate: Dropout probability used in both sub-layers.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout_rate=0.0):
        super().__init__()
        self.embed_dim = embed_dim

        # Sub-layer 1: multi-head self-attention and its LayerNorm.
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads, dropout_rate)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Sub-layer 2: feed-forward network and its LayerNorm.
        self.mlp = MLP(embed_dim, mlp_ratio, dropout_rate)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Transform tokens of shape ``(batch, seq_len, embed_dim)``."""
        x = self.norm1(x + self.attention(x))
        return self.norm2(x + self.mlp(x))


class VisionTransformer(nn.Module):
    """ViT-style classifier over patch tokens plus a learnable class token.

    Pipeline: patch embedding -> prepend class token -> add learned
    positional embedding -> ``depth`` transformer blocks -> LayerNorm ->
    bias-free linear head applied to the class token.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of one patch.
        in_channels: Channels of the input image.
        num_classes: Number of output logits.
        embed_dim: Token dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden-width multiplier of each block's MLP.
        dropout_rate: Dropout probability inside the blocks.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout_rate=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        token_count = self.patch_embed.n_patches + 1  # patches plus class token

        # Learned positional embedding and class token, small truncated-normal init.
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

        encoder = []
        for _ in range(depth):
            encoder.append(TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_rate))
        self.blocks = nn.ModuleList(encoder)

        # Final normalisation and unit-gain He-initialised classifier.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes, bias=False)
        nn.init.kaiming_normal_(self.head.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        n_images = x.shape[0]

        tokens = self.patch_embed(x)                      # (B, N, E)
        cls = self.cls_token.expand(n_images, -1, -1)     # (B, 1, E)
        tokens = torch.cat((cls, tokens), dim=1)          # (B, N + 1, E)
        tokens = tokens + self.pos_embed

        for layer in self.blocks:
            tokens = layer(tokens)

        tokens = self.norm(tokens)
        # Classify from the class token, which sits at sequence index 0.
        return self.head(tokens[:, 0])

    def get_effective_depth(self):
        """Return ``2 * depth + 2``.

        Two sub-layers (attention and MLP) are counted per block, plus one
        for the patch embedding and one for the classifier.
        """
        return self.depth * 2 + 2




# ==================== ViT Variants ====================

class DeiT(VisionTransformer):
    """DeiT-style ViT with an additional distillation token and head.

    The distillation token is part of the architecture, but this class
    defines no distillation loss: the forward pass simply averages the
    logits of the class head and the distillation head.

    Accepts the same constructor arguments as ``VisionTransformer``,
    positionally or by keyword.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Extra learnable token placed right after the class token.
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        nn.init.trunc_normal_(self.dist_token, std=0.02)

        # The parent allocates num_patches + 1 positions; one more is needed
        # for the distillation token, so the embedding is rebuilt.
        num_patches = self.patch_embed.n_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Size the distillation head from the parent's classifier so that it
        # always matches, even when num_classes was passed positionally
        # (reading kwargs alone would silently fall back to 1000).
        num_classes = self.head.out_features
        self.dist_head = nn.Linear(self.embed_dim, num_classes, bias=False)
        nn.init.kaiming_normal_(self.dist_head.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return the mean of the class-head and distillation-head logits."""
        batch_size = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)

        # Prepend class and distillation tokens (sequence indices 0 and 1).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        dist_tokens = self.dist_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, dist_tokens, x), dim=1)

        # Add positional embedding
        x = x + self.pos_embed

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Average of the two heads gives the final prediction.
        cls_logits = self.head(x[:, 0])
        dist_logits = self.dist_head(x[:, 1])
        return (cls_logits + dist_logits) / 2

    def get_effective_depth(self):
        """Return ``2 * depth + 2``, the same count as the plain ViT.

        The distillation token widens the sequence but adds no layers.
        """
        return self.depth * 2 + 2


class SeqPool(nn.Module):
    """Attention-weighted pooling of a token sequence into a single vector.

    A linear layer scores every token, the scores are softmax-normalised
    over the sequence axis, and the tokens are summed with those weights.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.attention_pool = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Pool ``(batch, seq_len, embed_dim)`` into ``(batch, embed_dim)``."""
        scores = self.attention_pool(x)             # (B, S, 1)
        weights = F.softmax(scores, dim=1)          # normalise across tokens
        pooled = torch.matmul(x.transpose(1, 2), weights)  # (B, E, 1)
        return pooled.squeeze(2)

class CCT(VisionTransformer):
    """Compact Convolutional Transformer.

    Architecture:
    1. Convolutional tokenizer (replaces the patch embedding): two
       Conv3x3 -> ReLU -> MaxPool(3, stride 2, padding 1) stages.
    2. Standard transformer blocks built by the parent class.
    3. Sequence pooling (replaces the class token).
    4. Bias-free linear head.

    Accepts the same constructor arguments as ``VisionTransformer``,
    positionally or by keyword. When omitted, CCT uses its own defaults
    for the parts it rebuilds: ``img_size=32``, ``in_channels=3`` and
    ``num_classes=10``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def ctor_arg(name, position, default):
            # Honour both keyword and positional use of a parent parameter;
            # `position` is its index in VisionTransformer.__init__.
            if name in kwargs:
                return kwargs[name]
            if len(args) > position:
                return args[position]
            return default

        img_size = ctor_arg('img_size', 0, 32)
        in_channels = ctor_arg('in_channels', 2, 3)
        num_classes = ctor_arg('num_classes', 3, 10)
        embed_dim = self.embed_dim

        # Two-stage conv tokenizer; each max-pool halves the resolution
        # (rounding up), for roughly a /4 downsample overall.
        self.tokenizer = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            nn.Conv2d(64, embed_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # He init for the tokenizer convolutions
        nn.init.kaiming_normal_(self.tokenizer[0].weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.tokenizer[3].weight, mode='fan_in', nonlinearity='relu')

        # The parent's patch embedding and class token are not used by CCT.
        self.patch_embed = None
        self.cls_token = None

        # Learnable positional embedding sized to the tokenizer output
        # (square input assumed). The 3x3 convs keep the resolution; each
        # MaxPool2d(3, 2, 1) maps n -> (n - 1) // 2 + 1, which equals
        # n // 2 only for even n, so compute it exactly.
        feature_size = img_size
        for _ in range(2):
            feature_size = (feature_size - 1) // 2 + 1
        num_patches = feature_size * feature_size
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Sequence pooling
        self.seq_pool = SeqPool(embed_dim)

        # Rebuild norm and head so they are freshly initialised with the
        # class count resolved above.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes, bias=False)
        nn.init.kaiming_normal_(self.head.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        # 1. Conv tokenizer
        x = self.tokenizer(x)  # (batch_size, embed_dim, H', W')

        # Flatten the spatial grid into a token sequence
        x = x.flatten(2).transpose(1, 2)  # (batch_size, seq_len, embed_dim)

        # 2. Add positional embedding
        if x.shape[1] == self.pos_embed.shape[1]:
            x = x + self.pos_embed
        else:
            # Best effort for inputs smaller than the configured img_size:
            # use the leading positions only. Longer sequences still fail.
            x = x + self.pos_embed[:, :x.shape[1], :]

        # 3. Transformer blocks
        for block in self.blocks:
            x = block(x)

        # 4. Sequence pooling (attention-based aggregation)
        x = self.norm(x)
        x = self.seq_pool(x)

        # 5. Head
        logits = self.head(x)
        return logits

    def get_effective_depth(self):
        """Return ``2 * depth + 4``.

        Counts 2 tokenizer convs, 2 sub-layers per block, 1 for sequence
        pooling and 1 for the linear head.
        """
        return self.depth * 2 + 4


class BeitAttention(MultiHeadSelfAttention):
    """Multi-head self-attention with a learned additive attention bias.

    The bias (``rel_pos_bias``) is a parameter of shape
    ``(1, num_heads, seq_len, seq_len)`` with ``seq_len = num_patches + 1``
    (patches plus the class token). It is added to the attention scores
    before the softmax, and skipped when the input sequence length does
    not match it.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate, num_patches):
        super().__init__(embed_dim, num_heads, dropout_rate)
        self.num_patches = num_patches
        token_count = num_patches + 1  # include the class token
        self.rel_pos_bias = nn.Parameter(torch.zeros(1, num_heads, token_count, token_count))
        nn.init.trunc_normal_(self.rel_pos_bias, std=0.02)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, embed_dim)``."""
        batch_size, seq_len, embed_dim = x.shape

        # (B, S, 3*E) -> (3, B, heads, S, head_dim)
        packed = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        queries, keys, values = packed[0], packed[1], packed[2]

        scores = (queries @ keys.transpose(-2, -1)) * (self.head_dim ** -0.5)

        # Apply the learned bias only when its grid matches this sequence.
        bias = self.rel_pos_bias
        if bias.shape[2] == seq_len and bias.shape[3] == seq_len:
            scores = scores + bias

        weights = self.dropout(F.softmax(scores, dim=-1))

        context = (weights @ values).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.proj(context)


class BeitBlock(nn.Module):
    """Post-norm transformer block whose attention is a ``BeitAttention``.

    The structure mirrors the standard block: each sub-layer output is
    added to its input and the sum is layer-normalised.

    Args:
        embed_dim: Token dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-width multiplier of the MLP.
        dropout_rate: Dropout probability used in both sub-layers.
        num_patches: Number of patch tokens (excluding the class token).
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio, dropout_rate, num_patches):
        super().__init__()
        self.embed_dim = embed_dim

        # Sub-layer 1: attention with learned additive bias.
        self.attention = BeitAttention(embed_dim, num_heads, dropout_rate, num_patches)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Sub-layer 2: feed-forward network.
        self.mlp = MLP(embed_dim, mlp_ratio, dropout_rate)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Transform tokens of shape ``(batch, seq_len, embed_dim)``."""
        x = self.norm1(x + self.attention(x))
        return self.norm2(x + self.mlp(x))


class Beit(VisionTransformer):
    """BEiT-style architecture variant of the Vision Transformer.

    The parent's blocks are replaced by ``BeitBlock`` instances and the
    absolute positional embedding is removed, so position information
    comes only from the learned attention bias inside each block.

    Accepts the same constructor arguments as ``VisionTransformer``,
    positionally or by keyword.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def ctor_arg(name, position, default):
            # Honour both keyword and positional use of a parent parameter;
            # `position` is its index in VisionTransformer.__init__. Reading
            # kwargs alone would rebuild the blocks with default settings
            # whenever the caller passed these values positionally.
            if name in kwargs:
                return kwargs[name]
            if len(args) > position:
                return args[position]
            return default

        num_patches = self.patch_embed.n_patches
        embed_dim = self.embed_dim
        num_heads = ctor_arg('num_heads', 6, 12)
        mlp_ratio = ctor_arg('mlp_ratio', 7, 4)
        dropout_rate = ctor_arg('dropout_rate', 8, 0.0)

        # Replace the standard blocks with BeitBlocks.
        self.blocks = nn.ModuleList([
            BeitBlock(embed_dim, num_heads, mlp_ratio, dropout_rate, num_patches)
            for _ in range(self.depth)
        ])

        # Remove the absolute positional embedding.
        self.pos_embed = None

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        batch_size = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # No absolute positional embedding is added here.

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]
        logits = self.head(cls_token_final)
        return logits


def get_data_loaders(dataset_name, batch_size=128):
    """Build train/test loaders for MNIST, CIFAR-10 or CIFAR-100.

    Datasets are read from ``../data`` and are never downloaded.

    Args:
        dataset_name: ``'mnist'``, ``'cifar10'`` or ``'cifar100'``.
        batch_size: Batch size of both loaders.

    Returns:
        tuple: ``(trainloader, testloader, input_channels, num_classes)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'mnist':
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # 3 channels for consistency with CIFAR
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        trainset = torchvision.datasets.MNIST(root='../data', train=True,
                                             download=False, transform=transform)
        testset = torchvision.datasets.MNIST(root='../data', train=False,
                                            download=False, transform=transform)
        input_channels = 3  # grayscale is replicated to 3 channels above
        num_classes = 10

    elif dataset_name == 'cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
        trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                               download=False, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                              download=False, transform=transform)
        input_channels = 3
        num_classes = 10

    elif dataset_name == 'cifar100':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        ])
        trainset = torchvision.datasets.CIFAR100(root='../data', train=True,
                                                download=False, transform=transform)
        testset = torchvision.datasets.CIFAR100(root='../data', train=False,
                                               download=False, transform=transform)
        input_channels = 3
        num_classes = 100

    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"Dataset {dataset_name} not supported")

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                             shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

    return trainloader, testloader, input_channels, num_classes


# ==================== Audio Data Loading ====================



# ==================== ImageNet Data Loading ====================

def get_imagenet_data_loaders(batch_size=128):
    """Build 224x224 loaders for ImageNet, falling back to CIFAR-100.

    ImageNet is used when ``../data/imagenet/train`` exists (read through
    ``ImageFolder`` from the ``train`` and ``val`` sub-directories).
    Otherwise CIFAR-100, resized to 224, is used as a substitute.

    Args:
        batch_size: Batch size of both loaders.

    Returns:
        tuple: ``(trainloader, testloader, input_channels, num_classes)``.
    """
    # Check if ImageNet data exists locally
    imagenet_dir = Path("../data/imagenet")
    if imagenet_dir.exists() and (imagenet_dir / "train").exists():
        print("Loading ImageNet from local files...")
        # Resize(224) with an int only matches the shorter edge, so images
        # with different aspect ratios would end up with different shapes
        # and could not be stacked into a batch. The centre crop makes every
        # sample exactly 224x224.
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        trainset = torchvision.datasets.ImageFolder(imagenet_dir / "train", transform=transform)
        testset = torchvision.datasets.ImageFolder(imagenet_dir / "val", transform=transform)

        input_channels = 3
        num_classes = len(trainset.classes)

    else:
        print("ImageNet not found locally. Using CIFAR-100 as substitute...")
        # CIFAR images are square, so Resize(224) already yields 224x224.
        transform = transforms.Compose([
            transforms.Resize(224),  # Resize to ImageNet size
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        # NOTE(review): this root is 'data' while the other loaders in this
        # file read from '../data' — confirm which location is intended.
        trainset = torchvision.datasets.CIFAR100(
            root='data', train=True, download=False, transform=transform
        )
        testset = torchvision.datasets.CIFAR100(
            root='data', train=False, download=False, transform=transform
        )

        input_channels = 3
        num_classes = 100  # CIFAR-100 has 100 classes

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return trainloader, testloader, input_channels, num_classes


def get_vit_data_loaders(dataset_name, batch_size=128, img_size=224):
    """Build train/test loaders for the Vision Transformer experiments.

    CIFAR-10 and CIFAR-100 are read from ``../data`` (never downloaded)
    and resized to ``img_size``. ``'imagenet'`` is delegated to
    ``get_imagenet_data_loaders``, which receives only ``batch_size``.

    Args:
        dataset_name: ``'cifar10'``, ``'cifar100'`` or ``'imagenet'``.
        batch_size: Batch size of both loaders.
        img_size: Target size passed to ``transforms.Resize`` (CIFAR only).

    Returns:
        tuple: ``(trainloader, testloader, input_channels, num_classes)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'imagenet':
        # Reuse the dedicated ImageNet loader.
        return get_imagenet_data_loaders(batch_size)

    # Select the dataset class, normalisation statistics and class count.
    if dataset_name == 'cifar10':
        dataset_cls = torchvision.datasets.CIFAR10
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        num_classes = 10
    elif dataset_name == 'cifar100':
        dataset_cls = torchvision.datasets.CIFAR100
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
        num_classes = 100
    else:
        raise ValueError(f"Dataset {dataset_name} not supported for ViT")

    input_channels = 3
    transform = transforms.Compose([
        transforms.Resize(img_size),  # Resize to ViT input size
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    trainset = dataset_cls(root='../data', train=True, download=False, transform=transform)
    testset = dataset_cls(root='../data', train=False, download=False, transform=transform)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return trainloader, testloader, input_channels, num_classes


# ==================== Multi-GPU Configuration ====================

def setup_multi_gpu():
    """Report the visible CUDA devices and decide on multi-GPU training.

    Prints the name and total memory of every detected GPU.

    Returns:
        tuple: ``(use_multi_gpu, num_gpus)``. ``use_multi_gpu`` is True only
        when more than one GPU is detected; without CUDA the result is
        ``(False, 0)``, and with at most one GPU it is ``(False, 1)``.
    """
    if not torch.cuda.is_available():
        print("❌ 未检测到CUDA GPU，将使用CPU训练")
        return False, 0

    gpu_count = torch.cuda.device_count()
    print(f"检测到 {gpu_count} 个GPU:")
    for idx in range(gpu_count):
        device_name = torch.cuda.get_device_name(idx)
        memory_gb = torch.cuda.get_device_properties(idx).total_memory / 1024**3
        print(f"  GPU {idx}: {device_name} ({memory_gb:.1f} GB)")

    if gpu_count <= 1:
        print("⚠️  只检测到1个GPU，将使用单GPU训练")
        return False, 1

    # Both multi-GPU cases return the same result; only the message differs.
    if gpu_count >= 8:
        print("✅ 检测到8个或更多GPU，将使用多GPU并行训练")
    else:
        print(f"⚠️  检测到 {gpu_count} 个GPU，将使用多GPU并行训练")
    return True, gpu_count


# ==================== Grid Search for Optimal LR ====================

def train_one_epoch(model, dataloader, learning_rate, device, max_batches=None, use_multi_gpu=False):
    """Train ``model`` with plain SGD for one pass and return the mean loss.

    Args:
        model: Module to train; it is updated in place.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        learning_rate: SGD learning rate.
        device: Device the batches are moved to.
        max_batches: Stop after this many batches; ``None`` or ``0`` means
            no limit.
        use_multi_gpu: Wrap the model in ``nn.DataParallel`` when more than
            one CUDA device is available.

    Returns:
        float: Cross-entropy loss averaged over the processed batches.

    Raises:
        ValueError: If no batch was processed (empty dataloader).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Wrap the model when multi-GPU training is requested.
    if use_multi_gpu and torch.cuda.device_count() > 1:
        if not isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            # Use DataParallel to spread each batch across the GPUs.
            model = nn.DataParallel(model)
            # Move the wrapped model to the primary device.
            model = model.to(device)
        print(f"使用 {torch.cuda.device_count()} 个GPU进行并行训练")

    total_loss = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if max_batches and batch_idx >= max_batches:
            break

        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Avoid an opaque ZeroDivisionError when there was nothing to train on.
        raise ValueError("train_one_epoch processed no batches; the dataloader is empty")

    return total_loss / num_batches


def grid_search_lr(model_class, model_kwargs, dataloader, lr_range, device, 
                   num_trials=3, max_batches=100, use_multi_gpu=False):
    """Pick the learning rate in ``lr_range`` with the lowest mean training loss.

    Every candidate learning rate is evaluated ``num_trials`` times; each
    trial builds a fresh ``model_class(**model_kwargs)``, moves it to
    ``device`` and runs ``train_one_epoch`` on it. The trial losses are
    averaged with ``np.mean``.

    Returns:
        ``(best_lr, best_loss, losses)`` where ``losses`` holds the mean loss
        of every candidate in ``lr_range`` order. ``best_lr`` stays ``None``
        and ``best_loss`` stays ``float('inf')`` when no candidate's mean loss
        compares strictly below infinity.
    """
    def _run_trial(step_size):
        # A newly constructed model per trial so runs do not share weights.
        fresh_model = model_class(**model_kwargs).to(device)
        return train_one_epoch(fresh_model, dataloader, step_size, device, max_batches, use_multi_gpu)

    progress_label = f"Grid search (depth={model_kwargs.get('depth', 'N/A')})"
    winning_lr, winning_loss = None, float('inf')
    mean_losses = []

    for candidate_lr in tqdm(lr_range, desc=progress_label):
        mean_loss = np.mean([_run_trial(candidate_lr) for _ in range(num_trials)])
        mean_losses.append(mean_loss)

        # Only a strictly smaller mean replaces the incumbent; a NaN mean
        # fails this comparison and is therefore never selected.
        if not mean_loss < winning_loss:
            continue
        winning_lr, winning_loss = candidate_lr, mean_loss

    return winning_lr, winning_loss, mean_losses


# ==================== Segmented Experiment with Multiple Baselines ====================

def run_segmented_experiment(model_type='cnn', dataset_name='cifar10',
                            activation='relu', device='cuda'):
    """
    Run a depth sweep and evaluate segment-wise learning-rate predictions.

    For every configured depth a learning-rate grid search is run. Each
    segment then calibrates a constant ``k`` as the mean of
    ``best_lr * depth**1.5`` over its baseline depths and predicts
    ``k * depth**-1.5`` for its prediction depths:
    - depths 3-4 calibrate depths 5-9
    - depths 10-11 calibrate depths 12-16
    - depths 18 and 20 calibrate depths 22-30

    Side effects: prints progress and statistics, shows a 2x2 figure and saves
    it to ``{model_type}_{activation}_{dataset_name}_segmented.png``, and
    writes ``{model_type}_{activation}_{dataset_name}_experiment_data.json``.

    Args:
        model_type: ``'cnn'`` or ``'resnet'``.
        dataset_name: Passed to ``get_data_loaders``.
        activation: Forwarded to the model constructor as ``activation``.
        device: Device handed to the grid search.

    Returns:
        A dict with the keys ``model_type``, ``activation``, ``dataset``,
        ``all_results``, ``segment_results``, ``global_alpha`` and
        ``theoretical_alpha``; or ``None`` when no depth produced a valid
        learning rate (nothing is plotted or saved in that case).

    Raises:
        ValueError: If ``model_type`` is neither ``'cnn'`` nor ``'resnet'``.
    """
    # Fail fast: without this, an unknown model type only surfaced later as a
    # NameError on the never-assigned model configuration.
    if model_type not in ('cnn', 'resnet'):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'cnn' or 'resnet'"
        )

    print(f"\n{'='*60}")
    print(f"Running segmented experiment: {model_type.upper()} with {activation.upper()} on {dataset_name.upper()}")
    print(f"{'='*60}\n")

    # Load data (the test loader is not used by this experiment)
    trainloader, _testloader, input_channels, num_classes = get_data_loaders(dataset_name)

    # Segments are (baseline_depths, prediction_depths) pairs; both model
    # types use the same segmentation.
    segments = [
        ([3, 4], [5, 6, 7, 8, 9]),
        ([10, 11], [12, 13, 14, 15, 16]),
        ([18, 20], [22, 24, 26, 28, 30])
    ]

    if model_type == 'cnn':
        # Define all depths to test (at least 15 points)
        all_depths = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 26, 28, 30]

        model_configs = []
        for depth in all_depths:
            model_configs.append({
                'class': HomogeneousCNN,
                'kwargs': {
                    'depth': depth,
                    'channels': 64,
                    'kernel_size': 3,
                    'num_classes': num_classes,
                    'input_channels': input_channels,
                    'activation': activation
                },
                'depth': depth
            })
    else:  # 'resnet'
        # NOTE(review): the depth list is empty, so no ResNet is trained and
        # the function ends up returning None. Confirm whether this is a
        # deliberate switch-off or the CNN depth list was meant to be here.
        all_depths = [ ]

        model_configs = []
        for depth in all_depths:
            model_configs.append({
                'class': PreActResNet,
                'kwargs': {
                    'depth': depth,
                    'num_classes': num_classes,
                    'input_channels': input_channels,
                    'activation': activation
                },
                'depth': depth
            })

    # Learning rate search range: 60 log-spaced candidates in [1e-5, 1]
    lr_range = np.logspace(-5, 0, 60)

    # Store all results
    all_results = []

    # Grid search for ALL models first
    print("Phase 1: Grid searching optimal LR for all depths...")
    for config in model_configs:
        print(f"\nTesting depth L={config['depth']}...")

        best_lr, best_loss, _losses = grid_search_lr(
            config['class'],
            config['kwargs'],
            trainloader,
            lr_range,
            device,
            num_trials=2,  # Fewer trials for speed
            max_batches=100  # Limit batches for faster execution
        )

        all_results.append({
            'depth': config['depth'],
            'best_lr': best_lr,
            'best_loss': best_loss,
            'segment': None  # Placeholder; never filled in by this function
        })

    # Process results by segments
    print("\n" + "="*60)
    print("Phase 2: Calculating segmented predictions...")
    print("="*60)

    theoretical_alpha = -3/2
    segment_results = []

    for seg_idx, (baseline_depths, prediction_depths) in enumerate(segments):
        print(f"\nSegment {seg_idx + 1}:")
        print(f"  Baseline depths: {baseline_depths}")
        print(f"  Prediction depths: {prediction_depths}")

        # Get baseline results
        baseline_results = [r for r in all_results if r['depth'] in baseline_depths]

        # A baseline whose grid search found no learning rate cannot
        # contribute to k (multiplying None would raise), so drop it here.
        valid_baselines = []
        for br in baseline_results:
            if br['best_lr'] is not None:
                valid_baselines.append(br)
            else:
                print(f"    Warning: Baseline depth {br['depth']} has no valid learning rate")

        if len(valid_baselines) < 2:
            print(f"  Warning: Not enough baseline points in segment {seg_idx + 1}")
            continue

        # Calculate k from baseline using average: k_i = lr_i * L_i^(3/2)
        baseline_ks = [
            br['best_lr'] * (br['depth'] ** (-theoretical_alpha))
            for br in valid_baselines
        ]

        k_segment = np.mean(baseline_ks)
        print(f"  Calculated k = {k_segment:.6f}")

        # Make predictions for this segment
        for depth in prediction_depths:
            actual_result = next((r for r in all_results if r['depth'] == depth), None)
            if actual_result and actual_result['best_lr'] is not None:
                predicted_lr = k_segment * (depth ** theoretical_alpha)

                segment_results.append({
                    'segment': seg_idx + 1,
                    'depth': depth,
                    'actual_lr': actual_result['best_lr'],
                    'predicted_lr': predicted_lr,
                    'relative_error': abs(actual_result['best_lr'] - predicted_lr) / actual_result['best_lr'],
                    'is_baseline': False
                })

                print(f"    Depth {depth}: Actual={actual_result['best_lr']:.6f}, "
                      f"Predicted={predicted_lr:.6f}, Error={segment_results[-1]['relative_error']:.2%}")
            elif actual_result and actual_result['best_lr'] is None:
                print(f"    Depth {depth}: Warning - No valid learning rate found (best_lr is None)")
            else:
                print(f"    Depth {depth}: Warning - No results found for this depth")

        # Also add baseline points to results
        for br in valid_baselines:
            segment_results.append({
                'segment': seg_idx + 1,
                'depth': br['depth'],
                'actual_lr': br['best_lr'],
                'predicted_lr': br['best_lr'],  # For baseline, predicted = actual
                'relative_error': 0.0,
                'is_baseline': True
            })

    # Without any valid result there is nothing to fit or plot; bail out
    # before a figure is created so none is left open.
    valid_results = [r for r in all_results if r['best_lr'] is not None]
    if not valid_results:
        print("Warning: No valid results found for plotting")
        return None

    # Create comprehensive plot
    plt.figure(figsize=(15, 10))

    # Plot 1: Segmented predictions
    plt.subplot(2, 2, 1)
    colors = ['blue', 'green', 'red', 'purple', 'orange']

    for seg_idx, (baseline_depths, prediction_depths) in enumerate(segments):
        seg_data = [r for r in segment_results if r['segment'] == seg_idx + 1]

        if not seg_data:
            continue

        depths = [r['depth'] for r in seg_data]
        actual_lrs = [r['actual_lr'] for r in seg_data]
        predicted_lrs = [r['predicted_lr'] for r in seg_data]
        is_baseline = [r['is_baseline'] for r in seg_data]

        # Plot actual values
        baseline_mask = np.array(is_baseline)
        pred_mask = ~baseline_mask

        if np.any(baseline_mask):
            plt.scatter(np.array(depths)[baseline_mask], np.array(actual_lrs)[baseline_mask],
                       color=colors[seg_idx % len(colors)], s=100, alpha=0.8,
                       marker='s', label=f'Segment {seg_idx + 1} baseline')

        if np.any(pred_mask):
            plt.scatter(np.array(depths)[pred_mask], np.array(actual_lrs)[pred_mask],
                       color=colors[seg_idx % len(colors)], s=100, alpha=0.5,
                       marker='o', label=f'Segment {seg_idx + 1} actual')

            # Plot predictions as lines
            pred_depths = np.array(depths)[pred_mask]
            pred_lrs = np.array(predicted_lrs)[pred_mask]
            sorted_idx = np.argsort(pred_depths)
            plt.plot(pred_depths[sorted_idx], pred_lrs[sorted_idx],
                    '--', color=colors[seg_idx % len(colors)], alpha=0.5, linewidth=2)

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title(f'Segmented Predictions\n{model_type.upper()} with {activation.upper()} on {dataset_name.upper()}')
    plt.legend(fontsize=8)
    plt.grid(True, alpha=0.3)

    # Plot 2: All points with global fit
    plt.subplot(2, 2, 2)
    all_depths_array = np.array([r['depth'] for r in valid_results])
    all_lrs_array = np.array([r['best_lr'] for r in valid_results])

    # Fit global power law: linear regression of log(lr) on log(depth)
    log_depths = np.log(all_depths_array)
    log_lrs = np.log(all_lrs_array)
    reg = LinearRegression()
    reg.fit(log_depths.reshape(-1, 1), log_lrs)
    global_alpha = reg.coef_[0]
    global_k = np.exp(reg.intercept_)

    plt.scatter(all_depths_array, all_lrs_array, s=100, alpha=0.7, label='Grid Search')

    # Plot the fitted power law
    depth_range = np.linspace(min(all_depths_array), max(all_depths_array), 100)
    theoretical_line = global_k * (depth_range ** global_alpha)
    plt.plot(depth_range, theoretical_line, 'r--', linewidth=2,
             label=f'Global fit: η ∝ L^({global_alpha:.3f})')

    # Plot ideal theoretical line, anchored at the first valid result
    k_ideal = all_lrs_array[0] * (all_depths_array[0] ** 1.5)
    ideal_line = k_ideal * (depth_range ** (-1.5))
    plt.plot(depth_range, ideal_line, 'g-.', linewidth=2,
             label='Theory: η ∝ L^(-1.5)')

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title('Global Power Law Fit')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 3: Relative errors by segment (prediction points only)
    plt.subplot(2, 2, 3)
    segment_errors = {}
    for r in segment_results:
        if not r['is_baseline']:
            seg = r['segment']
            if seg not in segment_errors:
                segment_errors[seg] = {'depths': [], 'errors': []}
            segment_errors[seg]['depths'].append(r['depth'])
            segment_errors[seg]['errors'].append(r['relative_error'] * 100)

    for seg_idx, data in segment_errors.items():
        plt.scatter(data['depths'], data['errors'],
                   color=colors[(seg_idx-1) % len(colors)],
                   s=80, alpha=0.7, label=f'Segment {seg_idx}')

    plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    plt.axhline(y=10, color='red', linestyle='--', linewidth=1, alpha=0.5)
    plt.axhline(y=-10, color='red', linestyle='--', linewidth=1, alpha=0.5)
    plt.xlabel('Depth L')
    plt.ylabel('Relative Error (%)')
    plt.title('Prediction Errors by Segment')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 4: Learning rate vs depth (linear scale)
    plt.subplot(2, 2, 4)
    plt.plot(all_depths_array, all_lrs_array, 'o-', markersize=8, linewidth=2)
    plt.xlabel('Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title('Linear Scale View')
    plt.grid(True, alpha=0.3)

    # Only depths with a valid learning rate count as tested
    total_depths_tested = len(valid_results)

    plt.tight_layout()
    plt.savefig(f'{model_type}_{activation}_{dataset_name}_segmented.png', dpi=150)
    plt.show()

    # Save detailed results to JSON
    import json
    results_data = {
        'model_type': model_type,
        'activation': activation,
        'dataset': dataset_name,
        'all_results': all_results,
        'segment_results': segment_results,
        'global_alpha': global_alpha,
        'theoretical_alpha': theoretical_alpha,
        'experiment_timestamp': pd.Timestamp.now().isoformat()
    }

    # Save to JSON file
    results_filename = f'{model_type}_{activation}_{dataset_name}_experiment_data.json'
    with open(results_filename, 'w') as f:
        json.dump(results_data, f, indent=2, default=str)
    print(f"✅ 实验结果已保存到: {results_filename}")

    # Calculate overall statistics
    all_prediction_errors = [r['relative_error'] for r in segment_results if not r['is_baseline']]

    print(f"\n{'='*60}")
    print(f"Overall Statistics:")
    print(f"{'='*60}")
    print(f"Model: {model_type.upper()}, Activation: {activation.upper()}, Dataset: {dataset_name.upper()}")
    print(f"Total depths tested: {total_depths_tested}")
    print(f"Global fitted exponent: {global_alpha:.4f}")
    print(f"Theoretical exponent: {theoretical_alpha:.4f}")
    print(f"Difference: {abs(global_alpha - theoretical_alpha):.4f}")

    if all_prediction_errors:
        print(f"\nSegmented prediction statistics:")
        print(f"  Mean relative error: {np.mean(all_prediction_errors):.2%}")
        print(f"  Median relative error: {np.median(all_prediction_errors):.2%}")
        print(f"  Max relative error: {np.max(all_prediction_errors):.2%}")
        print(f"  Predictions within 10% error: {sum(e < 0.1 for e in all_prediction_errors)}/{len(all_prediction_errors)}")

    # Create detailed results DataFrame
    df_results = pd.DataFrame(all_results)
    df_results = df_results.sort_values('depth')

    print(f"\n{'='*60}")
    print("Detailed Results Table:")
    print(f"{'='*60}")
    print(df_results[['depth', 'best_lr', 'best_loss']].to_string(index=False))

    return {
        'model_type': model_type,
        'activation': activation,
        'dataset': dataset_name,
        'all_results': all_results,
        'segment_results': segment_results,
        'global_alpha': global_alpha,
        'theoretical_alpha': theoretical_alpha
    }



def run_vit_experiment(dataset_name='cifar10', device='cuda', img_size=224, use_multi_gpu=True):
    """Run the Vision Transformer depth sweep with segment-wise LR predictions.

    For every depth in 4..27 a learning-rate grid search is run. Each segment
    calibrates ``k`` as the mean of ``best_lr * L_eff**1.5`` over its baseline
    depths and predicts ``k * L_eff**-1.5`` for its prediction depths, where
    ``L_eff = 2 * depth + 2`` is the effective depth used for ViT.

    Side effects: prints progress and a summary, shows a 2x2 figure and saves
    it to ``vit_{dataset_name}.png``, and writes
    ``vit_{dataset_name}_experiment_data.json``.

    Args:
        dataset_name: ``'cifar10'`` and ``'cifar100'`` select the small ViT
            settings; any other value selects the ImageNet-style settings.
        device: Device handed to the grid search.
        img_size: Image side length passed to the data loaders and the model.
        use_multi_gpu: Request multi-GPU training; forced to ``False`` when
            CUDA is unavailable.

    Returns:
        A dict with the keys ``model_type``, ``dataset``, ``all_results``,
        ``segment_results``, ``global_alpha`` (``None`` when no depth produced
        a valid learning rate) and ``theoretical_alpha``.
    """
    print(f"\n{'='*60}")
    print(f"Running ViT experiment on {dataset_name.upper()}")
    print(f"{'='*60}\n")

    # Report the number of visible GPUs
    if use_multi_gpu and torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"检测到 {num_gpus} 个GPU，将使用多GPU并行训练")
        if num_gpus >= 8:
            print("✅ 检测到8个或更多GPU，将充分利用多GPU加速")
        else:
            print(f"⚠️  只有 {num_gpus} 个GPU，建议使用8卡以获得最佳性能")
    else:
        print("⚠️  未检测到GPU或禁用multi-GPU，将使用CPU或单GPU".replace("multi-GPU", "多GPU"))
        use_multi_gpu = False

    # Load data with a larger batch size for multi-GPU. From here on
    # use_multi_gpu being true implies CUDA is available (see above).
    if use_multi_gpu:
        batch_size = 64 * torch.cuda.device_count()  # 64 samples per GPU
        print(f"使用多GPU，batch size设置为: {batch_size}")
    else:
        batch_size = 128
        print(f"使用单GPU/CPU，batch size设置为: {batch_size}")

    # The test loader is not used by this experiment
    trainloader, _testloader, input_channels, num_classes = get_vit_data_loaders(dataset_name, batch_size=batch_size, img_size=img_size)

    # Number of transformer blocks to sweep over. The scaling analysis below
    # uses the effective depth 2 * depth + 2 instead of the raw block count.
    all_depths = [4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27]

    # Segments are (baseline_depths, prediction_depths) pairs
    segments = [
        ([4, 5], [6, 7, 8, 9, 10, 11,12]),
        ([12, 13], [14, 15, 16, 17, 18,19]),
        ([20,21],[22,23,24,25,26,27]),
    ]

    # Set ViT parameters based on dataset
    if dataset_name in ('cifar10', 'cifar100'):
        embed_dim = 384
        num_heads = 6
        patch_size = 4  # Smaller patch size for CIFAR
    else:  # imagenet
        embed_dim = 768
        num_heads = 12
        patch_size = 16  # Standard patch size for ImageNet

    model_configs = []
    for depth in all_depths:
        model_configs.append({
            'class': VisionTransformer,
            'kwargs': {
                'img_size': img_size,
                'patch_size': patch_size,
                'in_channels': input_channels,
                'num_classes': num_classes,
                'embed_dim': embed_dim,
                'depth': depth,
                'num_heads': num_heads,
                'mlp_ratio': 4,
                'dropout_rate': 0.0
            },
            'depth': depth
        })

    # Learning rate search range: 60 log-spaced candidates in [1e-5, 1]
    lr_range = np.logspace(-5, 0, 60)

    # Store all results
    all_results = []

    # Grid search for ALL models first
    print("Phase 1: Grid searching optimal LR for all depths...")
    for config in model_configs:
        print(f"\nTesting depth L={config['depth']}...")

        best_lr, best_loss, _losses = grid_search_lr(
            config['class'],
            config['kwargs'],
            trainloader,
            lr_range,
            device,
            num_trials=2,
            max_batches=1000,  # Cap on batches per training run
            use_multi_gpu=use_multi_gpu
        )

        all_results.append({
            'depth': config['depth'],
            'best_lr': best_lr,
            'best_loss': best_loss,
            'segment': None
        })

        # best_lr is None when the grid search never beat an infinite loss;
        # formatting None with ':.6f' would raise.
        if best_lr is not None:
            print(f"Depth {config['depth']}: Best LR = {best_lr:.6f}, Loss = {best_loss:.4f}")
        else:
            print(f"Depth {config['depth']}: Warning - No valid learning rate found (best_lr is None)")

    # Process results by segments
    theoretical_alpha = -3/2
    segment_results = []

    for seg_idx, (baseline_depths, prediction_depths) in enumerate(segments):
        print(f"\nSegment {seg_idx + 1}:")
        print(f"  Baseline depths: {baseline_depths}")
        print(f"  Prediction depths: {prediction_depths}")

        # Get baseline results
        baseline_results = [r for r in all_results if r['depth'] in baseline_depths]

        # Baselines without a learning rate cannot contribute to k
        valid_baselines = []
        for br in baseline_results:
            if br['best_lr'] is not None:
                valid_baselines.append(br)
            else:
                print(f"    Warning: Baseline depth {br['depth']} has no valid learning rate")

        if len(valid_baselines) < 2:
            print(f"  Warning: Not enough baseline points in segment {seg_idx + 1}")
            continue

        # Calculate k from baseline using average.
        # For ViT, use effective depth (depth * 2 + 2): two residual
        # connections per transformer block, plus 2 for patch embedding and
        # classifier.
        baseline_ks = []
        for br in valid_baselines:
            effective_depth = br['depth'] * 2 + 2
            k_i = br['best_lr'] * (effective_depth ** (-theoretical_alpha))
            baseline_ks.append(k_i)

        k_segment = np.mean(baseline_ks)
        print(f"  Calculated k = {k_segment:.6f}")

        # Make predictions for this segment
        for depth in prediction_depths:
            actual_result = next((r for r in all_results if r['depth'] == depth), None)
            if actual_result and actual_result['best_lr'] is not None:
                effective_depth = depth * 2 + 2
                predicted_lr = k_segment * (effective_depth ** theoretical_alpha)

                segment_results.append({
                    'segment': seg_idx + 1,
                    'depth': depth,
                    'actual_lr': actual_result['best_lr'],
                    'predicted_lr': predicted_lr,
                    'relative_error': abs(actual_result['best_lr'] - predicted_lr) / actual_result['best_lr'],
                    'is_baseline': False
                })

                print(f"    Depth {depth}: Actual={actual_result['best_lr']:.6f}, "
                      f"Predicted={predicted_lr:.6f}, Error={segment_results[-1]['relative_error']:.2%}")
            elif actual_result and actual_result['best_lr'] is None:
                print(f"    Depth {depth}: Warning - No valid learning rate found (best_lr is None)")
            else:
                print(f"    Depth {depth}: Warning - No results found for this depth")

        # Record the baseline points too, so that the baseline markers drawn
        # in Plot 1 have data (matches run_segmented_experiment).
        for br in valid_baselines:
            segment_results.append({
                'segment': seg_idx + 1,
                'depth': br['depth'],
                'actual_lr': br['best_lr'],
                'predicted_lr': br['best_lr'],  # For baseline, predicted = actual
                'relative_error': 0.0,
                'is_baseline': True
            })

    # Create plot
    plt.figure(figsize=(12, 8))

    # Plot 1: Segmented predictions
    plt.subplot(2, 2, 1)
    colors = ['blue', 'green', 'red', 'purple', 'orange']

    for seg_idx, (baseline_depths, prediction_depths) in enumerate(segments):
        seg_data = [r for r in segment_results if r['segment'] == seg_idx + 1]

        if not seg_data:
            continue

        depths = [r['depth'] for r in seg_data]
        actual_lrs = [r['actual_lr'] for r in seg_data]
        predicted_lrs = [r['predicted_lr'] for r in seg_data]
        is_baseline = [r['is_baseline'] for r in seg_data]

        # Plot actual values
        baseline_mask = np.array(is_baseline)
        pred_mask = ~baseline_mask

        if np.any(baseline_mask):
            plt.scatter(np.array(depths)[baseline_mask], np.array(actual_lrs)[baseline_mask],
                       color=colors[seg_idx % len(colors)], s=100, alpha=0.8,
                       marker='s', label=f'Segment {seg_idx + 1} baseline')

        if np.any(pred_mask):
            plt.scatter(np.array(depths)[pred_mask], np.array(actual_lrs)[pred_mask],
                       color=colors[seg_idx % len(colors)], s=100, alpha=0.5,
                       marker='o', label=f'Segment {seg_idx + 1} actual')

            # Plot predictions as lines
            pred_depths = np.array(depths)[pred_mask]
            pred_lrs = np.array(predicted_lrs)[pred_mask]
            sorted_idx = np.argsort(pred_depths)
            plt.plot(pred_depths[sorted_idx], pred_lrs[sorted_idx],
                    '--', color=colors[seg_idx % len(colors)], alpha=0.5, linewidth=2)

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title(f'ViT Experiment on {dataset_name.upper()}')
    plt.legend(fontsize=8)
    plt.grid(True, alpha=0.3)

    # Plot 2: Global fit
    plt.subplot(2, 2, 2)
    valid_results = [r for r in all_results if r['best_lr'] is not None]
    # Stays None when there is nothing to fit, so the summary below can
    # always reference it.
    global_alpha = None
    if valid_results:
        all_depths_array = np.array([r['depth'] for r in valid_results])
        all_lrs_array = np.array([r['best_lr'] for r in valid_results])

        # For ViT, use effective depth (depth * 2 + 2) for fitting
        effective_depths_array = all_depths_array * 2 + 2

        # Fit global power law: linear regression of log(lr) on log(L_eff)
        log_depths = np.log(effective_depths_array)
        log_lrs = np.log(all_lrs_array)
        reg = LinearRegression()
        reg.fit(log_depths.reshape(-1, 1), log_lrs)
        global_alpha = reg.coef_[0]
        global_k = np.exp(reg.intercept_)

        plt.scatter(all_depths_array, all_lrs_array, s=100, alpha=0.7, label='Grid Search')

        # Plot the fitted line against the raw depth axis
        depth_range = np.linspace(min(all_depths_array), max(all_depths_array), 100)
        effective_depth_range = depth_range * 2 + 2
        theoretical_line = global_k * (effective_depth_range ** global_alpha)
        plt.plot(depth_range, theoretical_line, 'r--', linewidth=2,
                 label=f'Global fit: η ∝ L^({global_alpha:.3f})')

        plt.xscale('log')
        plt.yscale('log')
        plt.xlabel('Depth L')
        plt.ylabel('Optimal Learning Rate η*')
        plt.title('Global Power Law Fit')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Plot 3: Relative errors by segment (prediction points only)
    plt.subplot(2, 2, 3)
    segment_errors = {}
    for r in segment_results:
        if not r['is_baseline']:
            seg = r['segment']
            if seg not in segment_errors:
                segment_errors[seg] = {'depths': [], 'errors': []}
            segment_errors[seg]['depths'].append(r['depth'])
            segment_errors[seg]['errors'].append(r['relative_error'] * 100)

    for seg_idx, data in segment_errors.items():
        plt.scatter(data['depths'], data['errors'],
                   color=colors[(seg_idx-1) % len(colors)],
                   s=80, alpha=0.7, label=f'Segment {seg_idx}')

    plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    plt.axhline(y=10, color='red', linestyle='--', linewidth=1, alpha=0.5)
    plt.axhline(y=-10, color='red', linestyle='--', linewidth=1, alpha=0.5)
    plt.xlabel('Depth L')
    plt.ylabel('Relative Error (%)')
    plt.title('Prediction Errors by Segment')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 4: Learning rate vs depth (linear scale)
    plt.subplot(2, 2, 4)
    if valid_results:
        plt.plot(all_depths_array, all_lrs_array, 'o-', markersize=8, linewidth=2)
    plt.xlabel('Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title('Linear Scale View')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'vit_{dataset_name}.png', dpi=150)
    plt.show()

    # Save detailed results to JSON
    import json
    results_data = {
        'model_type': 'vit',
        'dataset': dataset_name,
        'img_size': img_size,
        'embed_dim': embed_dim,
        'num_heads': num_heads,
        'patch_size': patch_size,
        'all_results': all_results,
        'segment_results': segment_results,
        'global_alpha': global_alpha,
        'theoretical_alpha': theoretical_alpha,
        'experiment_timestamp': pd.Timestamp.now().isoformat()
    }

    # Save to JSON file
    results_filename = f'vit_{dataset_name}_experiment_data.json'
    with open(results_filename, 'w') as f:
        json.dump(results_data, f, indent=2, default=str)
    print(f"✅ 实验结果已保存到: {results_filename}")

    # Calculate and print summary statistics
    all_prediction_errors = [r['relative_error'] for r in segment_results if not r['is_baseline']]

    print(f"\n{'='*60}")
    print(f"ViT 实验结果总结:")
    print(f"{'='*60}")
    print(f"数据集: {dataset_name.upper()}")
    print(f"图像尺寸: {img_size}x{img_size}")
    print(f"嵌入维度: {embed_dim}")
    print(f"注意力头数: {num_heads}")
    print(f"Patch 尺寸: {patch_size}x{patch_size}")
    print(f"测试深度数: {len(valid_results)}")
    # Compare against None explicitly: a fitted exponent of exactly 0.0 is
    # still a valid result.
    if global_alpha is not None:
        print(f"全局拟合指数: {global_alpha:.4f}")
    else:
        print("全局拟合指数: N/A")
    print(f"理论指数: {theoretical_alpha:.4f}")
    if global_alpha is not None:
        print(f"指数差异: {abs(global_alpha - theoretical_alpha):.4f}")

    if all_prediction_errors:
        print(f"\n分段预测统计:")
        print(f"  平均相对误差: {np.mean(all_prediction_errors):.2%}")
        print(f"  中位数相对误差: {np.median(all_prediction_errors):.2%}")
        print(f"  最大相对误差: {np.max(all_prediction_errors):.2%}")
        print(f"  10%误差内预测数: {sum(e < 0.1 for e in all_prediction_errors)}/{len(all_prediction_errors)}")

    return {
        'model_type': 'vit',
        'dataset': dataset_name,
        'all_results': all_results,
        'segment_results': segment_results,
        'global_alpha': global_alpha,
        'theoretical_alpha': theoretical_alpha
    }




def run_vit_variants_experiment(dataset_name='cifar10', device='cuda', use_multi_gpu=True):
    """Run the learning-rate depth sweep for several ViT variants.

    For each variant (ViT, DeiT, CCT, BEiT) and each depth in 4..27 a
    learning-rate grid search is run. The best learning rates are then fitted
    against the effective depth (``2 * depth + 4`` for CCT, ``2 * depth + 2``
    otherwise) with a log-log linear regression.

    Side effects: prints progress, shows and saves
    ``vit_variants_{dataset_name}_loss_landscapes.png`` and
    ``vit_variants_{dataset_name}.png``, and writes
    ``vit_variants_{dataset_name}.json``.

    Args:
        dataset_name: Names starting with ``'cifar'`` use 32x32 images with
            patch size 4; anything else uses 224x224 with patch size 16.
        device: Device handed to the grid search.
        use_multi_gpu: Request multi-GPU training; forced to ``False`` when
            CUDA is unavailable.

    Returns:
        A list of ``(variant, results)`` tuples in the order the variants were
        run, where ``results`` is a list with one dict per depth.
    """
    print(f"\n{'='*60}")
    print(f"Running ViT Variants Experiment on {dataset_name.upper()}")
    print(f"{'='*60}\n")

    # Check GPU
    if use_multi_gpu and torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Detected {num_gpus} GPUs, using multi-GPU training")
    else:
        print("Using single GPU/CPU")
        use_multi_gpu = False

    # Load data. use_multi_gpu being true implies CUDA is available here.
    batch_size = 64 * torch.cuda.device_count() if use_multi_gpu else 128

    img_size = 32 if dataset_name.startswith('cifar') else 224
    patch_size = 4 if img_size == 32 else 16

    # The test loader is not used by this experiment
    trainloader, _testloader, input_channels, num_classes = get_vit_data_loaders(dataset_name, batch_size=batch_size, img_size=img_size)

    # Define depths to test
    test_depths = [4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27]

    # Variant name -> model class, in the order the variants are run
    variant_classes = {
        'vit': VisionTransformer,
        'deit': DeiT,
        'cct': CCT,
        'beit': Beit,
    }

    # Constructor arguments shared by every variant and depth.
    # NOTE(review): only 'cifar10' gets the 384/6 setting here, while
    # run_vit_experiment uses 384/6 for 'cifar100' as well - confirm which
    # one is intended.
    model_kwargs = {
        'img_size': img_size,
        'patch_size': patch_size,
        'in_channels': input_channels,
        'num_classes': num_classes,
        'embed_dim': 384 if dataset_name == 'cifar10' else 768,
        'num_heads': 6 if dataset_name == 'cifar10' else 12,
        'mlp_ratio': 4,
        'dropout_rate': 0.0
    }

    # 80 log-spaced learning-rate candidates in [1e-5, 1]
    lr_range = np.logspace(-5, 0, 80)

    all_results = []

    for variant, model_cls in variant_classes.items():
        print(f"\n{'='*40}")
        print(f"Testing Variant: {variant.upper()}")
        print(f"{'='*40}")

        variant_results = []

        print("Phase 1: Grid searching optimal LR...")
        for depth in test_depths:
            # Calculate effective depth based on variant; done first so the
            # progress message shows the same value that gets recorded.
            if variant == 'cct':
                eff_depth = depth * 2 + 4
            else:
                eff_depth = depth * 2 + 2

            print(f"\nTesting depth L={depth} (Effective={eff_depth})...")

            kwargs = model_kwargs.copy()
            kwargs['depth'] = depth

            best_lr, best_loss, losses = grid_search_lr(
                model_cls,
                kwargs,
                trainloader,
                lr_range,
                device,
                num_trials=1,
                max_batches=800,
                use_multi_gpu=use_multi_gpu
            )

            variant_results.append({
                'depth': depth,
                'effective_depth': eff_depth,
                'best_lr': best_lr,
                'best_loss': best_loss,
                'variant': variant,
                'all_losses': losses if isinstance(losses, list) else losses.tolist(),
                'lr_range': lr_range.tolist()
            })

            # best_lr is None when no candidate produced a usable loss;
            # formatting None with ':.6f' would raise.
            if best_lr is not None:
                print(f"Variant {variant}, Depth {depth} (Eff {eff_depth}): Best LR = {best_lr:.6f}")
            else:
                print(f"Variant {variant}, Depth {depth} (Eff {eff_depth}): Warning - No valid learning rate found")

        all_results.append((variant, variant_results))

    # Plot loss curves for debugging (once; best-effort, so a plotting
    # failure must not discard the results collected above)
    try:
        plt.figure(figsize=(15, 10))
        for i, (variant, results) in enumerate(all_results):
            plt.subplot(2, 2, i+1)
            for r in results:
                depth = r['depth']
                losses = r['all_losses']
                lrs = r['lr_range']
                # Filter out None or infinite losses for plotting
                valid_indices = [j for j, l in enumerate(losses) if l is not None and not np.isinf(l) and not np.isnan(l)]
                if valid_indices:
                    valid_lrs = [lrs[j] for j in valid_indices]
                    valid_losses = [losses[j] for j in valid_indices]
                    plt.semilogx(valid_lrs, valid_losses, label=f'D={depth}')

                    # Mark best LR
                    best_lr = r['best_lr']
                    best_loss = r['best_loss']
                    if best_lr is not None:
                        plt.scatter(best_lr, best_loss, marker='*', s=100, edgecolors='black')

            plt.title(f'{variant.upper()} Loss Landscapes')
            plt.xlabel('Learning Rate')
            plt.ylabel('Loss')
            plt.legend(fontsize='x-small')
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'vit_variants_{dataset_name}_loss_landscapes.png', dpi=150)
        plt.show()
    except Exception as e:
        print(f'Error plotting loss landscapes: {e}')

    # Plotting Scaling Laws
    plt.figure(figsize=(12, 8))

    colors = {'vit': 'blue', 'deit': 'green', 'cct': 'purple', 'beit': 'red'}
    markers = {'vit': 'o', 'deit': 's', 'cct': 'D', 'beit': '^'}

    for variant, results in all_results:
        # Keep only the depths for which a learning rate was found
        valid_points = [
            (r['effective_depth'], r['best_lr'])
            for r in results
            if r['best_lr'] is not None
        ]
        if not valid_points:
            continue

        vd, vl = zip(*valid_points)
        vd = np.array(vd)
        vl = np.array(vl)

        plt.scatter(vd, vl, label=f'{variant.upper()} Actual', color=colors[variant], marker=markers[variant], s=100, alpha=0.7)

        # Power-law fit: linear regression of log(lr) on log(effective depth)
        log_d = np.log(vd)
        log_l = np.log(vl)
        reg = LinearRegression().fit(log_d.reshape(-1, 1), log_l)
        alpha = reg.coef_[0]
        k = np.exp(reg.intercept_)

        d_range = np.linspace(min(vd), max(vd), 100)
        l_fit = k * (d_range ** alpha)

        plt.plot(d_range, l_fit, '--', color=colors[variant], label=f'{variant.upper()} Fit (α={alpha:.2f})')

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Effective Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title(f'ViT Variants Scaling Law on {dataset_name.upper()}')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'vit_variants_{dataset_name}.png', dpi=150)
    plt.show()

    import json
    results_data = {
        'dataset': dataset_name,
        'results': {v: r for v, r in all_results}
    }
    with open(f'vit_variants_{dataset_name}.json', 'w') as f:
        json.dump(results_data, f, indent=2, default=str)

    print(f"Results saved to vit_variants_{dataset_name}.json")
    return all_results


# ==================== Quick Test Function ====================

def quick_test():
    """Run a reduced depth sweep of the ReLU CNN on CIFAR-10 for debugging.

    For each depth in a short list, a learning-rate grid search is run through
    ``grid_search_lr`` and the best learning rate and loss are recorded.  A
    power law ``lr ~ depth**alpha`` is then fitted in log-log space, the fitted
    exponent is compared against -1.5, and the measured points are plotted
    next to the ``L^(-3/2)`` reference line.

    Depths whose search produced no usable learning rate (``None``,
    non-finite or non-positive) stay in the returned list but are excluded
    from the fit and the plot, because their logarithm is undefined.  If fewer
    than two usable points remain, the analysis and plot are skipped.

    Returns:
        list[dict]: One entry per tested depth with the keys ``'depth'``,
        ``'best_lr'`` and ``'best_loss'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Only the training loader and the class count are used here; the other
    # two return values are ignored.
    trainloader, _, _, num_classes = get_data_loaders('cifar10', batch_size=128)

    # Reduced sweep: six depths and a 20-point log-spaced LR grid on [1e-3, 1].
    test_depths = [3, 4, 6, 8, 10, 12]
    lr_range = np.logspace(-3, 0, 20)

    results = []
    for depth in test_depths:
        print(f"\nTesting depth {depth}...")
        model_kwargs = {
            'depth': depth,
            'channels': 64,
            'kernel_size': 3,
            'num_classes': num_classes,
            'input_channels': 3,
            'activation': 'relu'
        }

        best_lr, best_loss, _ = grid_search_lr(
            HomogeneousCNN,
            model_kwargs,
            trainloader,
            lr_range,
            device,
            num_trials=1,
            max_batches=800
        )

        results.append({
            'depth': depth,
            'best_lr': best_lr,
            'best_loss': best_loss
        })

        # NOTE(review): other result handling in this file filters out
        # `best_lr is None`, so a failed search is presumably reported as
        # None -- confirm against grid_search_lr.
        if best_lr is None:
            print("  No valid learning rate found for this depth")
        else:
            print(f"  Best LR: {best_lr:.6f}, Loss: {best_loss:.4f}")

    # Keep only points whose logarithm is defined.
    usable = [
        (r['depth'], r['best_lr'])
        for r in results
        if r['best_lr'] is not None
        and np.isfinite(r['best_lr'])
        and r['best_lr'] > 0
    ]
    if len(usable) < 2:
        # A line cannot be fitted through fewer than two points.
        print("\nNot enough valid learning rates to fit a power law; "
              "skipping analysis.")
        return results

    depths = np.array([d for d, _ in usable])
    lrs = np.array([lr for _, lr in usable], dtype=float)

    # Fit log(lr) = alpha * log(depth) + c; the slope is the exponent.
    log_depths = np.log(depths)
    log_lrs = np.log(lrs)
    reg = LinearRegression()
    reg.fit(log_depths.reshape(-1, 1), log_lrs)
    alpha = reg.coef_[0]

    print(f"\nFitted exponent: {alpha:.4f}")
    print("Theoretical: -1.5")
    print(f"Difference: {abs(alpha + 1.5):.4f}")

    # Plot measured optima against the reference curve.
    plt.figure(figsize=(8, 6))
    plt.scatter(depths, lrs, s=100, alpha=0.7, label='Grid Search')

    # Reference line k * L^(-3/2), anchored so it passes through the first
    # usable measurement.
    k_theory = lrs[0] * (depths[0] ** 1.5)
    theory_lrs = k_theory * (depths ** (-1.5))
    plt.plot(depths, theory_lrs, 'r--', linewidth=2, label='Theory: η ∝ L^(-3/2)')

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Depth L')
    plt.ylabel('Optimal Learning Rate η*')
    plt.title('Quick Test: CNN with ReLU on CIFAR-10')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    return results


if __name__ == "__main__":
    # Script entry point.  Most experiments below are kept as commented-out
    # calls; enable the ones you want to run.  The section banners are printed
    # regardless of whether the calls beneath them are enabled.

    # Option 1: Quick test (for debugging)
    # results = quick_test()

    # Option 2: Single detailed experiment with segmented baselines
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Horizontal rule shared by every banner below.
    _RULE = "=" * 80

    # Detect the GPU environment and report the selected device / mode.
    print(_RULE, "GPU 环境检测", _RULE, sep="\n")
    use_multi_gpu, num_gpus = setup_multi_gpu()
    print(f"使用设备: {device}")
    print(f"多GPU模式: {'启用' if use_multi_gpu else '禁用'}")
    print(_RULE)

    # CNN with ReLU
    # result = run_segmented_experiment('cnn', 'cifar10', 'relu', device)

    # CNN with GELU
    # result = run_segmented_experiment('cnn', 'cifar10', 'gelu', device)

    # ResNet with ReLU
    # result = run_segmented_experiment('resnet', 'cifar10', 'relu', device)
    # result = run_segmented_experiment('resnet', 'cifar100', 'relu', device)

    # ==================== New Experiments ====================

    # ---- Experiment 1: 1D CNN on audio datasets ----
    print(_RULE, "EXPERIMENT 1: 1D CNN on Audio Datasets", _RULE, sep="\n")

    # 1D CNN on Google Speech Commands v2
    # result = run_audio_experiment('1dcnn', 'speech_commands', 'relu', device)
    # result = run_audio_experiment('1dcnn', 'speech_commands', 'gelu', device)

    # 1D CNN on ESC-50
    # result = run_audio_experiment('1dcnn', 'esc50', 'relu', device)
    # result = run_audio_experiment('1dcnn', 'esc50', 'gelu', device)

    # ---- Experiment 2: ResNet with regularization on CIFAR ----
    print(_RULE, "EXPERIMENT 2: ResNet with Regularization on CIFAR", _RULE, sep="\n")

    # ResNet with different regularization on CIFAR-10
    # result = run_resnet_regularization_experiment('cifar10', 'relu', 'none', device)
    # result = run_resnet_regularization_experiment('cifar10', 'relu', 'dropout', device)
    # result = run_resnet_regularization_experiment('cifar10', 'relu', 'batchnorm', device)
    # result = run_resnet_regularization_experiment('cifar10', 'relu', 'both', device)

    # ResNet with different regularization on CIFAR-100
    # result = run_resnet_regularization_experiment('cifar100', 'relu', 'none', device)
    # result = run_resnet_regularization_experiment('cifar100', 'relu', 'dropout', device)
    # result = run_resnet_regularization_experiment('cifar100', 'relu', 'batchnorm', device)
    # result = run_resnet_regularization_experiment('cifar100', 'relu', 'both', device)

    # ---- Experiment 3: ImageNet experiments ----
    print(_RULE, "EXPERIMENT 3: ImageNet Experiments", _RULE, sep="\n")

    # 2D CNN on ImageNet
    print("\n--- 2D CNN on ImageNet ---")
    # result = run_imagenet_cnn_experiment('relu', device)
    # result = run_imagenet_cnn_experiment('gelu', device)

    # ResNet on ImageNet with different regularization
    print("\n--- ResNet on ImageNet ---")
    # result = run_imagenet_resnet_experiment('relu', 'none', device)
    # result = run_imagenet_resnet_experiment('relu', 'dropout', device)
    # result = run_imagenet_resnet_experiment('relu', 'batchnorm', device)
    # result = run_imagenet_resnet_experiment('relu', 'both', device)

    # result = run_imagenet_resnet_experiment('gelu', 'none', device)
    # result = run_imagenet_resnet_experiment('gelu', 'dropout', device)
    # result = run_imagenet_resnet_experiment('gelu', 'batchnorm', device)
    # result = run_imagenet_resnet_experiment('gelu', 'both', device)

    # ---- Experiment 4: Vision Transformer experiments ----
    print(_RULE, "EXPERIMENT 4: Vision Transformer Experiments", _RULE, sep="\n")

    # ViT on CIFAR-10 with multi-GPU
    print("\n--- ViT on CIFAR-10 (Multi-GPU) ---")
    # result = run_vit_experiment('cifar10', device, img_size=32, use_multi_gpu=use_multi_gpu)

    # ViT on CIFAR-100 with multi-GPU
    # print("\n--- ViT on CIFAR-100 (Multi-GPU) ---")
    # result = run_vit_experiment('cifar100', device, img_size=32, use_multi_gpu=use_multi_gpu)

    # ViT on ImageNet with multi-GPU
    # print("\n--- ViT on ImageNet (Multi-GPU) ---")
    # result = run_vit_experiment('imagenet', device, img_size=224, use_multi_gpu=use_multi_gpu)

    # Option 3: Run all original experiments
    # all_results = run_all_experiments()

    # ==================== ViT Variants Experiments ====================
    print(_RULE, "EXPERIMENT 5: ViT Variants (Vanilla, DeiT, CCT, BEiT)", _RULE, sep="\n")

    # Run the ViT-variants sweep on each dataset in turn.
    # NOTE: 'imagenet' is currently enabled and runs unconditionally; remove
    # it from this tuple if the ImageNet data is not available.
    for _dataset in ('cifar10', 'cifar100', 'imagenet'):
        run_vit_variants_experiment(_dataset, device, use_multi_gpu)

