import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.laplace import Laplace
from autoattack.autopgd_base import APGDAttack_targeted


class Attack(nn.Module):
    """Common base for the adversarial attacks in this module.

    Stores the classifier under attack, the hyperparameter mapping and the
    device. Subclasses override ``forward(imgs, labels)``.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__()
        # Assigned in this order: classifier (an nn.Module, so it is registered
        # as a submodule by nn.Module.__setattr__), then hparams, then device.
        self.classifier, self.hparams, self.device = classifier, hparams, device

    def forward(self, imgs, labels):
        """Produce adversarial inputs; concrete attacks must override this."""
        raise NotImplementedError


class Attack_Linf(Attack):
    """Base class for attacks whose perturbations are bounded in l_inf norm.

    The radius of the l_inf ball is read from ``hparams['epsilon']``.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)

    def _clamp_perturbation(self, imgs, adv_imgs):
        """Project ``adv_imgs`` back onto the feasible set around ``imgs``.

        The returned tensor (1) differs from ``imgs`` by at most
        ``hparams['epsilon']`` in every coordinate and (2) lies in [0, 1]^d.
        """
        radius = self.hparams['epsilon']
        lo, hi = imgs - radius, imgs + radius
        within_ball = torch.minimum(torch.maximum(adv_imgs, lo), hi)
        return within_ball.clamp(0.0, 1.0)


class APGD_Linf(Attack_Linf):
    """Targeted APGD (from the ``autoattack`` package) under an l_inf budget.

    Uses a single restart, ``hparams['beta_n_steps']`` iterations and radius
    ``hparams['epsilon']``.
    """

    def __init__(self, classifier, hparams, device, num_classes):
        super().__init__(classifier, hparams, device)

        self.eps = hparams['epsilon']
        # NOTE(review): num_classes is stored but not forwarded to APGD; the
        # number of target classes APGD tries is whatever autoattack defaults
        # to -- confirm this matches the dataset.
        self.num_classes = num_classes
        self.apgd = APGDAttack_targeted(
            classifier,
            n_restarts=1,
            n_iter=hparams['beta_n_steps'],
            eps=self.eps,
            norm='Linf',
            device=device,
        )

    def forward(self, imgs, labels):
        """Delegate to ``APGDAttack_targeted.perturb`` and return its result."""
        return self.apgd.perturb(imgs, labels)

    
