import wandb
import torch
import torch.nn as nn
import json
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import set_global_seed
from cifar10_training import get_data, get_optimizer
import models
from lipschitz_tracker import LipschitzTracker
from optimizers.hooks import DiagnosticHook

# Sub-directory under tuning/ from which tuned hyper-parameter JSON files are read
# (path: tuning/{NAME}/{DATASET_NAME}/{MODEL_NAME}/{OPTIMIZER_NAME}.json).
NAME = ''
# FOR SIMPLECNN
# OPTIMIZER_NAMES = ['SoftSignumPT_not_decoupled_wd', 'SoftSignumPT', 'AdamW', 'Signum', 'Signum+SGD', 'Adam', 'SGD'] 
# Optimizers to train, one wandb run each, in this order.
OPTIMIZER_NAMES = ['Signum+SGD-like-SoftSignumPT-auto_not_decoupled_wd', 'Signum+SGD-like-SoftSignumPT-auto']

 
# OPTIMIZER_NAMES = ['SoftSignum_decoupled_wd', 'SoftSignumPT_not_decoupled_wd'] 


# FOR RESNET18_32x32
# OPTIMIZER_NAMES = ['Signum', 'Signum+SGD', 'SoftSignum', 'SoftSignumPT', 'AdamW']

# Accepted values in this script: 'simplecnn', 'simplecnnbinclass', 'resnet18_32x32'.
MODEL_NAME = 'simplecnn'
# Only used to build the tuning-file path and the wandb run name.
DATASET_NAME = 'cifar10'
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# The same seed is re-applied before every optimizer run.
SEED = 42
NUM_EPOCHS = 50
BATCH_SIZE = 128
USE_AUGMENTATIONS = False

# Number of bins for the per-sample loss histograms.
HISTOGRAM_BINS = 50


# Per-epoch per-sample loss diagnostics (each one triggers full passes over
# the train, validation and test loaders).
LOG_HISTOGRAMS = False
LOG_CDF = False          
LOG_PERCENTILES = False  
# Per-step Lipschitz tracking through LipschitzTracker during training.
LOG_LIPSCHITZ = False      
# Per-epoch call to estimate_lipschitz_test.
LOG_LIPSCHITZ_TEST = False 
# Training Linf Lipschitz stats are logged only while their max is <= this value.
LIPSCHITZ_MAX_THRESHOLD = 1e8  
# Presumably forwarded to estimate_lipschitz_test as epsilon / n_directions;
# the call site is outside the visible part of the file - TODO confirm.
LIPSCHITZ_TEST_EPSILON = 1e-4  
LIPSCHITZ_TEST_N_DIRS = 5       


# Step-variance estimation over a fixed set of training batches.
ESTIMATE_STEP_VARIANCE = False  
# Number of fixed batches; the same batches also serve as the reference set
# for the gradient-noise metrics.
STEP_VARIANCE_K_BATCHES = 50   
# Measured in epochs: the estimate runs when (epoch + 1) is a multiple of this.
STEP_VARIANCE_FREQUENCY = 20  


# Gradient-noise metrics: minibatch gradient vs. reference gradient.
LOG_GRAD_NOISE = True      
# Measured in iterations: metrics are computed when global_iteration is a multiple of this.
LOG_GRAD_NOISE_FREQ = 5   


# When True, a DiagnosticHook is passed to the optimizer via its parameters
# and its metrics are logged after every optimizer step.
USE_DIAGNOSTIC_HOOK = True
# Constructor arguments for DiagnosticHook.
SATURATION_THRESHOLD = 0.55
DAMPING_TOL = 1e-8


# Weights & Biases destination for every run.
WANDB_PROJECT = ''
WANDB_ENTITY = ''


# Not read anywhere in the visible part of the file.
SAVE_JSON = False 







# Force deterministic cuDNN kernels and disable the autotuner so runs are
# repeatable for a given seed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

set_global_seed(SEED)

print(f"Using device: {DEVICE}")


