import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import math
import time
from typing import Optional, Tuple
import warnings

# Install a blanket "ignore" filter: no warning category or module is given,
# so every warning raised after this point in the process is suppressed
# (including deprecation and numerical warnings from imported libraries).
warnings.filterwarnings('ignore')


# ====================== Helper Functions ======================

def calculate_ece(predictions, labels, confidences, n_bins=15):
    """
    Expected Calibration Error (ECE), reported as a percentage.

    The interval [0, 1] is split into ``n_bins`` equal-width bins. For every
    non-empty bin the absolute gap between mean confidence and accuracy is
    weighted by the fraction of samples falling in that bin, and the weighted
    gaps are summed.

    Args:
        predictions: predicted class labels (tensor)
        labels: true class labels (tensor, same shape as predictions)
        confidences: confidence of each prediction (tensor)
        n_bins: number of equal-width confidence bins

    Returns:
        ECE multiplied by 100, as a Python float.
    """
    edges = torch.linspace(0, 1, n_bins + 1)

    weighted_gap = 0.0
    for idx in range(n_bins):
        # Bins are half-open (lower, upper]; a confidence of exactly 0 lands in no bin.
        members = (confidences > edges[idx]) & (confidences <= edges[idx + 1])
        share = members.float().mean().item()

        # Skip empty bins (also skips the NaN share produced by empty input).
        if not share > 0:
            continue

        hit_rate = (predictions[members] == labels[members]).float().mean().item()
        mean_confidence = confidences[members].mean().item()
        weighted_gap += abs(mean_confidence - hit_rate) * share

    return weighted_gap * 100  # Return as percentage


def kl_divergence_gaussian(mu, log_var, prior_variance=1.0):
    """
    KL divergence from a diagonal Gaussian N(mu, exp(log_var)) to the
    zero-mean isotropic prior N(0, prior_variance), summed over all elements.

    Args:
        mu: tensor of posterior means
        log_var: tensor of posterior log-variances (same shape as mu)
        prior_variance: scalar variance of the prior

    Returns:
        Zero-dimensional tensor holding the total KL divergence.
    """
    variance = torch.exp(log_var)
    # Per-element closed form; the leading 0.5 is applied once after the sum.
    per_element = (
        variance / prior_variance
        + mu.pow(2) / prior_variance
        - 1
        - log_var
        + math.log(prior_variance)
    )
    return 0.5 * per_element.sum()


# ====================== SD-VI Layer Module ======================

