import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescentPyTorch, CarliniLInfMethod
from art.estimators.classification import PyTorchClassifier
from autoattack import AutoAttack
from model.wideresnet import WRN_34_10, WRN_34_20
from model.dfhl_mnv2 import MobileNet
from model.cmi_resnet import resnet18
from model.preactresnet import PreActResNet34, PreActResNet18
import torchattacks
from autoattack import AutoAttack

import seaborn as sns
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import torch.nn.functional as F

# Command-line interface for the robustness / fairness evaluation script.
parser = argparse.ArgumentParser(description='Test Robust Fair')

parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'cifar100', 'tiny'])
parser.add_argument('--model', default='WRN_34_10', type=str, choices=['WRN_34_10', 'WRN_34_20', 'cifar_resnet18', 'mobilenet', 'preactresnet34', 'preactresnet18'])
parser.add_argument('--model_name', type=str)
parser.add_argument('--batch_size', default=128, type=int)
# eps / step_size stay strings (e.g. "8/255") because main() converts them to
# floats itself. argparse does NOT run `type` on non-string defaults, so the
# defaults must be strings as well; a float default would reach main()
# unconverted and break its string-based conversion.
parser.add_argument('--eps', default='8/255', type=str,
                    help='L-inf perturbation budget, written as a number or a fraction such as 8/255')
parser.add_argument('--step_size', default='2/255', type=str,
                    help='per-iteration attack step, written as a number or a fraction such as 2/255')

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ==========================
# Utility
# ==========================

def clamp(X, lower, upper):
    """Return a copy of ``X`` with every element limited to [lower, upper]."""
    bounded = torch.clamp(X, lower, upper)
    return bounded


def cw_loss(logits, labels, kappa=0):
    """Margin loss: mean of (largest non-true logit - true-class logit).

    Args:
        logits: Tensor indexed as (sample, class).
        labels: Integer class index per sample.
        kappa: Accepted for signature compatibility; the computation does not
            use it.

    Returns:
        Scalar tensor; larger values mean the true class is losing by more.
    """
    true_mask = F.one_hot(labels, num_classes=logits.size(1))
    other_mask = 1 - true_mask

    # Logit assigned to the ground-truth class of each sample.
    true_logit = (true_mask * logits).sum(dim=1)

    # Push the true class down by 1e4 so it can never win the max below.
    competitors = other_mask * logits - 1e4 * true_mask
    best_other = competitors.max(dim=1)[0]

    margin = best_other - true_logit
    return margin.mean()


# ==========================
# Attacks
# ==========================

def fgsm(args, model, images, labels):
    """Craft single-step FGSM adversarial examples with cross-entropy loss.

    Args:
        args: Namespace providing ``eps``, the size of the signed-gradient step.
        model: Classifier mapping an image batch to logits.
        images: Clean input batch (copied, never modified in place).
        labels: Ground-truth class indices for ``images``.

    Returns:
        Detached tensor ``images + eps * sign(grad)`` clipped to [0, 1].
    """
    x = images.clone().detach().to(device)
    y = labels.to(device)

    # Track gradients with respect to the input pixels.
    x.requires_grad_(True)
    logits = model(x)
    loss = F.cross_entropy(logits, y)

    model.zero_grad()
    loss.backward()

    step = args.eps * x.grad.sign()
    perturbed = x + step
    return torch.clamp(perturbed, min=0, max=1).detach()


def pgd_linf(args, model, images, labels, steps=20):
    """Run an L-inf PGD attack with a random start and cross-entropy loss.

    Args:
        args: Namespace providing ``eps`` (perturbation radius) and
            ``step_size`` (signed-gradient step per iteration) as floats.
        model: Classifier mapping an image batch to logits.
        images: Clean input batch.
        labels: Ground-truth class indices for ``images``.
        steps: Number of ascent iterations. Defaults to 20, the value that
            used to be hard-coded, so existing four-argument calls behave
            as before.

    Returns:
        Detached adversarial batch that stays within ``args.eps`` of
        ``images`` element-wise and is clipped to [0, 1].
    """
    images = images.to(device)
    labels = labels.to(device)

    # Random start inside the eps-ball, clipped back to the valid pixel range.
    adv = images + torch.empty_like(images).uniform_(-args.eps, args.eps)
    adv = clamp(adv, 0, 1)

    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)

        # Only the gradient w.r.t. the input is needed; asking autograd for
        # exactly that avoids filling (and having to zero) parameter grads.
        (grad,) = torch.autograd.grad(loss, adv)

        adv = adv.detach() + args.step_size * grad.sign()
        # Project back onto the eps-ball around the clean images.
        delta = torch.clamp(adv - images, -args.eps, args.eps)
        adv = clamp(images + delta, 0, 1).detach()

    return adv


