#!/usr/bin/env python3
"""
Validation script for improved NPT implementation.
This script checks core functionality without running the full training.
"""

import torch
import torch.nn.utils
import numpy as np
from typing import Dict

def test_gradient_clipping(max_grad_norm: float = 1.0) -> float:
    """Check that ``clip_grad_norm_`` bounds the gradient norm of a toy model.

    Builds a ``Linear(10, 5)`` layer, runs one cross-entropy forward/backward
    pass on random data, clips the gradients in place, and then re-measures
    the gradient norm to confirm it does not exceed the limit.

    Args:
        max_grad_norm: Upper bound handed to ``clip_grad_norm_``. Defaults to
            1.0; pass a smaller value to force the clipping path.

    Returns:
        The total gradient norm measured before clipping, as a Python float.

    Raises:
        AssertionError: If the gradient norm after clipping still exceeds
            ``max_grad_norm`` (beyond a small floating-point slack).
    """
    print("Testing gradient clipping...")

    # Create a simple model
    model = torch.nn.Linear(10, 5)

    # Create dummy input and target
    x = torch.randn(8, 10)
    y = torch.randint(0, 5, (8,))

    # Forward and backward pass so every parameter has a populated .grad
    output = model(x)
    loss = torch.nn.functional.cross_entropy(output, y)
    loss.backward()

    # clip_grad_norm_ rescales the gradients in place and returns the norm
    # they had before rescaling.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

    print(f"  Gradient norm before clipping: {grad_norm.item():.4f}")
    print(f"  Max allowed gradient norm: {max_grad_norm}")
    print(f"  Clipping {'applied' if grad_norm > max_grad_norm else 'not needed'}")

    # Re-measure the norm: printing alone would not notice a clip that
    # failed to take effect.
    clipped_norm = sum(
        p.grad.pow(2).sum().item() for p in model.parameters()
    ) ** 0.5
    print(f"  Gradient norm after clipping: {clipped_norm:.4f}")

    # Small slack absorbs float32 rounding in the in-place rescale.
    assert clipped_norm <= max_grad_norm * (1 + 1e-4) + 1e-8, (
        f"gradient norm {clipped_norm} exceeds limit {max_grad_norm} after clipping"
    )
    print("  ✓ Gradient norm within limit after clipping")

    return grad_norm.item()

def test_adaptive_weight_decay():
    """Print the adaptive weight decay obtained for a spread of gradient norms.

    The adaptive value is the base weight decay multiplied by
    ``1 + scale * grad_norm``; one report line is printed per sample norm.
    """
    print("Testing adaptive weight decay...")

    base_wd, scale = 1e-4, 0.1

    # Sample norms ranging from small to large gradients
    sample_norms = (0.1, 0.5, 1.0, 2.0, 5.0)

    report = [
        f"  Grad norm: {norm:.1f} -> Adaptive WD: {base_wd * (1 + scale * norm):.2e}"
        for norm in sample_norms
    ]
    for line in report:
        print(line)

def test_optimizer_update():
    """Verify that a new ``weight_decay`` written into an optimizer is kept.

    An SGD optimizer is created over a small linear layer, an adaptive
    weight decay is computed from a fixed gradient norm, and that value is
    written into every param group before being read back.

    Raises:
        AssertionError: If the first param group does not hold the adaptive
            weight decay after the update.
    """
    print("Testing optimizer weight decay update...")

    # Optimizer under test, built over a throwaway layer
    layer = torch.nn.Linear(10, 5)
    sgd = torch.optim.SGD(layer.parameters(), lr=0.01, weight_decay=1e-4)

    initial_wd, scale, norm = 1e-4, 0.1, 2.0

    # Adaptive value: base decay scaled up with the gradient norm
    target_wd = initial_wd * (1 + scale * norm)

    # Push the new value into every param group
    for param_group in sgd.param_groups:
        param_group.update(weight_decay=target_wd)

    stored_wd = sgd.param_groups[0]['weight_decay']

    print(f"  Original weight decay: {initial_wd:.2e}")
    print(f"  Adaptive weight decay: {target_wd:.2e}")
    print(f"  Optimizer weight decay: {stored_wd:.2e}")

    assert abs(stored_wd - target_wd) < 1e-10
    print("  ✓ Weight decay update successful")

def test_metrics_tracking():
    """Average one simulated epoch of metrics and record them in histories.

    Per-step gradient norms and adaptive weight decays are averaged with
    ``np.mean``; each average becomes the single entry of its history list,
    and the averages plus the history length are printed.
    """
    print("Testing metrics tracking...")

    # Per-step values for one simulated epoch
    step_norms = [0.8, 1.2, 0.9, 1.1, 0.7]
    step_wds = [1.08e-4, 1.12e-4, 1.09e-4, 1.11e-4, 1.07e-4]

    mean_norm = np.mean(step_norms)
    mean_wd = np.mean(step_wds)

    # One epoch recorded so far, so each history holds exactly one average
    norm_history = [mean_norm]
    wd_history = [mean_wd]

    print(f"  Average grad norm: {mean_norm:.4f}")
    print(f"  Average adaptive WD: {mean_wd:.2e}")
    print(f"  History length: {len(norm_history)}")
    print("  ✓ Metrics tracking successful")

def main() -> int:
    """Run every validation check in order and report the overall outcome.

    Each check prints its own progress; a blank line is printed after each
    one. If any check raises, the error and its traceback are reported and
    the remaining checks are skipped.

    Returns:
        0 when all checks complete, 1 when any check raises. The value is
        intended to be used as the process exit status.
    """
    print("=== Improved NPT Validation Tests ===\n")

    checks = (
        test_gradient_clipping,
        test_adaptive_weight_decay,
        test_optimizer_update,
        test_metrics_tracking,
    )

    try:
        for check in checks:
            check()
            print()  # blank line between sections
    except Exception as e:
        # Top-level boundary: report the failure, show where it happened,
        # and signal it to the caller instead of exiting as a success.
        print(f"Validation failed: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print("=== All validation tests passed! ===")
    return 0


if __name__ == "__main__":
    # Propagate the result as the exit status so shells and CI can detect
    # a failed validation run.
    raise SystemExit(main())