class SDVILayer(nn.Module):
    """
    SD-VI variational layer managing mean and covariance for a single layer's weights.

    The posterior over the flattened weights is Gaussian with mean ``mu`` (a
    trainable ``nn.Parameter``) and a dense ``n_params x n_params`` covariance
    ``S``. ``S`` is a plain tensor attribute (not a parameter or buffer): it is
    updated manually through :meth:`pso_step`, is not part of ``state_dict()``
    and is not moved by ``Module.to()``.

    NOTE: memory for ``S`` and its Cholesky factor grows quadratically with the
    number of weights in the layer.

    Args:
        weight_shape: shape of the weight tensor this layer samples
        jitter: diagonal regularisation added before factorisations
        device: device on which ``mu``, ``S`` and the sampling noise are created
    """

    def __init__(self, weight_shape, jitter=1e-6, device='cuda'):
        super(SDVILayer, self).__init__()

        self.weight_shape = weight_shape
        # Plain Python int: np.prod returns a NumPy scalar (a float for an empty shape).
        self.n_params = int(np.prod(weight_shape))
        self.jitter = jitter
        self.device = device

        # Initialize variational parameters
        self.mu = nn.Parameter(torch.randn(weight_shape, device=device) * 0.1)

        # Initialize covariance as diagonal matrix
        self.S = torch.eye(self.n_params, dtype=torch.float32, device=device) * 0.01
        self.S.requires_grad = False  # We'll handle S updates manually

        # Store Cholesky decomposition for efficient sampling
        self.L = None
        self._update_cholesky()

    def _jittered_cholesky(self, S):
        """Return the lower Cholesky factor of ``S + jitter * I``.

        If the factorisation fails, it is retried once with ten times the
        jitter; a second failure propagates to the caller.
        """
        eye = torch.eye(self.n_params, dtype=S.dtype, device=S.device)
        try:
            return torch.linalg.cholesky(S + self.jitter * eye)
        except RuntimeError:
            # torch.linalg errors derive from RuntimeError; retry with more jitter.
            return torch.linalg.cholesky(S + (self.jitter * 10) * eye)

    def _update_cholesky(self):
        """Update the cached Cholesky decomposition of S (kept out of autograd)."""
        with torch.no_grad():
            self.L = self._jittered_cholesky(self.S)

    def forward(self, deterministic=False):
        """
        Sample weights using reparameterization trick

        Args:
            deterministic: If True, return mean weights (for evaluation)

        Returns:
            Tensor of shape ``weight_shape``: ``mu`` itself when deterministic,
            otherwise one sample ``mu + L @ eps``.
        """
        if deterministic:
            return self.mu

        # When the caller has enabled gradients on S, rebuild the factor inside
        # the graph so the loss can be differentiated w.r.t. S; the cached
        # factor carries no graph. Otherwise reuse the cached factor.
        if self.S.requires_grad and torch.is_grad_enabled():
            factor = self._jittered_cholesky(self.S)
        else:
            factor = self.L

        # Reparameterization trick: weight = mu + L @ eps
        eps = torch.randn(self.n_params, 1, device=self.device)
        weight_flat = self.mu.view(-1, 1) + factor @ eps

        # Keep the full weight shape: squeezing would drop singleton dims
        # (e.g. a 1x1 conv kernel) and disagree with the deterministic path.
        return weight_flat.view(self.weight_shape)

    def pso_step(self, grad_S, lr_S, lambda1):
        """
        Proximal Spectral Optimization step for updating covariance

        Args:
            grad_S: Gradient of the objective being maximised with respect to S
            lr_S: Learning rate for S
            lambda1: Regularization parameter
        """
        with torch.no_grad():
            # Gradient ascent step (we're maximizing log-likelihood)
            S_intermediate = self.S + lr_S * grad_S

            # Apply proximal spectral map
            self.S = self._proximal_spectral_map(S_intermediate, lr_S, lambda1)

            # Update Cholesky decomposition
            self._update_cholesky()

    def _proximal_spectral_map(self, S_intermediate, lr_S, lambda1):
        """
        Apply proximal spectral mapping (core of SD-VI algorithm)

        Symmetrises the input, eigendecomposes it (with jitter on the
        diagonal), soft-thresholds the eigenvalues at ``lr_S * lambda1`` and
        reassembles the matrix. Falls back to ``0.01 * I`` if the
        eigendecomposition fails on every retry.
        """
        # Ensure symmetry
        S_sym = (S_intermediate + S_intermediate.T) / 2.0

        # Add jitter for numerical stability, escalating tenfold per failed attempt
        eye = torch.eye(self.n_params, device=self.device)
        current_jitter = self.jitter
        max_retries = 5

        for _ in range(max_retries):
            try:
                eigenvalues, eigenvectors = torch.linalg.eigh(S_sym + current_jitter * eye)
                break
            except RuntimeError:
                current_jitter *= 10
        else:
            # Fall back to a scaled identity if every attempt failed
            return torch.eye(self.n_params, device=self.device) * 0.01

        # Apply soft thresholding to eigenvalues
        threshold = lr_S * lambda1
        shrunk_eigenvalues = torch.relu(eigenvalues - threshold)

        # Reassemble V diag(lambda) V^T; scaling the columns of V avoids
        # materialising a dense diagonal matrix.
        return (eigenvectors * shrunk_eigenvalues) @ eigenvectors.T


# ====================== Bayesian Layers with SD-VI ======================