def cw_linf(args, model, images, labels, steps=30):
    """Run an L-inf iterative attack that maximises the ``cw_loss`` margin.

    Identical in structure to a PGD attack (random start, signed-gradient
    ascent, projection onto the eps-ball) but driven by ``cw_loss`` instead
    of cross-entropy.

    Args:
        args: Namespace providing ``eps`` (perturbation radius) and
            ``step_size`` (signed-gradient step per iteration) as floats.
        model: Classifier mapping an image batch to logits.
        images: Clean input batch.
        labels: Ground-truth class indices for ``images``.
        steps: Number of ascent iterations. Defaults to 30, the value that
            used to be hard-coded, so existing four-argument calls behave
            as before.

    Returns:
        Detached adversarial batch that stays within ``args.eps`` of
        ``images`` element-wise and is clipped to [0, 1].
    """
    images = images.to(device)
    labels = labels.to(device)

    # Random start inside the eps-ball, clipped back to the valid pixel range.
    adv = images + torch.empty_like(images).uniform_(-args.eps, args.eps)
    adv = clamp(adv, 0, 1)

    for _ in range(steps):
        adv.requires_grad_(True)
        loss = cw_loss(model(adv), labels)

        # Only the gradient w.r.t. the input is needed; asking autograd for
        # exactly that avoids filling (and having to zero) parameter grads.
        (grad,) = torch.autograd.grad(loss, adv)

        adv = adv.detach() + args.step_size * grad.sign()
        # Project back onto the eps-ball around the clean images.
        delta = torch.clamp(adv - images, -args.eps, args.eps)
        adv = clamp(images + delta, 0, 1).detach()

    return adv


# ==========================
# Robust Evaluation
# ==========================

def robust_evaluation(args, model, loader, attack_fn, attack_name, num_classes):
    """Attack every batch of ``loader`` and report per-class robust accuracy.

    Prints overall / per-class accuracy and fairness statistics (min, mean,
    std, normalised std, sum of the lowest-k class accuracies) and writes the
    confusion-matrix heatmap to ``<args.img_save_dir>/<attack_name>.png``.

    Args:
        args: Namespace forwarded to ``attack_fn``; must also provide
            ``img_save_dir``.
        model: Classifier mapping an image batch to logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        attack_fn: Callable ``(args, model, images, labels) -> adv_images``.
        attack_name: Label used in the printout, plot title and file name.
        num_classes: Number of classes (size of the confusion matrix).

    Returns:
        Tuple ``(overall_acc, per_class_acc)``: a Python float and a float
        tensor of length ``num_classes``. Classes without samples get 0, and
        ``overall_acc`` is 0.0 when the loader yields no samples.
    """
    model.eval()

    # Rows are true labels, columns are predictions; kept on the CPU.
    confusion = torch.zeros((num_classes, num_classes), dtype=torch.int32)

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        adv_images = attack_fn(args, model, images, labels)

        # The attack needs gradients, scoring its output does not.
        with torch.no_grad():
            preds = model(adv_images).argmax(1)

        # One vectorised scatter-add per batch instead of a per-sample loop
        # with a device sync on every .item() call.
        true_idx = labels.detach().cpu().long()
        pred_idx = preds.detach().cpu().long()
        confusion.index_put_(
            (true_idx, pred_idx),
            torch.ones_like(true_idx, dtype=confusion.dtype),
            accumulate=True,
        )

    correct = confusion.diagonal().float()
    total = confusion.sum(dim=1).float()

    per_class_acc = correct / total.clamp(min=1)
    num_samples = total.sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    overall_acc = correct.sum().item() / num_samples if num_samples > 0 else 0.0

    # Fairness metrics
    acc_np = per_class_acc.cpu().numpy()
    min_acc = np.min(acc_np)
    mean_acc = np.mean(acc_np)
    std_acc = np.std(acc_np)
    nsd = std_acc / mean_acc if mean_acc != 0 else 0

    lowest_k = min(10, num_classes)
    lowest_k_sum = np.sum(np.sort(acc_np)[:lowest_k])

    print(f"\n===== {attack_name} =====")
    print("Per Class Acc:", per_class_acc)
    print("Overall Robust Acc:", overall_acc)
    print("Min Class Acc:", min_acc)
    print("Mean Class Acc:", mean_acc)
    print("Std:", std_acc)
    print("NSD:", nsd)
    print(f"Sum of lowest {lowest_k} classes:", lowest_k_sum)

    # Save confusion matrix (create the target directory on first use,
    # otherwise savefig fails when it does not exist yet).
    os.makedirs(args.img_save_dir, exist_ok=True)
    plt.figure(figsize=(8, 6))
    sns.heatmap(confusion.cpu().numpy(),
                cmap="Blues",
                xticklabels=False,
                yticklabels=False)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(attack_name)
    plt.tight_layout()
    plt.savefig(os.path.join(args.img_save_dir, f"{attack_name}.png"), dpi=300)
    plt.close()

    return overall_acc, per_class_acc


