"""
    transformer.py

    A simple transformer model with multi-head attention, positional encoding, and optional MLP layers.

    This model can be used for various sequence classification tasks, including binary classification.

    The model supports:
    - Multi-head attention with optional layer normalization and MLP layers.
    - Positional encoding (both fixed and randomized).
    - Residual connections.
"""

from xml.parsers.expat import model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
import numpy as np
import os
import json
import noise_stability.measure_noise_stability as ns
import visualization.plotting as plotting

# Evaluation cadence used by train_model: validation, LR-scheduler stepping,
# progress printing and noise-stability measurement only run on epochs where
# `epoch % epoch_period_global == 0`.
epoch_period_global = 10

##########################################################################################

######## MULTI-GPU UTILITIES ############

##########################################################################################

def setup_device(use_dp=False, gpu_ids=None, use_ddp=False, device=None):
    """
    Set up device configuration for single or multi-GPU training.

    When CUDA is available the modes are tried in this order: DDP, then
    DataParallel, then a single GPU.

    Args:
        use_dp (bool): Whether to use DataParallel
        gpu_ids (list): Specific GPU IDs to use (e.g., [0, 1, 2, 3]). IDs outside
            [0, number of visible GPUs) are dropped.
        use_ddp (bool): Whether to use DistributedDataParallel
        device (str): Device to use for single GPU training (e.g., 'cuda:0', 'cuda:1')

    Returns:
        tuple: (device, is_multi_gpu, gpu_count)

    Raises:
        KeyError: If use_ddp is set and neither LOCAL_RANK nor RANK is defined
            in the environment.
    """
    if not torch.cuda.is_available():
        print("CUDA not available. Using CPU.")
        return 'cpu', False, 0

    gpu_count = torch.cuda.device_count()
    print(f"Found {gpu_count} GPU(s)")

    # DDP setup - one process per GPU. The CUDA index must be the *local* rank;
    # the global RANK is only a valid device index on a single node, so it is
    # used as a fallback when LOCAL_RANK is not exported.
    if use_ddp:
        local_rank = int(os.environ.get('LOCAL_RANK') or os.environ['RANK'])
        device = f'cuda:{local_rank}'
        print(f"Using DDP on GPU {local_rank}")
        return device, True, 1

    # DataParallel setup
    if use_dp:
        if gpu_ids is None:
            valid_ids = list(range(gpu_count))
        else:
            valid_ids = [gpu_id for gpu_id in gpu_ids if 0 <= gpu_id < gpu_count]

        if len(valid_ids) > 1:
            device = f'cuda:{valid_ids[0]}'  # Primary device
            print(f"Using DataParallel with {len(valid_ids)} GPUs: {valid_ids}")
            return device, True, len(valid_ids)

        if valid_ids:
            device = f'cuda:{valid_ids[0]}'
            print(f"Using single GPU: {device}")
            return device, False, 1

        # Every requested ID was filtered out (or the list was empty): fall
        # through to the single-GPU default instead of indexing an empty list.
        print("No valid GPU IDs given for DataParallel; falling back to single GPU.")

    # Single GPU setup (default or specified)
    if device is None:
        device = 'cuda:0'
    print(f"Using single GPU: {device}")
    return device, False, 1

def setup_ddp(rank, world_size, backend='nccl'):
    """
    Initialize the distributed process group for DDP.

    MASTER_ADDR / MASTER_PORT default to localhost:29500 but values already
    present in the environment (e.g. exported by a launcher) are respected
    rather than overwritten.

    Args:
        rank (int): Rank of the current process
        world_size (int): Total number of processes
        backend (str): Backend to use ('nccl' for GPU, 'gloo' for CPU)
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')

    print(f"Initializing DDP with rank {rank} and world size {world_size}")
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    print(f"DDP initialized with rank {rank} and world size {world_size}")

def cleanup_ddp():
    """Tear down the distributed process group; a no-op when none is initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

