import numpy as np
import os
import torch
from torch import nn, optim
from torch.optim import optimizer
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pickle
import sys
import shutil
import copy
import torch.nn.functional as F
import cleverhans
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)

if __name__ == "__main__":
    # When this file is executed directly, put the directory two levels above
    # it on sys.path so the ``models.*`` import below can be resolved.
    # Uses only the standard library (previously the third-party ``path``
    # package was required just for this).
    folder_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(folder_path)

from models.attack_model_base import AttackModel

class PGD(AttackModel):
    """Projected Gradient Descent (PGD) attack against ``defender.classifier``.

    Three attack variants are provided:

    * ``get_perturbed`` - hand-written L-infinity PGD that maximises the KL
      divergence between the model's predictions on clean and perturbed
      inputs (the labels are not used).
    * ``get_perturbed_cleverhans`` - delegates to cleverhans'
      ``projected_gradient_descent``.
    * ``_pgd_whitebox`` - hand-written L-infinity PGD that maximises the
      cross-entropy loss against the supplied labels.

    Args:
        defender: Object whose ``classifier`` is callable and exposes a
            ``model`` used like a ``torch.nn.Module`` (callable, with
            ``eval()``/``train()``). The cleverhans variant also reads
            ``lazy_attack_update`` and ``lazy_attack_classifier``.
        epsilon: Radius of the allowed perturbation ball.
        epsilon_iter: Step size of each PGD iteration.
        num_steps: Number of PGD iterations.
        norm: Norm forwarded to cleverhans. The hand-written variants always
            use the L-infinity ball.
        targeted: If True, ``get_perturbed_cleverhans`` performs a targeted
            attack towards ``labels``.
    """

    def __init__(self, defender, epsilon=0.031, epsilon_iter=0.007, num_steps=20, norm=np.inf, targeted=False):
        self.defender = defender
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.epsilon_iter = epsilon_iter
        # Not updated by any method of this class; kept for external readers.
        self.last_batch_successes = 0
        self.targeted = targeted
        # Honour the caller's norm (it used to be silently forced to np.inf).
        self.norm = norm
        self.requires_training = False

    def get_perturbed(self, points, labels=None):
        """Run KL-divergence L-infinity PGD against ``defender.classifier.model``.

        Starting from ``points`` plus small Gaussian noise (std 0.001), take
        ``num_steps`` signed-gradient ascent steps of size ``epsilon_iter`` on
        the summed KL divergence between the model's softmax on the clean
        points and on the perturbed points. After every step, project onto
        the L-infinity ball of radius ``epsilon`` around ``points`` and onto
        the box [0, 1].

        The model runs in eval mode during the attack, and its previous
        train/eval mode is restored afterwards (even if an error occurs).

        Args:
            points (torch.Tensor): Clean inputs; assumed to lie in [0, 1],
                with class scores along dim 1 of the model output.
            labels: Unused; accepted for interface compatibility with the
                other attack methods.

        Returns:
            torch.Tensor: Detached adversarial examples on the same device as
            ``points``. When ``num_steps >= 1`` they are within ``epsilon``
            (L-infinity) of ``points`` and inside [0, 1].
        """
        model = self.defender.classifier.model
        was_training = model.training
        points = points.detach()
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        criterion_kl = nn.KLDivLoss(reduction='sum')
        model.eval()
        try:
            # The clean-input distribution is the fixed target for every step,
            # so compute it once without building an autograd graph.
            with torch.no_grad():
                clean_probs = F.softmax(model(points), dim=1)
            # Small random start; randn_like follows points' device and dtype
            # (the previous hard-coded .cuda() broke CPU-only runs).
            x_adv = points + 0.001 * torch.randn_like(points)
            for _ in range(self.num_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), clean_probs)
                    grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                x_adv = x_adv.detach() + self.epsilon_iter * torch.sign(grad)
                # Project onto the epsilon-ball around the clean points, then
                # onto the valid input range.
                x_adv = torch.min(torch.max(x_adv, points - self.epsilon), points + self.epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        finally:
            model.train(was_training)
        # detach() (unlike requires_grad_(False)) also works on non-leaf tensors.
        return x_adv.detach()

    def get_perturbed_cleverhans(self, points, labels=None, clip_min=None, clip_max=None):
        """Run cleverhans' PGD against the defender's classifier.

        Args:
            points (torch.Tensor): Clean inputs.
            labels (torch.Tensor, optional): Target labels; only forwarded when
                ``self.targeted`` is True. For untargeted attacks cleverhans
                picks its own labels (per its docs, the model's predictions).
            clip_min, clip_max (float, optional): Valid input range forwarded
                to cleverhans. The default ``None`` keeps the previous behaviour
                (no range clipping by cleverhans). Pass e.g. 0.0 and 1.0 to
                match the [0, 1] box used by ``get_perturbed``.

        Returns:
            torch.Tensor: Adversarial examples produced by cleverhans.
        """
        if self.targeted:
            # Targeted PGD attack towards ``labels``.
            return projected_gradient_descent(
                self.defender.classifier, points, self.epsilon, self.epsilon_iter,
                self.num_steps, norm=self.norm, clip_min=clip_min, clip_max=clip_max,
                targeted=True, y=labels,
            )
        # Compared with ``== False`` (not ``not ...``) to keep the original
        # semantics for non-bool values, e.g. None selects the lazy classifier.
        if self.defender.lazy_attack_update == False:  # noqa: E712
            print(f"PGD params: epsilon:{self.epsilon}, eps_iter:{self.epsilon_iter}, self.num_steps:{self.num_steps}, norm:{self.norm}")
            model_fn = self.defender.classifier
        else:
            model_fn = self.defender.lazy_attack_classifier
        return projected_gradient_descent(
            model_fn, points, self.epsilon, self.epsilon_iter, self.num_steps,
            norm=self.norm, clip_min=clip_min, clip_max=clip_max,
        )

    def _pgd_whitebox(self, X, y, random=True):
        """White-box L-infinity PGD maximising cross-entropy against ``y``.

        Each step moves ``epsilon_iter`` along the sign of the input gradient,
        then projects onto the ``epsilon`` L-infinity ball around ``X`` and onto
        [0, 1]. Gradients are taken w.r.t. the input only, so the model's
        parameter ``.grad`` buffers are left untouched (the previous
        ``loss.backward()`` accumulated into them).

        Args:
            X (torch.Tensor): Clean inputs; assumed to lie in [0, 1].
            y (torch.Tensor): Class-index labels for ``X``.
            random (bool): If True, start from ``X`` plus uniform noise in
                [-epsilon, epsilon]; otherwise start from ``X`` itself.

        Returns:
            tuple: ``(X_pgd, err, err_pgd)``. ``X_pgd`` holds the detached
            adversarial examples. ``err`` and ``err_pgd`` are the numbers of
            misclassified samples on ``X`` and ``X_pgd`` (float tensors).
        """
        model = self.defender.classifier.model
        X = X.detach()
        with torch.no_grad():
            err = (model(X).argmax(dim=1) != y).float().sum()

        X_pgd = X.clone()
        if random:
            # empty_like keeps X's device/dtype for the random start.
            X_pgd = X_pgd + torch.empty_like(X).uniform_(-self.epsilon, self.epsilon)

        for _ in range(self.num_steps):
            X_pgd.requires_grad_(True)
            with torch.enable_grad():
                loss = F.cross_entropy(model(X_pgd), y)
                grad = torch.autograd.grad(loss, [X_pgd])[0]
            X_pgd = X_pgd.detach() + self.epsilon_iter * grad.sign()
            eta = torch.clamp(X_pgd - X, -self.epsilon, self.epsilon)
            X_pgd = torch.clamp(X + eta, 0.0, 1.0)

        with torch.no_grad():
            err_pgd = (model(X_pgd).argmax(dim=1) != y).float().sum()
        print('err pgd (white-box): ', err_pgd)
        return X_pgd, err, err_pgd

    def indices_to_points(self, indices):
        """
        Utility function: Takes in indices and returns corresponding dataset slice.

        NOTE(review): ``self.dataset`` is never assigned in this class; it is
        presumably provided by ``AttackModel`` or set by the caller — confirm.

        Args:
            indices (List): List of indices. Could also be a numpy array.

        Returns:
            X, Y: A slice of the dataset, indexed by the input indices.
        """
        return self.dataset[indices]


import argparse

def _str_to_bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"0" to bool.

    argparse's ``type=bool`` must not be used for this: ``bool("False")`` is
    True, so any non-empty value would silently enable the option.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar10', type=str)
parser.add_argument('--classifier_type', '-ct', type=str, default='nn_cifar10')
parser.add_argument('--attack_model_type', '-amt', type=str, default='pgd')
parser.add_argument('--epoch', type=int, default=500)
parser.add_argument('--base_train_epochs', type=int, default=120)
#parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--rho', type=float, default=0.1)
parser.add_argument('--lam', type=float, default=1)
# NOTE(review): with nargs='+' this is a list when -s is given but the int 0
# by default — confirm what consumers of args.seed expect.
parser.add_argument('-s', '--seed', nargs='+', default=0, type=int)
# parser.add_argument('-K', '--K', default=20, type=int, help='cardinality of the set S')
parser.add_argument('-eta', '--eta', default=10, type=float, help='percentage of the dataset on which to attack')
parser.add_argument('-T', '--T', default=20, type=int, help='Number of iterations for the defense algorithm')
# parser.add_argument('-num', '--num_intervals', default=10, type=int)
parser.add_argument('-p', '--plot', action='store_true')
parser.add_argument('--save_dir', type=str, default='saved_models')
parser.add_argument('--save_dir_base_model', type=str, default='models/defense/CIFAR10_models')
parser.add_argument('--data_dir', type=str, default='data')
# Parsed with _str_to_bool so that "--load False" actually yields False.
parser.add_argument('--load', type=_str_to_bool, default=True)
parser.add_argument('--save-freq', '-sf', default=1, type=int, metavar='N', help='save frequency')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')

## arguments for distorted greedy
# type=float: the defaults are fractional, and type=int rejected values like 0.05.
parser.add_argument('-g', '--gamma', default=0.01, type=float, help='gamma parameter for distorted greedy')
parser.add_argument('--dg_batch_size', type=int, default=512, help='batch size for distorted greedy')
parser.add_argument('--eps', type=float, default=0.1, help='error threshold for distorted greedy')

## arguments for base classifier
parser.add_argument('--base_epochs', type=int, default=50)
parser.add_argument('--base_lr', type=float, default=0.001)
parser.add_argument('--base_batch_size', type=int, default=32)
parser.add_argument('--base_weight_decay', type=float, default=0)

# parser.add_argument('--num_labels', type=int, default=10)
# parser.add_argument('--num_channels', type=int, default=1)
# parser.add_argument('--min_clamp', type=float, default=0)
# parser.add_argument('--max_clamp', type=float, default=1)
parser.add_argument('--adv_lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('-tbc', action='store_true')

parser.add_argument('-lr', '--lr', type=float, default=0.001)
parser.add_argument('-gs', '--gradient_steps', type=int, default=10)
parser.add_argument('--ts_batch_size', type=int, default=512, help='batch size for the train step while updating classifier')
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
parser.add_argument('-lazy_attack_update', default=-1, type=int, help='update attack after few timesteps')

args = parser.parse_args()

# Pick the compute device: the GPU selected with -gpu when CUDA exists,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Newly created float tensors default to CUDA tensors from here on.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(args.gpu)
    print('Using Device: ', torch.cuda.get_device_name())


from torchvision import datasets, transforms 
from models.defender import Defender
# Only referenced by the commented-out CIFAR10 alternative below.
transform_test = transforms.Compose([
        transforms.ToTensor(),
        ])
# test_dataset = datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)

# NOTE(security): pickle.load can execute arbitrary code — only load a
# dataset_split.pkl that you created or otherwise trust.
with open('dataset_split.pkl', 'rb') as f:
    pkl_data = pickle.load(f)
test_dataset = pkl_data['val_ds']       ## validation dataset, used as the test set

defender = Defender(classifier_type='nn_cifar10', dataset=None, args=args)
load_path = "models/defense/CIFAR10_models/model-120-checkpoint"
# map_location lets a checkpoint saved on a GPU be loaded on the device chosen
# above, including the CPU-only path.
checkpoint = torch.load(load_path, map_location=device)
defender.classifier.model.load_state_dict(checkpoint["state_dict"])
defender.init_optimizer()
defender.optimizer.load_state_dict(checkpoint["optimizer"])
print('Loaded defender')

PGD_attack = PGD(defender)


# ######## single image testing ########
# idx = 107  ### 77
# x_test = test_dataset[idx][0].to(device)
# y_test = test_dataset[idx][1]
# print('Label: ',y_test)
# y_test = torch.tensor(y_test).to(device)
# y_test = torch.unsqueeze(y_test, dim=0)

# plt.figure()
# x_np = x_test.cpu().numpy()
# x_np = np.moveaxis(x_np, 0, 2)
# plt.imshow(x_np)
# plt.title('Original Image')
# plt.savefig(f'Original_img_{idx}.png')
# plt.show()

# x_test = torch.unsqueeze(x_test, dim=0)
# out = defender.classifier.model(x_test)

# print('using new pgd')
# x_adv = PGD_attack.get_perturbed(x_test)
# out_new = defender.classifier.model(x_adv)
# plt.figure()
# x_adv = torch.squeeze(x_adv).cpu().numpy()
# x_adv = np.moveaxis(x_adv, 0, 2)
# plt.imshow(x_adv)
# plt.title('New PGD')
# plt.savefig(f'New_pgd_{idx}.png')
# plt.show()

# print('using CH pgd')
# x_adv_ch = PGD_attack.get_perturbed_cleverhans(x_test)
# out_ch = defender.classifier.model(x_adv_ch)
# x_adv_ch = torch.squeeze(x_adv_ch).detach().cpu().numpy()
# x_adv_ch = np.moveaxis(x_adv_ch, 0, 2)
# plt.figure()
# plt.imshow(x_adv_ch)
# plt.title('Cleverhans PGD')
# plt.savefig(f'CH_pgd_{idx}.png')
# plt.show()

# print('using pgd_whitebox')
# x_adv_wb, err, err_pgd = PGD_attack._pgd_whitebox(x_test, y_test)
# out_wb = defender.classifier.model(x_adv_wb)
# x_adv_wb = torch.squeeze(x_adv_wb).detach().cpu().numpy()
# x_adv_wb = np.moveaxis(x_adv_wb, 0, 2)
# plt.figure()
# plt.imshow(x_adv_wb)
# plt.title('Cleverhans PGD')
# plt.savefig(f'wb_pgd_{idx}.png')
# plt.show()



############ accuracy calc ############
# Clean accuracy and robust accuracy of the loaded defender under three PGD
# variants, accumulated batch by batch over test_dataset:
#   new - PGD.get_perturbed (KL-divergence objective, labels unused)
#   ch  - PGD.get_perturbed_cleverhans
#   wb  - PGD._pgd_whitebox (cross-entropy against the true labels)
correct = 0
adv_correct_new = 0
adv_correct_ch = 0
adv_correct_wb = 0

dl = DataLoader(test_dataset, batch_size=500)
for imgs, lbls in dl:
    imgs = imgs.to(device)
    lbls = lbls.to(device)

    # Accuracy must be measured in inference mode (BatchNorm/Dropout), and an
    # attack may change the model's train/eval mode, so eval mode is set
    # explicitly before the clean evaluation and after get_perturbed.
    defender.classifier.model.eval()
    acc, loss = list(defender.classifier.evaluate(imgs, lbls).values())
    correct += acc * len(imgs)
    # total_loss += loss*len(imgs)

    x_adv = PGD_attack.get_perturbed(imgs)
    defender.classifier.model.eval()
    adv_acc_new, loss = list(defender.classifier.evaluate(x_adv, lbls).values())
    adv_correct_new += adv_acc_new * len(imgs)

    x_adv_ch = PGD_attack.get_perturbed_cleverhans(imgs)
    adv_acc_ch, loss = list(defender.classifier.evaluate(x_adv_ch, lbls).values())
    adv_correct_ch += adv_acc_ch * len(imgs)

    # White-box cross-entropy PGD. This previously re-ran the cleverhans attack
    # by mistake, so the "wb" numbers merely duplicated the "ch" ones.
    x_adv_wb = PGD_attack._pgd_whitebox(imgs, lbls)[0]
    adv_acc_wb, loss = list(defender.classifier.evaluate(x_adv_wb, lbls).values())
    adv_correct_wb += adv_acc_wb * len(imgs)


clean_accuracy = correct/len(test_dataset)
adv_accuracy_new = adv_correct_new/len(test_dataset)
adv_accuracy_ch = adv_correct_ch/len(test_dataset)
adv_accuracy_wb = adv_correct_wb/len(test_dataset)

print('Clean accuracy: ', clean_accuracy)
print('Adv accuracy new: ', adv_accuracy_new)
print('Adv accuracy ch: ', adv_accuracy_ch)
print('Adv accuracy wb: ', adv_accuracy_wb)