# ==========================
# AutoAttack
# ==========================

def autoattack_eval(args, model, loader, num_classes):
    """Evaluate per-class robustness under the standard L-inf AutoAttack.

    Prints per-class robust accuracy, fairness statistics and the confusion
    matrix, and writes its heatmap to ``<args.img_save_dir>/AA.png``.

    Args:
        args: Namespace providing ``eps``, ``batch_size`` and
            ``img_save_dir``.
        model: Classifier mapping an image batch to logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        num_classes: Number of classes (size of the confusion matrix).

    Returns:
        Tuple ``(mean_robustness, robustness)``: the unweighted mean of the
        per-class robust accuracies, and a dict mapping class index to its
        robust accuracy (0.0 for classes that never occur in ``loader``).
    """
    model.eval()
    adversary = AutoAttack(model, norm='Linf', eps=args.eps, version='standard')

    # Rows are true labels, columns are predictions; kept on the CPU.
    adv_confusion_matrix = torch.zeros((num_classes, num_classes), dtype=torch.int32)

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        # Must go through `adversary`; the previous code called an undefined
        # name `attack` and crashed with NameError on the first batch.
        adv_images = adversary.run_standard_evaluation(images, labels, bs=args.batch_size)
        # as_tensor avoids the copy (and warning) of torch.tensor(<tensor>).
        adv_images = torch.as_tensor(adv_images).to(device)
        with torch.no_grad():
            adv_preds = model(adv_images).argmax(dim=1)

        true_idx = labels.detach().cpu().long()
        pred_idx = adv_preds.detach().cpu().long()
        adv_confusion_matrix.index_put_(
            (true_idx, pred_idx),
            torch.ones_like(true_idx, dtype=adv_confusion_matrix.dtype),
            accumulate=True,
        )

    correct = adv_confusion_matrix.diagonal().tolist()
    total = adv_confusion_matrix.sum(dim=1).tolist()

    # A class with no samples scores 0.0 instead of raising ZeroDivisionError.
    robustness = {
        cls: (correct[cls] / total[cls] if total[cls] else 0.0)
        for cls in range(num_classes)
    }
    mean_robustness = sum(robustness.values()) / len(robustness)
    print("### AA ###")
    print("Each class robust acc:")
    print(robustness)
    print("Overall robust acc:")
    print(mean_robustness)
    robust_accs = np.array(list(robustness.values()))
    min_robust = np.min(robust_accs)
    mean_robust = np.mean(robust_accs)
    std_robust = np.std(robust_accs)
    nsd_robust = std_robust / mean_robust if mean_robust != 0 else 0.0
    lowest_10_robust_sum = np.sum(np.sort(robust_accs)[:10])
    print(f"Min Robust Acc:  {min_robust:.4f}")
    print(f"Mean Robust Acc: {mean_robust:.4f}")
    print(f"NSD:             {nsd_robust:.4f}")
    print(f"Sum of lowest 10 classes: {lowest_10_robust_sum:.4f}")

    print("Adversarial Confusion Matrix (rows: true labels, cols: predicted labels):")
    print(adv_confusion_matrix)
    # Save Confusion Matrix (create the target directory on first use).
    os.makedirs(args.img_save_dir, exist_ok=True)
    plt.figure(figsize=(10, 8))
    # An integer tick-label argument tells seaborn to draw only every n-th
    # label, so passing num_classes hid all but the first; "auto" labels the
    # class indices as densely as fits.
    sns.heatmap(adv_confusion_matrix.cpu().numpy(), annot=True, fmt="d", cmap="Blues", xticklabels="auto", yticklabels="auto")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("AA")
    save_path = os.path.join(args.img_save_dir, 'AA.png')
    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close()

    return mean_robustness, robustness