def wrap_model_for_multi_gpu(model, device, use_dp=False, gpu_ids=None, use_ddp=False):
    """
    Wrap model for multi-GPU training using DataParallel or DistributedDataParallel.

    The model is always moved to `device` first. Without CUDA it is returned
    unwrapped. DDP takes precedence over DataParallel when both flags are set.

    Args:
        model: PyTorch model
        device: Primary device
        use_dp (bool): Whether to use DataParallel
        gpu_ids (list): GPU IDs to use for DataParallel; IDs outside
            [0, number of visible GPUs) are dropped.
        use_ddp (bool): Whether to use DistributedDataParallel. The process
            group must already be initialized by the caller.

    Returns:
        Wrapped model (or the plain model when no wrapping applies)
    """
    model = model.to(device)

    if not torch.cuda.is_available():
        return model

    if use_ddp:
        # DDP's device_ids are CUDA indices local to this node, so prefer
        # LOCAL_RANK; the global RANK only coincides with it on a single node.
        local_rank = int(os.environ.get('LOCAL_RANK') or os.environ['RANK'])
        print(f"Wrapping model with DistributedDataParallel on GPU {local_rank}")
        return DDP(model, device_ids=[local_rank], output_device=local_rank)

    if use_dp:
        visible_gpus = torch.cuda.device_count()
        if gpu_ids is None:
            dp_ids = list(range(visible_gpus))
        else:
            dp_ids = [gpu_id for gpu_id in gpu_ids if 0 <= gpu_id < visible_gpus]

        if len(dp_ids) > 1:
            print(f"Wrapping model with DataParallel on GPUs: {dp_ids}")
            # Gather outputs on the first GPU, which is also where DataParallel
            # requires the module's parameters to live.
            primary_gpu = dp_ids[0]
            model = model.to(f'cuda:{primary_gpu}')
            model = DataParallel(model, device_ids=dp_ids, output_device=primary_gpu)
        else:
            print(f"Single GPU specified for DataParallel, using single GPU mode")

    return model

def test_multi_gpu_setup():
    """
    Print a short report about the CUDA devices visible to this process.

    Returns:
        bool: True only when CUDA is available and more than one GPU is visible.
    """
    print("=== Multi-GPU Setup Information ===")

    cuda_ok = torch.cuda.is_available()
    if not cuda_ok:
        print("❌ CUDA is not available")
        return False

    n_gpus = torch.cuda.device_count()
    print(f"✅ CUDA is available")
    print(f"🔢 Number of GPUs: {n_gpus}")

    # One line per device: name and total memory in GiB.
    for index in range(n_gpus):
        props = torch.cuda.get_device_properties(index)
        total_gb = props.total_memory / (1024**3)
        print(f"   GPU {index}: {props.name} ({total_gb:.1f} GB)")

    has_multiple = n_gpus > 1
    if has_multiple:
        print(f"✅ Multi-GPU training available with {n_gpus} GPUs")
    else:
        print("ℹ️  Only 1 GPU available - multi-GPU training not needed")
    return has_multiple

def get_model_memory_usage(model, device):
    """
    Report the CUDA memory currently allocated on a device.

    Note that this is the allocator total for the whole device as seen by
    torch.cuda.memory_allocated, not a per-model figure.

    Args:
        model: PyTorch model (unused; kept for interface compatibility)
        device: Device the model is on. Accepts a string ('cpu', 'cuda',
            'cuda:1'), a torch.device, or None (treated as GPU 0).

    Returns:
        Memory usage in MB (0 when CUDA is unavailable or the device is not CUDA)
    """
    if not torch.cuda.is_available():
        return 0

    if device is None:
        gpu_id = 0
    else:
        # torch.device() normalizes strings and torch.device objects alike, so
        # torch.device('cpu') / torch.device('cuda:1') are handled correctly.
        resolved = torch.device(device)
        if resolved.type != 'cuda':
            return 0
        gpu_id = resolved.index if resolved.index is not None else 0

    torch.cuda.synchronize(gpu_id)
    memory_mb = torch.cuda.memory_allocated(gpu_id) / (1024**2)
    return memory_mb

def print_training_info(model, device, is_multi_gpu, gpu_count):
    """
    Print training configuration information.

    Args:
        model: PyTorch model (possibly wrapped in DataParallel / DDP)
        device: Primary device
        is_multi_gpu: Whether using multi-GPU
        gpu_count: Number of GPUs being used
    """
    print("\n=== Training Configuration ===")
    print(f"Device: {device}")
    print(f"Multi-GPU: {'Yes' if is_multi_gpu else 'No'}")

    if is_multi_gpu:
        print(f"GPUs in use: {gpu_count}")
        # Report the actual wrapper class so DDP runs are not labelled DataParallel.
        print(f"Model type: {type(model).__name__}")

    # Model parameter count (single pass over the parameters)
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Memory usage (only printed when a non-zero figure is reported)
    memory_mb = get_model_memory_usage(model, device)
    if memory_mb > 0:
        print(f"Model memory usage: {memory_mb:.1f} MB")

    print("=" * 35)

##########################################################################################

######## MODEL SPECIFICS ############

##########################################################################################

