import logging
import os
import json
import torch
import numpy as np
import pickle
import argparse
import tqdm

from dataloader import create_dataloader
from log import setup_default_logging
from models import create_model
from utils.utils import torch_seed, AverageMeter, extract_correct, calculate_l2_distance, ImageNet_load_data
from adv_attacks import create_attack
from adv_attacks.TA import attack_mask
from foolbox import PyTorchModel
import clip
from torchvision import transforms
from dataloader import ImageNetValidationDataset_test
from torch.utils.data import TensorDataset, DataLoader

_logger = logging.getLogger('adv sample') 

def make_bucket():
    """Return a fresh, empty result bucket.

    The bucket maps ``'clean'``, ``'adv'`` and ``'targets'`` (in that order)
    to independent empty ``torch.Tensor`` objects that are later grown with
    ``torch.cat``.
    """
    return {key: torch.Tensor([]) for key in ('clean', 'adv', 'targets')}

def extract_success(
    bucket,
    inputs, inputs_adv, targets,
    correct_clean, correct_adv
):
    """Append this batch's successful adversarial examples to ``bucket``.

    A sample counts as a success when it is classified correctly on the clean
    input (``correct_clean``) but not on the adversarial input
    (``correct_adv``). For every success, the clean input, adversarial input
    and target are detached, moved to CPU and concatenated onto
    ``bucket['clean']``, ``bucket['adv']`` and ``bucket['targets']``, in
    batch order.

    Args:
        bucket: dict with ``'clean'``, ``'adv'`` and ``'targets'`` tensors.
        inputs, inputs_adv, targets: batch tensors indexed along dim 0.
        correct_clean, correct_adv: per-sample correctness flags (any
            truthy/falsy values convertible to a tensor), one per row of
            ``inputs``.

    Returns:
        The same ``bucket`` dict, updated in place.
    """
    # Build a single success mask and do one cat per key. Concatenating one
    # sample at a time re-copies the whole (growing) bucket on every call,
    # which is quadratic over a dataset and syncs with the device per element.
    clean_ok = torch.as_tensor(correct_clean, device=inputs.device).bool()
    adv_ok = torch.as_tensor(correct_adv, device=inputs.device).bool()
    success = clean_ok & ~adv_ok

    # Leave the bucket untouched when nothing succeeded (same as the old loop).
    if not success.any():
        return bucket

    bucket['clean'] = torch.cat([
        bucket['clean'],
        inputs.detach()[success].cpu(),
    ])
    bucket['adv'] = torch.cat([
        bucket['adv'],
        inputs_adv.detach()[success.to(inputs_adv.device)].cpu(),
    ])
    bucket['targets'] = torch.cat([
        bucket['targets'],
        targets.detach()[success.to(targets.device)].cpu(),
    ])

    return bucket