class SBETA_Linf(Attack_Linf):
    """Smoothed BETA attack under an l_inf budget.

    Every image is copied once per class. Copy ``k`` of image ``i`` is
    optimized with Adam (lr ``hparams['beta_lr']``, ``hparams['beta_n_steps']``
    steps) to minimize the margin ``logit[y_i] - logit[k]``. After every step
    it is projected onto the eps-ball around the clean image intersected with
    [0, 1]. The final margins are turned into per-image weights with a
    softmax scaled by ``hparams['sbeta_temperature']``.
    """

    def __init__(self, classifier, hparams, device, num_classes):
        super(SBETA_Linf, self).__init__(classifier, hparams, device)

        self.eps = self.hparams['epsilon']
        self.num_classes = num_classes

        # Kronecker factors that repeat each label / image num_classes times.
        self.ones_row = torch.ones(
            self.num_classes
        ).long().to(self.device)
        self.ones_col = torch.ones(
            (self.num_classes, 1, 1, 1)
        ).long().to(self.device)

        # shape: [batch_size * n_classes, n_classes]; row i * n_classes + k is
        # the one-hot vector of target class k. Precomputed for the configured
        # batch size; forward() rebuilds it for batches of a different size.
        self.target_class_matrix = torch.eye(
            self.num_classes).repeat(self.hparams['batch_size'], 1).to(self.device)

    def forward(self, imgs, labels):
        """Run SBETA on one batch.

        Args:
            imgs: clean images; assumed [B, C, H, W] with values in [0, 1]
                (TODO confirm with callers).
            labels: integer class labels, shape [B].

        Returns:
            ``(softmax_F_scores, adv_imgs)``:
            - ``softmax_F_scores`` is [B, n_classes], the softmax of the
              temperature-scaled final margins.
            - ``adv_imgs`` is [B * n_classes, C, H, W]; copy ``k`` of image
              ``i`` is at row ``i * n_classes + k``.
            Neither is detached.
        """
        batch_size = imgs.shape[0]

        # Each image / label repeated num_classes times consecutively.
        extended_imgs = torch.kron(imgs, self.ones_col)
        extended_labels = torch.kron(labels, self.ones_row)

        # B * n_classes
        extended_batch_size = extended_labels.shape[0]

        lower_img_bound = torch.clamp(extended_imgs - self.eps, min=0.0)
        upper_img_bound = torch.clamp(extended_imgs + self.eps, max=1.0)

        def projection(t):
            return torch.minimum(
                torch.maximum(t, lower_img_bound),
                upper_img_bound
            )

        if self.target_class_matrix.shape[0] == extended_batch_size:
            target_class_matrix = self.target_class_matrix
        else:
            # Batch size differs from hparams['batch_size'] (e.g. the last,
            # partial batch of an epoch): rebuild for the actual size.
            target_class_matrix = torch.eye(
                self.num_classes).repeat(batch_size, 1).to(self.device)

        # One-hot of each row's true label, in the dtype/device of the targets.
        real_class_matrix = F.one_hot(
            extended_labels, self.num_classes).to(target_class_matrix)

        # Each row has +1 at the true class and -1 at the target class, so
        # (logits * row).sum() == logit[true] - logit[target] (0 when equal).
        compute_output_matrix = real_class_matrix - target_class_matrix

        # Uniform random start inside the feasible box.
        noise = torch.rand_like(extended_imgs)
        adv_imgs = noise * lower_img_bound + (1 - noise) * upper_img_bound
        adv_imgs = adv_imgs.detach().requires_grad_()

        perturbation_optimizer = optim.Adam(
            [adv_imgs],
            lr=self.hparams['beta_lr'],
        )

        with torch.enable_grad():
            for _ in range(self.hparams['beta_n_steps']):
                F_scores = (self.classifier(adv_imgs) * compute_output_matrix).sum(dim=1)
                # Differentiate w.r.t. adv_imgs only, so the attack does not
                # accumulate gradients into the classifier's parameters.
                adv_imgs.grad = torch.autograd.grad(F_scores.sum(), [adv_imgs])[0]
                perturbation_optimizer.step()
                adv_imgs.data = projection(adv_imgs.data)

        final_F_scores = (self.classifier(adv_imgs) * compute_output_matrix).sum(dim=1)
        final_F_scores = final_F_scores.reshape(batch_size, self.num_classes)

        softmax_F_scores = F.softmax(
            self.hparams['sbeta_temperature'] * final_F_scores,
            dim=1)
        return softmax_F_scores, adv_imgs
    

