import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from HomOpt import HomM
from torch.amp import autocast
from lion_pytorch import Lion
from collections import defaultdict
import os
import sys
import json
from torch.utils.data import TensorDataset, DataLoader
from optimizer_configs import get_optimizer_configurations, generate_param_combinations

# Make the directory one level above this script importable, so project-local
# packages such as HomOpt resolve when the script is launched from its folder.
# NOTE(review): the top-of-file ``from HomOpt import HomM`` executes before this
# path tweak, so that earlier import only succeeds if HomOpt is already
# importable; the import at the end of this block is the one that benefits.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)  # parent of the script's directory
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from HomOpt import HomM  # noqa: E402  (must come after the sys.path tweak)



# ---------------------------
# Settings
# ---------------------------
SEED = 42  # default seed used by set_seed()
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
EPOCHS = 20  # training epochs per sweep configuration
BATCH_SIZE = 256  # mini-batch size for both train and test loaders

# ---------------------------
# Utility functions
# ---------------------------
def set_seed(seed=SEED):
    """Seed every random number generator the sweep relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU plus all CUDA devices). It then forces cuDNN onto deterministic
    kernels and turns off its autotuner, so repeated runs with the same seed
    are reproducible.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_data_loaders(batch_size=BATCH_SIZE, data_path=None, max_rows=1_000_000,
                     train_fraction=0.8):
    """Load the HIGGS CSV and wrap it in train/test ``DataLoader`` objects.

    Column 0 of the file is used as the class label and the remaining columns
    as features. Rows are split in file order, with no shuffling before the
    split. The first ``train_fraction`` of the loaded rows go to training and
    the rest to testing.

    Args:
        batch_size: Mini-batch size for both loaders.
        data_path: CSV file to read. NumPy reads ``.gz`` files transparently.
            Defaults to ``data/HIGGS.csv.gz`` relative to the current working
            directory.
        max_rows: Maximum number of rows to read, to keep loading fast.
            ``None`` reads the whole file.
        train_fraction: Fraction of the loaded rows used for training; must
            lie strictly between 0 and 1.

    Returns:
        ``(trainloader, testloader)``. The train loader shuffles every epoch;
        the test loader keeps file order.

    Raises:
        ValueError: if ``train_fraction`` is outside (0, 1) or either split
            would be empty.
    """
    if data_path is None:
        data_path = os.path.join("data", "HIGGS.csv.gz")
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    # ndmin=2 keeps a single-row file two-dimensional so the column slicing works.
    data = np.loadtxt(data_path, delimiter=',', max_rows=max_rows, ndmin=2)
    X = data[:, 1:]
    y = data[:, 0]

    split = int(train_fraction * len(X))
    if split == 0 or split == len(X):
        raise ValueError(
            f"splitting {len(X)} rows at train_fraction={train_fraction} leaves an empty split"
        )
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    # Labels are stored as floats in the file; CrossEntropyLoss needs int64 class indices.
    train_dataset = TensorDataset(
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long),
    )
    test_dataset = TensorDataset(
        torch.tensor(X_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.long),
    )

    # Pinned host memory only matters for host-to-GPU copies; requesting it on a
    # CPU-only machine just makes PyTorch warn and ignore the flag.
    pin = torch.cuda.is_available()
    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin)

    return trainloader, testloader

# ----------------------------------
# --- Model definition and factory ---
# ----------------------------------

# MLP for HIGGS dataset
class HiggsMLP(nn.Module):
    """Fully connected classifier for HIGGS feature vectors.

    Each hidden stage is ``Linear -> ReLU -> BatchNorm1d -> Dropout``. A final
    ``Linear`` layer produces ``num_classes`` raw logits (no softmax applied).
    """

    def __init__(self, input_dim=28, hidden_dims=(512, 512, 256), num_classes=2, dropout_rate=0.2):
        """Build the layer stack.

        Args:
            input_dim: Number of input features per sample.
            hidden_dims: Widths of the hidden layers, in order. Any sequence of
                ints works (list or tuple). The default is a tuple so it can't
                be mutated and shared across instances.
            num_classes: Number of output logits.
            dropout_rate: Dropout probability after every hidden stage.
        """
        super().__init__()
        # Unpacking (rather than ``[input_dim] + hidden_dims``) accepts any sequence.
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_features, out_features in zip(dims, dims[1:]):
            layers += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.BatchNorm1d(out_features),
                nn.Dropout(dropout_rate),
            ]
        # With no hidden layers, dims[-1] is input_dim and this is a plain linear model.
        layers.append(nn.Linear(dims[-1], num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        return self.net(x)


def create_model(input_dim=28, hidden_dims=(512, 512, 256), num_classes=2, dropout_rate=0.2):
    """Create a freshly initialised ``HiggsMLP``.

    Args:
        input_dim: Number of input features per sample.
        hidden_dims: Hidden-layer widths, in order (any sequence of ints). The
            default is a tuple so it can't be mutated and shared across calls.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability after every hidden stage.

    Returns:
        A new ``HiggsMLP``. Callers move it to a device themselves.
    """
    return HiggsMLP(
        input_dim=input_dim,
        # Pass a fresh list: HiggsMLP never aliases a caller-owned list, and
        # list input works whichever HiggsMLP signature is in use.
        hidden_dims=list(hidden_dims),
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )
# ---------------------------
# Training
# ---------------------------

def _train_one_epoch(model, trainloader, optimizer, criterion, device, scaler, use_amp):
    """Run one optimisation pass over ``trainloader``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct, total = 0, 0

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Mixed precision only on CUDA. Disabling it explicitly avoids asking for
        # CUDA autocast during a CPU run.
        with autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # GradScaler keeps float16 gradients from underflowing to zero. When it
        # is disabled these calls reduce to loss.backward() / optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        num_batches += 1
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches")
    return running_loss / num_batches, 100 * correct / total


