import numpy as np
import os
import torch
from torch import nn, optim
from torch.optim import optimizer
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pickle
import sys
import shutil
import copy
import torch.nn.functional as F
# import cleverhans
# from cleverhans.torch.attacks.projected_gradient_descent import (
#     projected_gradient_descent,
# )

if __name__=="__main__":
    # When this file is executed directly (rather than imported as part of its package),
    # append the directory two levels above this file to sys.path so that the `models`
    # package imported below can be resolved.
    # os.path yields exactly the same string the third-party `path` package produced
    # (Path.abspath()/.parent wrap os.path.abspath/os.path.dirname), without the extra dependency.
    folder_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(folder_path)

from models.attack_model_base import AttackModel

class PGD(AttackModel):
    """L-infinity PGD attack that maximizes a KL-divergence objective.

    Starting from a slightly noised copy of the clean input, every step moves the adversarial
    input by ``epsilon_iter`` in the sign direction of the gradient of
    ``KL(softmax(model(points)) || softmax(model(x_adv)))``. It then clips the result to within
    ``epsilon`` of the clean input (element-wise) and clamps it to ``[0, 1]``.

    An earlier version of this class delegated to cleverhans' ``projected_gradient_descent``;
    the attack is now implemented directly below.

    Args:
        defender: Object whose ``classifier.model`` is the network under attack.
        epsilon (float): Radius of the L-infinity perturbation ball.
        epsilon_iter (float): Step size of each iteration.
        num_steps (int): Number of iterations.
        norm: Norm of the perturbation ball. Only ``np.inf`` is implemented. Any other value makes
            ``get_perturbed`` raise ``NotImplementedError`` instead of silently running L-inf.
        targeted (bool): Stored on the instance. The KL objective used by ``get_perturbed`` is
            untargeted and does not read this flag.
    """

    def __init__(self, defender, epsilon=0.031, epsilon_iter=0.007, num_steps=20, norm=np.inf, targeted=False):
        self.defender = defender
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.epsilon_iter = epsilon_iter
        # NOTE(review): kept for callers that read it; get_perturbed does not update it.
        self.last_batch_successes = 0
        self.targeted = targeted
        # Previously hard-coded to np.inf, which silently discarded the caller's argument.
        self.norm = norm
        self.requires_training = False

    def _project_linf(self, x_adv, points):
        """Clip ``x_adv`` to within ``self.epsilon`` of ``points`` element-wise, then clamp it to [0, 1]."""
        x_adv = torch.min(torch.max(x_adv, points - self.epsilon), points + self.epsilon)
        return torch.clamp(x_adv, 0.0, 1.0)

    def get_perturbed(self, points, labels=None):
        """
        Run the L-infinity PGD attack against ``self.defender.classifier.model``.

        Every element of the result is clipped to within ``self.epsilon`` of ``points`` and then
        clamped to ``[0, 1]``. ``self.last_batch_successes`` is NOT updated by this method.

        Side effect: the model is put into eval mode and is left in eval mode. Callers that
        continue training must switch it back themselves (e.g. ``model.train()``).

        Args:
            points (PyTorch Tensor): Tensor of clean input points. Presumably scaled to ``[0, 1]``,
                since the output is clamped to that range. The model's output is assumed to hold
                class logits along dim 1 (softmax is taken over dim 1).
            labels (PyTorch Tensor): Unused by the untargeted KL objective; accepted for interface
                compatibility.

        Returns:
            PyTorch Tensor: Detached adversarial examples on the same device as ``points``.

        Raises:
            NotImplementedError: If ``self.norm`` is not ``np.inf``.
        """
        if self.norm != np.inf:
            raise NotImplementedError(
                f"PGD only implements the L-infinity attack (norm=np.inf), got norm={self.norm!r}"
            )
        model = self.defender.classifier.model
        model.eval()
        criterion_kl = nn.KLDivLoss(reduction="sum")  # same as the deprecated size_average=False
        # Detach so that a `points` tensor which requires grad neither leaks into the attack graph
        # nor turns x_adv into a non-leaf tensor (on which requires_grad_ would fail).
        points = points.detach()
        # The clean prediction is the fixed KL target: compute it once, without building a graph.
        with torch.no_grad():
            clean_probs = F.softmax(model(points), dim=1)
        # Small random start; randn_like follows the device/dtype of `points` (was hard-coded .cuda()).
        x_adv = points + 0.001 * torch.randn_like(points)
        for _ in range(self.num_steps):
            x_adv.requires_grad_()
            # enable_grad so the attack still works when the caller runs under torch.no_grad().
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), clean_probs)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = self._project_linf(x_adv.detach() + self.epsilon_iter * torch.sign(grad), points)
        # Idempotent after any step; makes the documented bound hold even when num_steps == 0.
        return self._project_linf(x_adv.detach(), points)

    def indices_to_points(self, indices):
        """
        Utility function: Takes in indices and returns corresponding dataset slice.

        NOTE(review): ``self.dataset`` is not assigned anywhere in this class; presumably it is set
        by ``AttackModel`` or by the caller. Verify before relying on this method.

        Args:
            indices (List): List of indices. Could also be a numpy array.

        Returns:
            X, Y: A slice of the dataset, indexed by the input indices.
        """
        return self.dataset[indices]