def positional_encoding(seq_length, d_model):
    """
        Generate sinusoidal positional encodings.

        Even feature indices hold sin(pos * w_k) and odd indices hold
        cos(pos * w_k) with w_k = 10000^(-2k / d_model). Odd values of
        d_model are supported: the last feature is then a sine column with
        no cosine partner.

        Args:
            seq_length (int): Number of positions to encode.
            d_model (int): Feature dimension of the encoding.

        Returns:
            Tensor of shape (1, seq_length, d_model).
    """
    position = torch.arange(seq_length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    angles = position * div_term  # (seq_length, ceil(d_model / 2))

    pos_enc = torch.zeros(1, seq_length, d_model)
    pos_enc[0, :, 0::2] = torch.sin(angles)
    # There are only floor(d_model / 2) odd slots, so drop the surplus
    # frequency when d_model is odd instead of failing on a shape mismatch.
    pos_enc[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pos_enc

class MultiHeadAttention(nn.Module):
    """
        One transformer block: multi-head self-attention followed by a
        two-layer feed-forward network, each with a residual connection.

        Layer normalization is applied once, to the block input, before the
        attention projections.

        Args:
            d_model (int): Feature dimension; must be divisible by num_heads.
            num_heads (int): Number of attention heads.
            dropout_rate (float): Dropout probability used for the attention
                weights, the attention output and inside the feed-forward net.

        Raises:
            ValueError: If d_model is not a positive multiple of num_heads.
    """
    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        # Fail early with a clear message; otherwise the head reshape in
        # forward() fails later with an opaque size-mismatch error.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Projections for Q, K, V for all heads at once
        self.W_Q = nn.Linear(d_model, d_model) # d x d
        self.W_K = nn.Linear(d_model, d_model) # d x d
        self.W_V = nn.Linear(d_model, d_model) # d x d
        self.W_O = nn.Linear(d_model, d_model)

        # Layer Norm
        self.layer_norm = nn.LayerNorm(d_model)

        # Dropout layers
        self.attention_dropout = nn.Dropout(dropout_rate)
        self.output_dropout = nn.Dropout(dropout_rate)
        # Not applied in forward(); kept so the module's attributes are unchanged.
        self.ff_dropout = nn.Dropout(dropout_rate)

        # MLP
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 2 * d_model),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # Add dropout after ReLU
            nn.Linear(2 * d_model, d_model)
        )

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (B, n, d_model) into (B, H, n, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attention_mask):
        """
            Args:
                x: Input of shape (B, n, d_model).
                attention_mask: (B, n) mask over key positions; entries equal
                    to 0 are excluded from attention.

            Returns:
                Tensor of shape (B, n, d_model).
        """
        assert(attention_mask is not None), "Attention mask must be provided"

        # x is like: (B, n, d)
        batch_size, seq_len, _ = x.shape

        # Residual connection
        residual = x

        # Apply Layer Normalization
        normed = self.layer_norm(x)

        # Linear projections and reshape to multi-head: (B, H, n, head_dim)
        Q = self._split_heads(self.W_Q(normed), batch_size, seq_len)
        K = self._split_heads(self.W_K(normed), batch_size, seq_len)
        V = self._split_heads(self.W_V(normed), batch_size, seq_len)

        # Scale dot-product attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) # (B, H, n, n)

        # Broadcast the key mask over heads and query positions. The fill value
        # is the most negative finite number of the score dtype, which stays
        # representable in reduced precision (a literal -1e9 is not in fp16).
        key_mask = attention_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, n)
        attention_scores = attention_scores.masked_fill(
            key_mask == 0, torch.finfo(attention_scores.dtype).min)

        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.attention_dropout(attention_weights)  # Apply dropout to attention weights

        # Apply attention and merge the heads back: (B, n, d_model)
        out = torch.matmul(attention_weights, V) # (B, H, n, head_dim)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # Output projection, dropout and first residual connection
        out = residual + self.output_dropout(self.W_O(out))

        # Feed-forward network with the second residual connection
        out = out + self.feed_forward(out)

        return out