class BETA_Linf(Attack_Linf):
    """BETA attack under an l_inf budget.

    Every image is copied once per class. Copy ``k`` of image ``i`` is
    optimized with RMSprop (lr ``hparams['beta_lr']``,
    ``hparams['beta_n_steps']`` steps) to minimize the margin
    ``logit[y_i] - logit[k]``. After every step it is projected onto the
    eps-ball around the clean image intersected with [0, 1]. For each image,
    the copy with the lowest final margin is returned.
    """

    def __init__(self, classifier, hparams, device, num_classes):
        super(BETA_Linf, self).__init__(classifier, hparams, device)

        self.eps = self.hparams['epsilon']
        self.num_classes = num_classes

        # Kronecker factors that repeat each label / image num_classes times.
        self.ones_row = torch.ones(
            self.num_classes
        ).long().to(self.device)
        self.ones_col = torch.ones(
            (self.num_classes, 1, 1, 1)
        ).long().to(self.device)

        # shape: [batch_size * n_classes, n_classes]; row i * n_classes + k is
        # the one-hot vector of target class k. Precomputed for the configured
        # batch size; forward() rebuilds it for batches of a different size.
        self.target_class_matrix = torch.eye(
            self.num_classes).repeat(self.hparams['batch_size'], 1).to(self.device)

    def forward(self, imgs, labels):
        """Run BETA on one batch.

        Args:
            imgs: clean images; assumed [B, C, H, W] with values in [0, 1]
                (TODO confirm with callers).
            labels: integer class labels, shape [B].

        Returns:
            Detached adversarial images with the same shape as ``imgs``.
        """
        batch_size = imgs.shape[0]

        # extended_imgs is [B * n_classes, C, H, W]: each image repeated
        # num_classes times consecutively (labels likewise).
        extended_imgs = torch.kron(imgs, self.ones_col)
        extended_labels = torch.kron(labels, self.ones_row)

        # B * n_classes
        extended_batch_size = extended_labels.shape[0]

        lower_img_bound = torch.clamp(extended_imgs - self.eps, min=0.0)
        upper_img_bound = torch.clamp(extended_imgs + self.eps, max=1.0)

        def projection(t):
            return torch.minimum(
                torch.maximum(t, lower_img_bound),
                upper_img_bound
            )

        if self.target_class_matrix.shape[0] == extended_batch_size:
            target_class_matrix = self.target_class_matrix
        else:
            # Batch size differs from hparams['batch_size'] (e.g. the last,
            # partial batch of an epoch): rebuild for the actual size.
            target_class_matrix = torch.eye(
                self.num_classes).repeat(batch_size, 1).to(self.device)

        # One-hot of each row's true label, in the dtype/device of the targets.
        real_class_matrix = F.one_hot(
            extended_labels, self.num_classes).to(target_class_matrix)

        # [B * n_classes, n_classes]: +1 at the true class, -1 at the target
        # class, so (logits * row).sum() == logit[true] - logit[target].
        compute_output_matrix = real_class_matrix - target_class_matrix

        # Uniform random start inside the feasible box.
        noise = torch.rand_like(extended_imgs)
        adv_imgs = noise * lower_img_bound + (1 - noise) * upper_img_bound
        adv_imgs = adv_imgs.detach().requires_grad_()
        perturbation_optimizer = optim.RMSprop(
            [adv_imgs],
            lr=self.hparams['beta_lr'],
        )

        with torch.enable_grad():
            for _ in range(self.hparams['beta_n_steps']):
                F_scores = (self.classifier(adv_imgs) * compute_output_matrix).sum(dim=1)
                # Differentiate w.r.t. adv_imgs only, so the attack does not
                # accumulate gradients into the classifier's parameters.
                adv_imgs.grad = torch.autograd.grad(F_scores.sum(), [adv_imgs])[0]
                perturbation_optimizer.step()
                adv_imgs.data = projection(adv_imgs.data)

        # Only the argmin is needed, so no autograd graph is built here.
        with torch.no_grad():
            final_F_scores = (self.classifier(adv_imgs) * compute_output_matrix).sum(dim=1)
        final_F_scores = final_F_scores.reshape(batch_size, self.num_classes)

        # Per image, keep the copy whose margin was driven lowest.
        which = torch.argmin(final_F_scores, dim=1)
        adv_imgs = adv_imgs.detach().reshape(
            (batch_size, self.num_classes) + adv_imgs.shape[1:]
        )
        return adv_imgs[torch.arange(batch_size, device=which.device), which]


