import argparse
import os
import csv

import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import time

import torchvision.models as models
from generators import GeneratorResnet
from gaussian_smoothing import *

from my_dataset import sample_from_imagenet_val


import logging
logger = logging.getLogger(__name__)

# Command-line configuration for the evaluation sweep.
# NOTE: --target and --target_model are overwritten by the sweep loops further
# down; they stay here so that `print(args)` and the log lines report the
# (target ID, target model) pair currently being evaluated.
parser = argparse.ArgumentParser(description='Targeted Transferable Perturbations')
parser.add_argument('--test_dir', default='../../../data/IN/val',
                    help='ImageNet validation image directory')
parser.add_argument('--batch_size', type=int, default=100, help='Batch size for evaluation')
parser.add_argument('--eps', type=int, default=16, help='Perturbation Budget')
parser.add_argument('--target_model', type=str, default='vgg19', help='Black-Box(unknown) model: SIN, Augmix etc')
parser.add_argument('--target', type=int, default=23, help='Target label to transfer')
# Implicit string concatenation (instead of backslash continuations inside the
# literal) keeps the choice list readable and the separators explicit.
parser.add_argument('--source_model', type=str, default='resnet50',
                    help='TTP Discriminator: '
                         '{res18, res50, res101, res152, dense121, dense161, dense169, dense201,'
                         ' vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn,'
                         ' ens_vgg16_vgg19_vgg11_vgg13_all_bn,'
                         ' ens_res18_res50_res101_res152,'
                         ' ens_dense121_161_169_201}')
parser.add_argument('--source_domain', type=str, default='IN', help='Source Domain (TTP): Natural Images (IN) or painting')

args = parser.parse_args()

store_data = []
target_models = ['resnet18', 'resnet50', 'resnet101', 'densenet121', 'densenet161', 'vgg16_bn', 'vgg19_bn', 'mobilenet_v2', 'vit_b_16',]
store_data.append(['target ID', ] + target_models)

epoch = 20  # NOTE(review): not read anywhere in this script.

# Target class IDs swept by the evaluation (one CSV row per ID).
target_ids = [23, 54, 60, 124, 344, 443, 465, 600, 642, 744, 769, 885]

# Validation images before this index are skipped; only the tail is evaluated.
test_start_index = 45000

# ImageNet channel statistics applied before every target-model forward pass.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# NOTE(review): the path has no per-target component, so every target ID is
# evaluated with the same generator weights -- fill in the real naming scheme.
generator_path = 'pretrained_generators/generator_name.pth'


def normalize(t):
    """Return ``t`` normalised per channel with ImageNet ``mean``/``std``.

    Assumes ``t`` is (batch, 3, H, W). Unlike an in-place version, the input
    tensor is not modified.
    """
    m = torch.tensor(mean, dtype=t.dtype, device=t.device).view(1, 3, 1, 1)
    s = torch.tensor(std, dtype=t.dtype, device=t.device).view(1, 3, 1, 1)
    return (t - m) / s


def _list_torchvision_models():
    """Return the sorted lowercase callables exported by ``torchvision.models``."""
    return sorted(name for name in models.__dict__
                  if name.islower() and not name.startswith("__")
                  and callable(models.__dict__[name]))


def _load_generator(path, device):
    """Load ``GeneratorResnet`` weights from ``path`` and return it wrapped, on ``device``, in eval mode."""
    net = GeneratorResnet()
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # the module is moved to `device` afterwards.
    net.load_state_dict(torch.load(path, map_location='cpu'))
    return nn.DataParallel(net).to(device).eval()


def _load_resnet50_checkpoint(path):
    """Build a DataParallel ResNet-50 and load ``checkpoint["state_dict"]`` from ``path``."""
    # The checkpoint keys carry the DataParallel 'module.' prefix, so wrap first.
    net = torch.nn.DataParallel(torchvision.models.resnet50(pretrained=False))
    checkpoint = torch.load(path, map_location='cpu')
    net.load_state_dict(checkpoint["state_dict"])
    return net


