from torch.utils.data import DataLoader
import load_datasets
import os
import numpy as np
import torch
from utils import Utils
from torchvision.utils import save_image
from get_model import GetModel
from fourier_attack import FourierAttack
from piq import multi_scale_ssim, LPIPS, mdsi
import random
import argparse

'''
    Possible models (names accepted by GetModel.getModel in get_model.py):
    ['resnet50_1k', 'resnet152_1k', 'bit_large', 'ViT-B', 'ViT-B-1k',
    'ViT-L', 'Swin', 'DeiT-S', 'DeiT-s-nodist']
'''

# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class FourierAdversarialAttack:
    """Run a Fourier-domain adversarial attack over a dataset and report metrics.

    The constructor seeds the random generators, loads a pretrained model via
    ``GetModel.getModel``, wraps it in a ``FourierAttack`` and builds a
    batch-size-1, non-shuffled ``DataLoader`` over
    ``load_datasets.AdversarialTrainDataset``.
    """

    def __init__(self, model_name, attack, lam, iteration, lr, decay, endure_thres, src_path, label_path,
                 dst_path, save, save_num, seed):
        """Set up the model, the attacker and the data loader.

        Args:
            model_name: model key passed to ``GetModel.getModel``.
            attack: attack variant forwarded to ``FourierAttack``.
            lam: weight forwarded to ``FourierAttack``; also echoed in the
                final report printed by ``attack``.
            iteration: forwarded to ``FourierAttack`` as ``iteration``.
            lr: forwarded to ``FourierAttack`` as ``lr``.
            decay: forwarded to ``FourierAttack`` as ``weight_decay``.
            endure_thres: forwarded to ``FourierAttack`` as ``endure_thres``.
            src_path: dataset root for ``AdversarialTrainDataset``.
            label_path: label file for ``AdversarialTrainDataset``.
            dst_path: output directory for saved results (passed to
                ``Utils.checkDir``).
            save: if truthy, ``attack`` saves images and frequency
                histograms for every adversarial example it finds.
            save_num: when ``save`` is truthy, ``attack`` stops once this many
                examples have been saved.
            seed: seed applied to ``random``, ``numpy`` and ``torch``.
        """
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Prepare the output directory.
        Utils.checkDir(dst_path)

        # bit_large uses 480x480 inputs; every other model uses 224x224.
        self.img_shape = (480, 480) if model_name == "bit_large" else (224, 224)

        # Run configuration.
        self.save = save
        self.save_num = save_num
        self.incorrect = 0  # adversarial examples found by the latest attack() call
        self.save_idx = 0
        self.dst_path = dst_path
        self.model_name = model_name
        self.lam = lam

        # Pretrained model plus its preprocessing and (inverse) normalization callables.
        self.model, self.transform, self.normalize, self.invNormalize = GetModel.getModel(model_name, pretrained=True)
        self.model.to(device)

        self.attacker = FourierAttack(attack=attack, model=self.model, lam=lam, iteration=iteration, lr=lr,
                                      weight_decay=decay, endure_thres=endure_thres,
                                      normalize=self.normalize, invNormalize=self.invNormalize, device=device)

        # One sample per batch, in dataset order.
        self.datasets = load_datasets.AdversarialTrainDataset(src_path, label_path, self.transform)
        self.loader = DataLoader(self.datasets, batch_size=1, shuffle=False, num_workers=0, drop_last=False)

    def predict(self):
        """Print the top-1 accuracy of the (unattacked) model on the dataset."""
        self.model.eval()
        total = len(self.datasets)
        if total == 0:
            print("accuracy: n/a (empty dataset)")
            return
        correct = 0
        with torch.no_grad():
            for images, labels in self.loader:
                # Inputs must live on the same device as the model.
                images, labels = images.to(device), labels.to(device)
                output = self.model(self.normalize(images))
                predicted = torch.argmax(output, dim=1)
                correct += (predicted == labels).sum().item()
        print("accuracy: %f" % (correct / total))

    def save_img(self, img, dst_path, mode, model_name, idx=0, inv_norm=None):
        """Save ``img`` as ``<dst_path>/<mode>[/<model_name>]/<mode>_<idx>.png``.

        Args:
            img: image tensor handed to ``save_image`` (4 images per grid row).
            dst_path: root output directory.
            mode: sub-directory name and file-name prefix (e.g. ``"adv"``).
            model_name: optional extra sub-directory; skipped when falsy.
            idx: index embedded in the file name.
            inv_norm: optional callable applied before saving; its output is
                clamped to [0, 1].
        """
        out_dir = os.path.join(dst_path, mode)
        Utils.checkDir(out_dir)
        if model_name:
            out_dir = os.path.join(out_dir, model_name)
            Utils.checkDir(out_dir)
        if inv_norm is not None:
            img = torch.clamp(inv_norm(img), 0, 1)
        save_image(img, os.path.join(out_dir, "%s_%d.png" % (mode, idx)), nrow=4)

    def _save_found_example(self, batch_x, adv_img, idx):
        """Save the perturbation spectrum histogram and the origin/adv/perturb images."""
        perturb = torch.subtract(adv_img, batch_x)
        # Magnitude of the 2-D spectrum of the perturbation, zero frequency centred.
        # Shift only the spatial axes so batch/channel axes are left untouched.
        perturb_fourier = torch.fft.fftshift(torch.fft.fft2(perturb), dim=(-2, -1))
        perturb_mag = torch.abs(perturb_fourier)
        Utils.getFreqDistribution(perturb_mag, dst_path=os.path.join(self.dst_path, "hist"),
                                  model_name=self.model_name, idx=idx)

        self.save_img(batch_x, self.dst_path, "origin", model_name=None, idx=idx, inv_norm=None)
        self.save_img(adv_img, self.dst_path, "adv", model_name=self.model_name, idx=idx, inv_norm=None)
        # Scale the perturbation by 20 so it is visible, then clip to the valid image range.
        save_perturb = torch.clamp(20 * torch.abs(perturb), 0, 1)
        self.save_img(save_perturb, self.dst_path, "perturb", model_name=self.model_name,
                      idx=idx, inv_norm=None)

    def attack(self):
        """Attack every sample and report accuracy and image-quality metrics.

        ``self.attacker.fourierAttack`` is run on each sample, and a truthy
        ``found`` flag counts as a successful attack. MDSI, LPIPS, MS-SSIM and
        PSNR between the clean and the attacked image are recorded for every
        processed sample. When ``self.save`` is truthy, found examples are
        written to disk, and the loop stops once ``self.save_num`` examples
        have been saved.

        Returns:
            ``(accuracy, psnr, ssim, mdsi, lpips)``. ``accuracy`` is
            ``1 - found / processed`` (``nan`` if nothing was processed). The
            metrics are means over the processed samples after
            ``Utils.remove_inf``.
        """
        psnr_list, ssim_list, lpips_list, mdsi_list = [], [], [], []
        save_idx = 1
        processed = 0
        # Count only the adversarial examples found during this call.
        self.incorrect = 0

        # Attack the model in inference mode (dropout off, BatchNorm running stats).
        self.model.eval()
        # Build the LPIPS network once rather than once per sample.
        lpips_metric = LPIPS()

        print('Start find adversarial examples')
        for batch_x, batch_y in self.loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            adv_img, found = self.attacker.fourierAttack(batch_x, batch_y)
            # Metrics and saving never need gradients through the attack result.
            adv_img = adv_img.detach()
            processed += 1
            if found:
                self.incorrect += 1

            mdsi_list.append(torch.mean(mdsi(batch_x, adv_img), dim=0).item())
            lpips_list.append(torch.mean(lpips_metric(batch_x, adv_img), dim=0).item())
            ssim_list.append(torch.mean(multi_scale_ssim(batch_x, adv_img), dim=0).item())
            psnr_list.append(torch.mean(Utils.psnr(batch_x, adv_img), dim=0).item())

            if self.save and found:
                self._save_found_example(batch_x, adv_img, save_idx)
                save_idx += 1
                if save_idx > self.save_num:
                    break

        # Divide by the samples actually attacked: the loop may stop early while saving.
        total_accuracy = 1 - (self.incorrect / processed) if processed else float("nan")

        # Utils.remove_inf presumably drops infinite entries (e.g. PSNR of an unchanged image).
        total_mdsi = np.array(Utils.remove_inf(mdsi_list)).mean()
        total_lpips = np.array(Utils.remove_inf(lpips_list)).mean()
        total_ssim = np.array(Utils.remove_inf(ssim_list)).mean()
        total_psnr = np.array(Utils.remove_inf(psnr_list)).mean()

        print("model: %s, acc: %.3f,  lam: %f, psnr: %f, ssim: %f, mdsi: %f, lpips: %f" %
              (self.model_name, total_accuracy, self.lam, total_psnr, total_ssim, total_mdsi, total_lpips))
        return total_accuracy, total_psnr, total_ssim, total_mdsi, total_lpips

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``0`` to ``bool``.

    ``bool`` inputs are returned unchanged. Unrecognised strings raise
    ``argparse.ArgumentTypeError`` so argparse reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def main():
    """Parse command-line options, build a FourierAdversarialAttack and run it.

    ``--model_name`` and ``--attack`` are required, so their ``default``
    values are never used.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', required=True,
                        default='resnet50_1k', help='name of model, you might want to refer get_model.py')
    parser.add_argument('--attack', required=True, default='phase',
                        help='Attacks in Fourier attack framework, [\'phase\', \'mag\', \'pixel\','
                             ' \'phase+mag\', \'all\']')
    parser.add_argument('--lam', required=False, default=5e+4, type=float, help='weight parameter to MSE Loss')
    parser.add_argument('--iteration', required=False, default=1000, type=int, help='max iteration for finding adversarial examples')
    parser.add_argument('--lr', required=False, default=5e-3, type=float, help='learning rate')
    parser.add_argument('--decay', required=False, default=5e-6, type=float, help='weight decay')
    parser.add_argument('--endure_thres', required=False, default=5, type=int, help='enduring count for no loss decrease')
    # Parse as a real boolean: without a type, "--save False" yields the truthy string "False".
    # A bare "--save" (no value) also means True.
    parser.add_argument('--save', required=False, default=False, type=_str2bool, nargs='?', const=True,
                        help='save image and distribution')
    parser.add_argument('--save_num', required=False, default=30, type=int, help='number of save data')
    parser.add_argument('--src_path', required=False, default='./samples', help='root dir for dataset')
    parser.add_argument('--label_path', required=False, default='./label.txt', help='path for label file')
    parser.add_argument('--dst_path', required=False, default='./saved', help='dir for saving results')
    parser.add_argument('--seed', required=False, default=7, type=int, help='random seed')
    args = parser.parse_args()

    attack = FourierAdversarialAttack(model_name=args.model_name, attack=args.attack, lam=args.lam,
                                      iteration=args.iteration, lr=args.lr, decay=args.decay,
                                      endure_thres=args.endure_thres, src_path=args.src_path,
                                      label_path=args.label_path, dst_path=args.dst_path, save=args.save,
                                      save_num=args.save_num, seed=args.seed)
    # To report clean (unattacked) accuracy instead, call attack.predict().
    attack.attack()


if __name__ == "__main__":
    main()