def validate(model, num_classes, testloader, adv_method, adv_params, savedir, log_interval=1, device='cpu', ta_args=None):
    """Attack the correctly classified test samples and save the outcome.

    For each batch, only the samples the model already classifies correctly
    are attacked with ``adv_method``. Pairs whose prediction is flipped by the
    attack are collected via ``extract_success``. Running "clean" and "adv"
    accuracies are tracked over those selected samples. At the end, the L2
    distance between successful adversarial and clean inputs is computed, and
    ``successed_images.pkl`` and ``results.json`` are written to ``savedir``.

    Args:
        model: classifier under attack; switched to eval mode here.
        num_classes: number of classes, forwarded to ``create_attack``.
        testloader: sized iterable of ``(inputs, targets)`` batches.
        adv_method: attack name; ``'TA'`` uses ``attack_mask.TA`` on a foolbox
            wrapper, anything else is built with ``create_attack``.
        adv_params: attack configuration forwarded to ``create_attack``.
        savedir: existing directory the result files are written to.
        log_interval: log running accuracies every this many batches.
        device: device the batches are moved to.
        ta_args: namespace passed to ``TA.attack``. When ``None``, falls back
            to a module-level ``args`` (defined when this file runs as a script).

    Raises:
        ValueError: if ``adv_method == 'TA'`` and no argument namespace is available.
    """
    if adv_method == 'TA' and ta_args is None:
        # Keep the historical behaviour of reading the script-level ``args``,
        # but fail with a clear message instead of a NameError when imported.
        ta_args = globals().get('args')
        if ta_args is None:
            raise ValueError("adv_method 'TA' requires ta_args (no module-level 'args' found)")

    clean_acc = AverageMeter()
    adv_acc = AverageMeter()

    successed_images = make_bucket()
    if adv_method not in ['TA']:
        atk = create_attack(model, adv_method, adv_params, num_classes=num_classes)

    model.eval()
    fmodel = PyTorchModel(model, bounds=(0,1), device=device)

    for i, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader), desc='Testing...'):
        inputs, targets = inputs.to(device), targets.to(device)

        # Only attack samples the model already gets right. Pure evaluation
        # forwards need no autograd graph.
        with torch.no_grad():
            outputs_init = model(inputs)
        correct_clean_init, _, _, _ = extract_correct(outputs_init, targets)
        selected_inputs = inputs[correct_clean_init]
        selected_targets = targets[correct_clean_init]

        if adv_method != 'TA' and i == 0:
            print("{} attack starts".format(adv_method))

        n_selected = selected_targets.size(0)
        if n_selected == 0:
            # Nothing to attack in this batch; also avoids a division by zero
            # in the accuracy update below.
            continue

        if adv_method == 'TA':
            TA = attack_mask.TA(model=fmodel, input_device = device)
            inputs_adv = TA.attack(ta_args, selected_inputs, selected_targets)
        else:
            inputs_adv = atk(selected_inputs, selected_targets)

        with torch.no_grad():
            outputs = model(selected_inputs)
            # adv pred
            outputs_adv = model(inputs_adv)
        correct_clean, _, _, _ = extract_correct(outputs, selected_targets)
        correct_adv, _, _, _ = extract_correct(outputs_adv, selected_targets)

        # check success
        successed_images = extract_success(
            bucket        = successed_images,
            inputs        = selected_inputs,
            inputs_adv    = inputs_adv,
            targets       = selected_targets,
            correct_clean = correct_clean,
            correct_adv   = correct_adv
        )

        # accuracy over the selected (initially correct) samples
        clean_acc.update(correct_clean.sum().item()/n_selected, n=n_selected)
        adv_acc.update(correct_adv.sum().item()/n_selected, n=n_selected)

        if i % log_interval == 0 and i != 0:
            _logger.info('TEST [{:>4d}/{}] '
                         'CLEAN: {clean.val:>6.4f} ({clean.avg:>6.4f}) '
                         'ADV: {adv.val:>6.4f} ({adv.avg:>6.4f}) '.format(
                             i+1, len(testloader),
                             clean     = clean_acc,
                             adv       = adv_acc
                         ))

    # The format string only uses named fields; referencing the loop variable
    # here would raise NameError on an empty loader.
    _logger.info('TEST [FINAL] '
                 'CLEAN: {clean.avg:>6.4f} '
                 'ADV: {adv.avg:>6.4f} '.format(
                     clean     = clean_acc,
                     adv       = adv_acc
                 ))

    l2_distance = calculate_l2_distance(successed_images['adv'], successed_images['clean'])
    print('L2 distance is {:.2f}'.format(l2_distance))

    # save successed images (closed deterministically)
    with open(os.path.join(savedir, 'successed_images.pkl'), 'wb') as f:
        pickle.dump(successed_images, f)

    # save results
    with open(os.path.join(savedir, 'results.json'), 'w') as f:
        json.dump(
            {
                'clean acc':clean_acc.avg,
                'adv acc':adv_acc.avg,
                'L2':l2_distance.tolist(),
            },
            f,
            indent=4
        )

def save_args_as_json(args, save_path):
    """Write the attributes of ``args`` to ``save_path`` as JSON indented by 4.

    ``args`` is anything ``vars()`` accepts (e.g. an argparse Namespace); all
    of its attribute values must be JSON-serializable.
    """
    payload = vars(args)
    with open(save_path, mode='w') as out_file:
        json.dump(payload, fp=out_file, indent=4)

def _build_testloader(args):
    """Build the evaluation DataLoader selected by ``args.dataname``.

    CIFAR10 comes from the project's ``create_dataloader``. ImageNet uses the
    validation set located by ``ImageNet_load_data``, and the resolved image
    and label paths are stored on ``args``.

    Raises:
        ValueError: if ``args.dataname`` is neither ``'CIFAR10'`` nor ``'ImageNet'``.
    """
    if args.dataname == 'CIFAR10':
        _, _, testloader = create_dataloader(
            datadir     = args.datadir,
            dataname    = args.dataname,
            batch_size  = args.batch_size,
            num_workers = args.num_workers
        )
        return testloader

    if args.dataname == 'ImageNet':
        transform_imagenet = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        args.imagenet_datadir, args.imagenet_labeldir = ImageNet_load_data(args)
        val_dataset = ImageNetValidationDataset_test(
            image_folder=args.imagenet_datadir,
            label_file=args.imagenet_labeldir,
            transform=transform_imagenet,
        )
        # NOTE(review): num_workers is fixed at 4 here, whereas the CIFAR10
        # branch honours --num_workers; confirm whether that is intentional.
        return DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Fail early with a clear message instead of an unbound ``testloader``.
    raise ValueError('Unsupported dataname: {}'.format(args.dataname))


