import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import random
import scipy.stats as st
import copy
from utils import ROOT_PATH
from functools import partial
import copy
import pickle as pkl
from torch.autograd import Variable
import torch.nn.functional as F

from dataset import params
from model import get_model
from utils_dct import dct_2d, idct_2d

class BaseAttack(object):
    """Shared scaffolding for the attacks defined in this module.

    The constructor loads one model per entry of ``model_name`` (each moved
    to CUDA and put in eval mode).  The helper methods convert images between
    the normalised space and the un-normalised space, using the ``'mean'``
    and ``'std'`` entries of ``params(model_name)``.

    Subclasses implement :meth:`forward` and must set ``self.epsilon``, which
    :meth:`_update_perts` reads.
    """

    def __init__(self, attack_name, model_name, target, pre_trained=True, weight_path=None):
        """Store the attack configuration and load the model ensemble.

        Args:
            attack_name: Label stored on the instance.
            model_name: Iterable of model identifiers; each one is passed to
                ``get_model`` and the whole object is passed to ``params``.
            target: Truthy for a targeted attack (loss sign ``-1``), falsy
                for an untargeted one (loss sign ``+1``).
            pre_trained: Forwarded to ``get_model`` when no explicit weight
                paths are given.
            weight_path: Optional iterable of weight paths, paired with
                ``model_name`` by position.  Any falsy value (``None``,
                ``False``, empty list) means "no explicit weights".
        """
        self.attack_name = attack_name
        self.model_name = model_name
        self.target = target
        self.loss_flag = -1 if self.target else 1

        self.used_params = params(self.model_name)

        # Loading models.  The subclasses in this file default ``weight_path``
        # to ``False``; testing only ``is None`` sent that default into the
        # ``zip`` branch, where it raised ``TypeError``.
        if not weight_path:
            self.model = [
                get_model(model, pre_trained=pre_trained, weight_path=None).cuda().eval()
                for model in self.model_name
            ]
        else:
            self.model = [
                get_model(model, pre_trained=False, weight_path=path).cuda().eval()
                for model, path in zip(self.model_name, weight_path)
            ]

    def forward(self, *input):
        """Run the attack; must be overridden by subclasses."""
        raise NotImplementedError

    def _channel_stats(self, inps):
        """Return ``(mean, std)`` shaped ``(C, 1, 1)`` on ``inps``' device and dtype."""
        mean = torch.as_tensor(self.used_params['mean'], dtype=inps.dtype, device=inps.device)
        std = torch.as_tensor(self.used_params['std'], dtype=inps.dtype, device=inps.device)
        return mean[:, None, None], std[:, None, None]

    def _mul_std_add_mean(self, inps):
        """Undo the normalisation: return ``inps * std + mean``.

        A new tensor is returned and ``inps`` is left untouched.  The earlier
        in-place version silently overwrote the caller's tensor.
        """
        mean, std = self._channel_stats(inps)
        return inps * std + mean

    def _sub_mean_div_std(self, inps):
        """Apply the normalisation: return ``(inps - mean) / std``."""
        mean, std = self._channel_stats(inps)
        return (inps - mean) / std

    def _save_images(self, inps, filenames, output_dir):
        """Write each normalised image of the batch to ``output_dir``.

        Args:
            inps: Batch of normalised images indexed as ``inps[i]`` with
                layout (c, h, w); it is not modified.
            filenames: One file name per image, joined onto ``output_dir``.
            output_dir: Destination directory.
        """
        # Clip to the valid pixel range before converting to 8-bit.
        unnorm_inps = self._mul_std_add_mean(inps).detach().clamp(0, 1)
        for i, filename in enumerate(filenames):
            save_path = os.path.join(output_dir, filename)
            image = unnorm_inps[i].permute([1, 2, 0])  # c,h,w to h,w,c
            image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
            image.save(save_path)

    def _update_inps(self, inps, grad, step_size):
        """Take one sign-gradient step in pixel space and re-normalise.

        The stepped image is clamped to ``[0, 1]`` before normalising.
        """
        unnorm_inps = self._mul_std_add_mean(inps.detach())
        unnorm_inps = unnorm_inps + step_size * grad.sign()
        unnorm_inps = torch.clamp(unnorm_inps, min=0, max=1).detach()
        return self._sub_mean_div_std(unnorm_inps)

    def _update_perts(self, perts, grad, step_size):
        """Step ``perts`` along ``sign(grad)`` and clip to ``[-epsilon, epsilon]``.

        Reads ``self.epsilon``, which this base class does not define.
        """
        perts = perts + step_size * grad.sign()
        return torch.clamp(perts, -self.epsilon, self.epsilon)

    def _return_perts(self, clean_inps, inps):
        """Return the pixel-space difference between ``inps`` and ``clean_inps``."""
        clean_unnorm = self._mul_std_add_mean(clean_inps.detach())
        adv_unnorm = self._mul_std_add_mean(inps.detach())
        return adv_unnorm - clean_unnorm

    def __call__(self, *input, **kwargs):
        """Delegate to :meth:`forward`."""
        return self.forward(*input, **kwargs)

