import torch
import argparse
from torch.utils.data import DataLoader
import os
import pandas as pd
import time
import pickle as pkl

from dataset import AdvDataset
from utils import BASE_ADV_PATH, ROOT_PATH
import methods_ens
import random
import numpy as np

from config import ens_attack_config

def arg_parse(config=None):
    """Parse command-line options for the ensemble attack and prepare the output directory.

    Args:
        config: Optional mapping of argument defaults (e.g. ``ens_attack_config``)
            applied with ``parser.set_defaults`` before parsing, so flags given
            explicitly on the command line still take precedence.

    Returns:
        argparse.Namespace: the parsed arguments, extended with ``opt_path`` —
        ``BASE_ADV_PATH/ENS/<mode>/<attack>/models_<name1#name2#...>`` — which is
        created on disk if it does not already exist.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--attack', type=str, default='PGD', help='the name of specific attack method')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
    # %(default)s reflects the effective default, including overrides from set_defaults.
    parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                        help='input batch size for reference (default: %(default)s)')
    parser.add_argument('--model_name', nargs='+', default='', help='Enter model names separated by space')
    parser.add_argument('--weight_path', nargs='+', default=None, help='')
    parser.add_argument('--filename_prefix', type=str, default='', help='')
    parser.add_argument('--mode', type=str, default='pretrained', help='Enable alignment')
    if config:
        parser.set_defaults(**config)
    args = parser.parse_args()

    model_names = args.model_name
    # A bare string (the '' default, or a single name supplied via config) would
    # otherwise be joined character by character into the directory name.
    if isinstance(model_names, str):
        model_names = [model_names] if model_names else []
    model_names_str = '#'.join(model_names)

    args.opt_path = os.path.join(BASE_ADV_PATH, 'ENS', args.mode, args.attack, f'models_{model_names_str}')

    print(f'Saving adversarial examples in {args.opt_path}')

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(args.opt_path, exist_ok=True)
    return args

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``seed`` for reproducible runs.

    When CUDA is available, every CUDA device is also seeded and cuDNN is
    switched to deterministic (non-benchmarking) mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

if __name__ == '__main__':
    args = arg_parse(config=ens_attack_config)
    print(args)

    # Restrict the visible GPUs; this must happen before CUDA is first used.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    set_seed(42)

    # loading dataset
    dataset = AdvDataset(args.model_name, os.path.join(ROOT_PATH, 'clean_resized_images'))
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
    print (args.attack, args.model_name)

    model_names = args.model_name
    weight_paths = args.weight_path

    # The attack class is looked up by name in methods_ens (e.g. --attack PGD).
    attack_method = getattr(methods_ens, args.attack)(model_name=model_names, weight_path=weight_paths)

    # Main
    all_loss_info = {}
    for batch_idx, batch_data in enumerate(data_loader):
        if batch_idx % 100 == 0:
            print ('Runing batch_idx', batch_idx)
        batch_x = batch_data[0]
        batch_y = batch_data[1]
        batch_name = batch_data[3]

        adv_inps, loss_info = attack_method(batch_x, batch_y)
        attack_method._save_images(adv_inps, batch_name, args.opt_path)
        if loss_info is not None:
            # Presumably batch_name is the collated image names: default_collate
            # turns a batch of str into a list, which cannot be a dict key.
            key = tuple(batch_name) if isinstance(batch_name, list) else batch_name
            all_loss_info[key] = loss_info

    # Write whenever any batch reported loss info. Checking only the last
    # batch's loss_info dropped earlier results and raised NameError when the
    # loader produced no batches.
    if all_loss_info:
        # NOTE(review): this is a pickle despite the .json extension; the file
        # name is kept unchanged for compatibility with existing readers.
        with open(os.path.join(args.opt_path, 'loss_info.json'), 'wb') as opt:
            pkl.dump(all_loss_info, opt)
