import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
import os
from tqdm import tqdm

from src.data_loader import get_sota_loaders
from src.models import get_cifar10_sota_model, get_coco_sota_model
from src.attacks import get_sota_attacks
from src.id_estimator import get_gradient_vector, get_coco_gradient_vector, estimate_id
from src.utils import save_checkpoint, load_checkpoint

def train_model(model, train_loader, epochs, lr, device, checkpoint_path):
    """Train ``model`` in place and save a checkpoint when finished.

    Makes ``epochs`` passes over ``train_loader``, minimising cross-entropy
    loss with an Adam optimizer at learning rate ``lr``. Every batch is moved
    to ``device`` before the forward pass. Once all epochs are done, the model
    and the optimizer are passed to ``save_checkpoint`` together with
    ``checkpoint_path``.

    Returns nothing. This function puts the model in training mode and never
    switches it back, so callers that want inference behaviour must call
    ``model.eval()`` themselves.
    """
    print(f"--- Training model for {epochs} epochs ---")
    loss_fn = nn.CrossEntropyLoss().to(device)
    adam = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch_no in range(1, epochs + 1):
        progress = tqdm(train_loader, desc=f"Epoch {epoch_no}/{epochs}")
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Standard step: clear stale grads, forward, backward, update.
            adam.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            adam.step()
        print(f"Epoch {epoch_no} complete.")

    save_checkpoint(model, adam, checkpoint_path)

def build_reference_gradients(model, val_loader, num_ref, device, dataset_name, criterion):
    """Build the reference gradient matrix ``G_norm`` from validation data.

    Walks the first ``num_ref`` batches of ``val_loader`` and, for each one,
    computes a gradient vector for the batch's first element only
    (``x[0]``, ``y[0]``). The gradient helper is
    ``get_coco_gradient_vector`` when ``dataset_name`` is ``'coco'`` and
    ``get_gradient_vector`` otherwise.

    Returns the collected vectors stacked along a new leading axis with
    ``np.stack``; this raises if no vector was collected.
    """
    print(f"--- Building reference gradient set G_norm from {num_ref} samples ---")

    if dataset_name == 'coco':
        compute_grad = get_coco_gradient_vector
    else:
        compute_grad = get_gradient_vector

    reference_rows = []
    progress = tqdm(val_loader, total=num_ref)
    for batch_idx, (inputs, targets) in enumerate(progress):
        if batch_idx >= num_ref:
            break
        # Only the first element of every batch contributes a reference row.
        sample = inputs[0].to(device)
        target = targets[0].to(device)
        reference_rows.append(compute_grad(model, sample, target, criterion))

    return np.stack(reference_rows, axis=0)