class BayesianConv2d_SDVI(nn.Module):
    """2-D convolution whose kernel is drawn from an SDVILayer posterior.

    The bias, when enabled, is an ordinary (deterministic) parameter.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, device='cuda'):
        super().__init__()

        # A scalar kernel size means a square kernel.
        if not isinstance(kernel_size, tuple):
            kernel_size = (kernel_size, kernel_size)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.device = device

        # Variational posterior over the kernel tensor
        kernel_shape = (out_channels, in_channels // groups) + self.kernel_size
        self.weight_layer = SDVILayer(kernel_shape, device=device)

        # Deterministic bias (registered as None when disabled)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))

    def forward(self, x, deterministic=False):
        """Convolve ``x`` with a sampled kernel (or the mean kernel if deterministic)."""
        kernel = self.weight_layer(deterministic=deterministic)
        return F.conv2d(
            x,
            kernel,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )


class BayesianLinear_SDVI(nn.Module):
    """Fully connected layer whose weight matrix is drawn from an SDVILayer posterior.

    The bias, when enabled, is an ordinary (deterministic) parameter.
    """

    def __init__(self, in_features, out_features, bias=True, device='cuda'):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.device = device

        # Variational posterior over the (out_features, in_features) matrix
        self.weight_layer = SDVILayer((out_features, in_features), device=device)

        # Deterministic bias (registered as None when disabled)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device))

    def forward(self, x, deterministic=False):
        """Apply the affine map with a sampled weight (or the mean if deterministic)."""
        sampled = self.weight_layer(deterministic=deterministic)
        return F.linear(x, sampled, self.bias)


# ====================== Mean-Field VI Layers (Baseline) ======================

class BayesianConv2d_MFVI(nn.Module):
    """2-D convolution with a mean-field (fully factorised) Gaussian posterior.

    Every kernel element has its own mean and log-variance; the bias, when
    enabled, is an ordinary (deterministic) parameter.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, device='cuda'):
        super().__init__()

        # A scalar kernel size means a square kernel.
        if not isinstance(kernel_size, tuple):
            kernel_size = (kernel_size, kernel_size)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.device = device

        # Posterior mean and log-variance, one pair per kernel element
        initial_mean = torch.randn(out_channels, in_channels // groups,
                                   *self.kernel_size, device=device) * 0.1
        self.weight_mu = nn.Parameter(initial_mean)
        self.weight_log_var = nn.Parameter(torch.full_like(self.weight_mu, -5.0))

        # Deterministic bias (registered as None when disabled)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))

    def _draw_weight(self):
        """Reparameterised sample: mean + std * standard-normal noise."""
        std = torch.exp(0.5 * self.weight_log_var)
        noise = torch.randn_like(std)
        return self.weight_mu + std * noise

    def forward(self, x, deterministic=False):
        """Convolve ``x`` with the mean kernel (deterministic) or a fresh sample."""
        kernel = self.weight_mu if deterministic else self._draw_weight()
        return F.conv2d(x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def kl_divergence(self):
        """KL divergence of the kernel posterior, via kl_divergence_gaussian with its default prior."""
        return kl_divergence_gaussian(self.weight_mu, self.weight_log_var)


class BayesianLinear_MFVI(nn.Module):
    """Fully connected layer with a mean-field (fully factorised) Gaussian posterior.

    Every weight has its own mean and log-variance; the bias, when enabled,
    is an ordinary (deterministic) parameter.
    """

    def __init__(self, in_features, out_features, bias=True, device='cuda'):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.device = device

        # Posterior mean and log-variance, one pair per weight
        initial_mean = torch.randn(out_features, in_features, device=device) * 0.1
        self.weight_mu = nn.Parameter(initial_mean)
        self.weight_log_var = nn.Parameter(torch.full_like(self.weight_mu, -5.0))

        # Deterministic bias (registered as None when disabled)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device))

    def _draw_weight(self):
        """Reparameterised sample: mean + std * standard-normal noise."""
        std = torch.exp(0.5 * self.weight_log_var)
        noise = torch.randn_like(std)
        return self.weight_mu + std * noise

    def forward(self, x, deterministic=False):
        """Apply the affine map with the mean weight (deterministic) or a fresh sample."""
        matrix = self.weight_mu if deterministic else self._draw_weight()
        return F.linear(x, matrix, self.bias)

    def kl_divergence(self):
        """KL divergence of the weight posterior, via kl_divergence_gaussian with its default prior."""
        return kl_divergence_gaussian(self.weight_mu, self.weight_log_var)


# ====================== Wide ResNet Architecture ======================

class BasicBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> conv, twice) for Wide ResNet.

    ``layer_type`` selects the convolution implementation: 'sdvi' and 'mfvi'
    use the Bayesian layers, any other value uses ``nn.Conv2d``. A 1x1
    convolution is placed on the shortcut whenever the stride or the channel
    count changes; otherwise the shortcut is an empty ``nn.Sequential``.
    """

    # Layer types whose convolutions accept a ``deterministic`` keyword
    _BAYESIAN_TYPES = ('sdvi', 'mfvi')

    def __init__(self, in_planes, out_planes, stride, dropout_rate, layer_type='sdvi', device='cuda'):
        super().__init__()

        self.layer_type = layer_type

        # Pick the convolution class for this block
        if layer_type == 'sdvi':
            conv_cls = BayesianConv2d_SDVI
        elif layer_type == 'mfvi':
            conv_cls = BayesianConv2d_MFVI
        else:
            conv_cls = nn.Conv2d

        # Main branch
        self.bn1 = nn.BatchNorm2d(in_planes, device=device)
        self.conv1 = conv_cls(in_planes, out_planes, kernel_size=3, stride=stride,
                              padding=1, bias=True, device=device)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(out_planes, device=device)
        self.conv2 = conv_cls(out_planes, out_planes, kernel_size=3, stride=1,
                              padding=1, bias=True, device=device)

        # Shortcut branch: identity unless the tensor shape changes.
        # For non-Bayesian layer types conv_cls is nn.Conv2d, so a single
        # construction covers both cases.
        self.shortcut = nn.Sequential()
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            projection = conv_cls(in_planes, out_planes, kernel_size=1, stride=stride,
                                  bias=True, device=device)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x, deterministic=False):
        """Run the block; ``deterministic`` is forwarded only to Bayesian convolutions."""
        bayesian = self.layer_type in self._BAYESIAN_TYPES
        conv_kwargs = {'deterministic': deterministic} if bayesian else {}

        out = self.dropout(self.conv1(F.relu(self.bn1(x)), **conv_kwargs))
        out = self.conv2(F.relu(self.bn2(out)), **conv_kwargs)

        if not bayesian:
            # An empty nn.Sequential returns its input, i.e. the identity shortcut.
            out += self.shortcut(x)
        elif isinstance(self.shortcut, nn.Sequential) and len(self.shortcut) > 0:
            # nn.Sequential cannot pass keyword arguments, so call the layer directly.
            out += self.shortcut[0](x, **conv_kwargs)
        else:
            out += x

        return out


class WideResNet(nn.Module):
    """Wide ResNet classifier (defaults give the 28-10 configuration).

    ``layer_type`` selects the layer implementation for every convolution and
    the final linear layer: 'sdvi', 'mfvi', or anything else for the standard
    ``nn.Conv2d`` / ``nn.Linear``.
    """

    # Layer types whose layers accept a ``deterministic`` keyword
    _BAYESIAN_TYPES = ('sdvi', 'mfvi')

    def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3,
                 num_classes=10, layer_type='sdvi', device='cuda'):
        super().__init__()

        self.layer_type = layer_type
        self.device = device

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        blocks_per_stage = (depth - 4) // 6

        # Channel widths: stem, then the three widened stages
        widths = [16] + [base * widen_factor for base in (16, 32, 64)]

        # Pick the layer classes
        if layer_type == 'sdvi':
            conv_cls, linear_cls = BayesianConv2d_SDVI, BayesianLinear_SDVI
        elif layer_type == 'mfvi':
            conv_cls, linear_cls = BayesianConv2d_MFVI, BayesianLinear_MFVI
        else:
            conv_cls, linear_cls = nn.Conv2d, nn.Linear

        # Stem convolution
        self.conv1 = conv_cls(3, widths[0], kernel_size=3, stride=1,
                              padding=1, bias=True, device=device)

        # Three residual stages, registered as layer1..layer3; the last two downsample
        for stage_no, stage_stride in enumerate((1, 2, 2), start=1):
            stage = self._make_layer(BasicBlock, widths[stage_no - 1], widths[stage_no],
                                     blocks_per_stage, stride=stage_stride,
                                     dropout_rate=dropout_rate)
            setattr(self, f'layer{stage_no}', stage)

        # Classification head
        self.bn1 = nn.BatchNorm2d(widths[3], device=device)
        self.linear = linear_cls(widths[3], num_classes, device=device)

    def _make_layer(self, block, in_planes, out_planes, num_blocks, stride, dropout_rate):
        """Build one stage: the first block applies ``stride`` and the channel change."""
        tail = num_blocks - 1
        block_strides = [stride] + [1] * tail
        block_inputs = [in_planes] + [out_planes] * tail

        return nn.Sequential(*[
            block(channels, out_planes, block_stride, dropout_rate,
                  self.layer_type, self.device)
            for channels, block_stride in zip(block_inputs, block_strides)
        ])

    def forward(self, x, deterministic=False):
        """Compute logits; ``deterministic`` is forwarded only for Bayesian layer types."""
        stages = (self.layer1, self.layer2, self.layer3)

        if self.layer_type in self._BAYESIAN_TYPES:
            head_kwargs = {'deterministic': deterministic}
            out = self.conv1(x, **head_kwargs)
            # nn.Sequential cannot pass keyword arguments, so step through the blocks.
            for stage in stages:
                for block in stage:
                    out = block(out, **head_kwargs)
        else:
            head_kwargs = {}
            out = self.conv1(x)
            for stage in stages:
                out = stage(out)

        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        return self.linear(out, **head_kwargs)

    def get_kl_divergence(self):
        """Sum the KL terms of all mean-field layers; returns 0 for other layer types."""
        if self.layer_type != 'mfvi':
            return 0

        mfvi_types = (BayesianConv2d_MFVI, BayesianLinear_MFVI)
        total = 0
        for module in self.modules():
            if not isinstance(module, mfvi_types):
                continue
            total += module.kl_divergence()
        return total

    def get_sdvi_layers(self):
        """Collect the SDVILayer of every SD-VI conv/linear module, in module order."""
        sdvi_types = (BayesianConv2d_SDVI, BayesianLinear_SDVI)
        return [module.weight_layer
                for module in self.modules()
                if isinstance(module, sdvi_types)]


# ====================== Training and Evaluation Functions ======================

def train_sdvi(model, train_loader, test_loader, epochs, lr_mu, lr_S, lambda1, device):
    """Train SD-VI model

    Means (parameters whose name contains ``weight_layer.mu``) and all other
    parameters are optimised with Adam on the mini-batch cross-entropy. Each
    covariance ``S`` returned by ``model.get_sdvi_layers()`` is updated with
    that layer's ``pso_step`` whenever autograd produced a gradient for it.

    Args:
        model: network exposing ``get_sdvi_layers()`` and accepting ``deterministic=``
        train_loader: iterable of ``(inputs, targets)`` batches (must support ``len``)
        test_loader: loader passed to ``evaluate`` every fifth epoch
        epochs: number of passes over ``train_loader``
        lr_mu: Adam learning rate (used for both optimisers)
        lr_S: step size for the covariance update
        lambda1: regularisation strength forwarded to ``pso_step``
        device: device the batches are moved to

    Returns:
        The same ``model`` object, trained in place.
    """

    print("\n" + "=" * 60)
    print("Training SD-VI Model")
    print("=" * 60)

    # Split parameters: posterior means vs. everything else
    sdvi_layers = model.get_sdvi_layers()
    mu_params = []
    other_params = []
    for name, param in model.named_parameters():
        if 'weight_layer.mu' in name:
            mu_params.append(param)
        else:
            other_params.append(param)

    # Optimizers
    optimizer_mu = optim.Adam(mu_params, lr=lr_mu)
    optimizer_other = optim.Adam(other_params, lr=lr_mu)

    # Training loop
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer_mu.zero_grad()
            optimizer_other.zero_grad()

            # Enable gradients for S BEFORE the forward pass: autograd only
            # tracks tensors that require grad when the graph is built, so
            # flipping the flag after the forward pass can never yield S.grad.
            for layer in sdvi_layers:
                layer.S.grad = None
                layer.S.requires_grad = True

            # Forward pass
            outputs = model(inputs, deterministic=False)
            loss = F.cross_entropy(outputs, targets)

            loss.backward()

            # Update mean parameters
            optimizer_mu.step()
            optimizer_other.step()

            # PSO step for covariance matrices
            for layer in sdvi_layers:
                if layer.S.grad is not None:
                    # Clip gradients for stability
                    torch.nn.utils.clip_grad_norm_([layer.S], max_norm=10.0)
                    # pso_step ascends along the gradient it receives, while
                    # S.grad is the gradient of the loss being minimised, so
                    # the sign is flipped here.
                    layer.pso_step(-layer.S.grad, lr_S, lambda1)
                    layer.S.grad = None
                layer.S.requires_grad = False

            # Statistics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch + 1}/{epochs} [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {loss.item():.3f} | Acc: {100. * correct / total:.2f}%')

        # Evaluation
        if (epoch + 1) % 5 == 0:
            test_acc, test_ece, test_nll = evaluate(model, test_loader, device, n_samples=10)
            print(f'Epoch {epoch + 1}: Test Acc: {test_acc:.2f}% | ECE: {test_ece:.2f}% | NLL: {test_nll:.4f}')

    return model


def train_mfvi(model, train_loader, test_loader, epochs, lr, kl_weight, device):
    """Fit the mean-field VI model with Adam on a mini-batch negative ELBO.

    Per batch the loss is the summed cross-entropy plus ``kl_weight`` times
    the model's KL term divided by the number of batches in ``train_loader``.
    Progress is printed every 100 batches and ``evaluate`` is run on
    ``test_loader`` every fifth epoch. Returns the same ``model`` object.
    """

    rule = "=" * 60
    print("\n" + rule)
    print("Training Mean-Field VI Model")
    print(rule)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_batches = len(train_loader)

    for epoch in range(epochs):
        epoch_no = epoch + 1
        model.train()
        running_loss = 0
        n_correct = 0
        n_seen = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Stochastic forward pass
            logits = model(inputs, deterministic=False)

            # Negative ELBO: data term summed over the batch, KL spread over batches
            nll = F.cross_entropy(logits, targets, reduction='sum')
            kl = model.get_kl_divergence()
            loss = nll + kl_weight * kl / num_batches

            # Gradient step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics
            batch_size = targets.size(0)
            running_loss += loss.item()
            predicted = logits.max(1)[1]
            n_seen += batch_size
            n_correct += predicted.eq(targets).sum().item()

            if batch_idx % 100 != 0:
                continue
            print(f'Epoch {epoch_no}/{epochs} [{batch_idx}/{len(train_loader)}] '
                  f'Loss: {loss.item():.3f} | NLL: {nll.item() / batch_size:.3f} | '
                  f'KL: {kl.item() / num_batches:.3f} | Acc: {100. * n_correct / n_seen:.2f}%')

        # Periodic test-set evaluation
        if epoch_no % 5 == 0:
            test_acc, test_ece, test_nll = evaluate(model, test_loader, device, n_samples=10)
            print(f'Epoch {epoch_no}: Test Acc: {test_acc:.2f}% | ECE: {test_ece:.2f}% | NLL: {test_nll:.4f}')

    return model


def evaluate(model, test_loader, device, n_samples=20):
    """Evaluate model with multiple forward passes

    For every batch, ``n_samples`` stochastic forward passes are averaged in
    probability space, i.e. the predictive distribution is the mean of the
    per-sample softmax outputs. Accuracy, confidence/ECE and NLL are all
    computed from that single averaged distribution, so the three metrics
    describe the same predictor. (Averaging log-probabilities instead would
    score an unnormalised geometric mean and overstate the NLL.)

    Args:
        model: network called as ``model(inputs, deterministic=False)``
        test_loader: iterable of ``(inputs, targets)`` batches
        device: device the batches are moved to
        n_samples: number of stochastic forward passes per batch (>= 1)

    Returns:
        Tuple ``(accuracy_percent, ece_percent, nll)`` of Python floats.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    model.eval()
    all_targets = []
    all_log_probs = []
    log_n_samples = math.log(n_samples)

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Multiple forward passes for uncertainty estimation
            sample_log_probs = torch.stack([
                F.log_softmax(model(inputs, deterministic=False), dim=1)
                for _ in range(n_samples)
            ])

            # log(mean_s p_s) computed stably as logsumexp_s(log p_s) - log(n_samples)
            all_log_probs.append(torch.logsumexp(sample_log_probs, dim=0) - log_n_samples)
            all_targets.append(targets)

    # Concatenate all batches
    all_targets = torch.cat(all_targets, dim=0)
    all_log_probs = torch.cat(all_log_probs, dim=0)

    # Predicted class and its confidence under the averaged distribution.
    # Clamp to 1 so rounding cannot push a sample outside the last ECE bin.
    confidences, predicted_classes = all_log_probs.exp().max(1)
    confidences = confidences.clamp(max=1.0)

    accuracy = (predicted_classes == all_targets).float().mean().item() * 100

    # Calculate ECE
    ece = calculate_ece(predicted_classes, all_targets, confidences)

    # Calculate NLL of the averaged predictive distribution
    nll = F.nll_loss(all_log_probs, all_targets).item()

    return accuracy, ece, nll