def _parse_ratio(value):
    """Convert a number or a string such as "8/255" or "0.03" into a float."""
    if isinstance(value, (int, float)):
        return float(value)
    numerator, slash, denominator = str(value).partition('/')
    if slash:
        return float(numerator) / float(denominator)
    return float(numerator)


def main():
    """Parse the CLI, load dataset and checkpoint, then run every attack.

    Raises:
        NotImplementedError: If ``--dataset`` names a dataset this script has
            no loader for.
        ValueError: If ``--model`` has no constructor registered, or if
            ``--eps`` / ``--step_size`` is not a number or ``a/b`` fraction.
    """
    args = parser.parse_args()
    # eval() was used here before: it raised TypeError on non-string values
    # and executed arbitrary command-line text. Only plain numbers and "a/b"
    # fractions are accepted now.
    args.eps = _parse_ratio(args.eps)
    args.step_size = _parse_ratio(args.step_size)
    args.model_path = f'/home/zhengxiao/RoBen/models/{args.dataset}/{args.model_name}.tar'
    args.img_save_dir = './Confusion_Matrix'
    # The evaluation routines save plots here; make sure it exists.
    os.makedirs(args.img_save_dir, exist_ok=True)

    # Load Dataset
    if args.dataset == 'cifar10':
        from torchvision.datasets import CIFAR10
        num_classes = 10
        CIFAR = CIFAR10
    elif args.dataset == 'cifar100':
        from torchvision.datasets import CIFAR100
        num_classes = 100
        CIFAR = CIFAR100
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise NotImplementedError(
            f"dataset '{args.dataset}' is accepted by the CLI but has no loader in this script"
        )

    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    testset = CIFAR(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # Load Model
    model_builders = {
        'WRN_34_10': WRN_34_10,
        'WRN_34_20': WRN_34_20,
        'cifar_resnet18': resnet18,
        'mobilenet': MobileNet,
        # NOTE(review): the PreActResNet constructors are assumed to accept
        # `num_classes` like the other models -- confirm against
        # model/preactresnet.py.
        'preactresnet34': PreActResNet34,
        'preactresnet18': PreActResNet18,
    }
    try:
        build_model = model_builders[args.model]
    except KeyError:
        raise ValueError(f"unsupported model '{args.model}'") from None
    # Wrapped before loading: the state dict is loaded into the DataParallel
    # wrapper, exactly as before.
    model = nn.DataParallel(build_model(num_classes=num_classes))

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    robust_evaluation(args, model, testloader, fgsm, "FGSM", num_classes)
    robust_evaluation(args, model, testloader, pgd_linf, "PGD-20", num_classes)
    robust_evaluation(args, model, testloader, cw_linf, "CW", num_classes)
    autoattack_eval(args, model, testloader, num_classes)


if __name__ == '__main__':
    main()
    
    