
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torchvision.models import resnet18, resnet34
import torchvision.transforms as transforms

from torch.amp import autocast
from lion_pytorch import Lion
import json
import sys
import os
from collections import defaultdict
import time
from datetime import datetime
from optimizer_configs import get_best_parameters

# Put the parent directory of this script's folder on the import path.
# Presumably this is what lets the `HomOpt` import just below resolve when the
# script is run directly -- TODO confirm where HomOpt lives.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, '..'))
if project_root not in sys.path:
    # Insert at position 0 so this directory is searched before other entries.
    sys.path.insert(0, project_root)

from HomOpt import HomM


def print_gpu_usage():
    """Print PyTorch's allocated and reserved CUDA memory in MB.

    Prints nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_mb = 1024**2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    # The "Cached" label is fed from torch.cuda.memory_reserved().
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb

    print(f"GPU Memory Allocated: {allocated_mb:.2f} MB")
    print(f"GPU Memory Cached: {reserved_mb:.2f} MB")

# -----------------------------------------------
# ---  Set random seeds for reproducibility ---
# -----------------------------------------------

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed passed to every random number generator.
    """
    # Same seeders, same order as before: Python, NumPy, torch, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and ask for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# --------------------------------
# ---  Data loading function ---
# --------------------------------

def get_data_loaders(batch_size=128, data_root='./data', num_workers=2):
    """Build the CIFAR-100 training and test data loaders.

    The training pipeline applies random crop (with 4px padding), horizontal
    flip, rotation and colour jitter before normalisation; the test pipeline
    only converts to tensor and normalises.

    Args:
        batch_size: Mini-batch size used by both loaders.
        data_root: Directory the dataset is stored in. ``download=True`` is
            passed to torchvision for both splits.
        num_workers: Number of worker processes per loader.

    Returns:
        Tuple ``(trainloader, testloader)``; only the training loader shuffles.
    """
    # CIFAR-100 stats, shared by both pipelines so they cannot drift apart.
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR100(
        root=data_root, train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(
        root=data_root, train=False, download=True, transform=transform_test)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return trainloader, testloader

# -------------------------------
# --- Model creation function ---
# -------------------------------

def create_model():
    """Return a randomly initialised ResNet-34 adapted for CIFAR-100.

    Three parts of the stock torchvision model are replaced: ``conv1`` becomes
    a 3x3 stride-1 convolution, ``maxpool`` becomes an identity, and ``fc`` is
    rebuilt with 100 outputs.
    """
    net = resnet34(weights=None)

    # Layers are created in the same order as before (conv1, then fc) so the
    # random initialisation drawn after set_seed() is unchanged.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, 100)
    return net


# --------------------------------------------------------
# --- Model training function with detailed logging ---
# --------------------------------------------------------

def _train_one_epoch(model, trainloader, optimizer, criterion, device, use_amp, epoch_start):
    """Run one optimisation pass over ``trainloader``; return ``(mean_loss, accuracy_pct)``."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = len(trainloader)
    print_every = max(1, num_batches // 4)  # Print 4 times per epoch

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        # NOTE(review): the forward pass runs under autocast without a
        # GradScaler, so fp16 gradients may underflow. Left unchanged on
        # purpose so that optimizer comparisons stay numerically identical.
        with autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Print progress during training (overwrites the same console line).
        if (i + 1) % print_every == 0 or (i + 1) == num_batches:
            current_loss = running_loss / (i + 1)
            current_acc = 100 * correct / total
            progress = (i + 1) / num_batches * 100

            print(f"  Progress: {progress:5.1f}% | "
                  f"Batch {i+1:3d}/{num_batches} | "
                  f"Loss: {current_loss:.4f} | "
                  f"Acc: {current_acc:5.2f}% | "
                  f"Time: {time.time() - epoch_start:4.1f}s", end='\r')

    # Clear the progress line before the caller prints the final stats.
    print(" " * 100, end='\r')

    # Guard against empty loaders instead of dividing by zero.
    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy


def _evaluate(model, testloader, device):
    """Return the top-1 accuracy (%) of ``model`` on ``testloader``, printing progress dots."""
    model.eval()
    correct = 0
    total = 0
    dot_every = max(1, len(testloader) // 4)

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(testloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Simple test progress indicator
            if (i + 1) % dot_every == 0:
                print(".", end='', flush=True)

    return 100 * correct / total if total else 0.0


def train_model_detailed(model, trainloader, testloader, optimizer, device, num_epochs=100, scheduler=None, save_path=None, optimizer_name="Unknown"):
    """Train ``model`` and evaluate it after every epoch, logging to stdout.

    Uses cross-entropy with label smoothing 0.1. Autocast is enabled only when
    ``device`` is a CUDA device. Whenever the test accuracy improves and
    ``save_path`` is given, the model's ``state_dict`` is written there.

    Args:
        model: Module to train; expected to already live on ``device``.
        trainloader: Loader yielding ``(inputs, labels)`` training batches.
        testloader: Loader yielding ``(inputs, labels)`` evaluation batches.
        optimizer: Optimizer whose ``step()`` is called once per batch.
        device: ``torch.device`` or device string the batches are moved to.
        num_epochs: Number of epochs; ``0`` returns an empty history.
        scheduler: Optional LR scheduler, stepped once at the end of each epoch.
        save_path: Optional file path for the best-so-far ``state_dict``.
        optimizer_name: Label used in the printed banners.

    Returns:
        Tuple ``(history, best_test_acc)``. ``history`` holds per-epoch lists
        (``train_losses``, ``train_accuracies``, ``test_accuracies``,
        ``learning_rates``, ``epoch_times``) plus ``best_epoch`` (0-based) and
        ``total_time`` in seconds.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Only request CUDA autocast when actually training on a CUDA device; the
    # previous unconditional request triggered a warning on CPU-only hosts.
    use_amp = torch.device(device).type == 'cuda'
    history = {
        'train_losses': [],
        'train_accuracies': [],
        'test_accuracies': [],
        'learning_rates': [],
        'epoch_times': [],
        'best_epoch': 0,
        'total_time': 0
    }
    best_test_acc = 0.0
    start_time = time.time()

    # Enhanced header with more details
    print(f"\n{'='*80}")
    print(f"TRAINING {optimizer_name}")
    print(f"{'='*80}")
    print(f"Epochs: {num_epochs} | Device: {device} | Batch Size: {trainloader.batch_size}")
    print(f"Training samples: {len(trainloader.dataset)} | Test samples: {len(testloader.dataset)}")
    print(f"Batches per epoch: {len(trainloader)}")

    # Print initial GPU memory if available
    if torch.cuda.is_available():
        print(f"Initial GPU Memory: {torch.cuda.memory_allocated() / 1024**2:.1f}MB allocated, {torch.cuda.memory_reserved() / 1024**2:.1f}MB cached")
    print(f"{'='*80}")

    for epoch in range(num_epochs):
        epoch_start = time.time()

        # Print epoch header
        print(f"\nEPOCH {epoch+1:3d}/{num_epochs} | LR: {optimizer.param_groups[0]['lr']:.2e}")
        print(f"{'─'*50}")

        # Training phase
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, optimizer, criterion, device, use_amp, epoch_start)

        # Record the LR that was in effect during this epoch (the scheduler is
        # stepped only at the end of the loop body).
        current_lr = optimizer.param_groups[0]['lr']
        history['train_losses'].append(train_loss)
        history['train_accuracies'].append(train_acc)
        history['learning_rates'].append(current_lr)

        # Testing phase
        print(f"  Training  → Loss: {train_loss:.4f} | Acc: {train_acc:6.2f}%")
        print(f"  Testing   → ", end='', flush=True)

        test_start = time.time()
        test_acc = _evaluate(model, testloader, device)
        history['test_accuracies'].append(test_acc)

        epoch_time = time.time() - epoch_start
        test_time = time.time() - test_start
        history['epoch_times'].append(epoch_time)

        # Update best accuracy
        is_best = test_acc > best_test_acc
        if is_best:
            best_test_acc = test_acc
            history['best_epoch'] = epoch
            if save_path:
                torch.save(model.state_dict(), save_path)

        # Clear test progress and print final results
        print(f"\r  Testing   → Acc: {test_acc:6.2f}% | Time: {test_time:4.1f}s | {'🎯 BEST!' if is_best else ''}")

        # Print epoch summary
        print(f"  Summary   → Best: {best_test_acc:6.2f}% (epoch {history['best_epoch']+1}) | "
              f"Total: {epoch_time:5.1f}s")

        # Print GPU memory usage every 10 epochs
        if torch.cuda.is_available() and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
            print(f"  GPU Mem   → {torch.cuda.memory_allocated() / 1024**2:.1f}MB allocated, "
                  f"{torch.cuda.memory_reserved() / 1024**2:.1f}MB cached")

        # Step scheduler
        if scheduler is not None:
            scheduler.step()

        # Print milestone updates
        if (epoch + 1) % 25 == 0 and epoch < num_epochs - 1:
            remaining_epochs = num_epochs - epoch - 1
            avg_epoch_time = sum(history['epoch_times']) / len(history['epoch_times'])
            estimated_remaining = avg_epoch_time * remaining_epochs / 60

            print(f"\n   MILESTONE: {epoch+1}/{num_epochs} epochs completed")
            print(f"   Estimated remaining time: {estimated_remaining:.1f} minutes")
            print(f"   Best accuracy so far: {best_test_acc:.2f}%")

    # Final summary
    history['total_time'] = time.time() - start_time

    # Read the final metrics from the history so a zero-epoch run reports 0.0
    # instead of failing on variables that were never assigned.
    final_test_acc = history['test_accuracies'][-1] if history['test_accuracies'] else 0.0
    final_train_acc = history['train_accuracies'][-1] if history['train_accuracies'] else 0.0

    print(f"\n{'='*80}")
    print(f"TRAINING COMPLETED: {optimizer_name}")
    print(f"{'='*80}")
    print(f"Total time: {history['total_time']/60:.1f} minutes ({history['total_time']:.1f}s)")
    print(f"Best test accuracy: {best_test_acc:.2f}% (achieved at epoch {history['best_epoch']+1})")
    print(f"Final test accuracy: {final_test_acc:.2f}%")
    print(f"Final train accuracy: {final_train_acc:.2f}%")
    print(f"Average time per epoch: {history['total_time']/max(1, num_epochs):.1f}s")

    if torch.cuda.is_available():
        print(f"Final GPU Memory: {torch.cuda.memory_allocated() / 1024**2:.1f}MB allocated")

    print(f"{'='*80}\n")

    return history, best_test_acc