# The three loaders are built once and shared by all optimizer runs.
train_loader, eval_loader, test_loader = get_data(
    batch_size=BATCH_SIZE, 
    seed=SEED, 
    use_augmentations=USE_AUGMENTATIONS
)
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Val samples: {len(eval_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")


for OPTIMIZER_NAME in OPTIMIZER_NAMES:
    print("\n" + "="*70)
    print(f"Starting training with optimizer: {OPTIMIZER_NAME}")
    print("="*70 + "\n")

    # Re-seed so every optimizer starts from the same random state.
    set_global_seed(SEED)

    # Build the hook before touching the parameter file so that `hook` is
    # always bound for this run. Previously it was created inside the `try`
    # below, so a missing JSON file left it undefined (or pointing at the
    # previous optimizer's hook) while the training loop still used it.
    hook = DiagnosticHook(
        saturation_threshold=SATURATION_THRESHOLD,
        damping_tol=DAMPING_TOL
    )

    params_path = f'tuning/{NAME}/{DATASET_NAME}/{MODEL_NAME}/{OPTIMIZER_NAME}.json'
    try:
        with open(params_path, 'r', encoding='utf-8') as f:
            optimizer_params = json.load(f)
    except FileNotFoundError:
        loaded_from_file = False
        print(f"Warning: No tuned parameters found for {OPTIMIZER_NAME}, using defaults")
        optimizer_params = {
            'lr': 0.01,
            'momentum': 0.9,
            'weight_decay': 0.001,
            'tmin': 2.0,
            'tmax': 20.0,
            'warmup_iters': 0.8,  # Fraction of total iterations
        }
    else:
        loaded_from_file = True
        # Scores stored next to the tuned values are not optimizer arguments.
        optimizer_params.pop('val_score', None)
        optimizer_params.pop('test_score', None)

    # Attach the hook on both paths so the per-step hook metrics always come
    # from the hook the optimizer is actually using.
    if USE_DIAGNOSTIC_HOOK:
        optimizer_params['hook'] = hook

    if loaded_from_file:
        print(f"Loaded parameters: {optimizer_params}")

    optimizer_params['batch_size'] = BATCH_SIZE

    # One wandb run per optimizer; the optimizer parameters are stored in the
    # run config alongside the global settings.
    run = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        name=f'{OPTIMIZER_NAME}_{DATASET_NAME}_{MODEL_NAME}',
        config={
            'optimizer': OPTIMIZER_NAME,
            'model': MODEL_NAME,
            'seed': SEED,
            'num_epochs': NUM_EPOCHS,
            'batch_size': BATCH_SIZE,
            'use_augmentations': USE_AUGMENTATIONS,
            **optimizer_params
        }
    )

    # Custom x-axes: everything is plotted against "epoch" unless it belongs
    # to one of the per-iteration metric families declared below.
    run.define_metric("epoch", hidden=True)
    run.define_metric("iteration", hidden=True)

    run.define_metric("*", step_metric="epoch")

    run.define_metric("grad_noise/*", step_metric="iteration")
    run.define_metric("tanh/*", step_metric="iteration")
    run.define_metric("inner_metrics/*", step_metric="iteration")

    # Mirror the optimizer settings into the run summary.
    for key, value in optimizer_params.items():
        run.summary[f'optimizer/{key}'] = value
    run.summary['optimizer/name'] = OPTIMIZER_NAME

    # A fresh model is built for every optimizer.
    if MODEL_NAME == 'simplecnn':
        model = models.SimpleCNN().to(DEVICE)
    elif MODEL_NAME == 'simplecnnbinclass':
        model = models.SimpleCNNBinClass().to(DEVICE)
    elif MODEL_NAME == 'resnet18_32x32':
        model = models.ResNet18_32x32().to(DEVICE)
    else:
        raise ValueError(f"Invalid model name: {MODEL_NAME}")

    print(f"Model: {MODEL_NAME}")

    # Total number of optimizer steps over the whole run.
    n_iters = NUM_EPOCHS * len(train_loader)
    optimizer, (clipping, scheduler) = get_optimizer(
        OPTIMIZER_NAME,
        model,
        search_space=None,
        trial=None,
        optimizer_params=optimizer_params,
        n_iters=n_iters
    )
    print(f"Optimizer: {OPTIMIZER_NAME}")
    if clipping:
        print(f"Gradient clipping: {clipping}")
        wandb.config.update({'clipping': clipping})
    if scheduler:
        print(f"LR scheduler: {scheduler}")
        wandb.config.update({'scheduler': scheduler})

    criterion = nn.CrossEntropyLoss()

    lipschitz_tracker = LipschitzTracker(criterion, DEVICE)


    def evaluate(model, loader, device):
        """Return ``(accuracy_percent, mean_loss)`` for ``model`` over ``loader``.

        The loss function is the ``criterion`` of the enclosing scope. Every
        batch loss is weighted by its batch size before averaging over all
        samples seen. The model is switched to eval mode and left there.
        """
        model.eval()
        n_seen = 0
        n_correct = 0
        loss_sum = 0.0
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                logits = model(batch_x)
                batch_loss = criterion(logits, batch_y)
                loss_sum += batch_loss.item() * batch_x.size(0)
                # Index of the largest logit is the predicted class.
                predicted = torch.max(logits.data, 1)[1]
                n_seen += batch_y.size(0)
                n_correct += (predicted == batch_y).sum().item()
        mean_loss = loss_sum / n_seen
        return 100 * n_correct / n_seen, mean_loss

    def collect_per_sample_losses(model, loader, device, criterion):
        """Return a flat Python list with the cross-entropy loss of every sample.

        The model is put in eval mode and evaluated without gradients. The
        ``criterion`` argument is accepted but not used: losses always come
        from an unreduced ``nn.CrossEntropyLoss``.
        """
        model.eval()
        unreduced_ce = nn.CrossEntropyLoss(reduction='none')
        collected = []

        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                batch_losses = unreduced_ce(model(batch_x), batch_y)
                collected += batch_losses.cpu().numpy().tolist()

        return collected

    def flatten_grads(model):
        """Concatenate the gradients of all parameters into one 1-D tensor.

        Parameters without a gradient contribute a zero block of matching
        size, device and dtype, so the layout follows ``model.parameters()``.
        """
        pieces = [
            torch.zeros(param.numel(), device=param.device, dtype=param.dtype)
            if param.grad is None
            else param.grad.detach().reshape(-1)
            for param in model.parameters()
        ]
        return torch.cat(pieces)

    def set_flat_grads(model, flat):
        """Write a flat gradient vector back into the ``.grad`` of each parameter.

        ``flat`` is consumed in ``model.parameters()`` order. Existing
        gradient tensors are overwritten in place; missing ones are created
        from a copy of the matching slice.
        """
        start = 0
        for param in model.parameters():
            stop = start + param.numel()
            chunk = flat[start:stop].view_as(param)
            if param.grad is not None:
                param.grad.copy_(chunk)
            else:
                param.grad = chunk.clone()
            start = stop

    @torch.no_grad()
    def grad_noise_metrics(gk, gref):
        """Compare a minibatch gradient ``gk`` with a reference gradient ``gref``.

        Both arguments are flat 1-D tensors of equal length. Returns a dict of
        Python floats:

        - ``m``: ||gk - gref||^2
        - ``m_sign``: ||sign(gk) - sign(gref)||^2
        - ``rel_m``: ``m`` divided by ||gref||^2 (+ eps)
        - ``rel_m_sign``: ``m_sign`` divided by ||sign(gref)||^2 (+ eps)
        - ``cos_sml`` / ``cos_sml_sign``: cosine similarity of the gradients
          and of their sign vectors
        - ``m_alpha_sign``: min over alpha >= 0 of ||alpha * sign(gk) - gref||^2,
          with ``cos_sml_alpha_sign`` the cosine of that best fit with ``gref``
        - ``m_alpha_normalized``: min over alpha >= 0 of
          ||alpha * gk / ||gk|| - gref||^2, with ``cos_sml_alpha_normalized``
          the cosine of that best fit with ``gref``
        - ``sparsity_ratio_2_inf``: ||noise||_2 / ||noise||_inf
        - ``sparsity_ratio_1_2``: ||noise||_1 / (d * ||noise||_2), where
          ``noise = gk - gref`` and ``d`` is the vector length

        An all-zero ``gk`` no longer produces NaN: the two alpha fits then
        collapse to the zero vector, so both ``m_alpha_*`` equal ||gref||^2.
        """
        eps = 1e-12
        noise = gk - gref
        sign_gk = gk.sign()
        sign_gref = gref.sign()
        sign_noise = sign_gk - sign_gref

        m = noise.pow(2).sum().item()
        m_sign = sign_noise.pow(2).sum().item()
        cos_sml = torch.nn.functional.cosine_similarity(gk, gref, dim=0).item()
        cos_sml_sign = torch.nn.functional.cosine_similarity(sign_gk, sign_gref, dim=0).item()

        rel_m = (noise.pow(2).sum() / (torch.dot(gref, gref) + eps)).item()
        rel_m_sign = (sign_noise.pow(2).sum() / (torch.dot(sign_gref, sign_gref) + eps)).item()

        # Best non-negative scaling of sign(gk) onto gref (least squares).
        # The denominator counts the non-zero entries of gk, so it is either 0
        # or >= 1; clamping only matters for an all-zero gradient.
        denom = torch.clamp((sign_gk * sign_gk).sum(), min=eps)
        alpha = torch.clamp((sign_gk * gref).sum() / denom, min=0.0)
        approx = alpha * sign_gk
        m_alpha_sign = (approx - gref).pow(2).sum().item()
        cos_sml_alpha_sign = torch.nn.functional.cosine_similarity(approx, gref, dim=0).item()

        # Same fit with the L2-normalized gradient instead of its sign.
        # The norm is clamped so that a zero gradient does not divide by zero.
        gk_norm = torch.clamp(torch.linalg.vector_norm(gk), min=eps)
        normalized_gk = gk / gk_norm
        alpha = torch.clamp((normalized_gk * gref).sum(), min=0.0)
        approx = alpha * normalized_gk
        m_alpha_normalized = (approx - gref).pow(2).sum().item()
        cos_sml_alpha_normalized = torch.nn.functional.cosine_similarity(approx, gref, dim=0).item()

        # Norm ratios of the noise vector, used as sparsity indicators.
        noise_norm_2 = torch.linalg.vector_norm(noise, ord=2).item()
        noise_norm_1 = torch.linalg.vector_norm(noise, ord=1).item()
        noise_norm_inf = noise.abs().max().item()
        d = gk.numel()
        sparsity_ratio_2_inf = noise_norm_2 / (noise_norm_inf + eps)
        sparsity_ratio_1_2 = noise_norm_1 / (d * noise_norm_2 + eps)

        return {
            'm': m,
            'm_sign': m_sign,
            'rel_m': rel_m,
            'rel_m_sign': rel_m_sign,
            'cos_sml': cos_sml,
            'cos_sml_sign': cos_sml_sign,
            'm_alpha_sign': m_alpha_sign,
            'cos_sml_alpha_sign': cos_sml_alpha_sign,
            'm_alpha_normalized': m_alpha_normalized,
            'cos_sml_alpha_normalized': cos_sml_alpha_normalized,
            'sparsity_ratio_2_inf': sparsity_ratio_2_inf,
            'sparsity_ratio_1_2': sparsity_ratio_1_2
        }
        
    
    def freeze_bn(m):
        """Put ``m`` into eval mode if it is a batch-norm layer; otherwise do nothing."""
        batch_norm_base = torch.nn.modules.batchnorm._BatchNorm
        if not isinstance(m, batch_norm_base):
            return
        m.eval()

    def reference_grad_fixed_batches(model, criterion, device, fixed_batches):
        """Return the gradient averaged over ``fixed_batches`` as a flat tensor.

        Each batch loss is divided by the number of batches before the
        backward pass, so the accumulated gradient is the mean of the
        per-batch gradients: an approximation of the full gradient at the
        current parameters.

        The model runs in train mode with batch-norm layers switched to eval
        so their running statistics are not updated. Existing ``.grad``
        tensors are cleared and left holding the reference gradient; the
        model-level train/eval flag is restored before returning.
        """
        previous_mode = model.training
        model.train()
        model.apply(freeze_bn)  # batch-norm running stats stay untouched

        model.zero_grad(set_to_none=True)
        n_batches = len(fixed_batches)
        for batch_x, batch_y in fixed_batches:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scaled_loss = criterion(model(batch_x), batch_y) / n_batches
            scaled_loss.backward()

        reference = flatten_grads(model).detach().clone()

        model.train(previous_mode)
        return reference

    def create_histogram_plot(losses, title, bins=50, log_scale=True):
        """Build a histogram figure of ``losses`` and return it.

        X-axis: loss value. Y-axis: count, logarithmic when ``log_scale`` is
        True. A box with mean, median and standard deviation is drawn in the
        top-right corner. Closing the figure is left to the caller.
        """
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.hist(losses, bins=bins, edgecolor='black', alpha=0.7)

        if log_scale:
            ax.set_yscale('log')
        ylabel = 'Count (log scale)' if log_scale else 'Count'

        ax.set_xlabel('Loss Value', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.grid(True, alpha=0.3, which='both')

        # Summary statistics shown inside the axes.
        summary_lines = [
            f'Mean: {np.mean(losses):.4f}',
            f'Median: {np.median(losses):.4f}',
            f'Std: {np.std(losses):.4f}',
        ]
        ax.text(
            0.95, 0.95, '\n'.join(summary_lines),
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
            fontsize=10,
        )

        plt.tight_layout()
        return fig

    def create_cdf_plot(losses, title):
        """Build an empirical-CDF figure of ``losses`` and return it.

        The curve shows the percentage of samples with loss at or below each
        value. The 25th, 50th, 75th and 90th percentiles are marked, and a
        box with mean, median and p90 is drawn in the bottom-right corner.
        Closing the figure is left to the caller.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        ordered = np.sort(losses)
        n_samples = len(ordered)
        cumulative_pct = np.arange(1, n_samples + 1) / n_samples * 100

        ax.plot(ordered, cumulative_pct, linewidth=2, color='blue')
        ax.fill_between(ordered, 0, cumulative_pct, alpha=0.3)

        ax.set_xlabel('Loss Value', fontsize=12)
        ax.set_ylabel('Cumulative Probability (%)', fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 100)

        # Mark selected percentiles on the curve.
        for pct in (25, 50, 75, 90):
            loss_at_pct = np.percentile(losses, pct)
            ax.plot(loss_at_pct, pct, 'ro', markersize=8)
            ax.text(loss_at_pct, pct + 3, f'p{pct}', fontsize=9, ha='center')

        summary_lines = [
            f'Mean: {np.mean(losses):.4f}',
            f'Median: {np.median(losses):.4f}',
            f'p90: {np.percentile(losses, 90):.4f}',
        ]
        ax.text(
            0.95, 0.05, '\n'.join(summary_lines),
            transform=ax.transAxes,
            verticalalignment='bottom',
            horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
            fontsize=10,
        )

        plt.tight_layout()
        return fig

    def compute_percentiles(losses):
        """Summarise ``losses`` as a dict of Python floats.

        Keys: ``p10``, ``p25``, ``p50``, ``p75``, ``p90``, ``p95``, ``p99``
        (percentiles), followed by ``min``, ``max``, ``mean`` and ``std``.
        """
        summary = {
            f'p{q}': float(np.percentile(losses, q))
            for q in (10, 25, 50, 75, 90, 95, 99)
        }
        for key, reducer in (('min', np.min), ('max', np.max), ('mean', np.mean), ('std', np.std)):
            summary[key] = float(reducer(losses))
        return summary

    def estimate_step_variance(model, optimizer, criterion, device, fixed_batches, lr):
        """Estimate how much the optimizer's step varies across minibatches.

        For every batch in ``fixed_batches`` the model and optimizer are reset
        to their state at call time, one optimizer step is taken on that
        batch, and the parameter displacement (divided by ``lr`` when
        ``lr > 0``) is recorded. The result is the unbiased per-coordinate
        variance of these displacements, averaged over all coordinates.

        Args:
            model: network to probe; left in train mode.
            optimizer: optimizer whose step is probed.
            criterion: loss applied to ``model(images)`` and ``labels``.
            device: device the batches are moved to.
            fixed_batches: sequence of ``(images, labels)`` batches.
            lr: scale used to turn a displacement into a step direction.

        Returns:
            ``{'step_variance': float | None, 'k_batches_used': int}``;
            ``step_variance`` is ``None`` with fewer than two batches.

        Parameters, module buffers (e.g. batch-norm running statistics) and
        the optimizer state are restored before returning, also when a probe
        step raises. State living outside these objects, such as an LR
        scheduler or an optimizer hook, is not saved here.
        """
        model.train()

        saved_params = [p.detach().clone() for p in model.parameters()]
        # Forward passes in train mode update buffers such as batch-norm
        # running statistics, so they are part of the state to roll back.
        saved_buffers = [b.detach().clone() for b in model.buffers()]
        saved_optimizer_state = copy.deepcopy(optimizer.state_dict())

        def restore_all():
            with torch.no_grad():
                for p, saved_p in zip(model.parameters(), saved_params):
                    p.copy_(saved_p)
                for b, saved_b in zip(model.buffers(), saved_buffers):
                    b.copy_(saved_b)
            # Load a copy: load_state_dict may keep references to the tensors
            # it is given, and optimizer.step() updates its state in place,
            # which would otherwise corrupt the snapshot for later restores.
            optimizer.load_state_dict(copy.deepcopy(saved_optimizer_state))
            model.zero_grad()

        step_directions = []

        try:
            for images, labels in fixed_batches:
                images, labels = images.to(device), labels.to(device)

                # Every probe starts from the same model/optimizer state.
                restore_all()

                optimizer.zero_grad()
                loss = criterion(model(images), labels)
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    direction = [
                        ((p - saved_p) / lr if lr > 0 else (p - saved_p)).view(-1)
                        for p, saved_p in zip(model.parameters(), saved_params)
                    ]
                step_directions.append(torch.cat(direction))
        finally:
            restore_all()

        if len(step_directions) < 2:
            return {'step_variance': None, 'k_batches_used': len(step_directions)}

        G = torch.stack(step_directions)  # [K, D]
        v = G.var(dim=0, unbiased=True)   # [D], variance across the K probes
        step_var = v.mean().item()

        return {
            'step_variance': step_var,
            'k_batches_used': len(step_directions)
        }

    def estimate_lipschitz_test(model, loader, criterion, device, epsilon=1e-4, n_directions=5):
        """Estimate gradient-Lipschitz constants by random finite differences.

        The sample-averaged cross-entropy gradient over ``loader`` is computed
        at the current parameters and again after moving the parameters by
        ``epsilon`` along ``n_directions`` random unit directions. Two ratios
        are recorded per direction:

        - ``L2``: ||grad change||_2 / ||delta||_2
        - ``Linf``: ||grad change||_1 / ||delta||_inf

        Returns ``{'L2': {'max', 'mean'}, 'Linf': {'max', 'mean'}}``; values
        are ``None`` when no ratio was recorded. ``criterion`` is accepted but
        not used. Parameters and the RNG state are restored before returning;
        the model is left in eval mode with cleared gradients.
        """
        # Snapshot the RNG so the random probes do not shift later sampling.
        cpu_rng = torch.get_rng_state()
        has_cuda = torch.cuda.is_available()
        if has_cuda:
            gpu_rng = torch.cuda.get_rng_state(device)

        model.eval()

        theta0 = torch.cat([p.data.view(-1).clone() for p in model.parameters()])
        summed_ce = nn.CrossEntropyLoss(reduction='sum')

        def load_flat(flat_params):
            """Copy a flat parameter vector into the model, in parameter order."""
            cursor = 0
            for p in model.parameters():
                count = p.numel()
                p.data.copy_(flat_params[cursor:cursor + count].view(p.shape))
                cursor += count

        def dataset_gradient():
            """Gradient of the loss summed over ``loader``, divided by the sample count."""
            model.zero_grad()
            n_samples = 0
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                summed_ce(model(images), labels).backward()
                n_samples += images.size(0)

            parts = []
            for p in model.parameters():
                if p.grad is None:
                    parts.append(torch.zeros(p.numel(), device=device))
                else:
                    parts.append(p.grad.data.view(-1) / n_samples)
            return torch.cat(parts)

        def summarize(values):
            """Max and mean of the recorded ratios, or None for both if empty."""
            if not values:
                return {'max': None, 'mean': None}
            return {'max': float(np.max(values)), 'mean': float(np.mean(values))}

        base_grad = dataset_gradient()

        ratios_l2 = []
        ratios_linf = []
        for _ in range(n_directions):
            # Random direction scaled to unit L2 length, then to length epsilon.
            unit = torch.randn_like(theta0)
            unit = unit / unit.norm(2)
            step = epsilon * unit

            load_flat(theta0 + step)
            grad_change = dataset_gradient() - base_grad

            # L2 -> L2
            step_l2 = step.norm(2).item()
            change_l2 = grad_change.norm(2).item()
            if step_l2 > 1e-12:
                ratios_l2.append(change_l2 / step_l2)

            # Linf -> L1
            step_linf = step.abs().max().item()
            change_l1 = grad_change.abs().sum().item()
            if step_linf > 1e-12:
                ratios_linf.append(change_l1 / step_linf)

        # Undo every side effect on parameters, gradients and RNG.
        load_flat(theta0)
        model.zero_grad()

        torch.set_rng_state(cpu_rng)
        if has_cuda:
            torch.cuda.set_rng_state(gpu_rng, device)

        return {
            'L2': summarize(ratios_l2),
            'Linf': summarize(ratios_linf),
        }

  
    print("\n" + "="*50)
    print("Starting training...")
    print("="*50 + "\n")

    # Per-run history lists; nothing in the visible part of the file appends
    # to them.
    train_losses = []
    val_accuracies = []
    test_accuracies = []
    
    
    # Take the first STEP_VARIANCE_K_BATCHES batches of the training loader
    # once. The same batches are reused for the whole run, both for the
    # step-variance estimate and as the reference set for gradient noise.
    if ESTIMATE_STEP_VARIANCE or LOG_GRAD_NOISE:
        print(f"Preparing {STEP_VARIANCE_K_BATCHES} fixed batches for variance/gradient noise estimation...")
        fixed_variance_batches = []
        temp_iterator = iter(train_loader)
        for _ in range(STEP_VARIANCE_K_BATCHES):
            try:
                batch = next(temp_iterator)
                fixed_variance_batches.append(batch)
            except StopIteration:
                # The loader has fewer batches than requested: keep what we got.
                break
        print(f"Fixed {len(fixed_variance_batches)} batches\n")

    # Counts training iterations across epochs; logged as the "iteration"
    # x-axis value for the gradient-noise metrics.
    global_iteration = 0 

    # One pass over the training loader per epoch, with optional per-iteration
    # diagnostics around the optimizer step.
    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        # Per-iteration norms of the applied step, summarised at epoch end.
        step_norms_l2 = []
        step_norms_linf = []
        
        
        # Step-variance estimate, taken before the epoch's updates on epochs
        # where (epoch + 1) is a multiple of STEP_VARIANCE_FREQUENCY.
        variance_stats = None
        if ESTIMATE_STEP_VARIANCE and (epoch + 1) % STEP_VARIANCE_FREQUENCY == 0:
            print(f"  Estimating step variance over {len(fixed_variance_batches)} fixed batches...")
            variance_stats = estimate_step_variance(
                model=model,
                optimizer=optimizer,
                criterion=criterion,
                device=DEVICE,
                fixed_batches=fixed_variance_batches,
                lr=optimizer_params.get('lr', 0.01)
            )
        
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            
            
            # Every LOG_GRAD_NOISE_FREQ iterations, compare the minibatch
            # gradient with the reference gradient over the fixed batches.
            if LOG_GRAD_NOISE and global_iteration % LOG_GRAD_NOISE_FREQ == 0:
                
                # Keep a copy: computing the reference gradient below clears
                # and overwrites the .grad tensors of the model.
                gk = flatten_grads(model).detach().clone()
                
               
                gref = reference_grad_fixed_batches(
                    model, criterion, DEVICE, fixed_variance_batches
                )
                
                
                metrics_grad_noise = grad_noise_metrics(gk, gref)
                
                         
               
                # Squared L2 norms of both gradients.
                gref_norm2 = torch.dot(gref, gref).item()
                gk_norm2 = torch.dot(gk, gk).item() 
                
               
                log_dict = {
                    'grad_noise/l2_distance_sq': metrics_grad_noise['m'],
                    'grad_noise/sign_l2_distance_sq': metrics_grad_noise['m_sign'],
                    'grad_noise/l2_distance_rel': metrics_grad_noise['rel_m'],
                    'grad_noise/sign_l2_distance_rel': metrics_grad_noise['rel_m_sign'],
                    'grad_noise/cosine_similarity': metrics_grad_noise['cos_sml'],
                    'grad_noise/cosine_similarity_sign': metrics_grad_noise['cos_sml_sign'],
                    'grad_noise/alpha_sign': metrics_grad_noise['m_alpha_sign'], 
                    'grad_noise/cos_sml_alpha_sign': metrics_grad_noise['cos_sml_alpha_sign'],
                    'grad_noise/alpha_normalized': metrics_grad_noise['m_alpha_normalized'],
                    'grad_noise/cos_sml_alpha_normalized': metrics_grad_noise['cos_sml_alpha_normalized'],
                    'grad_noise/gref_norm2': gref_norm2,
                    'grad_noise/gk_norm2': gk_norm2,
                    'sparsity/ratio_2_inf': metrics_grad_noise['sparsity_ratio_2_inf'],
                    'sparsity/ratio_1_2': metrics_grad_noise['sparsity_ratio_1_2'],
                    'iteration': global_iteration,
                }
                
                
                # Same comparison after mapping both gradients through
                # tanh(g * t / 2), with t read from the hook; skipped while
                # the hook reports no temperature.
                if USE_DIAGNOSTIC_HOOK: 
                    temperature = hook.t
                    if temperature is not None:
                        gref_tanh = torch.tanh(gref * temperature / 2.0)
                        gk_tanh = torch.tanh(gk * temperature / 2.0)
                        
                        m_tanh = (gk_tanh - gref_tanh).pow(2).sum().item()
                        cos_sml_tanh = torch.nn.functional.cosine_similarity(gk_tanh, gref_tanh, dim=0).item()
                    
                        log_dict.update({
                            'tanh/grads_l2_distance_sq': m_tanh, 
                            'tanh/grads_cosine_similarity': cos_sml_tanh
                        })
                
                
                wandb.log(log_dict)
                
               
                # Put the minibatch gradient back so the optimizer step below
                # uses it rather than the reference gradient.
                set_flat_grads(model, gk)
            
            global_iteration += 1
            
            # Clip by the infinity norm of the whole gradient.
            if clipping is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
            
            
            # Snapshot of the parameters, used after the step to measure it.
            params_before = []
            for p in model.parameters():
                params_before.append(p.detach().clone())
            
            
            if LOG_LIPSCHITZ:
                lipschitz_tracker.save_state(model, images, labels)
            
            optimizer.step()

            # The scheduler, when present, advances once per iteration.
            if scheduler:
                scheduler.step()
            
            if LOG_LIPSCHITZ:
                lipschitz_tracker.compute_lipschitz(model)
                
            running_loss += loss.item()

            # Metrics gathered by the hook are logged under inner_metrics/.
            # NOTE(review): this log call carries no 'iteration' key although
            # inner_metrics/* is declared with step_metric="iteration".
            if USE_DIAGNOSTIC_HOOK:
                metrics = hook.compute_and_reset()
                metrics = {f'inner_metrics/{key}': metrics[key] for key in metrics}
                
                wandb.log(metrics)
            
        
            # Applied step: parameter change divided by the configured 'lr'
            # (default 0.01), not by any learning rate set by the scheduler.
            with torch.no_grad():
                step_vector = []
                lr = optimizer_params.get('lr', 0.01)
                for p, p_old in zip(model.parameters(), params_before):
                    step_dir = (p - p_old) / lr if lr > 0 else (p - p_old)
                    step_vector.append(step_dir.view(-1))
                step_vector = torch.cat(step_vector)

                step_norm_l2 = step_vector.norm(2).item()
                step_norm_linf = step_vector.abs().max().item()
                step_norms_l2.append(step_norm_l2)
                step_norms_linf.append(step_norm_linf)

            # Compare the applied step with tanh(gref * t / 2), where gref is
            # recomputed at the updated parameters. This runs on every
            # iteration, i.e. one pass over all fixed batches per step.
            if USE_DIAGNOSTIC_HOOK:
                temperature = hook.t
                if temperature is not None and (ESTIMATE_STEP_VARIANCE or LOG_GRAD_NOISE):
                   
                    gref = reference_grad_fixed_batches(model, criterion, DEVICE, fixed_variance_batches)

                 
                    with torch.no_grad():
                        gref_tanh = torch.tanh(gref * temperature / 2.0)
                        # The step vector is used as-is; no tanh is applied to it.
                        step_vector_tanh = step_vector

                        m_tanh = (step_vector_tanh - gref_tanh).pow(2).sum().item()
                        cos_sml_tanh = torch.nn.functional.cosine_similarity(step_vector_tanh, gref_tanh, dim=0).item()

                    # NOTE(review): no 'iteration' key here either, see above.
                    wandb.log({
                        'tanh/step_l2_distance_sq': m_tanh,
                        'tanh/step_cosine_similarity': cos_sml_tanh
                    })
            
            # Update progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Calculate metrics
        avg_loss = running_loss / len(train_loader)
        val_acc, val_loss = evaluate(model, eval_loader, DEVICE)
        test_acc, test_loss = evaluate(model, test_loader, DEVICE)
        
        
        log_dict = {
            'train_loss': avg_loss,
            'val_loss': val_loss,
            'test_loss': test_loss,
            'val_acc': val_acc,
            'test_acc': test_acc,
            'epoch': epoch + 1
        }

        
        if step_norms_l2:
            log_dict.update({
                'step_norm/L2_mean': np.mean(step_norms_l2),
                'step_norm/L2_max': np.max(step_norms_l2),
                'step_norm/L2_min': np.min(step_norms_l2),
                'step_norm/Linf_mean': np.mean(step_norms_linf),
                'step_norm/Linf_max': np.max(step_norms_linf),
                'step_norm/Linf_min': np.min(step_norms_linf),
            })
        
      
        if LOG_HISTOGRAMS or LOG_CDF or LOG_PERCENTILES:
            train_sample_losses = collect_per_sample_losses(model, train_loader, DEVICE, criterion)
            val_sample_losses = collect_per_sample_losses(model, eval_loader, DEVICE, criterion)
            test_sample_losses = collect_per_sample_losses(model, test_loader, DEVICE, criterion)
            
            
            if LOG_HISTOGRAMS:
                
                train_hist_log = create_histogram_plot(
                    train_sample_losses, 
                    f'Train Loss Distribution (Epoch {epoch+1}) - Log Scale',
                    bins=HISTOGRAM_BINS,
                    log_scale=True
                )
                val_hist_log = create_histogram_plot(
                    val_sample_losses, 
                    f'Val Loss Distribution (Epoch {epoch+1}) - Log Scale',
                    bins=HISTOGRAM_BINS,
                    log_scale=True
                )
                test_hist_log = create_histogram_plot(
                    test_sample_losses, 
                    f'Test Loss Distribution (Epoch {epoch+1}) - Log Scale',
                    bins=HISTOGRAM_BINS,
                    log_scale=True
                )
                
               
                train_hist_linear = create_histogram_plot(
                    train_sample_losses, 
                    f'Train Loss Distribution (Epoch {epoch+1}) - Linear Scale',
                    bins=HISTOGRAM_BINS,
                    log_scale=False
                )
                val_hist_linear = create_histogram_plot(
                    val_sample_losses, 
                    f'Val Loss Distribution (Epoch {epoch+1}) - Linear Scale',
                    bins=HISTOGRAM_BINS,
                    log_scale=False
                )
                test_hist_linear = create_histogram_plot(
                    test_sample_losses, 
                    f'Test Loss Distribution (Epoch {epoch+1}) - Linear Scale',
                    bins=HISTOGRAM_BINS,
                    log_scale=False
                )
                
                log_dict.update({
                    'train_loss_histogram/log': wandb.Image(train_hist_log),
                    'val_loss_histogram/log': wandb.Image(val_hist_log),
                    'test_loss_histogram/log': wandb.Image(test_hist_log),
                    'train_loss_histogram/linear': wandb.Image(train_hist_linear),
                    'val_loss_histogram/linear': wandb.Image(val_hist_linear),
                    'test_loss_histogram/linear': wandb.Image(test_hist_linear),
                })
                
           
                plt.close(train_hist_log)
                plt.close(val_hist_log)
                plt.close(test_hist_log)
                plt.close(train_hist_linear)
                plt.close(val_hist_linear)
                plt.close(test_hist_linear)
            
         
            if LOG_CDF:
                train_cdf_fig = create_cdf_plot(train_sample_losses, f'Train Loss CDF (Epoch {epoch+1})')
                val_cdf_fig = create_cdf_plot(val_sample_losses, f'Val Loss CDF (Epoch {epoch+1})')
                test_cdf_fig = create_cdf_plot(test_sample_losses, f'Test Loss CDF (Epoch {epoch+1})')
                log_dict.update({
                    'train_loss_cdf': wandb.Image(train_cdf_fig),
                    'val_loss_cdf': wandb.Image(val_cdf_fig),
                    'test_loss_cdf': wandb.Image(test_cdf_fig),
                })
                plt.close(train_cdf_fig)
                plt.close(val_cdf_fig)
                plt.close(test_cdf_fig)
            
        
            # Optional per-split loss percentiles, logged as '<split>/<name>'.
            if LOG_PERCENTILES:
                train_percentiles = compute_percentiles(train_sample_losses)
                val_percentiles = compute_percentiles(val_sample_losses)
                test_percentiles = compute_percentiles(test_sample_losses)

                # (logged suffix, key read from the compute_percentiles result).
                # Only 'p50' is renamed on the way out (to 'p50_median').
                percentile_fields = (
                    ('p10', 'p10'),
                    ('p25', 'p25'),
                    ('p50_median', 'p50'),
                    ('p75', 'p75'),
                    ('p90', 'p90'),
                    ('p95', 'p95'),
                    ('p99', 'p99'),
                    ('min', 'min'),
                    ('max', 'max'),
                )
                percentiles_by_split = (
                    ('train', train_percentiles),
                    ('val', val_percentiles),
                    ('test', test_percentiles),
                )
                # Outer loop over splits, inner loop over fields: keys are emitted
                # split by split, in the field order listed above.
                log_dict.update({
                    f'{split_name}/{logged_suffix}': split_stats[stat_key]
                    for split_name, split_stats in percentiles_by_split
                    for logged_suffix, stat_key in percentile_fields
                })
        
        
        # Optional Lipschitz statistics gathered by lipschitz_tracker: print a
        # one-line status and add the L2 / Linf min, max and mean to log_dict.
        if LOG_LIPSCHITZ:
            L_epoch_stats = lipschitz_tracker.get_epoch_stats()
            L_linf_max = L_epoch_stats.get('Linf', {}).get('max')
            L_l2_max = L_epoch_stats.get('L2', {}).get('max')

            # Test against None explicitly: a max of exactly 0.0 is a real value
            # and must not be reported as "None". The L2 value is formatted
            # separately because the 'L2' entry may be absent even when 'Linf'
            # is present, and applying ':.2e' to None raises TypeError.
            if L_linf_max is None:
                print("  Lipschitz: None")
            else:
                L_l2_max_text = f"{L_l2_max:.2e}" if L_l2_max is not None else "None"
                print(f"  Lipschitz: L2_max={L_l2_max_text}, Linf_max={L_linf_max:.2e}")

            # L2 statistics are logged whenever the tracker reported them.
            if 'L2' in L_epoch_stats:
                log_dict.update({
                    'lipschitz_train/L2_min': L_epoch_stats['L2']['min'],
                    'lipschitz_train/L2_max': L_epoch_stats['L2']['max'],
                    'lipschitz_train/L2_mean': L_epoch_stats['L2']['mean'],
                })

            # Linf statistics are skipped for the epoch when their max exceeds
            # LIPSCHITZ_MAX_THRESHOLD.
            if L_linf_max is None or L_linf_max <= LIPSCHITZ_MAX_THRESHOLD:
                if 'Linf' in L_epoch_stats:
                    log_dict.update({
                        'lipschitz_train/Linf_min': L_epoch_stats['Linf']['min'],
                        'lipschitz_train/Linf_max': L_linf_max,
                        'lipschitz_train/Linf_mean': L_epoch_stats['Linf']['mean'],
                    })
        
        
        # Optional Lipschitz estimate on test_loader (estimate_lipschitz_test is
        # defined outside this section).
        if LOG_LIPSCHITZ_TEST:
            L_test_stats = estimate_lipschitz_test(
                model,
                test_loader,
                criterion,
                DEVICE,
                epsilon=LIPSCHITZ_TEST_EPSILON,
                n_directions=LIPSCHITZ_TEST_N_DIRS,
            )
            L_test_linf_max = L_test_stats.get('Linf', {}).get('max')

            # All four test metrics are dropped together when the Linf max exceeds
            # LIPSCHITZ_MAX_THRESHOLD; a missing (None) max still gets logged.
            linf_within_limit = (
                L_test_linf_max is None
                or L_test_linf_max <= LIPSCHITZ_MAX_THRESHOLD
            )
            if linf_within_limit:
                test_l2_stats = L_test_stats.get('L2', {})
                test_linf_stats = L_test_stats.get('Linf', {})
                log_dict.update({
                    'lipschitz_test/L2_max': test_l2_stats.get('max'),
                    'lipschitz_test/L2_mean': test_l2_stats.get('mean'),
                    'lipschitz_test/Linf_max': L_test_linf_max,
                    'lipschitz_test/Linf_mean': test_linf_stats.get('mean'),
                })
        
    
        # Add the step-variance figures only when variance_stats exists and
        # carries a non-None 'step_variance' entry.
        has_step_variance = (
            variance_stats is not None
            and variance_stats.get('step_variance') is not None
        )
        if has_step_variance:
            step_variance_metrics = {
                'step_variance/step_variance': variance_stats['step_variance'],
                'step_variance/k_batches': variance_stats['k_batches_used'],
            }
            log_dict.update(step_variance_metrics)
        
       
        # Send everything collected in log_dict for this epoch to wandb.
        wandb.log(log_dict)

        # Extend the per-epoch histories that the end-of-run summary reads.
        epoch_history_updates = (
            (train_losses, avg_loss),
            (val_accuracies, val_acc),
            (test_accuracies, test_acc),
        )
        for metric_history, epoch_metric_value in epoch_history_updates:
            metric_history.append(epoch_metric_value)

        # One-line console summary of the epoch.
        epoch_summary_parts = [
            f"Epoch [{epoch+1}/{NUM_EPOCHS}] Loss: {avg_loss:.4f}",
            f"Val Acc: {val_acc:.2f}%",
            f"Test Acc: {test_acc:.2f}%",
        ]
        print(" | ".join(epoch_summary_parts))


    # ---- End-of-run report for the current optimizer ----
    section_rule = "=" * 50
    print("\n" + section_rule)
    print("Training completed!")
    print(section_rule)

    # Best accuracy per split and the 1-based epoch of its first occurrence
    # (list.index returns the first match), plus the last-epoch values.
    top_val_acc = max(val_accuracies)
    top_test_acc = max(test_accuracies)
    top_val_epoch = val_accuracies.index(top_val_acc) + 1
    top_test_epoch = test_accuracies.index(top_test_acc) + 1
    last_val_acc = val_accuracies[-1]
    last_test_acc = test_accuracies[-1]

    print(f"Best Val Accuracy: {top_val_acc:.2f}% at epoch {top_val_epoch}")
    print(f"Best Test Accuracy: {top_test_acc:.2f}% at epoch {top_test_epoch}")
    print(f"Final Val Accuracy: {last_val_acc:.2f}%")
    print(f"Final Test Accuracy: {last_test_acc:.2f}%")

    # Store the same headline numbers in the wandb run summary.
    run_summary = wandb.run.summary
    run_summary['best_val_acc'] = top_val_acc
    run_summary['best_test_acc'] = top_test_acc
    run_summary['final_val_acc'] = last_val_acc
    run_summary['final_test_acc'] = last_test_acc
    run_summary['best_val_epoch'] = top_val_epoch
    run_summary['best_test_epoch'] = top_test_epoch

    # Full per-epoch record of the run; written to disk only if SAVE_JSON is set.
    results = dict(
        optimizer=OPTIMIZER_NAME,
        model=MODEL_NAME,
        train_losses=train_losses,
        val_accuracies=val_accuracies,
        test_accuracies=test_accuracies,
        best_val_acc=top_val_acc,
        best_test_acc=top_test_acc,
        optimizer_params=optimizer_params,
    )

    if SAVE_JSON:
        json_path = f'training_log_{OPTIMIZER_NAME}_wandb.json'
        with open(json_path, 'w') as json_file:
            json.dump(results, json_file, indent=2)
        print(f"\nResults saved to {json_path}")

    # Close the wandb run before the next optimizer iteration begins.
    wandb.finish()
    print(f"WandB run completed for {OPTIMIZER_NAME}!")


# Closing banner, printed once the loop over OPTIMIZER_NAMES has finished.
closing_rule = "=" * 70
print("\n" + closing_rule)
print(f"All {len(OPTIMIZER_NAMES)} optimizers completed successfully!")
print(closing_rule)