class SimpleTransformer(nn.Module):
    """
    A simple transformer model for sequence classification tasks and language modeling.

    Token ids are embedded, summed with fixed sinusoidal positional encodings,
    passed through `n_layers` MultiHeadAttention blocks and projected to the
    output space.

    Args:
        vocab_size (int): Size of the vocabulary.
        d_model (int): Dimension of the model (embedding size).
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        max_length (int): Nominal maximum sequence length; the positional table
            is built for 2 * max_length positions.
        num_classes (int | None): Number of output classes for classification
            (defaults to 2 when None). Ignored for language modeling.
        dropout_rate (float): Dropout probability used throughout the model.
        task_type (str): Type of task - 'classification' or 'language_modeling'.
            Any value other than 'language_modeling' is treated as classification.

    """
    def __init__(self,
                 vocab_size,
                 d_model,
                 n_layers,
                 n_heads,
                 max_length=512,
                 num_classes=None,
                 dropout_rate=0.1,
                 task_type='classification'
                ):

        # Initialize the transformer model.
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.task_type = task_type

        # Fixed (non-trainable) positional encoding, stored as a frozen
        # parameter so it moves with the module and appears in the state dict.
        self.pos_enc = nn.Parameter(
            positional_encoding(2 * self.max_length, self.d_model), requires_grad=False)

        # Embedding layer: R^vocab_size -> R^d_model
        self.embedding = nn.Embedding(self.vocab_size, self.d_model)

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(dropout_rate)

        # Transformer layers
        self.layers = nn.ModuleList(
            MultiHeadAttention(d_model, n_heads, dropout_rate)
            for _ in range(self.n_layers)
        )

        # Output projection
        if task_type == 'language_modeling':
            # For language modeling, output is vocab_size (predict next token)
            self.output_projection = nn.Linear(d_model, vocab_size)
        else:
            # For classification, output is num_classes
            out_classes = num_classes if num_classes is not None else 2
            self.output_projection = nn.Linear(d_model, out_classes)

    def forward(self, x, attention_mask):
        """
        Args:
            x: (B, n) integer token ids.
            attention_mask: (B, n) mask; entries equal to 0 mark positions that
                are excluded from attention and from classification pooling.

        Returns:
            (B, n, vocab_size) logits for language modeling, otherwise
            (B, num_classes) logits.

        Raises:
            ValueError: If n exceeds the length of the positional table.
        """
        assert(attention_mask is not None), "Attention mask must be provided"

        # x shape: (B, n) with integer token ids.
        _, seq_len = x.shape
        if seq_len > self.pos_enc.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding "
                f"table ({self.pos_enc.size(1)} positions)")

        # Project input to d_model dimensions
        x = self.embedding(x)  # (B, n, d_model)

        # Positional Encodings
        if self.pos_enc is not None:
            x = x + self.pos_enc[:, :seq_len, :].to(x.device)  # (B, n, d_model)

        # Apply dropout after embeddings + positional encoding
        x = self.embedding_dropout(x)

        # Go through the layers.
        for layer in self.layers:
            x = layer(x, attention_mask) # (B, n, d_model)

        if self.task_type == 'language_modeling':
            # For language modeling, apply projection to each token
            return self.output_projection(x)  # (B, n, vocab_size)

        # For classification, mean-pool over the positions the mask keeps so
        # that masked (e.g. padding) positions do not dilute the representation.
        keep = (attention_mask != 0).unsqueeze(-1).to(x.dtype)  # (B, n, 1)
        kept_count = keep.sum(dim=1).clamp(min=1.0)  # avoid 0/0 for all-masked rows
        pooled = (x * keep).sum(dim=1) / kept_count  # (B, d_model)
        return self.output_projection(pooled)  # (B, num_classes)


##########################################################################################


##########################################################################################

######## TRAINING AND TESTING ############

##########################################################################################

def _compute_task_loss(criterion, outputs, labels, task_type):
    """Apply `criterion`; language-modeling logits/labels are flattened per token first."""
    if task_type == 'language_modeling':
        # outputs: (B, seq_len, vocab_size), labels: (B, seq_len)
        return criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
    # Classification: outputs: (B, num_classes), labels: (B,)
    return criterion(outputs, labels)


def _count_correct_predictions(outputs, labels, task_type, ignore_index=-100):
    """Return (num_correct, num_scored) for one batch, skipping ignored LM tokens."""
    if task_type == 'language_modeling':
        predicted = outputs.argmax(dim=-1)  # (B, seq_len)
        scored = labels != ignore_index
        num_correct = ((predicted == labels) & scored).sum().item()
        return num_correct, int(scored.sum().item())
    predicted = outputs.argmax(dim=1)
    return (predicted == labels).sum().item(), labels.size(0)