# --------------------------------
# --- Main comparison function ---
# --------------------------------

def run_final_comparison(num_epochs=100, num_runs=3):
    """Train every configured optimizer ``num_runs`` times and save the results.

    For each optimizer name returned by ``get_best_parameters()`` that is also
    known here, a fresh model is trained ``num_runs`` times (seed ``42 + run``)
    with a cosine-annealing schedule. Per-optimizer statistics are collected
    and written to the ``results`` directory via ``save_raw_results``.

    Args:
        num_epochs: Epochs per training run; must be at least 1.
        num_runs: Independent runs per optimizer; must be at least 1.

    Returns:
        Dict mapping optimizer name to its statistics (accuracy/time summaries,
        per-run results and full training histories).

    Raises:
        ValueError: If ``num_epochs`` or ``num_runs`` is smaller than 1.
    """
    # Fail fast: non-positive values would otherwise only blow up after the
    # dataset download / first training run, with an obscure NumPy or index error.
    if num_epochs < 1 or num_runs < 1:
        raise ValueError(
            "num_epochs and num_runs must both be >= 1 "
            f"(got num_epochs={num_epochs}, num_runs={num_runs})"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Get best parameters
    best_params = get_best_parameters()

    # Display the parameters that will be used
    print("\nOptimizer Parameters:")
    for opt_name, params in best_params.items():
        print(f"  {opt_name}: {params}")

    # Create results directory
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)

    # Get data loaders
    trainloader, testloader = get_data_loaders(batch_size=128)

    # Optimizer name -> constructor. 'SGD_Nesterov' shares the SGD class;
    # presumably its parameter dict enables Nesterov momentum -- verify in
    # optimizer_configs.
    optimizer_configs = {
        'SGD': optim.SGD,
        'SGD_Nesterov': optim.SGD,
        'Adam': optim.Adam,
        'Lion': Lion,
        'HomM': HomM
    }

    all_results = {}

    print(f"\n{'='*80}")
    print(f"FINAL COMPARISON - {num_runs} RUNS PER OPTIMIZER")
    print(f"{'='*80}")

    for opt_name, opt_params in best_params.items():
        if opt_name not in optimizer_configs:
            print(f"Warning: {opt_name} not in optimizer configs, skipping...")
            continue

        print(f"\n{opt_name.upper()}")
        print(f"Parameters: {opt_params}")
        print("-" * 60)

        run_results = []
        run_histories = []

        for run in range(num_runs):
            print(f"\nRun {run + 1}/{num_runs}")

            # Each run index gets its own seed, identical across optimizers.
            set_seed(42 + run)

            # Create fresh model
            model = create_model().to(device)

            # Create optimizer with best parameters
            optimizer = optimizer_configs[opt_name](model.parameters(), **opt_params)

            # Create scheduler
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

            # Save path for this run
            save_path = f'{results_dir}/{opt_name}_run_{run+1}_best_model.pth'

            # Train model
            history, best_acc = train_model_detailed(
                model, trainloader, testloader, optimizer, device,
                num_epochs=num_epochs, scheduler=scheduler, save_path=save_path,
                optimizer_name=f"{opt_name} (Run {run+1})"
            )

            run_results.append({
                'best_test_acc': best_acc,
                'final_test_acc': history['test_accuracies'][-1],
                'best_epoch': history['best_epoch'],
                'total_time': history['total_time'],
                'final_train_acc': history['train_accuracies'][-1]
            })
            run_histories.append(history)

            # Clear GPU memory
            del model, optimizer, scheduler
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Calculate statistics across runs
        best_accs = [r['best_test_acc'] for r in run_results]
        final_accs = [r['final_test_acc'] for r in run_results]
        times = [r['total_time'] for r in run_results]

        best_mean, best_std = np.mean(best_accs), np.std(best_accs)
        best_min, best_max = np.min(best_accs), np.max(best_accs)
        time_mean, time_std = np.mean(times), np.std(times)

        all_results[opt_name] = {
            'optimizer': opt_name,
            'parameters': opt_params,
            'num_runs': num_runs,
            'best_accuracy': {
                'mean': best_mean,
                'std': best_std,
                'min': best_min,
                'max': best_max
            },
            'final_accuracy': {
                'mean': np.mean(final_accs),
                'std': np.std(final_accs)
            },
            'training_time': {
                'mean': time_mean,
                'std': time_std
            },
            'individual_results': run_results,
            'histories': run_histories
        }

        print(f"\n{opt_name} Summary ({num_runs} runs):")
        print(f"  Best Accuracy: {best_mean:.2f}% ± {best_std:.2f}%")
        print(f"  Range: [{best_min:.2f}%, {best_max:.2f}%]")
        print(f"  Avg Training Time: {time_mean:.1f}s ± {time_std:.1f}s")

    # Save raw results to JSON
    save_raw_results(all_results, results_dir)

    print(f"\n{'='*80}")
    print("TRAINING COMPLETED!")
    print(f"{'='*80}")
    print(f" Results saved to: {results_dir}/")
    print(" Main results: final_comparison_results.json")
    print(" Loss curves: training_histories.json")

    return all_results

