import torch
import argparse
from torch.utils.data import DataLoader
import os
import pandas as pd
import time
import pickle as pkl

from dataset import AdvDataset
from utils import BASE_ADV_PATH, ROOT_PATH
import methods
import random
import numpy as np

from config import attack_config

def arg_parse(config=None, argv=None):
    """Parse command-line options and prepare the output directory.

    Besides the parsed flags, the returned namespace carries ``opt_path``:
    the directory where adversarial examples will be written. It is derived
    from ``--weight_path`` when given, otherwise from ``--model_name`` and
    ``--attack``. The directory is created if it does not exist yet.

    Args:
        config: Optional mapping of option defaults (e.g. ``attack_config``).
            These replace the hard-coded argparse defaults but are still
            overridden by flags given explicitly on the command line.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            command-line behaviour.

    Returns:
        argparse.Namespace with all parsed options plus ``opt_path``.

    Exits (via ``parser.error``) if ``--weight_path`` has fewer than three
    parent directory components, since those are mirrored into ``opt_path``.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--attack', type=str, default='PGD', help='the name of specific attack method')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
    # %(default)s is expanded by argparse at help time, so the text stays
    # correct even when `config` overrides the batch size via set_defaults.
    parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                        help='input batch size for reference (default: %(default)s)')
    parser.add_argument('--model_name', type=str, default='', help='')
    parser.add_argument('--weight_path', type=str, default=None, help='')
    parser.add_argument('--filename_prefix', type=str, default='', help='')
    if config:
        parser.set_defaults(**config)
    args = parser.parse_args(argv)

    if args.weight_path is not None:
        # Mirror the last three directory levels of the checkpoint path under
        # BASE_ADV_PATH; the leaf is "model_<checkpoint stem>_<attack>", e.g.
        #   .../a/b/c/ckpt.pth  ->  BASE_ADV_PATH/a/b/c/model_ckpt_PGD
        parts = args.weight_path.split('/')
        if len(parts) < 4:
            parser.error(
                f'--weight_path must contain at least three parent directories, '
                f'got {args.weight_path!r}'
            )
        stem = parts[-1].split('.')[0]
        args.opt_path = os.path.join(BASE_ADV_PATH, parts[-4], parts[-3], parts[-2],
                                     'model_{}_{}'.format(stem, args.attack))
        # NOTE: for "base KL loss" checkpoints an earlier variant used only the
        # last two directory levels (parts[-3], parts[-2]) under BASE_ADV_PATH.
    else:
        # Pretrained model: output directory is keyed by model name and attack.
        args.opt_path = os.path.join(BASE_ADV_PATH, 'pretrained', 'model_{}-method_{}'.format(args.model_name, args.attack))

    print(f'Saving adversarial examples in {args.opt_path}')

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.opt_path, exist_ok=True)
    return args

def set_seed(seed):
    """Seed every random-number source this script relies on.

    Python's ``random``, NumPy's global RNG and torch's default generator are
    always seeded. If a CUDA device is visible, the CUDA generators are seeded
    as well and cuDNN is switched to deterministic, non-benchmarking mode so
    repeated runs pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def main():
    """Generate and save adversarial examples for the clean image dataset.

    Parses options (using ``attack_config`` as defaults), instantiates the
    attack class named by ``--attack`` from ``methods``, runs it over the
    dataset batch by batch and saves the adversarial images under
    ``args.opt_path``. Any per-batch loss information returned by the attack
    is collected and pickled to ``loss_info.json`` in the same directory.
    """
    args = arg_parse(config=attack_config)
    print(args)

    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet (e.g. by importing `methods`/`config`); confirm nothing earlier
    # touches the GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    set_seed(42)

    # loading dataset
    dataset = AdvDataset(args.model_name, os.path.join(ROOT_PATH, 'clean_resized_images'))
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
    print (args.attack, args.model_name)

    # Attack: look the class up once, then construct it either with its
    # default pretrained weights or from an explicit checkpoint.
    attack_cls = getattr(methods, args.attack)
    if args.weight_path is None:
        attack_method = attack_cls(model_name=args.model_name)
    else:
        attack_method = attack_cls(model_name=args.model_name, pre_trained=False, weight_path=args.weight_path)

    # Main loop
    all_loss_info = {}
    for batch_idx, batch_data in enumerate(data_loader):
        if batch_idx % 100 == 0:
            print ('Runing batch_idx', batch_idx)
        batch_x = batch_data[0]
        batch_y = batch_data[1]
        # Element 3 is handed to _save_images as the names of the saved
        # files (presumably image file names; confirm against AdvDataset).
        # Element 2 is not used by this script.
        batch_name = batch_data[3]

        adv_inps, loss_info = attack_method(batch_x, batch_y)
        attack_method._save_images(adv_inps, batch_name, args.opt_path)
        if loss_info is not None:
            all_loss_info[batch_name] = loss_info

    # Decide on what was actually collected, not on the last batch's
    # `loss_info`: that check raised NameError for an empty dataset and
    # silently dropped everything whenever the final batch returned None.
    if all_loss_info:
        # NOTE: despite the .json extension this file is a binary pickle;
        # the name is kept for compatibility with existing consumers.
        with open(os.path.join(args.opt_path, 'loss_info.json'), 'wb') as opt:
            pkl.dump(all_loss_info, opt)


if __name__ == '__main__':
    main()
