import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random

from utils import *

# Layer specification consumed by VGG._make_layers: each int adds
# Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU with that many output channels,
# each 'M' adds a 2x2, stride-2 MaxPool2d. That is 13 conv layers and 5 pools, so a
# 3x32x32 input leaves the trunk as 512x1x1, matching nn.Linear(512, ...) in VGG.dense.
# Must match the architecture stored in the 'vgg16.pt' checkpoint loaded in __main__.
vgg = [96, 96, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


class VGG(nn.Module):
    """VGG-style classifier: configurable conv trunk, two 4096-wide FC layers, 10-way head.

    Args:
        vgg: layer spec. Each int appends Conv3x3(padding=1) -> BatchNorm2d -> ReLU
            with that many output channels; each 'M' appends a 2x2/stride-2 max-pool.
            The trunk must end with 512 features per sample (e.g. 512x1x1), because
            the first dense layer is nn.Linear(512, 4096).
    """

    def __init__(self, vgg):
        super().__init__()
        # Attribute names and the module order inside them define the state_dict
        # keys, and construction order defines the random init stream, so neither
        # may change without breaking checkpoint compatibility.
        self.features = self._make_layers(vgg)
        hidden = 4096
        self.dense = nn.Sequential(
            nn.Linear(512, hidden), nn.ReLU(inplace=True), nn.Dropout(0.4),
            nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Dropout(0.4),
        )
        self.classifier = nn.Linear(hidden, 10)

    def forward(self, x):
        """Map a batch of 3-channel images to 10 class logits per sample."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(self.dense(flat))

    def _make_layers(self, vgg):
        """Expand the layer spec into an nn.Sequential trunk (input has 3 channels)."""
        modules = []
        channels = 3
        for spec in vgg:
            if spec != 'M':
                modules.extend([
                    nn.Conv2d(channels, spec, kernel_size=3, padding=1),
                    nn.BatchNorm2d(spec),
                    nn.ReLU(inplace=True),
                ])
                channels = spec
                continue
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # A 1x1 / stride-1 average pool is an identity op; it is kept so the
        # module layout (and indices) stays the same.
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)


# Select the GPU BEFORE the first CUDA query. torch.cuda.is_available() may
# initialise the CUDA runtime, which reads CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER
# only once, so setting them afterwards (as this script used to) has no effect.
# setdefault lets a value exported in the shell take precedence.
# NOTE(review): if `utils` initialises CUDA at import time, this is still too late.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '1')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create arg parser
parser = argparse.ArgumentParser(description='Arguments for Toy ZO-MODO task')

# general
parser.add_argument('--seed', type=int, default=2025, help='random seed')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--num_epochs', type=int, default=1000,
                    help='maximum number of attack iterations per image')
parser.add_argument('--lr', type=float, default=0.01,
                    help='step size that scales the estimated gradient into the perturbation')
# Forwarded to compute_gradient via `args`; presumably the finite-difference
# step of the zeroth-order estimator. TODO confirm in utils.
parser.add_argument('--fd_eps', default=1e-3, type=float)
parser.add_argument('--d', default=3072, type=int,
                    help='flattened input dimension (3*32*32 for CIFAR-10)')
parser.add_argument('--k', default=60, type=int,
                    help='number of perturbed input coordinates (L0 budget)')
parser.add_argument('--num_task', default=2, type=int,
                    help='number of tasks (losses) weighted by MoDo')

# Forwarded to compute_gradient via `args`; meaning is defined in utils
# (q is presumably the number of random query directions). TODO confirm.
parser.add_argument('--q', default=100, type=int)
parser.add_argument('--s2', default=200, type=int)
parser.add_argument('--lname', default='../results/CIFAR_vgg16_', type=str,
                    help='prefix of the results log file')

# MoDo
parser.add_argument('--gamma_modo', type=float, default=0.001, help='learning rate of lambda')
parser.add_argument('--rho_modo', type=float, default=0.0, help='regularization parameter')

args = parser.parse_args()

# parse args
print(args)


if __name__ == '__main__':
    # Random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Define the transforms for data preprocessing
    transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
#     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.,), (0.5,))])

    # Load the CIFAR dataset
    # get test dataset and create test dataloader
    data = torchvision.datasets.CIFAR10(root='../data', train=False, transform=transform, download=False)
    dataloader = torch.utils.data.DataLoader(data, batch_size=1000, shuffle=True)
    # Set-up for training
    # init model
    m_state_dict = torch.load('vgg16.pt', weights_only=False)
    model = VGG(vgg).to(device)
    model.load_state_dict(m_state_dict)

    # Defining loss functions
    # cross-netropy loss (same as nll loss)
    cross_entropy_loss = nn.CrossEntropyLoss()
    # l1loss
    l1_loss = nn.L1Loss()
    # hinge loss
    hinge_loss = torch.nn.MultiMarginLoss()
    # MSE loss
    mse_loss = torch.nn.MSELoss()
    # Huber loss
    huber_loss = torch.nn.HuberLoss(delta=0.1)  # to make sure this is deifferent from mse

    nllloss = nn.NLLLoss()

    # dictionary of losses
    loss_dict = {'cel': cross_entropy_loss, 'mse':mse_loss}

    # MoDo
    modo_kwargs = {'lambd': torch.ones(args.num_task).to(device) / args.num_task, 'gamma': args.gamma_modo, 'rho': args.rho_modo}

    index = str(args.k) + '_' + str(args.gamma_modo) + '_' + str(
        args.rho_modo) + '.txt'
    lname = args.lname + index

    image_number = 100
    # select 300 numbers from 1000 test_data
    image_id_set = np.random.choice(range(1000), image_number * 3, replace=False)

    succ_count, ii, iii = 0, 0, 0

    I = 1000
    l2_distortion_collect = np.zeros(image_number)
    attack_succ_count = np.zeros(image_number)
    cc = 0
    cc2 = 0

    images, labels = next(iter(dataloader))

    while iii < image_number:
        attack_flag = False
        image_id = image_id_set[ii]
        ii = ii + 1

        orig_img, target = images[image_id].unsqueeze(0).to(device), labels[image_id].to(device)

        orig_prob = model(orig_img)
        orig_class = torch.argmax(orig_prob)

        # untargeted attack
        target_label = target
        true_label = orig_class

        with open(lname, 'a+') as f:
            f.write("\n Image ID:{}, infer label:{}, true label:{} \n".format(image_id, orig_class, true_label))
        print("Image ID:{}, infer label:{}, true label:{}".format(image_id, orig_class, true_label))
        if true_label != orig_class:
            with open(lname, 'a+') as f:
                f.write("True Label is different from the original prediction, pass!\n")
            print("True Label is different from the original prediction, pass!")
            continue
        else:
            iii = iii + 1

        with open(lname, 'a+') as f:
            f.write('\n' + str(iii) + '/' + str(image_number) + '\n')
        print('\n', iii, '/', image_number)

        adv_image = orig_img
        count = 0
        # gradient = compute_gradient(model, adv_image, true_label, loss_dict, modo_kwargs, args)
        for i in range(args.num_epochs):
            gradient = compute_gradient(model, adv_image, true_label, loss_dict, modo_kwargs, args, device)
            # ||delta||_0<k
            delta_tmp = args.lr * gradient
            delta_tmp = delta_tmp.view(args.d)
            top_k_idx = torch.argsort(-torch.abs(delta_tmp))[0:args.k]
            delta = torch.zeros_like(delta_tmp)
            delta[top_k_idx] = delta_tmp[top_k_idx]
            l2_dist = torch.norm(delta, p=2)
            # print(torch.nonzero(delta))
            l0_num = torch.nonzero(delta).size(0)
            l0_dist = l0_num / args.d

            delta = delta.view(1, 3, 32, 32)
            # adv_image = orig_img + delta
            adv_image = torch.clamp(orig_img + delta, min=-0.5, max=0.5)
            attack_prob = model(adv_image)
            attack_predict_class = torch.argmax(attack_prob)
            # Judge whether the attack succeeds, if so, break
            if (i + 1) % 1 == 0:
                if true_label != attack_predict_class:
                    with open(lname, 'a+') as f:
                        f.write("Iter %d (Succ): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d \n" % (
                            i + 1, image_id, l0_dist, l2_dist, true_label, attack_predict_class))
                    print("Iter %d (Succ): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d" % (
                        i + 1, image_id, l0_dist, l2_dist, true_label, attack_predict_class))
                    attack_flag = True
                    count = count + 1
                    if count == 1:
                        attack_succ = i + 1
                        l2_distortion_collect[cc] = l2_dist
                        cc = cc + 1
                    break
                else:
                    with open(lname, 'a+') as f:
                        f.write("Iter %d (Fail): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d \n" % (
                            i + 1, image_id, l0_dist, l2_dist, true_label, attack_predict_class))
                    print("Iter %d (Fail): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d" % (
                        i + 1, image_id, l0_dist, l2_dist, true_label, attack_predict_class))
        if (attack_flag):
            succ_count = succ_count + 1
            attack_succ_count[cc2] = attack_succ
            cc2 = cc2 + 1
            with open(lname, 'a+') as f:
                f.write("It takes {} iterations to find the first attack \n".format(attack_succ))
            print("It takes {} iterations to find the first attack".format(attack_succ))
        else:
            with open(lname, 'a+') as f:
                f.write("Attack Fails\n")
            print("Attack Fails")

    l2_dist_avg = np.sum(l2_distortion_collect) / cc
    attack_succ_count_avg = np.sum(attack_succ_count) / cc2
    print("succ rate: %3.5f, l2_dist_avg: %3.5f, attack_succ_avg: %3.5f  \n" % (
        succ_count / image_number, l2_dist_avg, attack_succ_count_avg))
    print(l2_distortion_collect)
    print(attack_succ_count)
    with open(lname, 'a+') as f:
        f.write("succ rate: %3.5f, l2_dist_avg: %3.5f, attack_succ_avg: %3.5f  \n" % (
        succ_count / image_number, l2_dist_avg, attack_succ_count_avg))






