import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

from torch.amp import autocast
from lion_pytorch import Lion
from collections import defaultdict
import os
import sys
import json
from optimizer_configs import get_optimizer_configurations, generate_param_combinations


# Put the directory one level above this script on the import path (front of
# sys.path, added only once) so the local ``HomOpt`` module can be imported.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from HomOpt import HomM

# ---------------------------
# Settings
# ---------------------------
SEED = 42          # seed applied before data loading and before every run
EPOCHS = 10        # training epochs per configuration (also the cosine T_max)
BATCH_SIZE = 128   # batch size for both train and test loaders
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------
# Utility functions
# ---------------------------
def set_seed(seed=SEED):
    """Seed all random number generators used by the sweep.

    Seeds Python's ``random``, NumPy, torch's CPU generator and every CUDA
    device, then turns off cuDNN benchmark mode and requests deterministic
    cuDNN algorithms so repeated runs are reproducible.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def get_data_loaders(batch_size=BATCH_SIZE, root='./data', num_workers=2):
    """Build the CIFAR-100 train and test DataLoaders.

    The training pipeline applies a random 32px crop (4px padding), a random
    horizontal flip, a random rotation of up to 15 degrees and color jitter
    before normalization. The test pipeline only converts to a tensor and
    normalizes. torchvision downloads both splits into ``root`` if needed.

    Args:
        batch_size: Samples per batch for both loaders.
        root: Directory where torchvision stores / looks for the dataset.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    # Per-channel (mean, std) normalization shared by both pipelines.
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test)

    # Pinned host memory speeds up host-to-GPU copies; it has no benefit without CUDA.
    # (persistent_workers is intentionally NOT enabled: it would keep worker RNG
    # state across epochs and defeat the per-experiment re-seeding in the sweep.)
    pin_memory = torch.cuda.is_available()
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    return trainloader, testloader

