import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
from sklearn.metrics import accuracy_score

from src.data_loader import get_svhn_for_individual_exp, partition_svhn_for_exp2_b
from src.models import get_model
from src.attacks import get_torchattacks_attacks
from src.id_estimator import get_gradient_vector, estimate_id_from_embeddings
from src.plotting import plot_evaluation_histogram

def evaluate_group(model, loader, b_normal_gradients, criterion, device, bounds, attack=None, group_name="Natural"):
    """Score one evaluation group with the incremental-ID detector and print its accuracy.

    Only the first sample of every batch yielded by ``loader`` is used. Its
    gradient vector (from ``get_gradient_vector``) is stacked beneath
    ``b_normal_gradients`` and ``estimate_id_from_embeddings`` is run on the
    augmented matrix. The sample is predicted ``0`` when that estimate lies
    inside ``bounds`` (inclusive on both ends) and ``1`` otherwise.

    Args:
        model: Network handed to ``get_gradient_vector``.
        loader: Iterable yielding ``(data, label)`` batches.
        b_normal_gradients: Baseline gradient matrix the sample gradient is
            appended to (one row per baseline sample, assuming
            ``get_gradient_vector`` returns a flat vector -- TODO confirm).
        criterion: Loss function handed to ``get_gradient_vector``.
        device: Device every batch is moved to before use.
        bounds: ``(lower_bound, upper_bound)`` acceptance interval.
        attack: Optional callable ``attack(data, label)`` returning a batch of
            perturbed inputs; ``None`` evaluates the unmodified inputs.
        group_name: Label used in the printed report. The group's ground
            truth is ``0`` when this equals ``"Natural"`` and ``1`` otherwise.

    Returns:
        list: The incremental ID estimate of each evaluated sample, in
        loader order. Empty when the loader yields nothing.
    """
    lower_bound, upper_bound = bounds
    # Every sample in a group shares one ground-truth label, so derive it once.
    expected = 0 if group_name == "Natural" else 1

    pred_labels, group_ids = [], []
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        # Compare against None explicitly: the attack is an arbitrary callable
        # object whose truthiness carries no meaning here.
        sample = attack(data, label)[0] if attack is not None else data[0]

        sample_grad = get_gradient_vector(model, sample, label[0], criterion)
        augmented = np.vstack([b_normal_gradients, sample_grad])
        incremental_id = estimate_id_from_embeddings(augmented)

        pred_labels.append(0 if lower_bound <= incremental_id <= upper_bound else 1)
        group_ids.append(incremental_id)

    if not group_ids:
        # Accuracy is undefined for an empty group; report that explicitly
        # instead of passing zero samples to accuracy_score.
        print(f"Detection Accuracy for {group_name}: n/a (no samples evaluated)")
        return group_ids

    acc = accuracy_score([expected] * len(pred_labels), pred_labels)
    print(f"Detection Accuracy for {group_name}: {acc*100:.2f}%")
    return group_ids

def main(args):
    """Run the individual-gradient detection experiment on SVHN end to end.

    Pipeline:
      1. Seed torch and NumPy, pick the device, load data and build the model.
      2. Train the model for ``args.epochs`` epochs.
      3. Partition the test set into baseline, calibration and per-group
         evaluation subsets.
      4. Compute one gradient vector per baseline sample (``B_normal``).
      5. Calibrate the acceptance interval as the 10th/90th percentiles of
         the calibration samples' incremental ID estimates.
      6. Evaluate the natural group and every attack group, then plot the
         histogram of all estimates against the interval.

    Args:
        args: Parsed command-line namespace. Reads ``epochs``, ``batch_size``,
            ``learning_rate``, ``n_normal``, ``n_calib``, ``n_eval`` and
            ``epsilon``.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_dataset, test_dataset = get_svhn_for_individual_exp()
    model = get_model(num_channels=3, num_classes=10).to(device)

    print("--- Training model on SVHN ---")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    model.train()
    for epoch in range(args.epochs):
        # Standard supervised step. Assumes the training set yields
        # (input, label) pairs like the calibration/evaluation loaders below
        # -- TODO confirm against the data loader.
        for data, label in train_loader:
            data, label = data.to(device), label.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), label)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{args.epochs} complete.")

    b_normal_ds, calib_ds, eval_ds_dict = partition_svhn_for_exp2_b(
        test_dataset, n_normal=args.n_normal, n_calib=args.n_calib, n_eval_per_group=args.n_eval
    )

    model.eval()
    print("\n--- Computing baseline gradients (B_normal) ---")
    # A single batch of size n_normal pulls the whole baseline subset at once.
    b_normal_loader = torch.utils.data.DataLoader(b_normal_ds, batch_size=args.n_normal)
    b_normal_data, b_normal_labels = next(iter(b_normal_loader))
    b_normal_data, b_normal_labels = b_normal_data.to(device), b_normal_labels.to(device)
    # Iterate over what was actually loaded rather than range(n_normal), so a
    # partition smaller than requested does not raise IndexError.
    b_normal_gradients = np.array([
        get_gradient_vector(model, sample, target, criterion)
        for sample, target in zip(b_normal_data, b_normal_labels)
    ])

    print("\n--- Calibrating threshold ---")
    calib_loader = torch.utils.data.DataLoader(calib_ds, batch_size=1)
    calib_ids = []
    for data, label in calib_loader:
        data, label = data.to(device), label.to(device)
        sample_grad = get_gradient_vector(model, data[0], label[0], criterion)
        augmented = np.vstack([b_normal_gradients, sample_grad])
        calib_ids.append(estimate_id_from_embeddings(augmented))

    # Central 80% of the calibration estimates forms the acceptance interval.
    lower_bound, upper_bound = np.percentile(calib_ids, [10, 90])
    print(f"Thresholds set: [{lower_bound:.4f}, {upper_bound:.4f}]")
    bounds = (lower_bound, upper_bound)

    attacks = get_torchattacks_attacks(model, eps=args.epsilon)
    all_group_ids = {}

    print("\n--- Evaluating groups ---")
    natural_loader = torch.utils.data.DataLoader(eval_ds_dict["Natural"], batch_size=1)
    all_group_ids["Natural"] = evaluate_group(
        model, natural_loader, b_normal_gradients, criterion, device, bounds,
        group_name="Natural",
    )
    for attack_name, attack_fn in attacks.items():
        # Each attack is applied to its own evaluation subset, keyed by name.
        loader = torch.utils.data.DataLoader(eval_ds_dict[attack_name], batch_size=1)
        all_group_ids[attack_name] = evaluate_group(
            model, loader, b_normal_gradients, criterion, device, bounds,
            attack=attack_fn, group_name=attack_name,
        )

    plot_evaluation_histogram(all_group_ids, bounds)

if __name__ == '__main__':
    # Command-line entry point: declare the experiment's tunables, parse
    # them, and hand the resulting namespace to main().
    cli = argparse.ArgumentParser(
        description='Run Individual Gradient Analysis (Detection & Histogram).'
    )
    # (flag, type, default) for every option, registered in this order.
    option_table = (
        ('--epochs', int, 5),
        ('--batch_size', int, 64),
        ('--learning_rate', float, 0.001),
        ('--n_normal', int, 1000),
        ('--n_calib', int, 200),
        ('--n_eval', int, 200),
        ('--epsilon', float, 0.1),
    )
    for flag, caster, fallback in option_table:
        cli.add_argument(flag, type=caster, default=fallback)
    main(cli.parse_args())