class PGD_Linf(Attack_Linf):
    """Untargeted projected gradient attack under an l_inf budget.

    Takes ``hparams['pgd_n_steps']`` signed-gradient ascent steps of size
    ``hparams['pgd_step_size']`` on the cross-entropy loss. It starts from the
    clean images and projects with ``_clamp_perturbation`` after every step.
    """

    def __init__(self, classifier, hparams, device):
        super(PGD_Linf, self).__init__(classifier, hparams, device)

    def forward(self, imgs, labels):
        """Return detached adversarial images for ``imgs`` / ``labels``.

        The classifier runs in eval mode during the attack. Its previous
        train/eval mode is restored afterwards, even if the attack raises.
        """
        was_training = self.classifier.training
        self.classifier.eval()
        try:
            clean_imgs = imgs.detach()
            adv_imgs = clean_imgs
            for _ in range(self.hparams['pgd_n_steps']):
                # Fresh leaf each step so the autograd graph does not keep
                # growing across iterations.
                adv_imgs = adv_imgs.detach().requires_grad_(True)
                with torch.enable_grad():
                    adv_loss = F.cross_entropy(self.classifier(adv_imgs), labels)
                grad = torch.autograd.grad(adv_loss, [adv_imgs])[0]
                adv_imgs = adv_imgs.detach() + self.hparams['pgd_step_size'] * torch.sign(grad)
                adv_imgs = self._clamp_perturbation(clean_imgs, adv_imgs)
        finally:
            self.classifier.train(was_training)
        return adv_imgs.detach()


class SmoothAdv(Attack_Linf):
    """l_inf attack against a Gaussian-noise-smoothed version of the classifier."""

    def __init__(self, classifier, hparams, device):
        super(SmoothAdv, self).__init__(classifier, hparams, device)

    def sample_deltas(self, imgs):
        """Draw Gaussian noise shaped like ``imgs``, std ``hparams['rand_smoothing_sigma']``."""
        sigma = self.hparams['rand_smoothing_sigma']
        return sigma * torch.randn_like(imgs)

    def forward(self, imgs, labels):
        """Return detached adversarial images against the smoothed classifier.

        Each of the ``hparams['rand_smoothing_n_steps']`` steps does three things:
        1. Estimates the true-class probability under noise from
           ``hparams['rand_smoothing_n_samples']`` draws.
        2. Takes a signed-gradient ascent step of size
           ``hparams['rand_smoothing_step_size']`` on ``-log`` of that estimate.
        3. Projects with ``_clamp_perturbation``.

        The classifier runs in eval mode; its previous mode is restored.
        """
        n_samples = self.hparams['rand_smoothing_n_samples']
        was_training = self.classifier.training
        self.classifier.eval()
        try:
            clean_imgs = imgs.detach()
            adv_imgs = clean_imgs
            for _ in range(self.hparams['rand_smoothing_n_steps']):
                adv_imgs = adv_imgs.detach().requires_grad_(True)
                with torch.enable_grad():
                    # Monte-Carlo sum of the true-class probability under noise.
                    loss = 0.
                    for _ in range(n_samples):
                        deltas = self.sample_deltas(clean_imgs)
                        loss = loss + F.softmax(
                            self.classifier(adv_imgs + deltas), dim=1
                        )[range(clean_imgs.size(0)), labels]
                    total_loss = -1. * torch.log(loss / n_samples).mean()
                grad = torch.autograd.grad(total_loss, [adv_imgs])[0]
                # Step from the current iterate (not the clean images) so that
                # successive steps accumulate.
                adv_imgs = adv_imgs.detach() + self.hparams['rand_smoothing_step_size'] * torch.sign(grad)
                adv_imgs = self._clamp_perturbation(clean_imgs, adv_imgs)
        finally:
            self.classifier.train(was_training)
        return adv_imgs.detach()