def _load_target_model(name, model_names, device):
    """Return the black-box model ``name`` wrapped in DataParallel, on ``device``, in eval mode.

    Raises:
        ValueError: if ``name`` is neither a torchvision model nor 'SIN'/'Augmix'.
    """
    if name in model_names and name != 'deit_b':
        net = models.__dict__[name](pretrained=True)
    elif name == 'SIN':
        net = _load_resnet50_checkpoint('pretrained_models/resnet50_train_60_epochs-c8e5653e.pth.tar')
    elif name == 'Augmix':
        net = _load_resnet50_checkpoint('pretrained_models/checkpoint.pth.tar')
    else:
        # An explicit raise (not assert) so an unsupported name can never fall
        # through and silently reuse a model from a previous iteration.
        raise ValueError('Please provide correct target model names: {}'.format(model_names))
    if not isinstance(net, nn.DataParallel):
        net = nn.DataParallel(net)
    return net.to(device).eval()


def _build_test_loader(root, batch_size):
    """Return ``(dataset, loader)`` over the validation images from ``test_start_index`` onward."""
    # NOTE(review): preprocessing is whatever sample_from_imagenet_val applies
    # internally; a Resize(256)/CenterCrop(224)/ToTensor pipeline used to be
    # defined here but was never passed to the dataset.
    all_data = sample_from_imagenet_val(root)
    tail = range(test_start_index, len(all_data))
    test_data = sample_from_imagenet_val(root)
    test_data.labels = [all_data.labels[idx] for idx in tail]
    test_data.im_path = [all_data.im_path[idx] for idx in tail]
    loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0,
                                         pin_memory=True)
    return test_data, loader


def _evaluate(model, netG, kernel, loader, target, eps, device):
    """Attack every batch in ``loader`` toward class ``target`` and score it on ``model``.

    Returns:
        (fooled, distance_sum, n_batches): number of images classified as
        ``target``, the sum over batches of the per-batch L-inf perturbation
        (in 0-255 units), and the number of batches seen.
    """
    fooled = 0
    distance_sum = 0.0
    n_batches = 0
    # Pure inference: no autograd graph is needed for the generator or target.
    with torch.no_grad():
        for img, _, _label in loader:
            img = img.to(device)
            target_label = torch.full((img.size(0),), target, dtype=torch.long, device=device)

            adv = kernel(netG(img))
            # Project onto the eps L-inf ball around img, then the valid pixel range.
            adv = torch.min(torch.max(adv, img - eps), img + eps)
            adv = torch.clamp(adv, 0.0, 1.0)

            out = model(normalize(adv))
            fooled += torch.sum(out.argmax(dim=-1) == target_label).item()
            # L-inf magnitude of the perturbation (abs: the signed max misses negative steps).
            distance_sum += (img - adv).abs().max().item() * 255
            n_batches += 1
    return fooled, distance_sum, n_batches


# ---- Loop-invariant setup (done once instead of per target/model pair) ----
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
eps = args.eps / 255.0

# Gaussian smoothing applied to the generator output.
kernel_size = 3
pad = 2
sigma = 1
kernel = get_gaussian_kernel(kernel_size=kernel_size, pad=pad, sigma=sigma).to(device)

model_names = _list_torchvision_models()

test_data, test_loader = _build_test_loader(args.test_dir, args.batch_size)
test_size = len(test_data)
print('Test data size:', test_size)

folder_name = 'results'
os.makedirs(folder_name, exist_ok=True)
file_path = f'{folder_name}/file_name.csv'

# ---- Sweep over (target ID, black-box model) ----
for tarID in target_ids:
    args.target = tarID
    acc_all = [tarID]
    netG = _load_generator(generator_path, device)

    for target_model in target_models:
        args.target_model = target_model
        print(args)
        t1 = time.time()

        model = _load_target_model(args.target_model, model_names, device)

        logger.info('Target_model \t Epsilon \t Target \t Acc. \t Distance')
        acc, distance, n_batches = _evaluate(model, netG, kernel, test_loader, args.target, eps, device)

        # Guard the empty-dataset case instead of dividing by zero.
        accuracy = 100*acc / test_size if test_size else 0.0
        print(f'{args.target_model}:  Accuracy={accuracy},  totalSample:{test_size},   fooled:{acc}')
        acc_all.append(accuracy)
        logger.info('%s   \t %d             %d\t  %.4f\t \t %.4f',
                    args.target_model, int(eps * 255), args.target,
                    acc / test_size if test_size else 0.0,
                    distance / max(n_batches, 1))
        print('Required time:', time.time()-t1)

    store_data.append(acc_all)
    # Rewrite the whole CSV after every target ID so partial results survive
    # an interrupted run.
    with open(file_path, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(store_data)
