import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse

from src.data_loader import get_svhn_for_individual_exp, partition_svhn_for_exp2
from src.models import get_model
from src.attacks import get_torchattacks_attacks
from src.id_estimator import get_gradient_vector, estimate_id_from_embeddings
from src.plotting import plot_individual_comparison

# Names of the attacks that this experiment benchmarks; every other entry
# returned by get_torchattacks_attacks is skipped.
_BENCHMARKED_ATTACKS = ("PGD", "AutoAttack")


def _train_model(model, train_dataset, criterion, args, device):
    """Train ``model`` in place on ``train_dataset`` using Adam.

    Args:
        model: Network to optimise; it is switched to training mode here.
        train_dataset: Dataset yielding ``(data, labels)`` pairs.
        criterion: Loss applied to the model outputs and the labels.
        args: Namespace providing ``learning_rate``, ``batch_size`` and
            ``epochs``.
        device: Device every batch is moved to before the forward pass.
    """
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    model.train()
    for epoch in range(args.epochs):
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{args.epochs} complete.")


def _per_sample_gradients(model, data, labels, criterion):
    """Return one ``get_gradient_vector`` result per sample, stacked in an array.

    ``data`` and ``labels`` are iterated in lockstep along their first
    dimension, so row ``i`` of the result belongs to ``data[i]``.
    """
    return np.array([
        get_gradient_vector(model, sample, label, criterion)
        for sample, label in zip(data, labels)
    ])


def _benchmark_attack(model, attack, eval_loader, baseline_gradients, criterion, device):
    """Collect ID estimates for clean and attacked versions of each eval sample.

    For every batch of ``eval_loader`` (only element 0 of each batch is used),
    the gradient of the clean sample and of its adversarial counterpart is
    appended, separately, to ``baseline_gradients``; each augmented set is
    then passed to ``estimate_id_from_embeddings``.

    Returns:
        dict with keys ``'norm_ids'`` and ``'adv_ids'``, each a list holding
        one estimate per evaluated sample, in loader order.
    """
    norm_ids, adv_ids = [], []
    for idx, (data, label) in enumerate(eval_loader):
        data, label = data.to(device), label.to(device)
        adv_data = attack(data, label)

        # Both gradients are taken with respect to the true label so that the
        # clean and adversarial results differ only in the input.
        norm_grad = get_gradient_vector(model, data[0], label[0], criterion)
        adv_grad = get_gradient_vector(model, adv_data[0], label[0], criterion)

        # The reference set is extended by exactly one gradient at a time
        # (assumes get_gradient_vector returns a flat vector whose length
        # matches the baseline's column count -- confirm in src.id_estimator).
        norm_ids.append(
            estimate_id_from_embeddings(np.vstack([baseline_gradients, norm_grad]))
        )
        adv_ids.append(
            estimate_id_from_embeddings(np.vstack([baseline_gradients, adv_grad]))
        )

        print(f"Sample {idx}: Norm ID = {norm_ids[-1]:.4f}, Adv ID = {adv_ids[-1]:.4f}")

    return {'norm_ids': norm_ids, 'adv_ids': adv_ids}


def main(args: argparse.Namespace) -> None:
    """Run the individual-gradient comparison experiment on SVHN.

    Steps, in order:
      1. Seed torch and numpy, pick the device, load the data, build the model.
      2. Train the model on the training split.
      3. Split the test set into a reference subset (``B_normal``) and an
         evaluation subset, and compute the reference per-sample gradients.
      4. For each benchmarked attack, compute ID estimates for the reference
         set augmented with one clean / one adversarial gradient at a time.
      5. Hand all results to ``plot_individual_comparison``.

    Args:
        args: Parsed command-line options. Reads ``learning_rate``,
            ``batch_size``, ``epochs``, ``n_normal``, ``n_eval`` and
            ``epsilon``.
    """
    # Seed before anything that may draw random numbers (data loading, weight
    # initialisation, shuffling).
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_dataset, test_dataset = get_svhn_for_individual_exp()
    model = get_model(num_channels=3, num_classes=10).to(device)

    print("--- Training model on SVHN ---")
    criterion = nn.CrossEntropyLoss()
    _train_model(model, train_dataset, criterion, args, device)

    b_normal_dataset, eval_dataset = partition_svhn_for_exp2(
        test_dataset, n_normal=args.n_normal, n_eval=args.n_eval
    )
    # batch_size == n_normal: the first batch requested below is meant to hold
    # the whole reference subset (assuming the partition returns n_normal
    # items).
    b_normal_loader = torch.utils.data.DataLoader(
        b_normal_dataset, batch_size=args.n_normal, shuffle=False
    )
    eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=1, shuffle=False)

    print("\n--- Computing baseline gradients (B_normal) ---")
    # Switch to eval mode before any gradient extraction or attack construction.
    model.eval()
    b_normal_data, b_normal_labels = next(iter(b_normal_loader))
    b_normal_data, b_normal_labels = b_normal_data.to(device), b_normal_labels.to(device)
    b_normal_gradients = _per_sample_gradients(
        model, b_normal_data, b_normal_labels, criterion
    )

    attacks = get_torchattacks_attacks(model, eps=args.epsilon)
    benchmark_results = {}

    for attack_name, attack in attacks.items():
        if attack_name not in _BENCHMARKED_ATTACKS:
            continue
        print(f"\n--- Benchmarking: {attack_name} ---")
        benchmark_results[attack_name] = _benchmark_attack(
            model, attack, eval_loader, b_normal_gradients, criterion, device
        )

    plot_individual_comparison(benchmark_results)

if __name__ == '__main__':
    # Command-line entry point: declare the experiment's options, parse them
    # and pass the resulting namespace straight to main().
    cli = argparse.ArgumentParser(
        description='Run Individual Gradient Analysis (Comparison Plot).',
    )

    # Training hyper-parameters.
    cli.add_argument('--epochs', default=5, type=int)
    cli.add_argument('--batch_size', default=64, type=int)
    cli.add_argument('--learning_rate', default=0.001, type=float)

    # Experiment sizing and attack strength.
    cli.add_argument(
        '--n_normal',
        default=1000,
        type=int,
        help='Size of the B_normal reference set.',
    )
    cli.add_argument(
        '--n_eval',
        default=50,
        type=int,
        help='Number of samples for evaluation.',
    )
    cli.add_argument(
        '--epsilon',
        default=0.1,
        type=float,
        help='Epsilon for attacks.',
    )

    main(cli.parse_args())