class TRADES_Linf(Attack_Linf):
    """TRADES inner maximization under an l_inf budget.

    The attack ascends ``KLDivLoss(log_softmax(f(x_adv)), softmax(f(x)))``,
    i.e. the KL divergence from the clean prediction to the adversarial one.
    It uses signed-gradient steps of size ``hparams['trades_step_size']``.
    """

    def __init__(self, classifier, hparams, device):
        super(TRADES_Linf, self).__init__(classifier, hparams, device)
        # KLDivLoss expects log-probabilities as input, so the classifier must
        # output raw logits (no softmax layer at its output).
        self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean')

    def forward(self, imgs, labels):
        """Return detached adversarial images for ``imgs``.

        ``labels`` is accepted for interface compatibility but not used.
        The classifier runs in eval mode; its previous mode is restored.
        """
        was_training = self.classifier.training
        self.classifier.eval()
        try:
            clean_imgs = imgs.detach()
            # Small Gaussian start so the KL gradient is not zero at x_adv == x.
            adv_imgs = clean_imgs + 0.001 * torch.randn(imgs.shape).to(self.device)
            # The clean prediction does not change across steps (eval mode),
            # so compute the KL target once and without an autograd graph.
            with torch.no_grad():
                clean_probs = F.softmax(self.classifier(clean_imgs), dim=1)
            for _ in range(self.hparams['trades_n_steps']):
                adv_imgs = adv_imgs.detach().requires_grad_(True)
                with torch.enable_grad():
                    adv_loss = self.kl_loss_fn(
                        F.log_softmax(self.classifier(adv_imgs), dim=1),
                        clean_probs)
                grad = torch.autograd.grad(adv_loss, [adv_imgs])[0]
                adv_imgs = adv_imgs.detach() + self.hparams['trades_step_size'] * torch.sign(grad)
                adv_imgs = self._clamp_perturbation(clean_imgs, adv_imgs)
        finally:
            self.classifier.train(was_training)
        return adv_imgs.detach()


class FGSMBase(Attack):
    """Single-step FGSM of size ``hparams['epsilon']``, optionally plus uniform noise."""

    def __init__(self, classifier, hparams, device, add_noise):
        super(FGSMBase, self).__init__(classifier, hparams, device)
        self.add_noise = add_noise

    @staticmethod
    def uniform_like(tensor, lower, upper):
        """Sample uniformly from (lower, upper] with ``tensor``'s shape, dtype and device."""
        return (lower - upper) * torch.rand_like(tensor.detach()) + upper

    def forward(self, imgs, labels):
        """Return detached FGSM images clamped to [0, 1].

        If ``add_noise`` is True, uniform noise in (-epsilon, epsilon] is added
        after the gradient step. It is not projected back onto the eps-ball,
        so the total perturbation can reach 2 * epsilon.

        The caller's ``imgs`` is not modified. The classifier runs in eval mode
        and its previous mode is restored.
        """
        eps = self.hparams['epsilon']
        was_training = self.classifier.training
        self.classifier.eval()
        try:
            # Differentiate w.r.t. a detached copy so the caller's tensor keeps
            # its requires_grad flag; this also works for non-leaf inputs.
            x = imgs.detach().requires_grad_(True)
            with torch.enable_grad():
                adv_loss = F.cross_entropy(self.classifier(x), labels)
            grad = torch.autograd.grad(adv_loss, [x])[0]
            adv_imgs = x.detach() + eps * grad.sign()

            if self.add_noise is True:
                adv_imgs += self.uniform_like(
                    tensor=adv_imgs,
                    lower=-eps,
                    upper=eps).to(self.device)

            adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)
        finally:
            self.classifier.train(was_training)

        return adv_imgs.detach()

class FGSM_Linf(FGSMBase):
    """Plain FGSM: one signed-gradient step of size epsilon, without added noise."""

    def __init__(self, classifier, hparams, device):
        super().__init__(
            classifier=classifier,
            hparams=hparams,
            device=device,
            add_noise=False,
        )


class Noisy_FGSM_Linf(FGSMBase):
    """FGSM step followed by additive uniform noise in (-epsilon, epsilon]."""

    def __init__(self, classifier, hparams, device):
        # super() must name this class: Noisy_FGSM_Linf is not a subclass of
        # FGSM_Linf, so super(FGSM_Linf, self) raises TypeError.
        super(Noisy_FGSM_Linf, self).__init__(
            classifier=classifier,
            hparams=hparams,
            device=device,
            add_noise=True)