def find_best_thresholds(model, val_loader, G_norm, attack, device, dataset_name, criterion, id_method, num_search=500, num_skip=1000):
    """Grid-search the ID band that best separates natural from adversarial samples.

    Calibration samples are the first element (``x[0]``, ``y[0]``) of each
    batch of ``val_loader`` after the first ``num_skip`` batches, up to
    ``num_search`` samples. Each sample is scored twice: once as-is and once
    after ``attack`` is applied to it as a batch of one. A score is
    ``estimate_id`` evaluated on ``G_norm`` with the sample's gradient vector
    appended as an extra row.

    Candidate bands are ``[percentile(nat, lp), percentile(nat, hp)]`` for
    ``lp`` in 5/10/15/20 and ``hp`` in 80/85/90/95. A band's detection rate
    (DRa) is the percentage of scored samples classified correctly:
    adversarial scores outside the band plus natural scores inside it. The
    first band reaching the highest DRa wins.

    Args:
        model: Model passed through to the gradient helper.
        val_loader: Iterable yielding ``(x, y)`` batches.
        G_norm: Reference gradient matrix the sample gradient is stacked onto.
        attack: Callable ``attack(images, labels)`` whose first output is the
            adversarial image.
        device: Device the samples are moved to.
        dataset_name: ``'coco'`` selects ``get_coco_gradient_vector``; any
            other value selects ``get_gradient_vector``.
        criterion: Loss passed through to the gradient helper.
        id_method: Forwarded to ``estimate_id`` as ``method``.
        num_search: Maximum number of calibration samples.
        num_skip: Number of leading batches to skip (those used to build the
            reference set).

    Returns:
        ``(low, high)`` thresholds of the best band.

    Raises:
        ValueError: If no calibration sample could be collected.
    """
    print(f"--- Finding best thresholds for {attack.__class__.__name__} using {num_search} samples ---")

    # Natural and adversarial scores are derived from the same samples, so a
    # single list is enough.
    calib_samples = []
    if num_search > 0:
        for i, (x, y) in enumerate(val_loader):
            if i < num_skip:
                continue  # Skip reference set
            calib_samples.append((x[0], y[0]))
            if len(calib_samples) >= num_search:
                break

    if not calib_samples:
        raise ValueError(
            f"No calibration samples available: the loader yielded nothing after "
            f"skipping {num_skip} batches (num_search={num_search})."
        )
    if len(calib_samples) < num_search:
        print(f"Warning: only {len(calib_samples)} of {num_search} requested calibration samples were available.")

    grad_func = get_coco_gradient_vector if dataset_name == 'coco' else get_gradient_vector

    nat_ids, adv_ids = [], []
    for img, lbl in tqdm(calib_samples, desc="Processing natural samples"):
        img, lbl = img.to(device), lbl.to(device)
        gv = grad_func(model, img, lbl, criterion)
        nat_ids.append(estimate_id(np.vstack([G_norm, gv]), method=id_method))

    for img, lbl in tqdm(calib_samples, desc="Processing adversarial samples"):
        img, lbl = img.to(device), lbl.to(device)
        adv_img = attack(img.unsqueeze(0), lbl.unsqueeze(0))[0]
        gv = grad_func(model, adv_img, lbl, criterion)
        adv_ids.append(estimate_id(np.vstack([G_norm, gv]), method=id_method))

    # Convert once instead of on every grid iteration.
    nat_arr = np.asarray(nat_ids)
    adv_arr = np.asarray(adv_ids)
    # Normalise by the number of samples actually scored; the loader may have
    # supplied fewer than num_search.
    num_scored = len(nat_ids) + len(adv_ids)

    best_dra, best_thresholds = -1, (0, 0)
    for lp in (5, 10, 15, 20):
        low = np.percentile(nat_arr, lp)
        for hp in (80, 85, 90, 95):
            high = np.percentile(nat_arr, hp)
            tp = np.sum((adv_arr < low) | (adv_arr > high))
            tn = np.sum((nat_arr >= low) & (nat_arr <= high))
            dra = 100 * (tp + tn) / num_scored
            if dra > best_dra:
                best_dra = dra
                best_thresholds = (low, high)

    print(f"Best thresholds found: [{best_thresholds[0]:.4f}, {best_thresholds[1]:.4f}] with DRa: {best_dra:.2f}%")
    return best_thresholds

def evaluate_attack(model, val_loader, G_norm, attack, thresholds, device, dataset_name, criterion, id_method, num_skip=1500):
    """Measure how often attacked samples fall outside the ID threshold band.

    For every batch of ``val_loader`` after the first ``num_skip`` batches,
    the batch's first element (``x[0]``, ``y[0]``) is attacked as a batch of
    one. Its gradient vector is appended to ``G_norm`` and scored with
    ``estimate_id``. A sample counts as detected when its score lies outside
    the closed interval ``[low, high]``.

    Args:
        model: Model passed through to the gradient helper.
        val_loader: Iterable yielding ``(x, y)`` batches.
        G_norm: Reference gradient matrix the sample gradient is stacked onto.
        attack: Callable ``attack(images, labels)`` whose first output is the
            adversarial image.
        thresholds: ``(low, high)`` pair delimiting the "natural" band.
        device: Device the samples are moved to.
        dataset_name: ``'coco'`` selects ``get_coco_gradient_vector``; any
            other value selects ``get_gradient_vector``.
        criterion: Loss passed through to the gradient helper.
        id_method: Forwarded to ``estimate_id`` as ``method``.
        num_skip: Number of leading batches to skip (the default covers the
            reference and calibration sets).

    Returns:
        Detection rate as a percentage, or ``0.0`` when nothing was evaluated.
    """
    print(f"--- Full evaluation for {attack.__class__.__name__} ---")
    tp, fn = 0, 0
    low, high = thresholds
    grad_func = get_coco_gradient_vector if dataset_name == 'coco' else get_gradient_vector

    # tqdm works out the total itself when the loader has a length; calling
    # len() here would raise for loaders that do not define one.
    for i, (x, y) in enumerate(tqdm(val_loader)):
        if i < num_skip:
            continue  # Skip ref and calib sets
        img, lbl = x[0].to(device), y[0].to(device)
        adv_img = attack(img.unsqueeze(0), lbl.unsqueeze(0))[0]

        gv = grad_func(model, adv_img, lbl, criterion)
        inc_id = estimate_id(np.vstack([G_norm, gv]), method=id_method)

        if low <= inc_id <= high:
            fn += 1
        else:
            tp += 1

    evaluated = tp + fn
    if evaluated > 0:
        dra = 100 * tp / evaluated
    else:
        # Make an empty evaluation visible rather than reporting a bare 0.00%.
        print(f"Warning: no samples left to evaluate after skipping {num_skip} batches.")
        dra = 0.0
    print(f"======> Final Detection Rate (DRa) for {attack.__class__.__name__}: {dra:.2f}% <======")
    return dra