def _to_jsonable(obj):
    """Recursively convert NumPy scalars/arrays inside ``obj`` to plain Python types."""
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(item) for item in obj]
    return obj


def save_raw_results(results, results_dir):
    """Save raw results to JSON files inside ``results_dir``.

    Two files are written:

    * ``final_comparison_results.json`` -- every statistic except the
      ``'histories'`` entry (kept out to keep the file size manageable).
    * ``training_histories.json`` -- the per-run histories, keyed by optimizer
      name; only written when at least one optimizer has a ``'histories'`` key.

    NumPy values are converted at any nesting depth, so integers stay integers
    and nested lists/dicts (e.g. ``'individual_results'``) serialise cleanly.

    Args:
        results: Dict mapping optimizer name to its statistics dict.
        results_dir: Existing directory the JSON files are written to.
    """
    json_results = {}
    histories_data = {}  # Separate storage for histories

    for opt_name, stats in results.items():
        json_results[opt_name] = {
            key: _to_jsonable(value)
            for key, value in stats.items()
            if key != 'histories'
        }
        if 'histories' in stats:
            histories_data[opt_name] = [_to_jsonable(hist) for hist in stats['histories']]

    # Save main results (without histories)
    results_file = f'{results_dir}/final_comparison_results.json'
    with open(results_file, 'w') as f:
        json.dump(json_results, f, indent=2)

    # Save training histories separately (with all loss curves)
    if histories_data:
        histories_file = f'{results_dir}/training_histories.json'
        with open(histories_file, 'w') as f:
            json.dump(histories_data, f, indent=2)
        print(f"Training histories (all loss curves) saved to {histories_file}")

    print(f"Main results saved to {results_file}")

