import torch
from my_utils.utils import load_json
from os.path import join

class MyPGD:
    """Projected-gradient-style updater for an additive input perturbation.

    ``step`` moves ``attack_module.perturbation`` by ``-alpha * sign(grad)``,
    clips the perturbed input ``batch['past_target'] + perturbation`` into a
    range selected by ``pgd`` (the mode), and blends the resulting clipped
    perturbation with the previous perturbation using weight ``gamma``.
    """

    def __init__(self, dataset_name, pgd=1, alpha=0.001, beta=3.0, gamma=0.1, args=None):
        """
        Args:
            dataset_name: dataset name; global statistics are loaded from
                ``statistics/<dataset_name>.json`` (a path relative to the
                current working directory).
            pgd: clipping mode used by ``step``:
                1 -> global range [mean - beta*std, mean + beta*std] using the
                     statistics loaded from the JSON file;
                2 -> per-sample range [mean - beta*std, mean + beta*std] using
                     ``attack_module.mean`` / ``attack_module.std``;
                3 -> per-sample range center +/- ``args.epsilon``, where center
                     is the midpoint of the values at window positions
                     ``args.start`` and ``args.start + args.length - 1``.
            alpha: step size of the signed-gradient descent step.
            beta: half-width of the clipping range in units of std (modes 1, 2).
            gamma: weight of the newly computed perturbation when blending it
                with the previous one (1.0 replaces the old one entirely).
            args: namespace providing ``start``, ``length`` and ``epsilon``;
                only read in mode 3.
        """
        self.dataset_name = dataset_name
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # presumably load_json returns a dict with 'mean' and 'std' keys — verify
        self.statistics = load_json(join('statistics', f'{dataset_name}.json'))
        self.pgd = pgd
        self.args = args
        self.mean = self.statistics['mean']
        self.std = self.statistics['std']
        # 'mean' and 'std' must be plain numbers here: if the JSON stored lists,
        # the arithmetic below would raise TypeError.
        self.min = self.mean - self.beta * self.std
        self.max = self.mean + self.beta * self.std

    @staticmethod
    def _clip_per_sample(past_target, p, lower, upper, batch_size):
        """Clip ``past_target[i] + p[i]`` into ``[lower[i], upper[i]]`` for every
        ``i < batch_size`` and stack the results along dim 0."""
        return torch.stack(
            [torch.clip(past_target[i] + p[i], lower[i], upper[i]) for i in range(batch_size)],
            dim=0,
        )

    def step(self, attack_module):
        """Apply one update to ``attack_module.perturbation`` in place.

        Reads ``attack_module.perturbation`` (its ``.grad`` must already be
        populated), ``attack_module.batch['past_target']`` and, in mode 2,
        ``attack_module.mean`` / ``attack_module.std``. The gradient is not
        zeroed here.

        Raises:
            NotImplementedError: if ``self.pgd`` is not 1, 2 or 3.
        """
        with torch.no_grad():
            perturbation = attack_module.perturbation
            past_target = attack_module.batch['past_target']
            # NOTE(review): subtracting the signed gradient descends the loss;
            # the original author flagged this sign as a judgement call —
            # confirm it matches the attack objective.
            p = perturbation - self.alpha * torch.sign(perturbation.grad)
            batch_size = perturbation.shape[0]

            # Compute the new perturbed input and clip it into the allowed range.
            if self.pgd == 1:
                # Global (dataset-level) statistics: one range for all samples.
                perturbed_input_new = torch.clip(past_target + p, self.min, self.max)
            elif self.pgd == 2:
                # Per-sample statistics supplied by the attack module.
                lower = attack_module.mean - self.beta * attack_module.std
                upper = attack_module.mean + self.beta * attack_module.std
                perturbed_input_new = self._clip_per_sample(past_target, p, lower, upper, batch_size)
            elif self.pgd == 3:
                # Midpoint of the window's two endpoint values, +/- epsilon.
                # Indexing [:, t] requires past_target to have at least 2 dims.
                window_end = self.args.start + self.args.length - 1
                center = 0.5 * (past_target[:, self.args.start] + past_target[:, window_end])
                lower = center - self.args.epsilon
                upper = center + self.args.epsilon
                perturbed_input_new = self._clip_per_sample(past_target, p, lower, upper, batch_size)
            else:
                raise NotImplementedError(f"PGD mode = {self.pgd} is not implemented.")

            # Blend the clipped perturbation with the previous one (gamma = weight of the new one).
            perturbation.data = self.gamma * (perturbed_input_new - past_target) + (1 - self.gamma) * perturbation.data