def create_model(num_classes=100):
    """Build a randomly initialized ResNet-34 adapted to small images.

    torchvision's default stem convolution is replaced by a 3x3, stride-1
    convolution and its max-pool is removed. Early layers therefore keep more
    spatial resolution on small inputs, such as the 32x32 crops produced by
    ``get_data_loaders``.

    Args:
        num_classes: Number of output logits; defaults to 100 (CIFAR-100).

    Returns:
        The ``torch.nn.Module``, still on the CPU.
    """
    model = torchvision.models.resnet34(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    # New classification head sized for the target label set.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# ---------------------------
# Training
# ---------------------------

def _make_grad_scaler(enabled):
    """Return a CUDA GradScaler, using ``torch.amp`` on newer PyTorch and
    falling back to ``torch.cuda.amp`` on older releases."""
    scaler_cls = getattr(torch.amp, "GradScaler", None)
    if scaler_cls is not None:
        return scaler_cls("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def _percent(correct, total):
    """Return ``correct / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return 100 * correct / total if total else 0.0


def _train_one_epoch(model, trainloader, optimizer, criterion, scaler, device, use_amp):
    """Run one pass over ``trainloader``; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct, total = 0, 0

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # float16 autocast on CUDA only; with enabled=False this is a no-op.
        with autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # Scale the loss so small float16 gradients do not underflow to zero.
        # The scaler unscales before stepping and skips steps whose gradients
        # contain inf/NaN. When disabled, these calls reduce to a plain
        # backward() and optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        num_batches += 1
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    train_loss = running_loss / num_batches if num_batches else 0.0
    return train_loss, _percent(correct, total)


def _evaluate(model, testloader, device):
    """Return the top-1 accuracy (%) of ``model`` on ``testloader``."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    return _percent(correct, total)


def train_model(model, trainloader, testloader, optimizer, device, num_epochs=EPOCHS, scheduler=None):
    """Train ``model`` and evaluate it on ``testloader`` after every epoch.

    The loss is cross-entropy with label smoothing 0.1. When ``device`` is a
    CUDA device, the forward pass runs under float16 autocast and gradients go
    through a GradScaler. On any other device, training runs in full precision.

    Args:
        model: Network to train; the sweep in this file moves it to ``device`` first.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` evaluation batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: ``torch.device`` or device string that batches are moved to.
        num_epochs: Number of training epochs.
        scheduler: Optional LR scheduler, stepped once at the end of each epoch.

    Returns:
        ``(history, best_test_acc)``. ``history`` holds one value per epoch
        under each of 'train_losses', 'train_accuracies', 'test_accuracies'
        and 'learning_rates'. ``best_test_acc`` is the highest per-epoch test
        accuracy (%).
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    use_amp = torch.device(device).type == 'cuda'
    scaler = _make_grad_scaler(enabled=use_amp)
    history = {'train_losses': [], 'train_accuracies': [], 'test_accuracies': [], 'learning_rates': []}
    best_test_acc = 0.0

    for epoch in range(num_epochs):
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, optimizer, criterion, scaler, device, use_amp)
        # Read before scheduler.step(): this is the LR that was used this epoch.
        current_lr = optimizer.param_groups[0]['lr']
        history['train_losses'].append(train_loss)
        history['train_accuracies'].append(train_acc)
        history['learning_rates'].append(current_lr)

        test_acc = _evaluate(model, testloader, device)
        history['test_accuracies'].append(test_acc)
        best_test_acc = max(best_test_acc, test_acc)

        if scheduler is not None:
            scheduler.step()

        # Progress line every 5 epochs and on the final epoch.
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            print(f"    Epoch {epoch+1}/{num_epochs} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}% | LR: {current_lr:.2e}")

    return history, best_test_acc

# ---------------------------
# Parameter sweep
# ---------------------------
def _save_results(results, output_dir):
    """Write ``results`` as JSON to ``<output_dir>/final_results.json``.

    The data is written to a temporary file first and then moved into place
    with a single ``os.replace``. An interrupted write therefore never leaves
    a truncated results file behind.
    """
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, 'final_results.json')
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, path)


def _print_best_configs(all_results):
    """Print, per optimizer, the configuration with the highest best test accuracy."""
    print("\n=== BEST CONFIGS PER OPTIMIZER ===")
    for opt_name, results in all_results.items():
        best_result = max(results, key=lambda r: r['best_test_acc'])
        print(f"{opt_name}: Best Test Acc: {best_result['best_test_acc']:.2f}% | Params: {best_result['params']}")


def run_parameter_sweep(output_dir='parameter_sweep_results'):
    """Train one model per optimizer hyperparameter combination and report the best.

    For every optimizer from ``get_optimizer_configurations()`` and every
    parameter combination from ``generate_param_combinations()``:
    - re-seed the RNGs;
    - build a fresh model;
    - train it for ``EPOCHS`` epochs with cosine LR annealing;
    - record its best and final test accuracy.

    Results are checkpointed to ``<output_dir>/final_results.json`` after each
    run, so a crash late in a long sweep keeps every completed result. A
    non-JSON-serializable ``params`` value now fails after the first run
    instead of at the very end.

    Args:
        output_dir: Directory for the results JSON file.

    Returns:
        ``defaultdict(list)`` mapping optimizer name to a list of records with
        keys 'params', 'best_test_acc' and 'final_test_acc'.
    """
    set_seed(SEED)
    trainloader, testloader = get_data_loaders()
    configs = get_optimizer_configurations()
    all_results = defaultdict(list)

    # Expand each grid once; reuse it for the progress total and for the sweep itself.
    combos_by_opt = {name: generate_param_combinations(cfg) for name, cfg in configs.items()}
    total_experiments = sum(len(combos) for combos in combos_by_opt.values())
    exp_count = 0

    for opt_name, config in configs.items():
        param_combinations = combos_by_opt[opt_name]
        print(f"\n--- Sweeping {opt_name}: {len(param_combinations)} configurations ---")

        for i, params in enumerate(param_combinations, 1):
            exp_count += 1
            print(f"[{exp_count}/{total_experiments}] {opt_name} - Config {i}/{len(param_combinations)}: {params}")

            # Re-seed so every configuration starts from the same initial weights
            # (and, presumably, the same shuffle/augmentation randomness).
            set_seed(SEED)
            model = create_model().to(DEVICE)
            optimizer = config['class'](model.parameters(), **params)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

            history, best_acc = train_model(model, trainloader, testloader, optimizer, DEVICE, EPOCHS, scheduler)

            all_results[opt_name].append({
                'params': params,
                'best_test_acc': best_acc,
                'final_test_acc': history['test_accuracies'][-1]
            })
            # Checkpoint after every run so earlier results survive a later crash.
            _save_results(all_results, output_dir)

            del model, optimizer, scheduler
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"\n{opt_name} sweep finished. {len(param_combinations)} configurations tested.\n")

    _print_best_configs(all_results)
    # Final save also covers the zero-experiment case (writes an empty object).
    _save_results(all_results, output_dir)

    return all_results

# ---------------------------
# Main execution
# ---------------------------
if __name__ == "__main__":
    # Print a banner and the selected device, then launch the full sweep.
    for line in ("=== PARAMETER SWEEP TOOL ===", f"Using device: {DEVICE}"):
        print(line)
    run_parameter_sweep()
