import torch
import torchvision
from torchvision import transforms, models

import argparse
import os
import random

from utils import *
from Resnet18 import ResNet18
from SSIM import SSIM


# Compute device for the model and all tensors: CUDA when available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Command-line configuration. main() reads seed, lr, k, d, clamp, num_task, num_epochs,
# gamma_modo, rho_modo, random and lname directly. The whole namespace is also passed to the
# gradient estimators, which presumably consume the remaining options (fd_eps, q, s2, one,
# batch_size). TODO confirm against utils.


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``, so any
    non-empty string would switch the option on. This converter keeps
    ``--flag True`` working while making ``--flag False`` actually mean False.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='Arguments for Toy ZO-MODO task')

# general
parser.add_argument('--seed', type=int, default=199, help='random seed')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
# Maximum number of attack iterations per image.
parser.add_argument('--num_epochs', type=int, default=500, help='number of epochs')
# Step size: main() scales the estimated gradient by lr before the top-k selection.
parser.add_argument('--lr', type=float, default=0.1, help='learning rate for the model')
# Presumably the finite-difference step of the zeroth-order estimator (not read in this file).
parser.add_argument('--fd_eps', default=1e-4, type=float)
# Bound for the torch.clamp applied to the adversarial image in main().
parser.add_argument('--clamp', default=0.5, type=float)
# Flattened input size; the estimated gradient is reshaped to this many elements
# (3 * 32 * 32 = 3072 for a CIFAR-10 image).
parser.add_argument('--d', default=3072, type=int)
# Number of coordinates kept in each sparse (top-k) perturbation.
parser.add_argument('--k', default=60, type=int)
# Number of MoDo tasks; sets the length of the task-weight vector lambda.
parser.add_argument('--num_task', default=2, type=int)

# Not read in this file; presumably consumed by the gradient estimators. TODO confirm.
parser.add_argument('--q', default=50, type=int)
parser.add_argument('--s2', default=100, type=int)
# Log-file path prefix; main() appends '<lr>_<k>_<gamma>_.txt'.
parser.add_argument('--lname', default='../results/CIFAR_Resnet_', type=str)
parser.add_argument('--one', default=False, type=_str2bool)
# Use compute_gradient_random instead of compute_gradient.
parser.add_argument('--random', default=False, type=_str2bool)

# MoDo
parser.add_argument('--gamma_modo', type=float, default=0.001, help='learning rate of lambda')
parser.add_argument('--rho_modo', type=float, default=0.0, help='regularization parameter')

# Parse args
args = parser.parse_args()

# Echo the effective configuration.
print(args)

def _log(lname, file_msg, console_msg=None):
    """Append ``file_msg`` to the log file ``lname`` and echo a message to stdout.

    Args:
        lname: Path of the text log. It is opened in append mode, so repeated runs
            with the same hyper-parameters accumulate in one file.
        file_msg: Text written verbatim, including any newlines, to the log file.
        console_msg: Text printed to stdout. Defaults to ``file_msg`` with leading
            and trailing whitespace stripped, because ``print`` adds its own newline.
    """
    with open(lname, 'a+') as f:
        f.write(file_msg)
    print(file_msg.strip() if console_msg is None else console_msg)