def main(args):
    """Run the full experiment: train or load, build references, calibrate, evaluate.

    Steps, in order:
      1. Validate ``args.dataset`` and any manual thresholds.
      2. Load the data and build the dataset-specific model (``'mle'`` ID
         estimator for cifar10, ``'twonn'`` for coco).
      3. Train and checkpoint the model if no checkpoint exists or
         ``args.force_train`` is set; otherwise load the checkpoint.
      4. Build the reference gradient set from 1000 validation batches.
      5. Use the supplied thresholds, or search for them when either one is
         missing.
      6. Evaluate the selected attack.

    Args:
        args: Namespace with ``dataset``, ``attack``, ``force_train``,
            ``epochs``, ``learning_rate``, ``low_thresh`` and ``high_thresh``.

    Raises:
        ValueError: If the dataset or attack is unsupported, or if both
            thresholds are given and ``low_thresh > high_thresh``.
    """
    # Validate before any expensive work. Without this check an unknown
    # dataset would leave `model` / `id_method` unbound further down.
    supported_datasets = ['cifar10', 'coco']
    if args.dataset not in supported_datasets:
        raise ValueError(f"Dataset '{args.dataset}' not supported. Choose from {supported_datasets}")

    manual_thresholds = args.low_thresh is not None and args.high_thresh is not None
    if manual_thresholds and args.low_thresh > args.high_thresh:
        raise ValueError(
            f"low_thresh ({args.low_thresh}) must not exceed high_thresh ({args.high_thresh})"
        )
    if not manual_thresholds and (args.low_thresh is not None or args.high_thresh is not None):
        # A single bound cannot define the band, so the search still runs.
        print("Warning: only one of --low_thresh/--high_thresh was given; ignoring it and searching for thresholds.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, val_loader, num_classes = get_sota_loaders(args.dataset)
    model_path = f"checkpoints/{args.dataset}_sota_model.pth"
    os.makedirs("checkpoints", exist_ok=True)

    if args.dataset == 'cifar10':
        model = get_cifar10_sota_model(num_classes, device)
        id_method = 'mle'
    else:  # 'coco' -- guaranteed by the validation above
        model = get_coco_sota_model(num_classes, device)
        id_method = 'twonn'

    if not os.path.exists(model_path) or args.force_train:
        train_model(model, train_loader, args.epochs, args.learning_rate, device, model_path)
    else:
        print(f"Loading model from {model_path}")
        model = load_checkpoint(model, model_path)

    # Training leaves the model in train mode; switch to eval for everything
    # that follows.
    model.eval()
    criterion = nn.CrossEntropyLoss().to(device)

    attacks = get_sota_attacks(model)
    if args.attack not in attacks:
        raise ValueError(f"Attack '{args.attack}' not supported. Choose from {list(attacks.keys())}")

    selected_attack = attacks[args.attack]

    G_norm = build_reference_gradients(model, val_loader, 1000, device, args.dataset, criterion)

    if manual_thresholds:
        thresholds = (args.low_thresh, args.high_thresh)
        print(f"Using provided thresholds: [{thresholds[0]:.4f}, {thresholds[1]:.4f}]")
    else:
        thresholds = find_best_thresholds(model, val_loader, G_norm, selected_attack, device, args.dataset, criterion, id_method)

    evaluate_attack(model, val_loader, G_norm, selected_attack, thresholds, device, args.dataset, criterion, id_method)
    
if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    cli = argparse.ArgumentParser(description='Run SOTA Comparison Experiment.')

    # Options are registered in this order, which is also the order they
    # appear in the generated --help text.
    option_specs = (
        ('--dataset', dict(type=str, required=True, choices=['cifar10', 'coco'])),
        ('--attack', dict(type=str, required=True, choices=["FGSM", "PGD", "BIM", "DeepFool", "CW"])),
        ('--force_train', dict(action='store_true', help='Force training even if a checkpoint exists.')),
        ('--epochs', dict(type=int, default=20)),
        ('--learning_rate', dict(type=float, default=1e-3)),
        ('--low_thresh', dict(type=float, default=None, help='Manually set the lower threshold.')),
        ('--high_thresh', dict(type=float, default=None, help='Manually set the upper threshold.')),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    main(cli.parse_args())