import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from HomOpt import HomM
from torch.amp import autocast
from collections import defaultdict
import os
import json
from datetime import datetime
import time
from torchvision.models import resnet18

# Import from the auto-generated configs
try:
    from optimizer_configs_auto_generated import get_optimizer_configurations, generate_param_combinations
    print("✓ Using auto-generated optimizer configurations from LR tests")
except ImportError:
    print("⚠️  Auto-generated configs not found. Run LR test first!")
    print("   Falling back to manual configs...")
    from optimizer_configs import get_optimizer_configurations, generate_param_combinations

# ---------------------------
# Settings
# ---------------------------
SEED = 42  # Base RNG seed: default for set_seed(); each run's seed is derived from it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CUDA when available, otherwise CPU (decided at import time).
EPOCHS = 100  # Training epochs per run; also the cosine-annealing period (T_max).
BATCH_SIZE = 256  # Default mini-batch size for both the train and test loaders.
NUM_RUNS = 3  # Number of runs per configuration for statistical significance

# ---------------------------
# Utility functions
# ---------------------------
def set_seed(seed=SEED):
    """Seed every RNG a training run touches and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    all CUDA devices), then disables cuDNN autotuning in favour of
    deterministic kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Trade some speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_data_loaders(batch_size=BATCH_SIZE, num_workers=2):
    """Build CIFAR-10 train/test DataLoaders (downloading to ./data if missing).

    Training images receive random crop / horizontal flip / rotation / colour
    jitter augmentation; test images are only converted to tensors. Both are
    normalised with mean = std = 0.5 per channel, which maps the [0, 1] output
    of ``ToTensor`` to [-1, 1].

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader; ``0`` loads in the main
            process.

    Returns:
        ``(trainloader, testloader)``. Only the train loader shuffles.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        # Pinned host memory only pays off for host->CUDA copies; requesting it
        # on a CPU-only machine just produces a warning.
        'pin_memory': torch.cuda.is_available(),
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only keep workers alive when there are some.
        'persistent_workers': num_workers > 0,
    }
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
    testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

    return trainloader, testloader

def create_model():
    """Return a torchvision ResNet-18 adapted to 32x32, 10-class CIFAR-10 input.

    The stock stem is swapped for a 3x3, stride-1 convolution and the
    max-pool is replaced by an identity. Presumably this keeps full spatial
    resolution for small images; the stock stem itself is defined in
    torchvision, not here. The 1000-way classifier is replaced by a 10-way
    linear head of the same input width.
    """
    net = resnet18()
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, 10)
    return net