class PGD(BaseAttack):
    """Iterative sign-gradient attack on the averaged logits of the ensemble.

    The perturbation starts at zero, moves ``step_size`` along the sign of
    the loss gradient each step, and is clipped both to
    ``[-epsilon, epsilon]`` and so that the perturbed image stays in
    ``[0, 1]`` in un-normalised space.
    """

    def __init__(self, model_name, steps=20, epsilon=4/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        """Configure the attack.

        Args:
            model_name: Iterable of model identifiers for the ensemble.
            steps: Number of gradient steps.
            epsilon: Perturbation bound in un-normalised pixel units.
            target: Truthy for a targeted attack.
            decay: Stored on the instance; not used by :meth:`forward`.
            pre_trained: Forwarded to the base class.
            weight_path: Optional iterable of weight paths; any falsy value
                means "none".
        """
        # The ``False`` default is normalised to ``None`` because the base
        # class only recognises ``None`` as "no explicit weights".
        super(PGD, self).__init__('PGD', model_name, target, pre_trained=pre_trained, weight_path=weight_path or None)
        self.epsilon = epsilon
        self.steps = steps
        # Fixed step of one 8-bit intensity level (not epsilon / steps).
        self.step_size = 1/255
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Craft adversarial examples for a batch.

        Args:
            inps: Normalised images; moved to CUDA.  The caller's tensor is
                not modified.
            labels: Labels fed to the cross-entropy loss; moved to CUDA.

        Returns:
            ``(adv_inps, None)`` where ``adv_inps`` is the detached,
            re-normalised adversarial batch.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        # Work on a copy: the base helper may un-normalise in place, which
        # would corrupt a caller tensor that already lives on the GPU.
        unnorm_inps = self._mul_std_add_mean(inps.detach().clone())

        perts = torch.zeros_like(unnorm_inps)
        perts.requires_grad_()

        for _ in range(self.steps):
            outputs_ = [model(self._sub_mean_div_std(unnorm_inps + perts)) for model in self.model]
            outputs = torch.mean(torch.stack(outputs_), dim=0)  # average over the ensemble

            cost = self.loss_flag * loss(outputs, labels)
            cost.backward()
            # Only the sign of the gradient is used below, so the former
            # per-sample normalisation was a no-op, except that it produced
            # NaN (0 / 0) whenever a sample's gradient was entirely zero.
            grad = perts.grad.data

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # Keep the perturbed image inside the valid pixel range.
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None

class MI(BaseAttack):
    """Momentum iterative sign-gradient attack on the averaged ensemble logits.

    Each step normalises the gradient by its per-sample mean absolute value,
    adds ``decay`` times the previous accumulated gradient, and steps along
    the sign of the result.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        """Configure the attack.

        Args:
            model_name: Iterable of model identifiers for the ensemble.
            steps: Number of gradient steps.
            epsilon: Perturbation bound in un-normalised pixel units.
            target: Truthy for a targeted attack.
            decay: Momentum decay factor.
            pre_trained: Forwarded to the base class.
            weight_path: Optional iterable of weight paths; any falsy value
                means "none".
        """
        # The ``False`` default is normalised to ``None`` because the base
        # class only recognises ``None`` as "no explicit weights".
        super(MI, self).__init__('MI', model_name, target, pre_trained=pre_trained, weight_path=weight_path or None)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Craft adversarial examples for a batch.

        Args:
            inps: Normalised images; moved to CUDA.  The caller's tensor is
                not modified.
            labels: Labels fed to the cross-entropy loss; moved to CUDA.

        Returns:
            ``(adv_inps, None)`` where ``adv_inps`` is the detached,
            re-normalised adversarial batch.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        # Work on a copy: the base helper may un-normalise in place, which
        # would corrupt a caller tensor that already lives on the GPU.
        unnorm_inps = self._mul_std_add_mean(inps.detach().clone())
        momentum = torch.zeros_like(unnorm_inps)

        perts = torch.zeros_like(unnorm_inps)
        perts.requires_grad_()

        for _ in range(self.steps):
            outputs_ = [model(self._sub_mean_div_std(unnorm_inps + perts)) for model in self.model]
            outputs = torch.mean(torch.stack(outputs_), dim=0)  # average over the ensemble

            cost = self.loss_flag * loss(outputs, labels)
            cost.backward()
            grad = perts.grad.data
            # Per-sample L1-style normalisation.  A sample whose gradient is
            # entirely zero is divided by 1 instead of 0, so it contributes
            # zeros rather than NaNs that would poison the momentum.
            scale = torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
            grad = grad / torch.where(scale > 0, scale, torch.ones_like(scale))

            momentum = grad + momentum * self.decay

            perts.data = self._update_perts(perts.data, momentum, self.step_size)
            # Keep the perturbed image inside the valid pixel range.
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None

class SSA_TI_DI_MI_FGSM(BaseAttack):
    """Momentum sign-gradient attack with frequency-domain augmentation.

    Each step averages gradients over ``number`` randomly transformed copies
    of the current image.  Every copy gets Gaussian noise, a random
    element-wise rescaling of its ``dct_2d`` coefficients, and the ``DI``
    transform.  The averaged gradient is smoothed with a depthwise Gaussian
    kernel and fed into a momentum update.  The kernel is built for exactly
    three channels.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20,ti_kernel_size=5):
        """Configure the attack.

        Args:
            model_name: Iterable of model identifiers for the ensemble.
            steps: Number of outer gradient steps.
            epsilon: Perturbation bound in un-normalised pixel units; also
                the standard deviation of the added Gaussian noise.
            target: Truthy for a targeted attack.
            decay: Momentum decay factor.
            pre_trained: Forwarded to the base class.
            weight_path: Optional iterable of weight paths; any falsy value
                means "none".
            number: Augmented copies averaged per step.
            ti_kernel_size: Side length of the Gaussian smoothing kernel.
        """
        # The ``False`` default is normalised to ``None`` because the base
        # class only recognises ``None`` as "no explicit weights".
        super(SSA_TI_DI_MI_FGSM, self).__init__('SSA_TI_DI_MI_FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path or None)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224
        self.number = number

        self.ti_kernel_size = ti_kernel_size
        kernel = gkern(ti_kernel_size, 3).astype(np.float32)
        # One copy of the kernel per channel, shaped (3, 1, k, k) for a
        # grouped (depthwise) convolution.
        gaussian_kernel = np.stack([kernel, kernel, kernel])
        gaussian_kernel = np.expand_dims(gaussian_kernel, 1)
        self.gaussian_kernel = torch.from_numpy(gaussian_kernel).cuda()

    def forward(self, inps, labels):
        """Craft adversarial examples for a batch.

        Args:
            inps: Normalised images; moved to CUDA.  The caller's tensor is
                not modified.
            labels: Labels fed to the cross-entropy loss; moved to CUDA.

        Returns:
            ``(adv_inps, None)`` where ``adv_inps`` is the detached,
            re-normalised adversarial batch.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        # Work on a copy: the base helper may un-normalise in place, which
        # would corrupt a caller tensor that already lives on the GPU.
        unnorm_inps = self._mul_std_add_mean(inps.detach().clone())
        momentum = torch.zeros_like(unnorm_inps)
        perts = torch.zeros_like(unnorm_inps)

        # These models receive inputs resized back to 224x224 after DI.
        fixed_size_models = ['vit_base_patch16_224', 'swin_tiny_patch4_window7_224']
        half_kernel = (self.ti_kernel_size - 1) // 2

        for _step in range(self.steps):
            grad = 0
            img_x = unnorm_inps + perts
            for _sample in range(self.number):
                with torch.no_grad():
                    gauss = torch.randn_like(img_x) * self.epsilon
                    x_dct = dct_2d(img_x + gauss)
                    # Random per-coefficient scaling drawn from [0.5, 1.5).
                    mask = torch.rand_like(img_x) + 0.5
                    x_idct = idct_2d(x_dct * mask).detach()

                x_idct.requires_grad_()  # enable gradient computation for x_idct
                outputs_ = []
                for idx, model in enumerate(self.model):
                    if self.model_name[idx] in fixed_size_models:
                        outputs_.append(model(self._sub_mean_div_std(F.interpolate(DI(x_idct), size=(224, 224), mode='bilinear'))))
                    else:
                        outputs_.append(model(self._sub_mean_div_std(DI(x_idct))))
                outputs = torch.mean(torch.stack(outputs_), dim=0)  # average over the ensemble
                cost = self.loss_flag * loss(outputs, labels)
                cost.backward()

                grad += x_idct.grad.data

            grad = grad / self.number
            # Smooth the averaged gradient with the Gaussian kernel.
            grad = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=(half_kernel, half_kernel), groups=3)

            # Per-sample L1-style normalisation.  A sample whose gradient is
            # entirely zero is divided by 1 instead of 0, so it contributes
            # zeros rather than NaNs that would poison the momentum.
            scale = torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
            grad = grad / torch.where(scale > 0, scale, torch.ones_like(scale))

            momentum = grad + momentum * self.decay

            perts.data = self._update_perts(perts.data, momentum, self.step_size)
            # Keep the perturbed image inside the valid pixel range.
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data

        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None

class NI(BaseAttack):
    """Momentum sign-gradient attack evaluated at a look-ahead point.

    Each step computes the loss gradient at the current adversarial image
    shifted by ``step_size * sign(momentum)``, then applies the same
    normalise / accumulate / sign-step update as the momentum attack.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        """Configure the attack.

        Args:
            model_name: Iterable of model identifiers for the ensemble.
            steps: Number of gradient steps.
            epsilon: Perturbation bound in un-normalised pixel units.
            target: Truthy for a targeted attack.
            decay: Momentum decay factor.
            pre_trained: Forwarded to the base class.
            weight_path: Optional iterable of weight paths; any falsy value
                means "none".
        """
        # The ``False`` default is normalised to ``None`` because the base
        # class only recognises ``None`` as "no explicit weights".
        super(NI, self).__init__('NI', model_name, target, pre_trained=pre_trained, weight_path=weight_path or None)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Craft adversarial examples for a batch.

        Args:
            inps: Normalised images; moved to CUDA.  The caller's tensor is
                not modified.
            labels: Labels fed to the cross-entropy loss; moved to CUDA.

        Returns:
            ``(adv_inps, None)`` where ``adv_inps`` is the detached,
            re-normalised adversarial batch.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        # Work on a copy: the base helper may un-normalise in place, which
        # would corrupt a caller tensor that already lives on the GPU.
        unnorm_inps = self._mul_std_add_mean(inps.detach().clone())
        momentum = torch.zeros_like(unnorm_inps)

        perts = torch.zeros_like(unnorm_inps)
        perts.requires_grad_()

        for _ in range(self.steps):
            xadv = unnorm_inps + perts
            # Look-ahead point along the current momentum direction.
            xnes = xadv + self.step_size * momentum.sign()
            # ``self.model`` is a list of models, so it cannot be called
            # directly; average the logits of every member instead.
            outputs_ = [model(self._sub_mean_div_std(xnes)) for model in self.model]
            outputs = torch.mean(torch.stack(outputs_), dim=0)  # average over the ensemble
            cost = self.loss_flag * loss(outputs, labels)
            cost.backward()
            grad = perts.grad.data
            # Per-sample L1-style normalisation.  A sample whose gradient is
            # entirely zero is divided by 1 instead of 0, so it contributes
            # zeros rather than NaNs that would poison the momentum.
            scale = torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
            grad = grad / torch.where(scale > 0, scale, torch.ones_like(scale))

            momentum = grad + momentum * self.decay

            perts.data = self._update_perts(perts.data, momentum, self.step_size)
            # Keep the perturbed image inside the valid pixel range.
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None

def DI(X_in, prob=0.7, low=299, high=330):
    """Randomly resize and zero-pad a batch of images.

    With probability ``prob`` the input is resized to a random square side
    ``rnd`` drawn from ``[low, high)`` and zero-padded at a random offset to
    ``high`` x ``high``.  Otherwise the input is returned unchanged.

    Args:
        X_in: Image batch whose last two dimensions are height and width.
        prob: Probability of applying the transform.
        low: Smallest resized side length (inclusive).
        high: Exclusive upper bound of the resized side length, and the side
            length of the padded output.

    Returns:
        The transformed tensor, or ``X_in`` itself when the transform is
        skipped.
    """
    # All random numbers are drawn up front, in a fixed order, whether or not
    # the transform ends up being applied.
    rnd = int(np.random.randint(low, high, size=1)[0])
    h_rem = high - rnd
    w_rem = high - rnd
    pad_top = int(np.random.randint(0, h_rem, size=1)[0])
    pad_bottom = h_rem - pad_top
    pad_left = int(np.random.randint(0, w_rem, size=1)[0])
    pad_right = w_rem - pad_left
    c = np.random.rand(1)
    if c <= prob:
        resized = F.interpolate(X_in, size=(rnd, rnd))
        # F.pad orders the last two dimensions as (left, right, top, bottom).
        # The earlier (left, top, right, bottom) order mixed the axes and
        # produced non-square outputs.
        return F.pad(resized, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
    return X_in
    
def gkern(kernlen=15, nsig=3):
    """Build a normalised 2-D Gaussian kernel.

    The standard normal density is sampled at ``kernlen`` evenly spaced
    points on ``[-nsig, nsig]``.  The kernel is the outer product of that
    vector with itself, scaled so that its entries sum to one.

    Args:
        kernlen: Number of samples per side of the kernel.
        nsig: Half-width of the sampling interval.

    Returns:
        A ``(kernlen, kernlen)`` NumPy array.
    """
    grid = np.linspace(-nsig, nsig, kernlen)
    weights = st.norm.pdf(grid)
    unnormalised = weights[:, None] * weights[None, :]
    return unnormalised / unnormalised.sum()