def train_model(model, 
                train_loader, 
                val_loader, 
                num_epochs, 
                folder_name, 
                vocab_size, 
                lr=0.001, 
                device='cpu', 
                weight_decay=0.0001, 
                rho=None,
                seed=None,
                noise_reg_strength=0.1,
                noise_reg_r=0.05,
                patience=5,
                lr_factor=0.5,
                input_length=None,
                use_dp=False,
                gpu_ids=None,
                use_ddp=False,
                task_type='classification',
                label_smoothing=0.0
            ):

    """
        Train the transformer model with noise stability regularization

        Validation, scheduler stepping, logging and noise-stability measurement
        run only on epochs where `epoch % epoch_period_global == 0`.

        Arguments:
            model: The transformer model to train.
            train_loader: DataLoader for the training dataset. Batches are dicts
                with 'input_ids', 'attention_mask' and 'labels'.
            val_loader: DataLoader for the validation dataset (same batch format).
            num_epochs: Number of training epochs.
            folder_name: Directory to save the metrics and model checkpoint.
            vocab_size: Size of the vocabulary.
            lr: Base learning rate. It is multiplied once by the number of
                DataParallel GPUs, or once by WORLD_SIZE under DDP.
            device: Device to train the model on (e.g., 'cpu' or 'cuda').
            weight_decay: Weight decay for the optimizer.
            rho: Values passed as `r` to ns.measure_noise_stability; one
                stability series is tracked per value. None selects
                0.01, 0.02, ..., 0.10. An empty list disables the measurement.
            seed: Random seed for reproducibility.
            noise_reg_strength: Strength of the noise regularization
                (0 disables the regularizer).
            noise_reg_r: Radius for the noise regularization.
            patience: Scheduler patience, counted in validation rounds;
                0 disables the ReduceLROnPlateau scheduler.
            lr_factor: Factor by which to reduce the learning rate.
            input_length: Input length used when sampling random inputs for the
                noise-stability / L2 measurements. Required when rho is non-empty.
            use_dp: Whether to use DataParallel.
            gpu_ids: Specific GPU IDs to use (e.g., [0, 1, 2, 3]).
            use_ddp: Whether to use DistributedDataParallel instead of DataParallel.
            task_type: Type of task - 'classification' or 'language_modeling'
            label_smoothing: Label smoothing passed to CrossEntropyLoss.

        Returns:
            The trained model (still wrapped if DataParallel / DDP was used).

        Raises:
            ValueError: If rho is non-empty but input_length is None.
    """

    if rho is None:
        rho = [0.01 * i for i in range(1, 11)]

    # The measurements below sample inputs of size (num_samples, input_length);
    # fail before training instead of after the first full epoch.
    if rho and input_length is None and num_epochs > 0:
        raise ValueError(
            "input_length must be provided when rho is non-empty "
            "(pass rho=[] to disable noise-stability measurement)")

    # Set seed for reproducibility if provided
    if seed is not None:
        torch.manual_seed(seed)
        if device != 'cpu':
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # For multi-GPU
        np.random.seed(seed)
        
    # Setup device and multi-GPU configuration
    if isinstance(device, str) and device.startswith('cuda'):
        device, is_multi_gpu, gpu_count = setup_device(use_dp, gpu_ids, use_ddp, device)
    else:
        is_multi_gpu = False
        gpu_count = 0 if device == 'cpu' else 1
    
    # Wrap model for multi-GPU training
    model = wrap_model_for_multi_gpu(model, device, use_dp, gpu_ids, use_ddp)

    # Underlying module, used for noise-stability computations and for saving
    # (DataParallel and DDP wrap the original model in .module)
    base_model = model.module if isinstance(model, (DataParallel, DDP)) else model
    
    # Only rank 0 handles logging and saving -- only for DDP.
    rank = None
    world_size = None
    is_main_process = True
    if use_ddp:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        is_main_process = (rank == 0)

    # Print training configuration
    print_training_info(model, device, is_multi_gpu, gpu_count)
    
    # Scale the learning rate exactly once. DDP takes precedence, matching
    # setup_device / wrap_model_for_multi_gpu. max(..., 1) keeps the rate from
    # collapsing to 0 when DataParallel was requested but no GPU is in use.
    if use_ddp:
        lr = lr * world_size
        print(f"Adjusted learning rate for DDP: {lr}")
    elif use_dp:
        lr = lr * max(gpu_count, 1)
        print(f"Adjusted learning rate for multi-GPU: {lr}")
    else:
        print(f"Learning rate: {lr}")
    
    # Define loss function and optimizer
    # For language modeling, we need to handle per-token loss
    if task_type == 'language_modeling':
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)
    else:
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # The scheduler reduces the learning rate by a predetermined factor when 
    # validation loss plateaus. 
    scheduler = None
    if patience != 0:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                        'min', 
                                                        patience=patience, 
                                                        factor=lr_factor)
    
    # Track metrics
    train_losses = []
    val_losses = []
    val_accuracies = []
    val_perplexities = []  # For language modeling

    # Track noise stability and L2 norm
    noise_stabilities = {}
    model_l2_squared = []  # Track ||f||_2^2 over time
    for r in rho:
        noise_stabilities[r] = []
    
    ############################################################################
    ################### TRAINING ###################
    ############################################################################
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        # For each batch in the train loader:
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            # Zero gradients
            optimizer.zero_grad()

            # Forward pass without hooks for standard loss
            outputs = model(input_ids, attention_mask=attention_mask)
            task_loss = _compute_task_loss(criterion, outputs, labels, task_type)
            
            # Apply noise stability regularization
            noise_reg_loss = 0.0
            if noise_reg_strength > 0:
                # Compute with gradient tracking (NO torch.no_grad here!)
                noise_stability = ns.compute_batch_noise_stability_with_grad(
                    base_model, 
                    input_ids, 
                    attention_mask,
                    r=noise_reg_r,
                    vocab_size=vocab_size,
                    device=device
                )
                
                # Negative sign: minimizing the loss maximizes stability
                # (makes the model more robust to changes).
                noise_reg_loss = -noise_reg_strength * noise_stability

            # Combined loss
            loss = task_loss + noise_reg_loss
            
            # Backward pass and optimize
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            epoch_loss += task_loss.item()  # Only count task loss for reporting
            
        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        if epoch % epoch_period_global != 0:
            continue

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in val_loader: 
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                # Forward pass
                outputs = model(input_ids, attention_mask=attention_mask)
                
                val_loss += _compute_task_loss(criterion, outputs, labels, task_type).item()

                # Accuracy is per token for language modeling (ignored labels
                # excluded) and per sequence for classification.
                batch_correct, batch_total = _count_correct_predictions(
                    outputs, labels, task_type)
                correct += batch_correct
                total += batch_total
        
        avg_val_loss = val_loss / len(val_loader)
        # total can be 0 when every label in the validation set is ignored.
        val_accuracy = 100 * correct / total if total else 0.0
        val_perplexity = torch.exp(torch.tensor(avg_val_loss)).item()
        
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_accuracy)
        val_perplexities.append(val_perplexity)
        
        # Update learning rate based on validation loss
        if scheduler is not None:
            scheduler.step(avg_val_loss)
        
        if is_main_process:
            if task_type == 'language_modeling':
                print(f'Epoch {epoch+1}/{num_epochs}, '
                    f'Train Loss: {avg_train_loss:.4f}, '
                    f'Val Loss: {avg_val_loss:.4f}, '
                    f'Val Perplexity: {val_perplexity:.2f}, '
                    f'Val Accuracy: {val_accuracy:.2f}%, '
                    f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
            else:
                print(f'Epoch {epoch+1}/{num_epochs}, '
                    f'Train Loss: {avg_train_loss:.4f}, '
                    f'Val Loss: {avg_val_loss:.4f}, '
                    f'Val Accuracy: {val_accuracy:.2f}%, '
                    f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

        # Measure ||f||_2^2 and noise stability on validation epochs.
        # Only measure if rho is not empty (can be slow)
        if rho:
            l2_sq = estimate_model_l2_squared(base_model, 
                                              input_length, 
                                              vocab_size, 
                                              num_samples=5,
                                              device=device,
                                              task_type=task_type)
            model_l2_squared.append(l2_sq)
            
            for r in rho:
                noise_stability = ns.measure_noise_stability(base_model, 
                                                            input_length, 
                                                            r, 
                                                            vocab_size, 
                                                            num_trials=50,  # Reduced from 100 to 50
                                                            device=device)
                # Same key object that initialised the dict above.
                noise_stabilities[r].append(noise_stability)

    ###############################################################################################
    ################### END OF TRAINING ######################################
    ###############################################################################################

    if is_main_process:
        # Plot the training results just for this seed.
        plotting.plot_training_results(train_losses,
                                       val_losses,
                                       val_accuracies,
                                       noise_stabilities,
                                       folder_name,
                                       val_perplexities=val_perplexities,
                                       task_type=task_type)
        
        # Make sure the output directory exists before writing into it.
        if folder_name:
            os.makedirs(folder_name, exist_ok=True)

        # Save metrics as JSON for later analysis
        metrics = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_accuracies': val_accuracies,
            'val_perplexities': val_perplexities,
            'noise_stabilities': noise_stabilities,
            'model_l2_squared': model_l2_squared,
            'task_type': task_type
        }
        with open(f"{folder_name}/metrics.json", "w") as f:
            json.dump(metrics, f)

        # Dump the final model in a file.
        # Save the underlying model state_dict (unwrapped from DataParallel/DDP)
        torch.save(base_model.state_dict(), f"{folder_name}/model.pth")
        print("Model saved.")

    return model

def estimate_model_l2_squared(model, input_length, vocab_size, num_samples=1000, device='cpu', task_type='classification'):
    """
    Estimate ||f||_2^2 = E[f(X)^2] via uniform sampling over the input space.

    f(X) is the predicted class index for classification, or the mean of the
    predicted token indices over the sequence for language modeling.

    The model is evaluated in eval mode; its previous train/eval mode is
    restored before returning.

    Args:
        model: The model to evaluate
        input_length: Length of input sequences
        vocab_size: Size of vocabulary
        num_samples: Number of samples to average over
        device: Device to run on
        task_type: 'classification' or 'language_modeling'

    Returns:
        Estimated E[f(X)^2] (float)
    """
    was_training = model.training
    model.eval()

    try:
        # Generate random inputs with an all-ones attention mask
        X = torch.randint(0, vocab_size, size=(num_samples, input_length), device=device)
        attention_mask = torch.ones_like(X, dtype=torch.float)

        with torch.no_grad():
            outputs = model(X, attention_mask=attention_mask)

            if task_type == 'language_modeling':
                # outputs: (B, seq_len, vocab_size)
                predictions = torch.argmax(outputs, dim=-1)  # (B, seq_len)
                # Take mean over sequence to get one value per sample
                f_X = predictions.float().mean(dim=1)  # (B,)
            else:
                # For classification, get the predicted class
                predictions = torch.argmax(outputs, dim=1)  # (B,)
                f_X = predictions.float()
    finally:
        # Do not leak eval mode into a caller that was training.
        model.train(was_training)

    # Calculate E[f(X)^2]
    return torch.mean(f_X ** 2).item()

def test(model, test_loader, device, task_type='classification'):
    """
        Test a model on a testing dataset.
        
        Args:
            model: The model to test
            test_loader: DataLoader for test data
            device: Device to run on
            task_type: 'classification' or 'language_modeling'
            
        Returns:
            accuracy: Test accuracy (as percentage for consistency)
    """
    model.eval()

    correct = 0
    total = 0
    
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask)
            
            if task_type == 'language_modeling':
                # Per-token accuracy
                _, predicted = torch.max(outputs, dim=-1)  # (B, seq_len)
                correct += (predicted == labels).sum().item()
                total += labels.numel()
            else:
                # Classification accuracy
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

    # Return the accuracy as percentage
    accuracy = 100 * correct / total
    return accuracy