def main(args, SSIM):
    """Run the top-k sparse zeroth-order MoDo attack on CIFAR-10 test images.

    Draws one shuffled batch of 1000 CIFAR-10 test images and a pool of 300
    candidate indices from it. Up to 100 candidates that the pretrained ResNet-18
    classifies correctly are attacked (untargeted) for at most
    ``args.num_epochs`` iterations each. Per-iteration lines and a final summary
    are appended to ``args.lname + '<lr>_<k>_<gamma_modo>_.txt'``. The summary
    covers success rate, mean L2 distortion, mean iterations to success and mean
    SSIM.

    Args:
        args: Parsed command-line namespace. This function reads seed, num_task,
            gamma_modo, rho_modo, lr, k, d, clamp, num_epochs, random and lname,
            and passes the whole namespace on to the gradient estimator.
        SSIM: Zero-argument class or factory. The resulting object is called as
            ``metric(img_a, img_b)`` on CPU tensors and must return a scalar.
    """
    # Seed Python, NumPy and torch so shuffling and image selection are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Per-channel CIFAR-10 mean/std normalisation.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # CIFAR-10 test split; only the first shuffled batch of 1000 images is used below.
    data = torchvision.datasets.CIFAR10(root='../data', train=False, transform=transform, download=False)
    dataloader = torch.utils.data.DataLoader(data, batch_size=1000, shuffle=True)

    # Pretrained classifier under attack. map_location lets a GPU-saved checkpoint load on a
    # CPU-only machine.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files here.
    m_state_dict = torch.load('resnet18.pt', map_location=device, weights_only=False)
    model = ResNet18().to(device)
    model.load_state_dict(m_state_dict)
    # Inference mode: any BatchNorm/Dropout layers must use their trained behaviour and must not
    # update running statistics while the attack queries the model one image at a time.
    model.eval()

    ssim_metric = SSIM()

    # Losses handed to the gradient estimators, keyed by name.
    loss_dict = {'cel': nn.CrossEntropyLoss(), 'nll': nn.NLLLoss()}

    # MoDo state: uniform initial task weights (lambd), their step size (gamma), regulariser (rho).
    modo_kwargs = {'lambd': torch.ones(args.num_task).to(device) / args.num_task, 'gamma': args.gamma_modo,
                   'rho': args.rho_modo}

    lname = args.lname + '{}_{}_{}_.txt'.format(args.lr, args.k, args.gamma_modo)

    image_number = 100
    # Candidate pool: 3 * image_number distinct indices into the 1000-image batch. The extra
    # candidates allow images the model already misclassifies to be skipped.
    image_id_set = np.random.choice(range(1000), image_number * 3, replace=False)

    l2_distortion_collect = np.zeros(image_number)
    ssim_collect = np.zeros(image_number)
    attack_succ_count = np.zeros(image_number)
    succ_count = 0  # successful attacks so far; also the next free slot in the arrays above
    attacked = 0  # correctly classified images attacked so far
    next_candidate = 0  # position in image_id_set

    images, labels = next(iter(dataloader))

    # Chosen once: the estimator does not change between iterations.
    estimator = compute_gradient_random if args.random else compute_gradient

    while attacked < image_number:
        if next_candidate >= len(image_id_set):
            # Too many candidates were skipped; stop instead of indexing past the pool.
            _log(lname, "Candidate pool exhausted after {} attacked images\n".format(attacked))
            break
        image_id = image_id_set[next_candidate]
        next_candidate += 1

        orig_img = images[image_id].unsqueeze(0).to(device)
        # Untargeted attack: the ground-truth label is the class the attack must move away from.
        true_label = labels[image_id].to(device)
        with torch.no_grad():
            orig_class = torch.argmax(model(orig_img))

        _log(lname, "\n Image ID:{}, infer label:{}, true label:{} \n".format(
            image_id, int(orig_class), int(true_label)))
        if true_label != orig_class:
            # Already misclassified, so a "successful" attack would be meaningless.
            _log(lname, "True Label is different from the original prediction, pass!\n")
            continue
        attacked += 1
        _log(lname, '\n{}/{}\n'.format(attacked, image_number))

        attack_succ = None  # iteration of the first successful attack, if any
        adv_image = orig_img
        for i in range(args.num_epochs):
            gradient = estimator(model, adv_image, true_label, loss_dict, modo_kwargs, args, device)

            # Sparse step: keep only the k largest-magnitude coordinates, so ||delta||_0 <= k.
            delta_tmp = (args.lr * gradient).view(args.d)
            top_k_idx = torch.argsort(-torch.abs(delta_tmp))[:args.k]
            delta = torch.zeros_like(delta_tmp)
            delta[top_k_idx] = delta_tmp[top_k_idx]
            l2_dist = torch.norm(delta, p=2).item()
            l0_dist = torch.nonzero(delta).size(0) / args.d
            delta = delta.view_as(orig_img)

            # NOTE(review): delta is rebuilt from the current step only and added to orig_img, so
            # perturbations do not accumulate across iterations. Confirm this is intended.
            # NOTE(review): this clamps the whole mean/std-normalised image to [-clamp, clamp],
            # not the perturbation. Confirm against the intended input range.
            adv_image = torch.clamp(orig_img + delta, min=-args.clamp, max=args.clamp)
            with torch.no_grad():
                attack_predict_class = torch.argmax(model(adv_image))

            if attack_predict_class != true_label:
                # SSIM (like l2_dist) is measured on the unclamped image orig_img + delta.
                ssim = float(ssim_metric(orig_img.data.cpu(), (orig_img + delta).data.cpu()))
                _log(lname, "Iter %d (Succ): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, ssim=%3.5f, TL = %d, PL = %d \n" % (
                    i + 1, image_id, l0_dist, l2_dist, ssim, int(true_label), int(attack_predict_class)))
                attack_succ = i + 1
                l2_distortion_collect[succ_count] = l2_dist
                ssim_collect[succ_count] = ssim
                attack_succ_count[succ_count] = attack_succ
                succ_count += 1
                break
            _log(lname, "Iter %d (Fail): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d \n" % (
                i + 1, image_id, l0_dist, l2_dist, int(true_label), int(attack_predict_class)))

        if attack_succ is None:
            _log(lname, "Attack Fails\n")
        else:
            _log(lname, "It takes {} iterations to find the first attack \n".format(attack_succ))

    # Averages over successful attacks only; NaN when no attack succeeded.
    if succ_count:
        l2_dist_avg = np.sum(l2_distortion_collect) / succ_count
        ssim_avg = np.sum(ssim_collect) / succ_count
        attack_succ_count_avg = np.sum(attack_succ_count) / succ_count
    else:
        l2_dist_avg = ssim_avg = attack_succ_count_avg = float('nan')
    # Success rate over the images actually attacked (image_number unless the pool ran out).
    succ_rate = succ_count / attacked if attacked else float('nan')
    _log(lname, "succ rate: %3.5f, l2_dist_avg: %3.5f, attack_succ_avg: %3.5f, ssim_avg: %3.5f  \n" % (
        succ_rate, l2_dist_avg, attack_succ_count_avg, ssim_avg))
    print(l2_distortion_collect)
    print(attack_succ_count)

if __name__ == '__main__':
    # Hyper-parameter sweep over (k, lr, gamma_modo). Each combination overwrites the
    # matching fields of the shared ``args`` namespace and then runs one full attack.
    sweep = [(k, lr, gamma) for k in [60] for lr in [0.01] for gamma in [0.0001]]
    for k, lr, gamma in sweep:
        args.k, args.lr, args.gamma_modo = k, lr, gamma
        main(args, SSIM)