class LMC_Gaussian_Linf(Attack_Linf):
    """Noisy signed-gradient l_inf attack with additive Gaussian noise.

    Each of the ``hparams['g_dale_n_steps']`` steps does the following:
    1. Ascends ``log(1 - p_y)``, where ``p_y`` is the softmax probability of
       the true class, using a step of size ``hparams['g_dale_step_size']``.
    2. Adds Gaussian noise scaled by ``hparams['g_dale_noise_coeff']``.
    3. Projects with ``_clamp_perturbation``.
    """

    def __init__(self, classifier, hparams, device):
        super(LMC_Gaussian_Linf, self).__init__(classifier, hparams, device)

    def forward(self, imgs, labels):
        """Return detached adversarial images; the classifier's mode is restored afterwards."""
        was_training = self.classifier.training
        self.classifier.eval()
        try:
            clean_imgs = imgs.detach()
            adv_imgs = clean_imgs + 0.001 * torch.randn(imgs.shape).to(self.device)
            for _ in range(self.hparams['g_dale_n_steps']):
                adv_imgs = adv_imgs.detach().requires_grad_(True)
                with torch.enable_grad():
                    logits = self.classifier(adv_imgs)
                    # log(1 - p_y) = logsumexp(logits without class y)
                    #              - logsumexp(all logits).
                    # The naive torch.log(1 - softmax(...)[y]) gives -inf/NaN
                    # gradients once p_y rounds to 1.
                    other_logits = logits.scatter(1, labels.unsqueeze(1), float('-inf'))
                    adv_loss = (torch.logsumexp(other_logits, dim=1)
                                - torch.logsumexp(logits, dim=1)).mean()
                grad = torch.autograd.grad(adv_loss, [adv_imgs])[0]
                noise = torch.randn_like(adv_imgs)

                adv_imgs = (adv_imgs.detach()
                            + self.hparams['g_dale_step_size'] * torch.sign(grad)
                            + self.hparams['g_dale_noise_coeff'] * noise)
                adv_imgs = self._clamp_perturbation(clean_imgs, adv_imgs)
        finally:
            self.classifier.train(was_training)

        return adv_imgs.detach()


class LMC_Laplacian_Linf(Attack_Linf):
    """Noisy signed-gradient l_inf attack with Laplace noise inside the sign.

    Each of the ``hparams['l_dale_n_steps']`` steps does the following:
    1. Moves by ``hparams['l_dale_step_size'] * sign(grad + c * noise)``, where
       ``grad`` is the gradient of ``log(1 - p_y)``, ``noise`` is
       Laplace(0, 1) and ``c`` is ``hparams['l_dale_noise_coeff']``.
    2. Projects with ``_clamp_perturbation``.
    """

    def __init__(self, classifier, hparams, device):
        super(LMC_Laplacian_Linf, self).__init__(classifier, hparams, device)

    def forward(self, imgs, labels):
        """Return detached adversarial images; the classifier's mode is restored afterwards."""
        # Build the distribution on the images' device so the sampled noise can
        # be added to the gradient (CPU parameters break on CUDA inputs).
        noise_dist = Laplace(torch.tensor(0., device=imgs.device),
                             torch.tensor(1., device=imgs.device))
        was_training = self.classifier.training
        self.classifier.eval()
        try:
            clean_imgs = imgs.detach()
            adv_imgs = clean_imgs + 0.001 * torch.randn(imgs.shape).to(self.device)
            for _ in range(self.hparams['l_dale_n_steps']):
                adv_imgs = adv_imgs.detach().requires_grad_(True)
                with torch.enable_grad():
                    logits = self.classifier(adv_imgs)
                    # log(1 - p_y) = logsumexp(logits without class y)
                    #              - logsumexp(all logits).
                    # The naive torch.log(1 - softmax(...)[y]) gives -inf/NaN
                    # gradients once p_y rounds to 1.
                    other_logits = logits.scatter(1, labels.unsqueeze(1), float('-inf'))
                    adv_loss = (torch.logsumexp(other_logits, dim=1)
                                - torch.logsumexp(logits, dim=1)).mean()
                grad = torch.autograd.grad(adv_loss, [adv_imgs])[0]
                noise = noise_dist.sample(grad.shape)
                adv_imgs = adv_imgs.detach() + self.hparams['l_dale_step_size'] * torch.sign(
                    grad + self.hparams['l_dale_noise_coeff'] * noise)
                adv_imgs = self._clamp_perturbation(clean_imgs, adv_imgs)
        finally:
            self.classifier.train(was_training)
        return adv_imgs.detach()

