import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import ImageFolder
import wandb
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
from tqdm import tqdm
import copy
import json
import time
from datetime import datetime
from pathlib import Path
from datasets import load_dataset


# Create the output directory tree used for saving results (plots, models,
# ablation studies). Runs at import time; existing directories are left as-is.
_RESULT_DIRS = (
    'results',
    'results/plots',
    'results/models',
    'results/ablation_studies',
)
for _result_dir in _RESULT_DIRS:
    os.makedirs(_result_dir, exist_ok=True)


#####################################################
# COMMON COMPONENTS
#####################################################

class ScaledDotProductAttention(nn.Module):
    """
    Standard scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        dropout: dropout probability applied to the attention weights.

    ``forward(q, k, v, mask=None)``:
        q, k, v: tensors whose last two dims are (sequence, features); leading
            dims (e.g. batch, heads) must be compatible for ``torch.matmul``.
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` get (numerically) zero attention weight.
        Returns ``(output, attention_weights)``.
    """
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Most negative finite value of the score dtype: a hard-coded -1e9
            # cannot be represented in float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        return torch.matmul(attention, v), attention


class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Positional Encoding (RoPE) as used by RoFormer.

    Each head's feature vector (``d_k = d_model // n_heads`` features) is split
    into two halves. Feature ``i`` is rotated together with feature
    ``i + d_k // 2`` by the angle ``pos * 10000 ** (-2i / d_k)``. Every pair
    undergoes a true 2-D rotation, so the dot product of a rotated query at
    position m and a rotated key at position n depends only on ``n - m``. This
    "rotate-half" pairing equals RoFormer's interleaved pairing up to a fixed
    permutation of the feature dimensions.

    Args:
        d_model: total model width.
        n_heads: number of heads; ``d_model // n_heads`` must be even.
        max_len: number of positions precomputed at construction. The
            tables grow automatically when a longer sequence is seen.

    Raises:
        ValueError: if the per-head dimension is odd.
    """
    def __init__(self, d_model, n_heads, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        if self.d_k % 2 != 0:
            raise ValueError(f"RoPE requires an even per-head dimension, got d_k={self.d_k}")
        self.max_len = max_len
        self.register_buffer("cos_cached", torch.zeros((max_len, self.d_k // 2)))
        self.register_buffer("sin_cached", torch.zeros((max_len, self.d_k // 2)))
        self._update_cos_sin_tables()

    def _update_cos_sin_tables(self, seq_len=None):
        """Fill the cos/sin tables for positions ``[0, seq_len)``; grow the buffers if needed."""
        if seq_len is None:
            seq_len = self.max_len
        device = self.cos_cached.device
        dtype = self.cos_cached.dtype
        if seq_len > self.cos_cached.size(0):
            # The buffers are too short: reallocate them. Assigning a tensor to a
            # registered buffer name replaces that buffer.
            self.cos_cached = torch.zeros((seq_len, self.d_k // 2), device=device, dtype=dtype)
            self.sin_cached = torch.zeros((seq_len, self.d_k // 2), device=device, dtype=dtype)
            self.max_len = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.d_k, 2, device=device, dtype=torch.float32) / self.d_k))
        angles = positions * inv_freq  # (seq_len, d_k // 2)
        self.cos_cached[:seq_len] = torch.cos(angles).to(dtype)
        self.sin_cached[:seq_len] = torch.sin(angles).to(dtype)

    def forward(self, x):
        """
        Rotate ``x`` by position.

        x is expected as (batch, heads, seq_len, d_k); the sequence is read from
        dim 2. Returns a tensor of the same shape and dtype.
        """
        seq_len = x.size(2)
        if seq_len > self.cos_cached.size(0):
            self._update_cos_sin_tables(seq_len)
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        return self.rotate_half(x, cos, sin)

    def rotate_half(self, x, cos, sin):
        """
        Rotate each pair ``(x[..., i], x[..., i + d_k // 2])`` by the angle encoded in cos/sin.

        The math runs in float32; the result is cast back to ``x``'s dtype.
        """
        orig_dtype = x.dtype
        x = x.float()
        cos = cos.float()
        sin = sin.float()
        half = self.d_k // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return rotated.to(orig_dtype)


#####################################################
# SPECIFIC ATTENTION MECHANISMS
#####################################################

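# We keep the FourierRoFormer (the novel variant) below. The RoFormer baseline
# needs no dedicated class: EnhancedMultiHeadAttention applies
# RotaryPositionalEncoding and then ScaledDotProductAttention. The baseline
# "vit" and "deit" variants use standard dot-product attention without any
# rotary positional encoding.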

class FourierRoFormer(nn.Module):
    """
    FourierRoFormer uses a learned combination of Fourier components to modulate attention.

    Queries and keys are rotated with RotaryPositionalEncoding. The scaled
    dot-product scores are then multiplied element-wise by:

    * a distance damping term ``exp(-gamma * |i - j|)`` (if ``enable_damping``);
    * a Fourier modulation ``0.5 * tanh(sum_c a_c * cos(f_c * |i - j| + phi_c)) + 0.5``,
      with values in [0, 1] (if ``enable_fourier``).

    Here ``|i - j|`` is the distance between query position i and key position j.
    The mask, if any, is applied after these modulations, so masked positions
    always stay masked.

    Args:
        d_model: model width; the RoPE tables use ``d_model // n_heads`` features.
        n_heads: number of heads.
        dropout: dropout probability on the attention weights.
        num_components: number of learned Fourier components (f_c, a_c, phi_c).
        enable_fourier: apply the Fourier modulation.
        enable_damping: apply the exponential distance damping.

    ``forward(q, k, v, mask=None)`` takes per-head tensors with the sequence on
    dim 2 and features on the last dim. It returns ``(output, attention_weights)``.
    """
    def __init__(self, d_model, n_heads, dropout=0.1, num_components=4, 
                 enable_fourier=True, enable_damping=True):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout = nn.Dropout(dropout)
        self.num_components = num_components
        self.rope = RotaryPositionalEncoding(d_model, n_heads)
        self.frequencies = nn.Parameter(torch.linspace(0.1, 2.0, num_components))
        self.amplitudes = nn.Parameter(torch.ones(num_components) * 0.1)
        self.phases = nn.Parameter(torch.zeros(num_components))
        self.gamma = nn.Parameter(torch.ones(1) * 0.01)
        self.enable_fourier = enable_fourier
        self.enable_damping = enable_damping

    def forward(self, q, k, v, mask=None):
        d_k = q.size(-1)
        seq_len = q.size(2)
        q_rotated = self.rope(q)
        k_rotated = self.rope(k)
        scores = torch.matmul(q_rotated, k_rotated.transpose(-2, -1)) / math.sqrt(d_k)

        # |i - j| for every (query, key) pair; (seq_len, seq_len), broadcast over
        # the leading dims of `scores`.
        positions = torch.arange(seq_len, device=q.device)
        distances = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().float()

        if self.enable_damping:
            # NOTE(review): the damping is multiplicative, so for negative scores it
            # moves far-away scores *towards* zero (i.e. up). Confirm this is intended.
            scores = scores * torch.exp(-self.gamma * distances)
        if self.enable_fourier:
            # All components at once: (C, 1, 1) parameters against (seq, seq) distances,
            # summed over the component axis.
            freqs = self.frequencies.view(-1, 1, 1)
            amps = self.amplitudes.view(-1, 1, 1)
            phases = self.phases.view(-1, 1, 1)
            fourier_sum = (amps * torch.cos(freqs * distances + phases)).sum(dim=0)
            scores = scores * (torch.tanh(fourier_sum) * 0.5 + 0.5)

        if mask is not None:
            # Masking last: a multiplicative modulation that reaches 0 would otherwise
            # turn a masked score back into 0 and un-mask the position. finfo.min
            # (instead of -1e9) also keeps float16 scores from overflowing.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention = nn.functional.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        return torch.matmul(attention, v), attention


#####################################################
# INTEGRATION COMPONENTS
#####################################################

class EnhancedMultiHeadAttention(nn.Module):
    """
    Multi-head attention that supports different variants.

    * "roformer": rotary embeddings are applied to queries and keys, followed by
      vanilla scaled dot-product attention.
    * "vit" / "deit": vanilla scaled dot-product attention only.
    * "fourier": FourierRoFormer (RoPE plus learned distance damping and
      Fourier modulation, each switchable through ``ablation_config``).

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability on the attention weights.
        attention_type: one of "roformer", "fourier", "vit", "deit".
        gamma_init: initial value of FourierRoFormer's damping rate ``gamma``
            (only used by the "fourier" variant).
        omega_init: accepted for interface compatibility; currently unused.
        num_fourier_components: number of Fourier components, used when
            ``ablation_config`` does not specify "num_fourier_components".
        ablation_config: optional dict with keys "enable_damping",
            "enable_fourier" and "num_fourier_components".

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or
            ``attention_type`` is not supported.
    """
    def __init__(self, d_model, n_heads, dropout=0.1, attention_type='roformer', 
                 gamma_init=0.01, omega_init=1.0, num_fourier_components=4,
                 ablation_config=None):
        super().__init__()
        if d_model % n_heads != 0:
            # Otherwise the head reshape in forward can silently mix features
            # across positions instead of failing.
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.attention_type = attention_type
        
        # Linear projections for Q, K, V and output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        # Default ablation configuration if none provided
        if ablation_config is None:
            ablation_config = {
                'enable_damping': True,
                'enable_fourier': True,
                'num_fourier_components': num_fourier_components
            }
        
        if self.attention_type == 'roformer':
            self.rope = RotaryPositionalEncoding(d_model, n_heads)
            self.attention = ScaledDotProductAttention(dropout)
        elif self.attention_type == 'fourier':
            self.attention = FourierRoFormer(
                d_model, n_heads, dropout, 
                num_components=ablation_config.get('num_fourier_components', num_fourier_components),
                enable_fourier=ablation_config.get('enable_fourier', True),
                enable_damping=ablation_config.get('enable_damping', True)
            )
            # gamma_init was previously accepted but ignored; apply it to the
            # learnable damping rate.
            with torch.no_grad():
                self.attention.gamma.fill_(gamma_init)
        # For standard ViT and DeiT, we use vanilla dot-product attention.
        elif self.attention_type in ['vit', 'deit']:
            self.attention = ScaledDotProductAttention(dropout)
        else:
            raise ValueError(f"Unsupported attention type: {attention_type}")
        
        # Not used in forward (attention dropout lives in self.attention).
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        """
        Attend from ``q`` to ``k``/``v``.

        Inputs are expected as (batch, seq_len, d_model). Returns
        ``(output, attention_weights)``: output is (batch, seq_len, d_model) and
        the weights are (batch, n_heads, seq_q, seq_k).
        """
        batch_size = q.size(0)
        # Linear projections, then split d_model into (n_heads, d_k):
        # (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
        q = self.w_q(q).reshape(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).reshape(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).reshape(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        if self.attention_type == 'roformer':
            q_rotated = self.rope(q)
            k_rotated = self.rope(k)
            x, attn = self.attention(q_rotated, k_rotated, v, mask)
        else:
            # "fourier" applies RoPE internally; "vit"/"deit" use no RoPE.
            x, attn = self.attention(q, k, v, mask)
        # Merge heads back: (batch, n_heads, seq, d_k) -> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous().reshape(batch_size, -1, self.d_model)
        x = self.w_o(x)
        return x, attn


class EnhancedTransformerEncoderLayer(nn.Module):
    """
    Post-norm transformer encoder layer for the different attention variants.

    Layout: self-attention -> dropout -> residual add -> LayerNorm, then
    feed-forward (Linear(d_model, d_ff) -> ReLU -> Dropout -> Linear(d_ff, d_model))
    -> dropout -> residual add -> LayerNorm. The attention arguments are passed
    unchanged to EnhancedMultiHeadAttention.

    ``forward(x, mask=None)`` returns ``(output, attention_weights)``.
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1, attention_type='roformer', 
                 gamma_init=0.01, omega_init=1.0, num_fourier_components=4,
                 ablation_config=None):
        super().__init__()
        self.attention_type = attention_type
        self.self_attn = EnhancedMultiHeadAttention(
            d_model,
            n_heads,
            dropout=dropout,
            attention_type=attention_type,
            gamma_init=gamma_init,
            omega_init=omega_init,
            num_fourier_components=num_fourier_components,
            ablation_config=ablation_config,
        )
        # Created in the same order as the Sequential slots (0: expand, 3: project).
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        hidden_dropout = nn.Dropout(dropout)
        project = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(expand, activation, hidden_dropout, project)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attended, weights = self.self_attn(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))
        out = self.norm2(hidden + self.dropout(self.feed_forward(hidden)))
        return out, weights


class EnhancedTransformerEncoder(nn.Module):
    """
    Stack of ``n_layers`` EnhancedTransformerEncoderLayer modules applied in order.

    Every layer receives the same attention configuration, including the same
    ``ablation_config`` object. ``forward(x, mask=None)`` returns
    ``(output, attentions)``, where ``attentions`` lists each layer's attention
    weights, first layer first.
    """
    def __init__(self, d_model, n_heads, n_layers, d_ff, dropout=0.1, 
                 attention_type='roformer', gamma_init=0.01, omega_init=1.0, 
                 num_fourier_components=4, ablation_config=None):
        super().__init__()
        self.attention_type = attention_type
        layer_kwargs = dict(
            dropout=dropout,
            attention_type=attention_type,
            gamma_init=gamma_init,
            omega_init=omega_init,
            num_fourier_components=num_fourier_components,
            ablation_config=ablation_config,
        )
        self.layers = nn.ModuleList(
            EnhancedTransformerEncoderLayer(d_model, n_heads, d_ff, **layer_kwargs)
            for _ in range(n_layers)
        )
        # Not used in forward.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        per_layer_attention = []
        hidden = x
        for encoder_layer in self.layers:
            hidden, layer_attention = encoder_layer(hidden, mask)
            per_layer_attention.append(layer_attention)
        return hidden, per_layer_attention


class EnhancedVisionTransformer(nn.Module):
    """
    Vision Transformer supporting the "vit", "deit", "roformer" and "fourier" variants.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches. Each patch is flattened (channel-major) and linearly embedded.
    The sequence is prefixed with a learnable [CLS] token, plus a [DIST]
    distillation token for "deit", and summed with learned position
    embeddings. It is then encoded by an EnhancedTransformerEncoder that uses
    ``attention_type`` attention.

    ``forward(x)`` expects (batch, in_channels, H, W) images whose patch count
    matches ``(img_size // patch_size) ** 2``. It returns ``(logits, attentions)``.
    For "deit" the logits are the mean of the class head and the
    distillation head outputs.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, d_model=768, n_heads=12, 
                 n_layers=12, d_ff=3072, num_classes=10, dropout=0.1,
                 attention_type='roformer', gamma_init=0.01, omega_init=1.0, 
                 num_fourier_components=4, ablation_config=None):
        super().__init__()
        self.attention_type = attention_type
        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        is_deit = attention_type == 'deit'
        patches_per_side = img_size // patch_size
        token_count = patches_per_side ** 2 + (2 if is_deit else 1)
        self.patch_embed = nn.Linear(in_channels * patch_size ** 2, d_model)

        # Learnable prefix token(s): [CLS] always, plus [DIST] for DeiT.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        if is_deit:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, d_model))
        nn.init.normal_(self.cls_token, std=0.02)
        if is_deit:
            nn.init.normal_(self.dist_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

        self.transformer = EnhancedTransformerEncoder(
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            d_ff=d_ff,
            dropout=dropout,
            attention_type=attention_type,
            gamma_init=gamma_init,
            omega_init=omega_init,
            num_fourier_components=num_fourier_components,
            ablation_config=ablation_config
        )

        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)
        if is_deit:
            # Secondary classifier read from the distillation token.
            self.head_dist = nn.Linear(d_model, num_classes)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels

    def forward(self, x):
        n = x.size(0)
        p = self.patch_size
        is_deit = self.attention_type == 'deit'

        # (B, C, H, W) -> (B, C, H/p, W/p, p, p) -> (B, H/p, W/p, C, p, p) -> (B, N, C*p*p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(n, -1, self.in_channels * p * p)
        tokens = self.patch_embed(patches)

        prefix = [self.cls_token.expand(n, -1, -1)]
        if is_deit:
            prefix.append(self.dist_token.expand(n, -1, -1))
        tokens = torch.cat(prefix + [tokens], dim=1) + self.pos_embed

        encoded, attentions = self.transformer(tokens)

        if not is_deit:
            return self.head(self.norm(encoded[:, 0])), attentions
        cls_logits = self.head(self.norm(encoded[:, 0]))
        dist_logits = self.head_dist(self.norm(encoded[:, 1]))
        return (cls_logits + dist_logits) / 2, attentions


#####################################################
# DATASET UTILITIES
#####################################################

def get_dataset_stats(dataset_name):
    """
    Return normalisation and geometry defaults for ``dataset_name``.

    Returns:
        ``(mean, std, img_size, default_patch_size)``. ImageNet-style datasets
        (ImageNet, Oxford Pets/Flowers, Stanford Cars) use ImageNet channel
        statistics at 224 px with 16 px patches. CIFAR and any unrecognised
        name use mean = std = 0.5 at 32 px with 4 px patches.
    """
    imagenet_style = (
        'imagenet', 'imagenet-subset',
        'oxford-pets', 'oxford-flowers', 'oxford-iiit-pet', 'flowers102',
        'stanford-cars',
    )
    if dataset_name in imagenet_style:
        return (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 224, 16
    return (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 32, 4


def _cifar_transforms(img_size, mean, std):
    """
    Return ``(transform_train, transform_test)`` shared by CIFAR-10 and CIFAR-100.

    A Resize step is only inserted when ``img_size`` is not 32. Omitting the step,
    rather than using an identity ``transforms.Lambda``, keeps the pipelines free
    of lambdas, which cannot be pickled by DataLoader workers started with the
    "spawn" method.
    """
    resize = [transforms.Resize(img_size)] if img_size != 32 else []
    transform_train = transforms.Compose(resize + [
        transforms.RandomCrop(img_size, padding=int(img_size / 8)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose(resize + [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return transform_train, transform_test


class _HFImageNetDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over a Hugging Face image-classification split.

    Defined at module level (not inside ``prepare_dataset``) so DataLoader
    worker processes can pickle it under the "spawn" start method.

    Args:
        split: split name passed to ``datasets.load_dataset``.
        transform: optional callable applied to each (RGB) PIL image.
        hf_name: Hugging Face dataset identifier.
    """
    def __init__(self, split, transform, hf_name="zh-plus/tiny-imagenet"):
        self.dataset = load_dataset(hf_name, split=split)
        self.transform = transform
        # Exposing labels as `targets` lets create_balanced_subset read them
        # directly instead of calling __getitem__ (which runs the image
        # transform) for every sample.
        self.targets = list(self.dataset['label'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image = sample['image']
        label = sample['label']
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


def prepare_dataset(dataset_name, batch_size=128, img_size=None, subset_size=None):
    """
    Build the train/test DataLoaders for ``dataset_name``.

    Supported names:
    * "cifar10", "cifar100";
    * "imagenet-subset" (Tiny-ImageNet from the Hugging Face hub,
      "zh-plus/tiny-imagenet");
    * "oxford-pets" / "oxford-iiit-pet";
    * "flowers102" / "oxford-flowers";
    * "stanford-cars".

    Oxford Pets and Flowers-102 fall back to CIFAR-10 if loading fails.
    Stanford Cars always falls back to CIFAR-100.

    Args:
        dataset_name: dataset identifier (see above).
        batch_size: batch size for both loaders.
        img_size: output image size; defaults to ``get_dataset_stats``'s value.
        subset_size: if positive, keep at most this many training samples per
            class, and ``max(1, min(subset_size // 2, 50))`` test samples per class.

    Returns:
        ``(trainloader, testloader, num_classes)``.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    mean, std, default_img_size, _ = get_dataset_stats(dataset_name)
    if img_size is None:
        img_size = default_img_size
    print(f"Preparing {dataset_name} with image size {img_size}...")

    if dataset_name in ("cifar10", "cifar100"):
        transform_train, transform_test = _cifar_transforms(img_size, mean, std)
        if dataset_name == "cifar10":
            num_classes, dataset_cls = 10, torchvision.datasets.CIFAR10
        else:
            num_classes, dataset_cls = 100, torchvision.datasets.CIFAR100
        trainset = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
        testset = dataset_cls(root='./data', train=False, download=True, transform=transform_test)
    elif dataset_name == "imagenet-subset":
        # Images are converted to RGB inside _HFImageNetDataset, so the
        # pipelines need no conversion step of their own.
        transform_train = transforms.Compose([
            transforms.Resize(int(img_size * 1.14)),
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(int(img_size * 1.14)),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        print("Loading Tiny-ImageNet (zh-plus/tiny-imagenet) from Hugging Face...")
        trainset = _HFImageNetDataset('train', transform_train)
        testset = _HFImageNetDataset('valid', transform_test)
        # Tiny-ImageNet is not 1000-class ImageNet. Take the class count from
        # the label feature when it exposes one, else from the labels themselves.
        label_feature = trainset.dataset.features.get('label')
        num_classes = getattr(label_feature, 'num_classes', None) or (max(trainset.targets) + 1)
    elif dataset_name in ("oxford-pets", "oxford-iiit-pet"):
        try:
            num_classes = 37
            transform_train = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            transform_test = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            trainset = torchvision.datasets.OxfordIIITPet(
                root='./data', split='trainval', download=True, transform=transform_train)
            testset = torchvision.datasets.OxfordIIITPet(
                root='./data', split='test', download=True, transform=transform_test)
        except Exception:
            print("WARNING: OxfordIIITPet dataset not available. Falling back to CIFAR-10.")
            return prepare_dataset("cifar10", batch_size, img_size, subset_size)
    elif dataset_name in ("flowers102", "oxford-flowers"):
        try:
            num_classes = 102
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            transform_test = transforms.Compose([
                transforms.Resize(int(img_size * 1.14)),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            trainset = torchvision.datasets.Flowers102(
                root='./data', split='train', download=True, transform=transform_train)
            testset = torchvision.datasets.Flowers102(
                root='./data', split='test', download=True, transform=transform_test)
        except Exception:
            print("WARNING: Flowers102 dataset not available. Falling back to CIFAR-10.")
            return prepare_dataset("cifar10", batch_size, img_size, subset_size)
    elif dataset_name == "stanford-cars":
        # Stanford Cars is never loaded here (see the warning); CIFAR-100 is used instead.
        print("WARNING: StanfordCars dataset setup requires manual download. Falling back to CIFAR-100.")
        return prepare_dataset("cifar100", batch_size, img_size, subset_size)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    if subset_size is not None and subset_size > 0:
        trainset = create_balanced_subset(trainset, subset_size)
        # At least one test sample per class, so subset_size == 1 does not
        # produce an empty test set.
        testset = create_balanced_subset(testset, max(1, min(subset_size // 2, 50)))

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    print(f"Dataset {dataset_name} prepared with {len(trainset)} training samples and {len(testset)} test samples")
    return trainloader, testloader, num_classes


def _known_targets(dataset):
    """
    Return per-index labels of ``dataset`` without loading samples, or None.

    Uses a ``targets`` attribute when present (a tensor is converted to a
    list). For a ``Subset``, maps its ``indices`` onto the wrapped dataset's
    labels, recursively.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        return targets.tolist() if isinstance(targets, torch.Tensor) else targets
    if isinstance(dataset, Subset):
        base_targets = _known_targets(dataset.dataset)
        if base_targets is not None:
            return [base_targets[int(i)] for i in dataset.indices]
    return None


def create_balanced_subset(dataset, samples_per_class):
    """
    Return a ``Subset`` holding at most ``samples_per_class`` samples of every class.

    For each class, in order of first appearance, the first
    ``samples_per_class`` indices with that label are kept.

    Labels are looked up in this order:
    1. ``dataset.targets``;
    2. the wrapped dataset's labels, for a ``Subset``;
    3. ``dataset[i][1]`` for every index, which loads (and transforms) every sample.

    If labels cannot be obtained at all, ``samples_per_class * 10`` random
    samples are returned instead.
    """
    targets = _known_targets(dataset)
    if targets is None:
        try:
            targets = [dataset[i][1] for i in range(len(dataset))]
        except Exception:
            print("WARNING: Could not determine targets for subset creation")
            indices = torch.randperm(len(dataset))[:samples_per_class * 10]
            return Subset(dataset, indices)

    class_indices = {}
    for index, label in enumerate(targets):
        if isinstance(label, torch.Tensor):
            # 0-d tensors hash by identity; use the Python scalar as the class key.
            label = label.item()
        class_indices.setdefault(label, []).append(index)

    selected_indices = []
    for indices in class_indices.values():
        selected_indices.extend(indices[:samples_per_class])
    return Subset(dataset, selected_indices)


#####################################################
# LOGGING UTILITIES
#####################################################

class WandbLogger:
    """
    Wandb logger for tracking experiments.

    Args:
        experiment_type: default ``job_type`` for runs started via ``init_run``.
        group_name: wandb group for the runs. Fourier-parameter logging in
            ``log_model_parameters`` only happens when it contains "fourier".
    """
    def __init__(self, experiment_type='main', group_name=None):
        self.experiment_type = experiment_type
        self.group_name = group_name

    def init_run(self, config, run_name, tags=None, job_type=None):
        """Start a run in the "paper-roformer-experiments" project (``reinit=True``) and return it."""
        return wandb.init(
            project="paper-roformer-experiments",
            name=run_name,
            config=config,
            group=self.group_name,
            job_type=job_type or self.experiment_type,
            tags=tags,
            reinit=True
        )

    def log_metrics(self, metrics, step=None):
        """Log ``metrics`` to wandb, ordering the standard train/test/epoch keys first."""
        standard_metrics = {}
        for key in ['train_loss', 'train_acc', 'test_loss', 'test_acc', 'epoch']:
            if key in metrics:
                standard_metrics[key] = metrics[key]
        for key, value in metrics.items():
            if key not in standard_metrics:
                standard_metrics[key] = value
        wandb.log(standard_metrics, step=step)

    def log_model_parameters(self, model, epoch):
        """
        Log FourierRoFormer parameters at step ``epoch``.

        This only runs when ``group_name`` contains "fourier". For every
        FourierRoFormer module, it logs a bar chart of amplitudes per
        frequency. It also logs ``params/fourier_gamma`` and the dominant
        component, i.e. the one with the largest |amplitude|. When there are
        several modules, the scalar values come from the last one.
        """
        params_dict = {}
        if self.group_name is not None and "fourier" in self.group_name:
            for module in model.modules():
                if isinstance(module, FourierRoFormer):
                    params_dict['fourier_gamma'] = module.gamma.item()
                    frequencies = module.frequencies.detach().cpu().numpy()
                    amplitudes = module.amplitudes.detach().cpu().numpy()
                    plt.figure(figsize=(10, 6))
                    plt.bar(range(len(frequencies)), amplitudes)
                    plt.xticks(range(len(frequencies)), [f"{f:.2f}" for f in frequencies])
                    plt.xlabel("Frequency")
                    plt.ylabel("Amplitude")
                    plt.title(f"Fourier Components at Epoch {epoch}")
                    wandb.log({"fourier_components": wandb.Image(plt)}, step=epoch)
                    plt.close()
                    # np.argmax raises on an empty array (num_components == 0).
                    if amplitudes.size > 0:
                        # A negative amplitude is just a phase-shifted cosine, so
                        # dominance is judged by magnitude.
                        dominant_idx = int(np.argmax(np.abs(amplitudes)))
                        params_dict['dominant_frequency'] = frequencies[dominant_idx]
                        params_dict['dominant_amplitude'] = amplitudes[dominant_idx]
        if params_dict:
            wandb.log({f"params/{k}": v for k, v in params_dict.items()}, step=epoch)

    def log_attention_maps(self, model, test_loader, device, epoch):
        """
        Log up to three test images with the last layer's [CLS] attention overlaid.

        The model is put in eval mode for the forward passes and then restored
        to the training/eval mode it had on entry.
        """
        was_training = model.training
        model.eval()
        try:
            self._log_attention_maps_eval(model, test_loader, device, epoch)
        finally:
            # Without this, a call made mid-training would leave dropout disabled.
            model.train(was_training)

    def _log_attention_maps_eval(self, model, test_loader, device, epoch):
        """Build and log the attention-map figure; ``model`` must already be in eval mode."""
        images, _labels = next(iter(test_loader))
        num_examples = min(3, images.size(0))
        examples = images[:num_examples].to(device)
        try:
            if hasattr(test_loader.dataset, 'classes') and len(test_loader.dataset.classes) == 10:
                classes = ('plane', 'car', 'bird', 'cat', 'deer', 
                          'dog', 'frog', 'horse', 'ship', 'truck')
            else:
                classes = [str(i) for i in range(1000)]
        except Exception:
            classes = [str(i) for i in range(1000)]
        # Tokens that precede the patch tokens: [CLS], plus [DIST] for DeiT models
        # (assumes the model exposes `attention_type` like EnhancedVisionTransformer).
        num_prefix_tokens = 2 if getattr(model, 'attention_type', None) == 'deit' else 1
        # squeeze=False keeps `axes` 2-D even for a single example.
        fig, axes = plt.subplots(num_examples, 2, figsize=(12, 5 * num_examples), squeeze=False)
        for i in range(num_examples):
            img = examples[i].cpu().permute(1, 2, 0).numpy()
            # Undo normalisation: 32-px images are treated as CIFAR-style
            # (mean = std = 0.5), anything else as ImageNet statistics.
            if img.shape[0] == 32:
                img = (img * 0.5 + 0.5).clip(0, 1)
            else:
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                img = ((img * std) + mean).clip(0, 1)
            axes[i, 0].imshow(img)
            axes[i, 0].set_title("Original Image")
            axes[i, 0].axis('off')
            with torch.no_grad():
                outputs, attention_maps = model(examples[i:i + 1])
                predicted_class = outputs.argmax(dim=1).item()
                class_name = classes[predicted_class] if predicted_class < len(classes) else f"Class {predicted_class}"
            # Last layer, first (only) sample, head 0; row 0 is the [CLS] query.
            attn = attention_maps[-1][0, 0].float().cpu()
            cls_to_patches = attn[0, num_prefix_tokens:]
            grid_size = int(math.sqrt(cls_to_patches.numel()))
            try:
                patch_attn = cls_to_patches.reshape(grid_size, grid_size)
            except RuntimeError:
                print(f"Warning: Could not reshape attention map to {grid_size}x{grid_size} grid. Using flattened attention.")
                patch_attn = cls_to_patches.reshape(-1, 1)
            try:
                patch_attn_upsampled = torch.nn.functional.interpolate(
                    patch_attn.unsqueeze(0).unsqueeze(0),
                    size=(img.shape[0], img.shape[1]),
                    mode='bicubic'
                ).squeeze().numpy()
                axes[i, 1].imshow(img)
                axes[i, 1].imshow(patch_attn_upsampled, alpha=0.7, cmap='hot')
                axes[i, 1].set_title(f"Attention Map ({class_name})")
                axes[i, 1].axis('off')
            except Exception as e:
                print(f"Warning: Could not create attention overlay: {e}. Showing raw attention.")
                axes[i, 1].imshow(patch_attn.numpy(), cmap='hot')
                axes[i, 1].set_title(f"Attention Heatmap ({class_name})")
                axes[i, 1].axis('off')
        plt.tight_layout()
        wandb.log({"attention_maps": wandb.Image(fig)}, step=epoch)
        plt.close(fig)

    def create_comparison_table(self, results, columns):
        """Build a ``wandb.Table``; each row dict fills ``columns`` (missing keys become None)."""
        table = wandb.Table(columns=columns)
        for row in results:
            table.add_data(*[row[col] if col in row else None for col in columns])
        return table

    def finish_run(self):
        """Finish the active wandb run, if any."""
        if wandb.run is not None:
            wandb.run.finish()


#####################################################
# ABLATION STUDY UTILITIES (FOURIER VARIANT ONLY)
#####################################################

def generate_ablation_configs(base_config):
    """
    Build the list of ablation experiments for the Fourier RoFormer variant.

    ``base_config`` is deep-copied and never modified. Missing (or ``None``)
    ``img_size`` / ``patch_size`` are filled from ``get_dataset_stats``, with
    fallbacks of 32 and 4 when the dataset reports no default.

    Returns:
        A list of dicts with keys ``name``, ``description``, ``config`` (an
        independent copy of the filled-in base config) and ``ablation_config``
        (the feature switches handed to the model). The order is: baseline,
        the three on/off combinations of Fourier modulation and damping, then
        2/4/8/16 Fourier components.
    """
    base = copy.deepcopy(base_config)
    if base.get('img_size') is None:
        _, _, dataset_img_size, _ = get_dataset_stats(base['dataset'])
        base['img_size'] = 32 if dataset_img_size is None else dataset_img_size
    if base.get('patch_size') is None:
        _, _, _, dataset_patch_size = get_dataset_stats(base['dataset'])
        base['patch_size'] = 4 if dataset_patch_size is None else dataset_patch_size

    # Baseline: every Fourier enhancement switched on.
    experiments = [{
        'name': 'fourier_baseline',
        'description': 'Fourier RoFormer with Fourier modulation and damping enabled',
        'config': copy.deepcopy(base),
        'ablation_config': {
            'enable_damping': True,
            'enable_fourier': True,
            'num_fourier_components': base.get('num_fourier_components', 4)
        }
    }]

    # (name, description, enable_fourier, enable_damping) for the switch-off runs.
    switch_offs = (
        ('fourier_no_fourier',
         'Fourier RoFormer with Fourier modulation disabled', False, True),
        ('fourier_no_damping',
         'Fourier RoFormer with damping disabled', True, False),
        ('fourier_no_fourier_no_damping',
         'Fourier RoFormer with both Fourier modulation and damping disabled', False, False),
    )
    for run_name, run_description, fourier_on, damping_on in switch_offs:
        experiments.append({
            'name': run_name,
            'description': run_description,
            'config': copy.deepcopy(base),
            'ablation_config': {
                'enable_fourier': fourier_on,
                'enable_damping': damping_on
            }
        })

    # Sweep the number of Fourier components with both enhancements enabled.
    for n_components in (2, 4, 8, 16):
        sweep_config = copy.deepcopy(base)
        sweep_config['num_fourier_components'] = n_components
        experiments.append({
            'name': f'fourier_components_{n_components}',
            'description': f'Fourier RoFormer with {n_components} Fourier components',
            'config': sweep_config,
            'ablation_config': {
                'enable_fourier': True,
                'enable_damping': True,
                'num_fourier_components': n_components
            }
        })
    return experiments


def _save_ablation_results(ablation_results, results_path):
    """
    Write ``ablation_results`` to ``results_path`` as a single JSON document.

    NumPy arrays in the per-epoch curve fields are converted to lists first so
    they can be serialised. The conversion mutates the result dicts in place.
    """
    for result in ablation_results['results']:
        for key in ('train_losses', 'test_losses', 'train_accs', 'test_accs'):
            if isinstance(result[key], np.ndarray):
                result[key] = result[key].tolist()
    # Dump exactly once. Dumping per result would concatenate several JSON
    # documents into one file, which no JSON parser can read back.
    with open(results_path, 'w') as f:
        json.dump(ablation_results, f, indent=2)


def run_ablation_study(base_config, output_dir='results/ablation_studies'):
    """
    Train every Fourier-variant ablation and log, plot and save the results.

    Workflow:
      1. A short "parent" W&B run logs a table of the ablation configurations.
      2. Each configuration from ``generate_ablation_configs`` is trained in its
         own W&B run with AdamW + cosine annealing. The best checkpoint goes to
         ``<output_dir>/models/<name>_best.pth``.
      3. A "summary" W&B run logs comparison plots and a results table.
      4. All curves are written to ``<output_dir>/<group>_results.json``.

    Args:
        base_config: Experiment configuration (as produced by ``parse_args``).
            It is shallow-copied and not modified. On the copy, ``variants`` is
            forced to ``['fourier']`` and a missing ``img_size`` is filled from
            the dataset defaults.
        output_dir: Directory for checkpoints, plots and the JSON results file.

    Returns:
        dict with keys ``base_config``, ``ablation_configs`` and ``results``.
        ``results`` holds one entry per ablation, with loss/accuracy curves,
        best accuracy and trainable-parameter count.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Work on a copy so the caller's config (notably 'variants') is untouched.
    base_config = dict(base_config)
    if base_config.get('img_size') is None:
        _, _, default_img_size, _ = get_dataset_stats(base_config['dataset'])
        base_config['img_size'] = default_img_size
    # The ablation study only covers the Fourier variant.
    base_config['variants'] = ['fourier']
    ablation_configs = generate_ablation_configs(base_config)
    ablation_results = {
        'base_config': base_config,
        'ablation_configs': ablation_configs,
        'results': []
    }
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    ablation_group = f"ablation_{base_config['dataset']}_{timestamp}"
    model_type = ablation_configs[0]['config']['variants'][0]
    logger = WandbLogger(experiment_type='ablation', group_name=ablation_group)
    common_tags = [
        f"ablation_{model_type}",
        base_config['dataset'],
        f"d{base_config['d_model']}",
        f"h{base_config['n_heads']}",
        f"l{base_config['n_layers']}"
    ]
    parent_config = {
        **base_config,
        'ablation_group': ablation_group,
        'model_type': model_type,
        'ablation_configs': [c['name'] for c in ablation_configs],
        'timestamp': timestamp
    }

    # Parent run: records which ablations belong to this study.
    logger.init_run(
        config=parent_config,
        run_name=f"PARENT_{ablation_group}",
        tags=[*common_tags, "ablation_study"],
        job_type="parent"
    )
    ablation_table = logger.create_comparison_table(
        [{'Name': c['name'], 'Description': c['description']} for c in ablation_configs],
        ['Name', 'Description']
    )
    wandb.log({"ablation_configurations": ablation_table})
    logger.finish_run()

    # Loop-invariant: every ablation trains on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for config_idx, ablation_config in enumerate(ablation_configs):
        print("\n========================")
        print(f"Running ablation {config_idx+1}/{len(ablation_configs)}: {ablation_config['name']}")
        print(f"Description: {ablation_config['description']}")
        print("========================")
        config = ablation_config['config']
        specific_ablation_config = ablation_config['ablation_config']
        # Pass img_size (as run_experiment does) so the data resolution matches
        # the img_size the model is built with.
        trainloader, testloader, num_classes = prepare_dataset(
            config['dataset'], config['batch_size'],
            img_size=config.get('img_size'),
            subset_size=config.get('subset_size')
        )
        model = create_model(config, num_classes, specific_ablation_config)
        model = model.to(device)
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model has {num_params:,} trainable parameters")
        enhanced_config = {
            **config,
            'ablation_name': ablation_config['name'],
            'ablation_description': ablation_config['description'],
            'ablation_group': ablation_group,
            'ablation_timestamp': timestamp,
            'ablation_config': specific_ablation_config,
            'parameter_count': num_params
        }
        run_name = f"{ablation_config['name']}_{config['dataset']}"
        # Tag each run with its scalar switches, e.g. "enable_fourier_False".
        ablation_tags = common_tags.copy()
        for key, value in specific_ablation_config.items():
            if isinstance(value, (bool, int, float)):
                ablation_tags.append(f"{key}_{value}")
        logger.init_run(
            config=enhanced_config,
            run_name=run_name,
            tags=[*ablation_tags, "ablation_study"],
            job_type="ablation"
        )
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['lr'],
            weight_decay=config.get('weight_decay', 0.05)
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs']
        )
        ablation_epochs = config['epochs']
        train_losses, test_losses, train_accs, test_accs = [], [], [], []
        best_acc = 0
        for epoch in range(ablation_epochs):
            train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)
            train_losses.append(train_loss)
            train_accs.append(train_acc)
            test_loss, test_acc = evaluate(model, testloader, criterion, device)
            test_losses.append(test_loss)
            test_accs.append(test_acc)
            logger.log_metrics({
                'epoch': epoch,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'test_loss': test_loss,
                'test_acc': test_acc
            })
            logger.log_model_parameters(model, epoch)
            scheduler.step()
            if test_acc > best_acc:
                best_acc = test_acc
                model_dir = os.path.join(output_dir, "models")
                os.makedirs(model_dir, exist_ok=True)
                model_path = os.path.join(model_dir, f"{ablation_config['name']}_best.pth")
                torch.save(model.state_dict(), model_path)
                wandb.run.summary.update({
                    'best_model_epoch': epoch,
                    'best_model_accuracy': best_acc
                })
            if epoch % 20 == 0:
                logger.log_attention_maps(model, testloader, device, epoch)
            print(f"Epoch {epoch+1}/{ablation_epochs} | "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
        wandb.run.summary['best_accuracy'] = best_acc
        # With zero epochs the curves are empty; indexing [-1] would raise.
        if train_losses:
            wandb.run.summary['final_train_loss'] = train_losses[-1]
            wandb.run.summary['final_test_acc'] = test_accs[-1]
        wandb.run.summary['parameter_count'] = num_params
        ablation_results['results'].append({
            'name': ablation_config['name'],
            'description': ablation_config['description'],
            'train_losses': train_losses,
            'test_losses': test_losses,
            'train_accs': train_accs,
            'test_accs': test_accs,
            'best_acc': best_acc,
            'num_params': num_params
        })
        logger.finish_run()
        # Release the model before building the next one.
        del model
        torch.cuda.empty_cache()

    # Summary run: plots plus a comparison table across all ablations.
    logger.init_run(
        config=parent_config,
        run_name=f"SUMMARY_{ablation_group}",
        tags=[*common_tags, "summary", "ablation_study"],
        job_type="summary"
    )
    generate_ablation_plots(ablation_results, output_dir, ablation_group)
    comparison_columns = ['Ablation', 'Best Accuracy', 'Parameters', 'Description']
    comparison_rows = [
        {
            'Ablation': r['name'],
            'Best Accuracy': r['best_acc'],
            'Parameters': r['num_params'],
            'Description': r['description']
        }
        for r in ablation_results['results']
    ]
    comparison_table = logger.create_comparison_table(comparison_rows, comparison_columns)
    wandb.log({"ablation_comparison": comparison_table})
    logger.finish_run()

    results_path = os.path.join(output_dir, f"{ablation_group}_results.json")
    _save_ablation_results(ablation_results, results_path)
    return ablation_results


def _plot_metric_curves(results, metric_key, title, ylabel, save_path):
    """Draw one line per ablation for ``result[metric_key]`` and save the figure."""
    plt.figure(figsize=(12, 6))
    for entry in results:
        plt.plot(entry[metric_key], label=entry['name'])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def generate_ablation_plots(ablation_results, output_dir, experiment_id):
    """
    Save comparison figures for an ablation study under ``<output_dir>/plots``.

    Four PNG files are written, each prefixed with ``experiment_id``:
    a bar chart of best test accuracy (highest first), test-accuracy curves,
    test-loss curves, and a rendered summary table.
    """
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    results = ablation_results['results']

    # Bar chart of best accuracy, ranked from best to worst.
    ranking = np.argsort([entry['best_acc'] for entry in results])[::-1]
    ranked = [results[idx] for idx in ranking]
    plt.figure(figsize=(12, 6))
    plt.bar([entry['name'] for entry in ranked], [entry['best_acc'] for entry in ranked])
    plt.title('Best Test Accuracy for Different Ablations')
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Configuration')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f"{experiment_id}_accuracy_comparison.png"))
    plt.close()

    # Per-epoch curves.
    _plot_metric_curves(
        results, 'test_accs', 'Test Accuracy Over Epochs', 'Accuracy (%)',
        os.path.join(plots_dir, f"{experiment_id}_accuracy_curves.png"),
    )
    _plot_metric_curves(
        results, 'test_losses', 'Test Loss Over Epochs', 'Loss',
        os.path.join(plots_dir, f"{experiment_id}_loss_curves.png"),
    )

    # Summary table rendered as an image.
    header = ['Configuration', 'Best Accuracy', 'Parameters', 'Description']
    body_rows = [
        [entry['name'], f"{entry['best_acc']:.2f}%", f"{entry['num_params']:,}", entry['description']]
        for entry in results
    ]
    _, ax = plt.subplots(figsize=(12, len(results) * 0.5 + 1))
    ax.axis('tight')
    ax.axis('off')
    table = ax.table(
        cellText=body_rows,
        colLabels=header,
        loc='center',
        cellLoc='center'
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 1.5)
    plt.title('Ablation Study Summary')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f"{experiment_id}_summary_table.png"),
                bbox_inches='tight', dpi=200)
    plt.close()


#####################################################
# EXPERIMENT UTILITIES
#####################################################

def create_model(config, num_classes, ablation_config=None):
    """
    Build an ``EnhancedVisionTransformer`` from an experiment config.

    Image and patch size come from the config when set. Otherwise they come
    from ``get_dataset_stats`` for ``config['dataset']``, falling back to 32
    and 4 when the dataset has no default. The attention type is the first
    entry of ``config['variants']`` ('roformer' if the list is missing or
    empty).

    Args:
        config: Experiment configuration. Reads 'dataset', 'd_model',
            'n_heads', 'n_layers' and, optionally, 'img_size', 'patch_size',
            'variants', 'gamma_init', 'omega_init', 'num_fourier_components',
            'dropout' (default 0.1) and 'in_channels' (default 3).
        num_classes: Number of output classes.
        ablation_config: Optional feature switches forwarded to the model.

    Returns:
        The constructed (untrained) model.
    """
    _, _, default_img_size, default_patch_size = get_dataset_stats(config['dataset'])
    img_size = config.get('img_size')
    if img_size is None:
        img_size = default_img_size if default_img_size is not None else 32
    # An explicit None falls back to the dataset default, exactly like
    # img_size. Previously it jumped straight to 4.
    patch_size = config.get('patch_size')
    if patch_size is None:
        patch_size = default_patch_size if default_patch_size is not None else 4
    variants = config.get('variants')
    attention_type = variants[0] if variants else 'roformer'
    d_model = config['d_model']
    model = EnhancedVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=config.get('in_channels', 3),
        num_classes=num_classes,
        d_model=d_model,
        n_heads=config['n_heads'],
        n_layers=config['n_layers'],
        d_ff=d_model * 4,
        dropout=config.get('dropout', 0.1),
        attention_type=attention_type,
        gamma_init=config.get('gamma_init', 0.01),
        omega_init=config.get('omega_init', 1.0),
        num_fourier_components=config.get('num_fourier_components', 4),
        ablation_config=ablation_config
    )
    return model


def parse_args():
    """
    Parse the command-line options of the experiment runner.

    Returns:
        argparse.Namespace holding the optimisation hyper-parameters,
        transformer size, dataset choice, Fourier-attention initial values,
        the variants to train, and the ablation / subset switches.
    """
    parser = argparse.ArgumentParser(description="Enhanced Vision Transformer Experiment")
    option = parser.add_argument

    # Optimisation.
    option("--batch_size", type=int, default=128, help="Batch size for training")
    option("--epochs", type=int, default=1, help="Number of epochs for training")
    option("--lr", type=float, default=0.0005, help="Learning rate")
    option("--weight_decay", type=float, default=0.05, help="Weight decay")

    # Transformer size.
    option("--d_model", type=int, default=192, help="Model dimension")
    option("--n_heads", type=int, default=6, help="Number of attention heads")
    option("--n_layers", type=int, default=6, help="Number of transformer layers")

    # Data.
    dataset_choices = ["cifar10", "cifar100", "imagenet-subset",
                       "oxford-pets", "oxford-flowers", "stanford-cars"]
    option("--dataset", type=str, default="cifar10", choices=dataset_choices,
           help="Dataset to use")
    option("--img_size", type=int, default=None, help="Input image size")

    # Fourier-attention initial values.
    option("--gamma_init", type=float, default=0.01, help="Initial gamma parameter")
    option("--omega_init", type=float, default=1.0, help="Initial omega parameter")
    option("--num_fourier_components", type=int, default=4, help="Number of Fourier components")

    # What to run: by default all four comparison variants.
    option("--variants", type=str, nargs='+',
           default=['fourier', 'roformer', 'vit', 'deit'],
           help="List of variants to run")
    option("--run_ablation", action="store_true", help="Run comprehensive ablation studies (Fourier only)")
    option("--subset_size", type=int, default=None, help="Samples per class for faster experimentation")

    return parser.parse_args()


def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one training pass over ``dataloader`` and report loss and accuracy.

    The model's forward pass must return a 2-tuple whose first element is the
    class logits; the second element is ignored here.

    Returns:
        (epoch_loss, accuracy). ``epoch_loss`` is the mean of the per-batch
        losses (not weighted by batch size). ``accuracy`` is the percentage of
        correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for images, labels in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, _ = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        _, predictions = logits.max(1)
        n_seen += labels.size(0)
        n_correct += (predictions == labels).sum().item()

    mean_loss = loss_sum / len(dataloader)
    accuracy_pct = 100 * n_correct / n_seen
    return mean_loss, accuracy_pct


def evaluate(model, dataloader, criterion, device):
    """
    Score ``model`` on ``dataloader`` without tracking gradients.

    The model's forward pass must return a 2-tuple whose first element is the
    class logits.

    Returns:
        (epoch_loss, accuracy). ``epoch_loss`` is the mean of the per-batch
        losses. ``accuracy`` is the percentage of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            logits, _ = model(images)
            loss_sum += criterion(logits, labels).item()
            predictions = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += (predictions == labels).sum().item()
    return loss_sum / len(dataloader), 100 * n_correct / n_seen


def train_model(model, trainloader, testloader, criterion, optimizer, scheduler, device, config, logger):
    """
    Train ``model`` for ``config['epochs']`` epochs inside the active W&B run.

    Each epoch trains, evaluates, logs metrics and parameters, and steps the
    scheduler. Attention maps are logged once before training and again on
    every epoch divisible by 10. Whenever test accuracy improves, weights are
    saved to ``results/models/<run name>_best.pth``. The final weights go to
    ``results/models/<run name>_final.pth``. Both files are passed to
    ``wandb.save``.

    Args:
        model: Model exposing ``attention_type`` (and optionally ``patch_size``).
        trainloader, testloader: Data loaders for training and evaluation.
        criterion, optimizer, scheduler: Loss, optimizer and LR scheduler.
        device: Torch device the model already lives on.
        config: Run configuration. Reads 'epochs', 'dataset' and 'img_size'.
        logger: ``WandbLogger`` used for metrics, parameters and attention maps.

    Returns:
        The trained model (the same object that was passed in).
    """
    model_name = wandb.run.name
    print(f"\nTraining {model_name} model...")
    epochs = config['epochs']
    best_acc = 0
    # Checkpoints go under results/models. Make sure the directory exists even
    # if it was removed or the working directory changed after import.
    os.makedirs("results/models", exist_ok=True)
    wandb.run.summary.update({
        'model_type': model.attention_type,
        'dataset': config['dataset'],
        'image_size': config.get('img_size'),
        'patch_size': getattr(model, 'patch_size', 'unknown'),
        'parameter_count': sum(p.numel() for p in model.parameters() if p.requires_grad)
    })
    logger.log_attention_maps(model, testloader, device, epoch=0)
    logger.log_model_parameters(model, epoch=0)
    columns = ["epoch", "train_loss", "train_acc", "test_loss", "test_acc"]
    performance_table = wandb.Table(columns=columns)
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(model, testloader, criterion, device)
        logger.log_metrics({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        })
        performance_table.add_data(epoch, train_loss, train_acc, test_loss, test_acc)
        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
        if epoch % 10 == 0:
            logger.log_attention_maps(model, testloader, device, epoch)
        logger.log_model_parameters(model, epoch)
        if test_acc > best_acc:
            best_acc = test_acc
            model_path = f"results/models/{model_name}_best.pth"
            torch.save(model.state_dict(), model_path)
            wandb.save(model_path)
            wandb.run.summary['best_accuracy'] = test_acc
            wandb.run.summary['best_epoch'] = epoch
        scheduler.step()
    wandb.log({"performance_summary": performance_table})
    final_summary = {}
    if epochs > 0:
        # The loop variables still hold the last epoch's metrics. With zero
        # epochs they were never bound, and referencing them would raise
        # UnboundLocalError.
        final_summary.update({
            'final_train_loss': train_loss,
            'final_train_acc': train_acc,
            'final_test_loss': test_loss,
            'final_test_acc': test_acc,
        })
    final_summary['total_epochs'] = epochs
    wandb.run.summary.update(final_summary)
    model_path = f"results/models/{model_name}_final.pth"
    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)
    return model


def run_experiment(config):
    """
    Train and compare every requested model variant on one dataset.

    A single train/test loader pair is shared by all variants. Each variant
    is built with ``create_model`` (with ``variants`` narrowed to that one
    key), trained in its own W&B run via ``train_model``, and kept in memory.

    Supported variant keys, built and trained in this fixed order:
    'roformer', 'vit', 'deit', 'fourier'. Any other key in
    ``config['variants']`` is reported and skipped.

    Args:
        config: Experiment configuration (see ``parse_args``).

    Returns:
        dict mapping display name ('RoFormer', 'Standard_ViT', 'DeiT',
        'Fourier_RoFormer') to the trained model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    trainloader, testloader, num_classes = prepare_dataset(
        config['dataset'], config['batch_size'],
        img_size=config.get('img_size'),
        subset_size=config.get('subset_size')
    )
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    experiment_group = f"{config['dataset']}_{timestamp}"
    logger = WandbLogger(experiment_type='main', group_name=experiment_group)

    # (CLI key, display name). Iterating this fixed tuple keeps model
    # construction order independent of the order given on the command line.
    variant_table = (
        ('roformer', 'RoFormer'),
        ('vit', 'Standard_ViT'),
        ('deit', 'DeiT'),
        ('fourier', 'Fourier_RoFormer'),
    )
    requested = config['variants']
    known_keys = {key for key, _ in variant_table}
    unknown = [variant for variant in requested if variant not in known_keys]
    if unknown:
        # Previously these were dropped silently.
        print(f"Warning: ignoring unsupported variants: {unknown}")
    models = {
        display_name: create_model({**config, 'variants': [key]}, num_classes)
        for key, display_name in variant_table
        if key in requested
    }

    trained_models = {}
    for model_name, model in models.items():
        run_name = f"{model_name}_{config['dataset']}_{config['d_model']}d_{config['n_heads']}h"
        enhanced_config = {
            **config,
            'model_type': model_name,
            'experiment_group': experiment_group,
            'parameter_count': sum(p.numel() for p in model.parameters() if p.requires_grad),
            'img_size': config.get('img_size'),
            'timestamp': timestamp
        }
        logger.init_run(
            config=enhanced_config,
            run_name=run_name,
            tags=[
                model_name,
                config['dataset'],
                f"d{config['d_model']}",
                f"h{config['n_heads']}",
                f"l{config['n_layers']}",
                "variant_comparison"
            ]
        )
        model = model.to(device)
        criterion = nn.CrossEntropyLoss()
        # Same default as the ablation path when weight_decay is not configured.
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=config['lr'], weight_decay=config.get('weight_decay', 0.05)
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
        trained_model = train_model(
            model, trainloader, testloader, criterion, optimizer, scheduler,
            device, enhanced_config, logger
        )
        trained_models[model_name] = trained_model
        logger.finish_run()
    return trained_models


def main():
    """
    Command-line entry point.

    Parses arguments and fills in a missing ``img_size`` from the dataset
    defaults. Optionally runs the Fourier-only ablation study
    (``--run_ablation``), then trains and compares all selected variants.
    """
    args = parse_args()
    config = vars(args)
    if config['img_size'] is None:
        _, _, default_img_size, _ = get_dataset_stats(config['dataset'])
        config['img_size'] = default_img_size
    # Keep any WANDB_MODE already set in the environment; otherwise log online.
    os.environ.setdefault("WANDB_MODE", "online")
    print("Starting experiment with selected variants...")
    print(f"Using the following variants: {config['variants']}")
    print(f"Dataset: {config['dataset']} with image size: {config['img_size']}")

    if config['run_ablation']:
        print("\nRunning comprehensive ablation studies (Fourier variant only)...")
        # Hand the ablation study its own copy. It overrides 'variants' with
        # ['fourier'], and that must not leak into the main experiment below.
        run_ablation_study(dict(config))
        print("Ablation studies completed!")

    run_experiment(config)
    print("\nExperiment completed!")


# Run the CLI entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