def train_model(model, trainloader, testloader, optimizer, device, num_epochs=EPOCHS, scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on ``testloader`` after each.

    Mixed precision (autocast + gradient scaling) is used only when ``device``
    is a CUDA device; on any other device training runs in full precision.
    Evaluation always runs in full precision.

    Args:
        model: network to train; assumed to already live on ``device``.
        trainloader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        testloader: iterable of ``(inputs, labels)`` evaluation batches.
        optimizer: optimizer over ``model``'s parameters.
        device: ``torch.device`` (or device string) that batches are moved to.
        num_epochs: number of passes over ``trainloader``.
        scheduler: optional LR scheduler, stepped once per epoch after evaluation.

    Returns:
        ``(history, best_test_acc)``. ``history`` holds per-epoch lists
        ``train_losses`` (mean batch loss), ``train_accuracies`` and
        ``test_accuracies`` (percent), and ``learning_rates`` (LR of param
        group 0 used during that epoch). ``best_test_acc`` is the highest test
        accuracy seen (percent).
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    history = {'train_losses': [], 'train_accuracies': [], 'test_accuracies': [], 'learning_rates': []}
    best_test_acc = 0.0

    # Only enable AMP for the device actually used. The old hard-coded
    # autocast(device_type='cuda') merely warned and disabled itself on CPU.
    use_amp = torch.device(device).type == 'cuda'
    amp_device_type = 'cuda' if use_amp else 'cpu'
    # CUDA autocast computes in float16 by default, whose narrow range lets
    # small gradients underflow to zero. GradScaler scales the loss before
    # backward and unscales before the optimizer step. When disabled it is a
    # pass-through that simply calls optimizer.step().
    if hasattr(torch.amp, 'GradScaler'):  # torch >= 2.3
        scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    else:  # older torch only ships the CUDA-specific class
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train, total_train = 0, 0

        for inputs, labels in trainloader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            with autocast(device_type=amp_device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            # Skips the update (and lowers the scale) if inf/NaN gradients appear.
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = running_loss / len(trainloader)
        train_acc = 100 * correct_train / total_train
        # Read before scheduler.step(), so this is the LR used during this epoch.
        current_lr = optimizer.param_groups[0]['lr']
        history['train_losses'].append(train_loss)
        history['train_accuracies'].append(train_acc)
        history['learning_rates'].append(current_lr)

        model.eval()
        correct_test, total_test = 0, 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                predicted = model(inputs).argmax(dim=1)
                total_test += labels.size(0)
                correct_test += (predicted == labels).sum().item()
        test_acc = 100 * correct_test / total_test
        history['test_accuracies'].append(test_acc)
        best_test_acc = max(best_test_acc, test_acc)

        if scheduler is not None:
            scheduler.step()

        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            print(f"     Epoch {epoch+1}/{num_epochs} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}% | LR: {current_lr:.2e}")

    return history, best_test_acc

def calculate_statistics(run_results, ddof=0):
    """Summarise best/final test accuracy across repeated runs of one configuration.

    Args:
        run_results: non-empty sequence of run records, each with numeric
            ``'best_test_acc'`` and ``'final_test_acc'`` entries.
        ddof: delta degrees of freedom for the standard deviations. The
            default ``0`` (population std) preserves the historical output.
            Pass ``1`` for the sample estimate. Must be smaller than the number
            of runs, otherwise numpy yields ``nan``.

    Returns:
        dict of plain Python floats (JSON-friendly): ``mean_best_acc``,
        ``std_best_acc``, ``mean_final_acc``, ``std_final_acc``,
        ``min_best_acc``, ``max_best_acc``.

    Raises:
        ValueError: if ``run_results`` is empty. Previously numpy raised an
            opaque "zero-size array" error from ``min``.
    """
    if not run_results:
        raise ValueError("calculate_statistics() requires at least one run result")

    best_accs = np.asarray([r['best_test_acc'] for r in run_results], dtype=float)
    final_accs = np.asarray([r['final_test_acc'] for r in run_results], dtype=float)

    return {
        'mean_best_acc': float(best_accs.mean()),
        'std_best_acc': float(best_accs.std(ddof=ddof)),
        'mean_final_acc': float(final_accs.mean()),
        'std_final_acc': float(final_accs.std(ddof=ddof)),
        'min_best_acc': float(best_accs.min()),
        'max_best_acc': float(best_accs.max()),
    }

def _extract_param_combo(params):
    """Collapse a config's ``params`` mapping to one keyword argument per name.

    Values are expected to be single-element option lists, whose sole element
    is used. Bare scalars (and strings) are passed through unchanged instead of
    being indexed. If an option list holds several values, only the first is
    used and a warning is printed.
    """
    combo = {}
    for name, value in params.items():
        if isinstance(value, (list, tuple, np.ndarray)):
            if len(value) > 1:
                print(f"  ⚠️  '{name}' has {len(value)} candidate values; using only the first ({value[0]!r})")
            combo[name] = value[0]
        else:
            combo[name] = value
    return combo


def _format_lr(lr):
    """Format a learning rate for the ranking table; non-numeric values such as 'N/A' are shown as-is."""
    try:
        return f"{lr:.1e}"
    except (TypeError, ValueError):
        return str(lr)


def _json_default(obj):
    """``json.dump`` fallback: convert numpy scalars/arrays; reject anything else as the stdlib would."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _run_single_experiment(config, param_combo, run_idx, run_seed, trainloader, testloader):
    """Train one freshly initialised model with ``run_seed`` and return its run record.

    Any ``Exception`` raised while building or training is caught and turned
    into an error record, so one bad run does not abort the study.
    """
    set_seed(run_seed)
    print(f"  Run {run_idx+1}/{NUM_RUNS} (seed={run_seed})...", end=' ')

    model = optimizer = scheduler = None
    try:
        model = create_model().to(DEVICE)
        optimizer = config['class'](model.parameters(), **param_combo)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        history, best_acc = train_model(model, trainloader, testloader, optimizer, DEVICE, EPOCHS, scheduler)

        result = {
            'run_idx': run_idx,
            'run_seed': run_seed,
            'best_test_acc': best_acc,
            'final_test_acc': history['test_accuracies'][-1],
            'history': history,
        }
        print(f"Best: {best_acc:.2f}%")
        return result
    except Exception as e:
        print(f"ERROR: {str(e)}")
        return {
            'run_idx': run_idx,
            'run_seed': run_seed,
            'error': str(e),
        }
    finally:
        # Drop the last references to model/optimizer state before asking the
        # CUDA caching allocator to release unused blocks.
        del model, optimizer, scheduler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def _print_top_configurations(ranked_results, limit=10):
    """Print a table of the ``limit`` best-ranked configurations."""
    print("\n=== TOP CONFIGURATIONS ===")
    print(f"{'Rank':<4} {'Configuration':<25} {'Mean±Std (Best)':<15} {'Mean±Std (Final)':<15} {'Parameters'}")
    print("-" * 100)

    for rank, (config_name, data) in enumerate(ranked_results[:limit], 1):
        stats = data['statistics']
        params = data['params']
        alpha = params.get('alpha', 'N/A')
        beta = params.get('beta', 'N/A')
        # Configs without an 'lr' key used to crash here ('N/A' with format ':.1e'),
        # after all training had finished and before the summary file was written.
        lr = _format_lr(params.get('lr', 'N/A'))

        print(f"{rank:<4} {config_name:<25} {stats['mean_best_acc']:.2f}±{stats['std_best_acc']:.2f}%{'':<6} {stats['mean_final_acc']:.2f}±{stats['std_final_acc']:.2f}%{'':<6} α={alpha}, β={beta}, lr={lr}")


def run_ablation_study():
    """Run ablation study using optimal LRs from LR range tests.

    Every configuration from ``get_optimizer_configurations()`` is trained
    ``NUM_RUNS`` times (fresh model, optimizer and cosine LR schedule per run).
    Results go to a timestamped ``ablation_results_<YYYYmmdd_HHMMSS>``
    directory:
    - one ``<config_name>_results.json`` per configuration, written as soon as
      that configuration finishes;
    - ``ablation_summary.json``, with configurations ranked by mean best test
      accuracy.

    Returns:
        ``(all_results, results_dir)``. ``all_results`` maps each config name
        to a dict with ``params``, ``runs``, ``statistics`` (empty if every run
        failed) and ``config_time`` (seconds).
    """
    # Create timestamped results directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"ablation_results_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    print("=== ABLATION STUDY ===")
    print(f"Results will be saved to: {results_dir}")
    print(f"Device: {DEVICE}")
    print(f"Epochs: {EPOCHS}, Batch size: {BATCH_SIZE}, Runs per config: {NUM_RUNS}")

    # NOTE(review): the loaders are built once and reused for every run, and
    # get_data_loaders() requests persistent workers. Worker processes, and
    # presumably the RNG state behind their random augmentations, therefore
    # survive across runs, so set_seed() per run may not make augmentation
    # reproducible. Confirm before relying on exact per-seed reproducibility.
    trainloader, testloader = get_data_loaders()
    configs = get_optimizer_configurations()
    all_results = defaultdict(list)

    study_start_time = time.time()

    print(f"\nTesting {len(configs)} configurations from LR tests:")
    for i, (config_name, config) in enumerate(configs.items(), 1):
        param_combo = _extract_param_combo(config['params'])

        print(f"\n[{i}/{len(configs)}] {config_name}")
        print(f"Parameters: {param_combo}")

        # Run multiple times for statistical significance
        config_start_time = time.time()
        run_results = []
        for run_idx in range(NUM_RUNS):
            # The seed includes the config index i, so configurations do not
            # share initialisations. NOTE(review): paired comparisons across
            # configs usually reuse the same seeds; confirm this is intended.
            run_seed = SEED + run_idx * 1000 + i
            run_results.append(
                _run_single_experiment(config, param_combo, run_idx, run_seed, trainloader, testloader)
            )
        config_time = time.time() - config_start_time

        successful_runs = [r for r in run_results if 'best_test_acc' in r]
        if successful_runs:
            stats = calculate_statistics(successful_runs)
            print(f"  Summary: {stats['mean_best_acc']:.2f}±{stats['std_best_acc']:.2f}% (best), {stats['mean_final_acc']:.2f}±{stats['std_final_acc']:.2f}% (final)")
            print(f"  Range: [{stats['min_best_acc']:.2f}%, {stats['max_best_acc']:.2f}%]")
        else:
            print("  All runs failed!")
            stats = {}

        print(f"  Time: {config_time:.1f}s")

        all_results[config_name] = {
            'params': param_combo,
            'runs': run_results,
            'statistics': stats,
            'config_time': config_time,
        }

        # Save each configuration immediately so a later crash loses nothing.
        config_file = os.path.join(results_dir, f"{config_name}_results.json")
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(all_results[config_name], f, indent=2, default=_json_default)

    study_time = time.time() - study_start_time

    # Rank configurations by mean best accuracy
    successful_results = [(name, data) for name, data in all_results.items() if data['statistics']]
    ranked_results = sorted(successful_results, key=lambda x: x[1]['statistics']['mean_best_acc'], reverse=True)

    print(f"\n{'='*80}")
    print("ABLATION STUDY COMPLETED")
    print(f"{'='*80}")
    print(f"Total time: {study_time/60:.1f} minutes")
    print(f"Successful configurations: {len(successful_results)}")

    _print_top_configurations(ranked_results)

    summary = {
        'experiment_info': {
            'timestamp': timestamp,
            'device': str(DEVICE),
            'epochs': EPOCHS,
            'batch_size': BATCH_SIZE,
            'num_runs': NUM_RUNS,
            'total_configs': len(configs),
            'successful_configs': len(successful_results),
            'total_time_minutes': study_time / 60,
        },
        'rankings': [(name, data['statistics']) for name, data in ranked_results],
        'all_results': dict(all_results),
    }

    summary_file = os.path.join(results_dir, 'ablation_summary.json')
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, default=_json_default)

    print(f"\nDetailed results saved to: {results_dir}")
    print(f"Summary file: {summary_file}")

    return all_results, results_dir

# ---------------------------
# Main execution
# ---------------------------
if __name__ == "__main__":
    # Script entry point: run the full study. The returned results and output
    # directory stay bound at module level, so they can be inspected after an
    # interactive run (e.g. `python -i`).
    results, results_dir = run_ablation_study()