# ====================== Main Function ======================

def main():
    """Train the SD-VI and mean-field VI Wide ResNets on CIFAR-10 and print a comparison.

    Steps: pick the device, seed the RNGs, build the CIFAR-10 loaders, train
    and time both models, evaluate each with 20 stochastic passes, print the
    results table, and report how many covariance eigenvalues exceed 1e-6 for
    the first five SD-VI layers.
    """
    rule = "=" * 60

    def banner(title):
        # Blank line, rule, title, rule
        print("\n" + rule)
        print(title)
        print(rule)

    def build_model(layer_type):
        # Same architecture for both methods; only the layer type differs
        net = WideResNet(depth=28, widen_factor=10, dropout_rate=0.3,
                         num_classes=10, layer_type=layer_type, device=device)
        return net.to(device)

    def count_parameters(net):
        return sum(p.numel() for p in net.parameters())

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set random seeds
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Hyperparameters
    batch_size = 128
    epochs_sdvi = 50
    epochs_mfvi = 50
    lr_mu, lr_S, lambda1 = 0.001, 0.0001, 0.01  # SD-VI
    lr_mfvi, kl_weight = 0.001, 0.01            # MFVI

    # Data loading
    print("Loading CIFAR-10 dataset...")
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    print(f"Training set size: {len(trainset)}")
    print(f"Test set size: {len(testset)}")

    # Train SD-VI model
    banner("Initializing SD-VI Model (Wide ResNet-28-10)")
    model_sdvi = build_model('sdvi')
    print(f"Model parameters: {count_parameters(model_sdvi):,}")

    started = time.time()
    model_sdvi = train_sdvi(model_sdvi, trainloader, testloader, epochs_sdvi,
                            lr_mu, lr_S, lambda1, device)
    sdvi_train_time = time.time() - started

    # Train Mean-Field VI model
    banner("Initializing Mean-Field VI Model (Wide ResNet-28-10)")
    model_mfvi = build_model('mfvi')
    print(f"Model parameters: {count_parameters(model_mfvi):,}")

    started = time.time()
    model_mfvi = train_mfvi(model_mfvi, trainloader, testloader, epochs_mfvi,
                            lr_mfvi, kl_weight, device)
    mfvi_train_time = time.time() - started

    # Final evaluation
    banner("FINAL EVALUATION RESULTS")

    print("\nEvaluating SD-VI model...")
    sdvi_acc, sdvi_ece, sdvi_nll = evaluate(model_sdvi, testloader, device, n_samples=20)

    print("\nEvaluating Mean-Field VI model...")
    mfvi_acc, mfvi_ece, mfvi_nll = evaluate(model_mfvi, testloader, device, n_samples=20)

    # Print results table
    banner("Table 1: Wide ResNet-28-10 on CIFAR-10")
    print(f"{'Method':<15} {'Accuracy (%)':<15} {'ECE (%)':<15} {'NLL':<15} {'Time (min)':<15}")
    print("-" * 60)
    table_rows = (
        ('SD-VI', sdvi_acc, sdvi_ece, sdvi_nll, sdvi_train_time),
        ('Mean-Field VI', mfvi_acc, mfvi_ece, mfvi_nll, mfvi_train_time),
    )
    for method, acc, ece, nll, seconds in table_rows:
        print(f"{method:<15} {acc:<15.2f} {ece:<15.2f} {nll:<15.4f} {seconds / 60:<15.2f}")
    print(rule)

    # Additional analysis
    banner("ADDITIONAL ANALYSIS")

    # Count covariance eigenvalues above 1e-6 for the first 5 SD-VI layers
    for layer_no, layer in enumerate(model_sdvi.get_sdvi_layers()[:5], start=1):
        eigenvalues = torch.linalg.eigvalsh(layer.S)
        effective_rank = torch.sum(eigenvalues > 1e-6).item()
        total_params = layer.n_params
        print(f"Layer {layer_no}: Effective rank = {effective_rank}/{total_params} "
              f"({100 * effective_rank / total_params:.1f}%)")

    print("\nTraining completed successfully!")


# Run the full experiment only when this file is executed as a script;
# importing it as a module defines everything above without running main().
if __name__ == "__main__":
    main()