import random, numpy as np, torch, torchvision
import torch.nn as nn, torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from HomOpt import HomM
from torch.amp import autocast
import json
import os
from datetime import datetime
import time
import matplotlib.pyplot as plt
import itertools

def set_seed(seed=42):
    """Seed every random number generator this script uses.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the default
    generator plus all CUDA devices) with ``seed``, then switches cuDNN to
    deterministic mode with benchmarking turned off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def setup_data(batch_size=128, num_workers=4):
    """Build the CIFAR-10 training and test data loaders.

    The datasets are read from (and downloaded to, if missing) ``./data``.
    Training images are augmented with a padded random crop, horizontal
    flip, small rotation and colour jitter; both splits are converted to
    tensors and normalised per channel with mean 0.5 and std 0.5.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of loader worker processes. ``0`` loads data in
            the main process.

    Returns:
        Tuple ``(trainloader, testloader)``; only the training loader
        shuffles.
    """
    print("Setting up CIFAR-10 data loaders...")
    
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    # DataLoader rejects persistent_workers=True when there are no worker
    # processes, so only request it when workers are actually used.
    keep_workers_alive = num_workers > 0

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, 
                                            num_workers=num_workers, pin_memory=True,
                                            persistent_workers=keep_workers_alive)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, 
                                           num_workers=num_workers, pin_memory=True,
                                           persistent_workers=keep_workers_alive)
    
    return trainloader, testloader

def create_model():
    """Return a ResNet-18 reshaped for CIFAR-10 images and labels.

    ``resnet18()`` is instantiated with its defaults (no arguments are
    passed), then three layers are swapped: the first convolution becomes a
    3x3, stride-1, bias-free conv; the max-pool becomes an identity; and the
    final fully connected layer is replaced by one with 10 outputs.
    """
    net = resnet18()

    # Stem sized for CIFAR-10's smaller inputs.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    # Classifier head: one logit per CIFAR-10 class.
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, 10)
    return net

def find_optimal_lr(model, optimizer_class, optimizer_kwargs, trainloader, device, 
                   start_lr=1e-6, end_lr=1e-1, num_iter=200, plot_results=False):
    """Run an exponential LR sweep and return the LR with the lowest loss.

    The learning rate starts at ``start_lr`` and is multiplied by a constant
    factor after every batch so that it would reach ``end_lr`` after
    ``num_iter`` steps. The forward pass runs under ``autocast`` for the
    device type (no gradient scaler is used). The sweep stops early when the
    loss becomes non-finite or when the mean of the last 10 losses exceeds
    five times the mean of losses 10-19.

    The model's weights and buffers are snapshotted before the sweep and
    restored afterwards, even if the sweep raises. The model is left in
    training mode.

    Args:
        model: Module to probe; must already live on ``device``.
        optimizer_class: Callable invoked as
            ``optimizer_class(params, lr=..., **optimizer_kwargs)``.
        optimizer_kwargs: Extra keyword arguments for the optimizer. When it
            contains ``alpha``, ``beta`` and ``gamma`` they are used to label
            the log output.
        trainloader: Iterable of ``(inputs, labels)`` batches; it is
            restarted when exhausted.
        device: ``torch.device`` the batches are moved to.
        start_lr: First learning rate of the sweep.
        end_lr: Learning rate the schedule is scaled to reach.
        num_iter: Maximum number of sweep steps (must be >= 1).
        plot_results: Accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(recommended_lr, optimal_lrs, lrs, losses)`` where
        ``optimal_lrs`` maps ``'min_loss'`` (and ``'smoothed_min_loss'`` when
        more than 30 losses were recorded) to a learning rate,
        ``recommended_lr`` is the smoothed choice when available, and
        ``lrs`` / ``losses`` are the per-step histories.

    Raises:
        ValueError: If ``num_iter`` is smaller than 1.
        RuntimeError: If no finite loss was recorded.
    """
    if num_iter < 1:
        raise ValueError("num_iter must be at least 1")

    try:
        config_name = f"α={optimizer_kwargs['alpha']:.2f}_β={optimizer_kwargs['beta']:.1f}_γ={optimizer_kwargs['gamma']:.1f}"
    except KeyError:
        # Optimizers without the alpha/beta/gamma trio still get a label.
        config_name = ", ".join(f"{k}={v}" for k, v in optimizer_kwargs.items()) or "default"
    print(f"Finding optimal LR for {config_name}")
    
    model.train()
    # state_dict() hands back tensors that share storage with the live
    # parameters, so they must be cloned to obtain a real snapshot.
    original_state = {
        name: tensor.detach().clone()
        for name, tensor in model.state_dict().items()
    }
    
    optimizer = optimizer_class(model.parameters(), lr=start_lr, **optimizer_kwargs)
    criterion = nn.CrossEntropyLoss()
    
    # Constant multiplier giving an exponential ramp from start_lr to end_lr.
    mult_factor = (end_lr / start_lr) ** (1.0 / num_iter)
    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'
    
    lrs = []
    losses = []
    train_iter = iter(trainloader)
    
    try:
        for i in range(num_iter):
            try:
                inputs, labels = next(train_iter)
            except StopIteration:
                train_iter = iter(trainloader)
                inputs, labels = next(train_iter)
            
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            
            current_lr = optimizer.param_groups[0]['lr']
            
            optimizer.zero_grad()
            with autocast(device_type=autocast_device):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            loss_value = loss.item()
            if not np.isfinite(loss_value):
                # A NaN/inf loss would win np.argmin below and would also
                # defeat the explosion check, so stop without recording it.
                print(f"    Early stopping at iteration {i} due to non-finite loss")
                break
            
            lrs.append(current_lr)
            losses.append(loss_value)
            loss.backward()
            optimizer.step()
            
            for param_group in optimizer.param_groups:
                param_group['lr'] *= mult_factor
            
            # Early stopping if loss explodes
            if i > 30:
                recent_avg = np.mean(losses[-10:])
                early_avg = np.mean(losses[10:20])
                if recent_avg > early_avg * 5:
                    print(f"    Early stopping at iteration {i} due to loss explosion")
                    break
    finally:
        # Always hand the model back with its pre-sweep weights and buffers.
        model.load_state_dict(original_state)
    
    if not losses:
        raise RuntimeError("LR range test recorded no finite loss")
    
    # Method 1: direct minimum of the raw loss curve.
    optimal_lrs = {'min_loss': lrs[int(np.argmin(losses))]}
    
    # Method 2: minimum of a moving average (more robust to batch noise).
    if len(losses) > 30:
        window = 5
        smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
        # Each smoothed value is attributed to the LR at its window's centre.
        smoothed_lrs = lrs[window // 2:len(smoothed) + window // 2]
        optimal_lrs['smoothed_min_loss'] = smoothed_lrs[int(np.argmin(smoothed))]
    
    # Prefer the smoothed minimum when it exists.
    recommended_lr = optimal_lrs.get('smoothed_min_loss', optimal_lrs['min_loss'])
    
    print(f"  Selected LR: {recommended_lr:.2e}")
    
    return recommended_lr, optimal_lrs, lrs, losses

def get_parameter_grid():
    """Return the hyper-parameter grid swept by the study.

    The result maps ``'alpha'``, ``'beta'`` and ``'gamma'`` (in that order)
    to the list of candidate values for each. A fresh dict with fresh lists
    is built on every call.
    """
    alphas = [-0.75, -0.5, -0.25]
    betas = [0.1, 0.3, 0.5, 0.7, 0.9]
    gammas = [0.9]
    return dict(alpha=alphas, beta=betas, gamma=gammas)

def generate_all_combinations():
    """Expand the parameter grid into one config dict per combination.

    Takes the Cartesian product of the value lists returned by
    ``get_parameter_grid()``. Each resulting dict holds the grid keys
    followed by a ``'name'`` entry, an ASCII-only label built from the
    alpha, beta and gamma values.
    """
    grid = get_parameter_grid()
    keys = tuple(grid)
    value_lists = [grid[key] for key in keys]

    def labelled(values):
        params = dict(zip(keys, values))
        # ASCII label avoids encoding problems in file names and logs.
        label = f"HomM_a{params['alpha']:.2f}_b{params['beta']:.1f}_g{params['gamma']:.1f}"
        return {**params, 'name': label}

    return [labelled(values) for values in itertools.product(*value_lists)]

def run_lr_range_test_study():
    """Run an LR range test for every parameter combination and save results.

    Creates a timestamped ``lr_range_tests_all_combos_*`` directory relative
    to the working directory, builds the CIFAR-10 loaders, and then for each
    configuration from ``generate_all_combinations()``: reseeds with 42,
    creates a fresh model, and runs ``find_optimal_lr`` with the ``HomM``
    optimizer over 300 iterations from 1e-7 to 1e0. Each successful
    configuration is written to its own JSON file; any exception is caught,
    printed, and recorded as an error entry instead. A summary JSON is
    written at the end and a table of recommended learning rates is printed.

    Returns:
        Tuple ``(results_dir, all_results)``: the output directory name and
        the list of per-configuration result or error dicts.
    """
    # Output location
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = "lr_range_tests_all_combos_" + timestamp
    os.makedirs(results_dir, exist_ok=True)

    # Device and data
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    trainloader, _testloader = setup_data(batch_size=256, num_workers=4)

    test_configs = generate_all_combinations()
    total = len(test_configs)

    print("\nStarting LR Range Test Study for ALL combinations")
    print(f"Testing {total} parameter combinations")
    print(f"Results directory: {results_dir}")

    print("\nAll combinations to test:")
    for position, cfg in enumerate(test_configs, start=1):
        print(f"  {position:2d}. {cfg['name']}")

    banner = '=' * 60
    all_results = []

    for position, cfg in enumerate(test_configs, start=1):
        print("\n" + banner)
        print(f"Configuration {position}/{total}: {cfg['name']}")
        print(f"Parameters: α={cfg['alpha']}, β={cfg['beta']}, γ={cfg['gamma']}")
        print(banner)

        try:
            # Same seed for every configuration so each starts from the
            # same freshly initialised model.
            set_seed(42)
            model = create_model().to(device)

            # Everything except the display name is an optimizer argument.
            hyperparams = {key: value for key, value in cfg.items() if key != 'name'}

            started = time.time()
            recommended_lr, all_optimal_lrs, _lr_history, _loss_history = find_optimal_lr(
                model=model,
                optimizer_class=HomM,
                optimizer_kwargs=hyperparams,
                trainloader=trainloader,
                device=device,
                start_lr=1e-7,
                end_lr=1e0,
                num_iter=300,
                plot_results=False,
            )
            test_time = time.time() - started

            record = {
                'config_name': cfg['name'],
                'alpha': cfg['alpha'],
                'beta': cfg['beta'],
                'gamma': cfg['gamma'],
                'recommended_lr': float(recommended_lr),
                'optimal_lrs': {method: float(lr) for method, lr in all_optimal_lrs.items()},
                'test_time': test_time,
                'timestamp': datetime.now().isoformat()
            }
            all_results.append(record)

            # One JSON file per configuration.
            record_path = os.path.join(results_dir, f"{cfg['name']}_lr_test.json")
            with open(record_path, 'w') as out:
                json.dump(record, out, indent=2)

            print(f"✓ Completed in {test_time:.1f}s")

        except Exception as exc:
            # Best effort: log the failure, record it, move to the next one.
            print(f"✗ Error testing {cfg['name']}: {str(exc)}")
            all_results.append({
                'config_name': cfg['name'],
                'alpha': cfg['alpha'],
                'beta': cfg['beta'],
                'gamma': cfg['gamma'],
                'error': str(exc),
                'timestamp': datetime.now().isoformat()
            })

    # Summary of the whole study
    summary_path = os.path.join(results_dir, 'lr_range_test_summary.json')
    with open(summary_path, 'w') as out:
        summary = {
            'experiment_info': {
                'timestamp': timestamp,
                'total_configs': total,
                'successful_tests': sum(1 for entry in all_results if 'recommended_lr' in entry),
                'failed_tests': sum(1 for entry in all_results if 'error' in entry),
                'device': str(device),
                'parameter_grid': get_parameter_grid()
            },
            'results': all_results
        }
        json.dump(summary, out, indent=2)

    succeeded = [entry for entry in all_results if 'recommended_lr' in entry]

    wide_banner = '=' * 80
    print("\n" + wide_banner)
    print("LR RANGE TEST STUDY COMPLETED")
    print(wide_banner)
    print(f"Successful tests: {len(succeeded)}/{total}")
    print(f"Results saved in: {results_dir}")

    if not succeeded:
        return results_dir, all_results

    print("\nRecommended Learning Rates Summary:")
    print(f"{'Configuration':<25} {'Alpha':<6} {'Beta':<5} {'Gamma':<6} {'Recommended LR':<15}")
    print("-" * 70)
    for row in succeeded:
        print(f"{row['config_name']:<25} {row['alpha']:<6.2f} {row['beta']:<5.1f} {row['gamma']:<6.1f} {row['recommended_lr']:<15.2e}")

    # Highest and lowest recommended learning rates
    chosen_lrs = [row['recommended_lr'] for row in succeeded]
    top = np.argmax(chosen_lrs)
    bottom = np.argmin(chosen_lrs)

    print(f"\nBest (highest) LR: {succeeded[top]['config_name']} = {chosen_lrs[top]:.2e}")
    print(f"Worst (lowest) LR: {succeeded[bottom]['config_name']} = {chosen_lrs[bottom]:.2e}")

    return results_dir, all_results

# Script entry point: run the full study and report where the output went.
if __name__ == "__main__":
    print("Starting LR range test for ALL parameter combinations...")
    study_outcome = run_lr_range_test_study()
    results_directory, results = study_outcome
    print(f"\nAll results saved in: {results_directory}")