def run(args):
    """Generate, or reload, the adversarial samples for one experiment.

    Results live in ``<args.savedir>/<args.exp_name>``. If
    ``successed_images.pkl`` already exists there, it is loaded and the shape
    of its ``'adv'`` tensor is printed. Otherwise:

    1. The attack config ``<args.adv_config>/<args.adv_name.lower()>.json`` is
       merged into ``args``.
    2. ``args`` is logged and saved to ``args.json``.
    3. The test loader and model are built, and ``validate`` runs the attack.
    """
    setup_default_logging()
    torch_seed(args.seed)

    savedir = os.path.join(args.savedir,args.exp_name)
    os.makedirs(savedir, exist_ok=True)
    result_path = os.path.join(savedir, 'successed_images.pkl')

    # check file
    if os.path.isfile(result_path):
        _logger.info('Already result {} file exists'.format(args.exp_name))
        # pickle can execute arbitrary code on load: only point this at files
        # this script wrote itself.
        with open(result_path, 'rb') as f:
            bucket = pickle.load(f)
        adv = bucket['adv']
        print(adv.shape)
        return

    # load adversarial parameters and update arguments
    with open(os.path.join(args.adv_config, f'{args.adv_name.lower()}.json'),'r') as f:
        adv_params = json.load(f)
    vars(args).update(adv_params)

    # log and save args
    device = 'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu'
    _logger.info('Device: {}'.format(device))
    for key, value in vars(args).items():
        _logger.info('{}: {}'.format(key, value))

    save_args_as_json(args, os.path.join(savedir, 'args.json'))

    # Built after saving args.json so the ImageNet paths set on ``args`` are
    # not written there (same order as before).
    testloader = _build_testloader(args)

    # Build Model
    model = create_model(
        modelname             = args.modelname,
        dataname              = args.dataname,
        num_classes           = args.num_classes,
        test                  = True,
        logits_dim            = args.num_classes,
        device                = device
    ).to(device)

    # validate
    validate(
        model        = model,
        num_classes  = args.num_classes,
        testloader   = testloader,
        adv_method   = args.adv_method,
        adv_params   = adv_params,
        savedir      = savedir,
        log_interval = args.log_interval,
        device       = device
    )


if __name__ == '__main__':
    # Command-line entry point. ``args`` must stay a module-level name:
    # validate() reads it for the TA attack.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # experiment / hardware
    add_arg('--exp_name', type=str, default='PGD', help='experiment name')
    add_arg('--modelname', type=str, default='vgg19')
    add_arg('--device', type=str, default='0', help='specify the used device')

    # adversarial attack selection
    add_arg('--adv_name', type=str, default='PGD', help='adversrial experiments name')
    add_arg('--adv_method', type=str, default='PGD', help='adversarial attack method name')
    add_arg('--adv_config', type=str, default='configs_adv', help='adversarial attack configuration directory')

    # dataset and output locations
    add_arg('--datadir', type=str, default='datasets', help='data directory')
    add_arg('--savedir', type=str, default='results/CIFAR10/saved_adv_samples', help='saved model directory')
    add_arg('--dataname', type=str, default='CIFAR10', choices=['CIFAR10', 'ImageNet'], help='data name')
    add_arg('--num_classes', type=int, default=10, help='the number of classes')

    # loader / logging / reproducibility
    add_arg('--batch_size', type=int, default=128, help='batch size')
    add_arg('--num_workers', type=int, default=1, help='the number of workers (threads)')
    add_arg('--log_interval', type=int, default=5, help='log interval')
    add_arg('--seed', type=int, default=223, help='seed')

    # TA attack hyper-parameters
    add_arg('--max_queries', type=int, default=500, help='The max number of queries in model')
    add_arg('--ratio_mask', type=float, default=0.1, help='ratio of mask')
    add_arg('--dim_num', type=int, default=3, help='the number of picked dimensions')
    add_arg('--max_iter_num_in_2d', type=int, default=2, help='the maximum iteration number of attack algorithm in 2d subspace')
    add_arg('--init_alpha', type=float, default=np.pi / 2, help='the initial angle of alpha')
    add_arg('--plus_learning_rate', type=float, default=0.01, help='plus learning_rate when success')
    add_arg('--minus_learning_rate', type=float, default=0.05, help='minus learning_rate when fail')
    add_arg('--half_range', type=float, default=0.1, help='half range of alpha from pi/2')
    add_arg('--side_length', type=int, default=32, help='CIFAR-10 inputs side_length is 32')

    args = parser.parse_args()
    run(args)