##########################################################################################

##########################################################################################

######## DEPLOYING ############

##########################################################################################

# def init_weights(m):
#     """
#         Initialize weights for the model.
#         Parameter: 
#             m (nn.Module): The model layer to initialize.
#     """
#     if isinstance(m, nn.Linear):
#         # Smaller initialization for more stable training
#         torch.nn.init.xavier_uniform_(m.weight, gain=0.5)
#         if m.bias is not None:
#             torch.nn.init.zeros_(m.bias)
#     elif isinstance(m, nn.Embedding):
#         torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
#     elif isinstance(m, nn.LayerNorm):
#         torch.nn.init.ones_(m.weight)
#         torch.nn.init.zeros_(m.bias)

# Model initialization for addition.
def init_weights(m):
    """
        Initialize a single layer in place; intended for use with `model.apply`.

        Linear layers get weights drawn from a normal distribution
        (mean 0.0, std 0.02) and, when a bias exists, a zero bias.
        Every other module type is left untouched.

        Parameter:
            m (nn.Module): The layer to initialize.
    """
    # Guard clause: only nn.Linear layers are re-initialized.
    if not isinstance(m, nn.Linear):
        return

    nn.init.normal_(m.weight, mean=0.0, std=0.02)

    # Linear layers built with bias=False expose `bias` as None.
    layer_bias = m.bias
    if layer_bias is not None:
        nn.init.zeros_(layer_bias)

