import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
import eagerpy as ep
import foolbox as fb

from src.data_loader import get_loaders_and_stats
from src.models import get_model
from src.attacks import get_fmodel_and_attacks
from src.id_estimator import get_gradient_vectors, estimate_id
from src.plotting import plot_and_save_results

def _train_model(model, train_loader, criterion, optimizer, device, epochs):
    """Run a plain supervised training loop for ``epochs`` passes over ``train_loader``."""
    model.train()
    for epoch in range(epochs):
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs} complete.")


def _next_test_batch(test_iter, test_loader):
    """Fetch the next ``(data, labels)`` batch, restarting ``test_loader`` when exhausted.

    Returns:
        ``(batch, iterator)``: the batch and the iterator to pass to the next call
        (a fresh one after a restart).

    Raises:
        RuntimeError: if ``test_loader`` yields no batches at all.
    """
    try:
        return next(test_iter), test_iter
    except StopIteration:
        pass
    test_iter = iter(test_loader)
    try:
        return next(test_iter), test_iter
    except StopIteration:
        raise RuntimeError(
            "test_loader yielded no batches; cannot run simulation rounds."
        ) from None


def _attack_batch(attack, fmodel, images, labels, epsilon):
    """Return adversarial versions of ``images`` produced by a foolbox ``attack`` at ``epsilon``."""
    images_ep = ep.astensor(images)
    labels_ep = ep.astensor(labels)
    # foolbox attacks return (raw, clipped, success); keep the clipped adversarials,
    # i.e. the ones projected back into the epsilon ball.
    _, adv_images_ep, _ = attack(
        fmodel, images_ep, fb.criteria.Misclassification(labels_ep), epsilons=epsilon
    )
    return adv_images_ep.raw


def main(args):
    """Pre-train a model, then track gradient intrinsic dimension (ID) per simulated client.

    Each round draws one test batch (cycling the test loader). Client 1 uses the clean
    batch; every other client is an attack from ``get_fmodel_and_attacks`` applied to
    that same batch. For each client, the gradient vectors from ``get_gradient_vectors``
    are fed to ``estimate_id``, and the per-round values are plotted at the end.

    Args:
        args: parsed CLI namespace providing ``dataset``, ``batch_size``, ``epochs``,
            ``learning_rate`` and ``num_rounds``.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Data and Model
    train_loader, test_loader, norm_mean, norm_std, num_channels = get_loaders_and_stats(
        dataset_name=args.dataset, batch_size=args.batch_size
    )
    model = get_model(num_channels=num_channels).to(device)

    # 2. Train Model (a simple pre-training)
    print(f"--- Training model on {args.dataset.upper()} for {args.epochs} epochs ---")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    _train_model(model, train_loader, criterion, optimizer, device, args.epochs)

    # 3. Setup for Simulation
    print("\n--- Starting Batch-Wise Simulation ---")
    # Switch to eval mode *before* wrapping. The helper presumably builds an
    # fb.PyTorchModel, which warns when handed a model still in training mode.
    model.eval()
    fmodel, attacks = get_fmodel_and_attacks(model, device, norm_mean, norm_std)
    client_names = ["Client_1_Normal"] + list(attacks.keys())

    id_history = {client: [] for client in client_names}

    attack_epsilons = {'cifar10': 0.05, 'mnist': 0.1, 'svhn': 0.05}
    epsilon = attack_epsilons.get(args.dataset.lower(), 0.05)

    # 4. Run Simulation Rounds
    test_iter = iter(test_loader)
    for round_idx in range(args.num_rounds):
        print(f"--- Round {round_idx+1}/{args.num_rounds} ---")
        (data, labels), test_iter = _next_test_batch(test_iter, test_loader)
        data, labels = data.to(device), labels.to(device)

        for client_idx, client_name in enumerate(client_names):
            if client_idx == 0:
                data_client = data.clone()
            else:
                data_client = _attack_batch(attacks[client_name], fmodel, data, labels, epsilon)

            # Parameter .grad still holds the last training step's gradient, and
            # gradient-based attacks may add more via backward(). Zero it so the ID
            # estimate cannot pick up leftovers (harmless if get_gradient_vectors
            # already clears gradients itself).
            model.zero_grad(set_to_none=False)
            gradients = get_gradient_vectors(model, data_client, labels, criterion, device)
            id_val = estimate_id(gradients)
            id_history[client_name].append(id_val)
            print(f"  {client_name}: ID = {id_val:.4f}")

    # 5. Plot and Save Results
    print("\n--- Simulation complete. Generating plots. ---")
    plot_and_save_results(id_history, args.dataset, args.num_rounds)

if __name__ == '__main__':
    # Command-line entry point: each option becomes an attribute that main() reads.
    cli = argparse.ArgumentParser(description='Run Batch-Wise Adversarial Detection Simulation.')
    cli.add_argument(
        '--dataset',
        type=str,
        required=True,
        choices=['cifar10', 'mnist', 'svhn'],
        help='Dataset to use for the experiment.',
    )
    # Optional knobs as (flag, type, default, help); --help lists them in this order.
    optional_flags = (
        ('--batch_size', int, 64, 'Batch size for data loaders.'),
        ('--epochs', int, 5, 'Number of epochs for pre-training the model.'),
        ('--learning_rate', float, 0.001, 'Learning rate for training.'),
        ('--num_rounds', int, 30, 'Number of simulation rounds.'),
    )
    for flag, value_type, default_value, help_text in optional_flags:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = cli.parse_args()
    main(args)