# ----------------------
# --- Main execution ---
# ----------------------

def _prompt_positive_int(prompt, default):
    """Ask until the user enters a positive integer; blank input selects ``default``."""
    while True:
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            print(f"  '{raw}' is not a whole number, please try again.")
            continue
        if value > 0:
            return value
        print("  Please enter a number greater than zero.")


if __name__ == "__main__":

    # Check GPU availability
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU detected: {gpu_name}")
        print(f"GPU memory: {gpu_memory:.1f} GB")
    else:
        print("No GPU detected. Using CPU (training will be much slower).")

    # Get training parameters. Invalid or non-positive entries are re-prompted
    # instead of crashing with a traceback or starting a run that cannot finish.
    print("\nTraining Configuration:")
    epochs = _prompt_positive_int("Number of epochs (default: 100): ", 100)
    runs = _prompt_positive_int("Number of runs per optimizer (default: 3): ", 3)

    print(f"  Epochs per run: {epochs}")
    print(f"  Runs per optimizer: {runs}")

    # Any answer starting with "y" (case-insensitive) counts as confirmation.
    confirm = input("\nProceed with final comparison? (y/n): ").lower().startswith('y')

    if confirm:
        results = run_final_comparison(num_epochs=epochs, num_runs=runs)

        # An empty result dict means no optimizer was actually trained.
        if results:
            print("\n TRAINING COMPLETED!")
        else:
            print("Training failed - check error messages above.")
    else:
        print("Training cancelled.")
    
    