def _resolve_eval_device(trained_model, train_kwargs):
    """Return the device on which test batches should be placed for `trained_model`."""
    # An explicitly requested device always wins (original behaviour).
    requested = train_kwargs.get('device')
    if requested is not None:
        return requested

    # Otherwise follow the model: after training it may live on a GPU, and
    # feeding it CPU batches would fail with a device mismatch.
    try:
        return next(trained_model.parameters()).device
    except (AttributeError, StopIteration):
        # Not a module, or a module without parameters: keep the old default.
        return 'cpu'


def run_multiple_seeds(model_class,
                       model_args,
                       train_loader,
                       val_loader,
                       test_loader,
                       num_epochs,
                       folder_name,
                       vocab_size,
                       seeds=(42, 123, 456, 789, 101),
                       noise_reg_strength=0.1,
                       noise_reg_r=0.05,
                       learn_function_stabilities=None,
                       use_dp=False,
                       gpu_ids=None,
                       use_ddp=False,
                       epoch_period=100,
                       task_type='classification',
                       **train_kwargs):

    """
        Run training with multiple seeds and collect results.

        For every seed a fresh model is built, initialized with `init_weights`,
        trained via `train_model` and evaluated via `test`. The per-seed
        metrics are read back from `<folder_name>/seeds/seed_<seed>/metrics.json`,
        aggregated, written to `<folder_name>/combined_results.json` and plotted.

        Parameters:
            model_class (Type[nn.Module]): The model class to train.
            model_args (dict): Arguments to initialize the model.
            train_loader (DataLoader): DataLoader for the training set.
            val_loader (DataLoader): DataLoader for the validation set.
            test_loader (DataLoader): DataLoader for the test set.
            num_epochs (int): Number of training epochs.
            folder_name (str): Folder to save results.
            vocab_size (int): Size of the vocabulary.
            seeds (sequence of int): Random seeds to use, one training run per seed.
            noise_reg_strength (float): Strength of noise regularization.
            noise_reg_r (float): R value for noise regularization.
            learn_function_stabilities: Learning function stability for plotting.
            use_dp (bool): Forwarded to `train_model` (DataParallel switch).
            gpu_ids (list): Specific GPU IDs to use (e.g., [0, 1, 2, 3]).
            use_ddp (bool): Whether to use DistributedDataParallel instead of DataParallel.
                Requires the RANK and WORLD_SIZE environment variables.
            epoch_period (int): Stored in the module-level `epoch_period_global`
                and passed to the combined plot.
            task_type (str): Forwarded to `train_model` and `test`
                (e.g. 'classification' or 'language_modeling').
            **train_kwargs: Additional keyword arguments for training. Two keys are
                also read here: 'rho' (iterable used to pre-populate the
                noise-stability keys) and 'device' (device used for test evaluation).

        Returns:
            dict: Aggregated results with one entry per seed under 'train_losses',
                'val_losses', 'val_accuracies', 'val_perplexities',
                'test_accuracies' and 'model_l2_squared', plus
                'noise_stabilities' (per r, one list per seed) and 'task_type'.
                Under DDP, non-main ranks return the empty structure.

        Raises:
            RuntimeError: If `use_ddp` is True but RANK / WORLD_SIZE are not set.
    """
    global epoch_period_global

    # Initialize results dictionary with parity stabilities
    results = {
        'train_losses': [],
        'val_losses': [],
        'val_accuracies': [],
        'val_perplexities': [],
        'test_accuracies': [],
        'noise_stabilities': {r: [] for r in train_kwargs.get('rho', [])},
        'model_l2_squared': [],
        'task_type': task_type
    }

    epoch_period_global = epoch_period

    # Create subfolder for seed experiments
    seeds_folder = f"{folder_name}/seeds"
    os.makedirs(seeds_folder, exist_ok=True)

    # DDP is only supported when launched externally (torchrun & co.), which
    # exports RANK and WORLD_SIZE for every process.
    rank = None
    if use_ddp:
        if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
            raise RuntimeError(
                "DDP mode requires external launch with torchrun or torch.distributed.launch. "
                "Example: torchrun --nproc_per_node=4 grokking_experiments/grokking_experiment.py --training_mode ddp"
            )
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"Running in external DDP environment: rank {rank}/{world_size}")

    # Evaluation, result collection, saving and plotting happen on one process only.
    is_main_process = not use_ddp or rank == 0

    # Train with each seed
    for i, seed in enumerate(seeds):
        print(f"\n{'='*50}\nTraining with seed {seed} ({i+1}/{len(seeds)})\n{'='*50}")

        # Create seed-specific folder
        seed_folder = f"{seeds_folder}/seed_{seed}"
        os.makedirs(seed_folder, exist_ok=True)

        # Seed before building the model so the initial weights are a function
        # of the seed, not of whatever RNG state earlier runs left behind.
        torch.manual_seed(seed)

        # Create a new model instance for each seed
        model = model_class(**model_args)

        # Initialize weights.
        model.apply(init_weights)

        # Train the model with this seed
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=num_epochs,
            folder_name=seed_folder,
            vocab_size=vocab_size,
            noise_reg_strength=noise_reg_strength,
            noise_reg_r=noise_reg_r,
            seed=seed,
            use_dp=use_dp,
            gpu_ids=gpu_ids,
            use_ddp=use_ddp,
            task_type=task_type,
            **train_kwargs
        )

        # Non-main DDP processes neither evaluate nor collect results.
        if not is_main_process:
            continue

        # Evaluate on test set
        test_accuracy = test(trained_model, test_loader,
                             device=_resolve_eval_device(trained_model, train_kwargs),
                             task_type=task_type)
        print(f"Test accuracy for seed {seed}: {test_accuracy:.2f}%")

        # Load metrics from this run
        with open(f"{seed_folder}/metrics.json", "r") as f:
            metrics = json.load(f)

        # Store results
        results['train_losses'].append(metrics['train_losses']) # a list of losses
        results['val_losses'].append(metrics['val_losses'])     # a list of losses
        results['val_accuracies'].append(metrics['val_accuracies']) # list of accuracies
        results['val_perplexities'].append(metrics['val_perplexities']) # list of perplexities
        results['test_accuracies'].append(test_accuracy)            # one accuracy per seed

        # Store ||f||_2^2 values if available
        if 'model_l2_squared' in metrics:
            results['model_l2_squared'].append(metrics['model_l2_squared'])

        # metrics['noise_stabilities'] maps each r (a string key after the JSON
        # round-trip) to that seed's list of noise stability values.
        for r, stability_values in metrics['noise_stabilities'].items():
            # One list per seed is appended under the numeric value of r.
            results['noise_stabilities'].setdefault(float(r), []).append(stability_values)

    # Only main process saves results and plots in DDP
    if is_main_process:
        # Save combined results
        with open(f"{folder_name}/combined_results.json", "w") as f:
            json.dump(results, f)

        # Plot combined results with variance
        plotting.plot_combined_results(results,
                                       folder_name,
                                       epoch_period_global,
                                       learn_function_stabilities)

    return results