def _evaluate(model, testloader, device):
    """Return top-1 accuracy (%) of ``model`` on ``testloader`` (no autocast, no grad).

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("testloader yielded no samples")
    return 100 * correct / total


def train_model(model, trainloader, testloader, optimizer, device, num_epochs=EPOCHS, scheduler=None):
    """Train ``model`` and evaluate it on ``testloader`` after every epoch.

    Uses cross-entropy with label smoothing 0.1. On CUDA devices, training runs
    under float16 autocast with gradient scaling. ``scheduler`` (if given) is
    stepped once per epoch, after evaluation. A progress line is printed every
    5 epochs and after the last epoch.

    Args:
        model: Module to train in place (already on ``device``).
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` evaluation batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: ``torch.device`` or device string the batches are moved to.
        num_epochs: Number of passes over ``trainloader``.
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        ``(history, best_test_acc)``. ``history`` maps ``'train_losses'``,
        ``'train_accuracies'``, ``'test_accuracies'`` and ``'learning_rates'``
        to per-epoch lists. ``best_test_acc`` is the highest test accuracy seen
        (0.0 if ``num_epochs`` is 0).

    Raises:
        ValueError: if either loader yields no data.
    """
    device = torch.device(device)
    use_amp = device.type == 'cuda'
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    history = {'train_losses': [], 'train_accuracies': [], 'test_accuracies': [], 'learning_rates': []}
    best_test_acc = 0.0

    for epoch in range(num_epochs):
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, optimizer, criterion, device, scaler, use_amp
        )
        # Read before scheduler.step() so the value is the LR used during this epoch.
        current_lr = optimizer.param_groups[0]['lr']
        history['train_losses'].append(train_loss)
        history['train_accuracies'].append(train_acc)
        history['learning_rates'].append(current_lr)

        test_acc = _evaluate(model, testloader, device)
        history['test_accuracies'].append(test_acc)
        best_test_acc = max(best_test_acc, test_acc)

        if scheduler is not None:
            scheduler.step()

        # Epoch progress print (every 5 epochs and on the final epoch)
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            print(f"    Epoch {epoch+1}/{num_epochs} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}% | LR: {current_lr:.2e}")

    return history, best_test_acc

# ---------------------------
# Parameter sweep
# ---------------------------
def _json_fallback(obj):
    """Convert values ``json`` can't encode natively when saving sweep results.

    The value types in ``params`` come from ``generate_param_combinations`` and
    aren't visible here. NumPy scalars and arrays become plain Python numbers
    and lists. Anything else is stored as its ``repr``, so a finished sweep is
    never lost to a serialisation error at the very end.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return repr(obj)


def _save_results(results, path):
    """Write ``results`` to ``path`` as indented JSON, atomically, creating the parent dir."""
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_json_fallback)
    # Replace in one step so an interrupted write never leaves a truncated results file.
    os.replace(tmp_path, path)


def run_parameter_sweep():
    """Grid-search every optimizer configuration and report the best one per optimizer.

    For every optimizer in ``get_optimizer_configurations()``, each parameter
    combination from ``generate_param_combinations`` is trained from the same
    seed for ``EPOCHS`` epochs with a cosine-annealed learning rate. The best
    and final test accuracy of each run are recorded, and the best configuration
    per optimizer is printed. Results are written to
    ``parameter_sweep_results/final_results.json``. The file is also
    checkpointed after each optimizer finishes, so a crash mid-sweep keeps the
    completed runs.

    Returns:
        ``defaultdict(list)`` mapping optimizer name to a list of
        ``{'params', 'best_test_acc', 'final_test_acc'}`` dicts.
    """
    set_seed(SEED)
    trainloader, testloader = get_data_loaders()
    configs = get_optimizer_configurations()
    all_results = defaultdict(list)
    results_path = 'parameter_sweep_results/final_results.json'

    # Expand every parameter grid once; reused for the progress total and the sweep.
    combos_by_optimizer = {name: generate_param_combinations(cfg) for name, cfg in configs.items()}
    total_experiments = sum(len(combos) for combos in combos_by_optimizer.values())
    exp_count = 0

    for opt_name, config in configs.items():
        param_combinations = combos_by_optimizer[opt_name]
        print(f"\n--- Sweeping {opt_name}: {len(param_combinations)} configurations ---")

        for i, params in enumerate(param_combinations, 1):
            exp_count += 1
            print(f"[{exp_count}/{total_experiments}] {opt_name} - Config {i}/{len(param_combinations)}: {params}")

            # Re-seed so every configuration starts from identical weights and batch order.
            set_seed(SEED)
            model = create_model().to(DEVICE)
            optimizer = config['class'](model.parameters(), **params)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

            history, best_acc = train_model(model, trainloader, testloader, optimizer, DEVICE, EPOCHS, scheduler)

            all_results[opt_name].append({
                'params': params,
                'best_test_acc': best_acc,
                'final_test_acc': history['test_accuracies'][-1]
            })

            del model, optimizer, scheduler
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"\n{opt_name} sweep finished. {len(param_combinations)} configurations tested.\n")
        _save_results(all_results, results_path)

    # Summary
    print("\n=== BEST CONFIGS PER OPTIMIZER ===")
    for opt_name, results in all_results.items():
        best_result = max(results, key=lambda x: x['best_test_acc'])
        print(f"{opt_name}: Best Test Acc: {best_result['best_test_acc']:.2f}% | Params: {best_result['params']}")

    # Save final results
    _save_results(all_results, results_path)

    return all_results

# ---------------------------
# Main execution
# ---------------------------
if __name__ == "__main__":
    # Script entry point: announce the tool and compute device, then run the sweep.
    for banner_line in ("=== PARAMETER SWEEP TOOL ===", f"Using device: {DEVICE}"):
        print(banner_line)
    run_parameter_sweep()
