import logging
import os
import json
import torch
import numpy as np
import pickle
import argparse
import tqdm
import clip
from PIL import Image

from log import setup_default_logging
from models import create_model
from utils.utils import torch_seed, AverageMeter, extract_correct, calculate_l2_distance, classify_images_with_ref_vit
from sklearn.metrics import roc_curve, auc
from torchvision import transforms
from dataloader import ImageNetValidationDataset_test
from torch.utils.data import TensorDataset, DataLoader

# Module-level logger shared by the functions in this script (logger name: 'adv sample').
_logger = logging.getLogger('adv sample') 

def detection(model, detect_model, transform, images, adv, targets, savedir, log_interval, batch_size, device='cpu'):
    """Score a reference-model adversarial detector on successful attacks and print its AUC.

    Every batch is filtered down to the samples for which ``model`` is correct on
    the clean input but wrong on the adversarial input. For those samples the
    detector score is ``1 - p``, where ``p`` is the probability returned by
    ``classify_images_with_ref_vit`` (presumably the reference model's confidence
    in the class predicted by ``model`` -- TODO confirm against utils.utils).
    Clean samples are labelled 0 and adversarial samples 1; the ROC AUC over all
    collected scores is printed at the end.

    Args:
        model: Classifier under attack; put into eval mode here.
        detect_model: Reference model used by the detector; put into eval mode here.
        transform: Forwarded to ``classify_images_with_ref_vit`` with the images.
        images: Tensor of clean images.
        adv: Tensor of adversarial images, aligned with ``images``.
        targets: Tensor of labels, aligned with ``images``.
        savedir: Unused; kept so existing callers keep working.
        log_interval: Running accuracies are logged when the batch index is a
            non-zero multiple of this value (skipped batches are not logged).
        batch_size: Batch size of the evaluation loader.
        device: Device every batch is moved to before inference.

    Returns:
        None. Accuracies are logged; the L2 distance and AUC are printed.
    """
    clean_acc = AverageMeter()
    clean_cnn_acc = AverageMeter()
    adv_acc = AverageMeter()
    adv_cnn_acc = AverageMeter()

    dataloader = DataLoader(
        TensorDataset(images, adv, targets),
        batch_size  = batch_size,
        shuffle     = False,
        num_workers = 1
    )

    model.eval()
    detect_model.eval()

    labels_for_detector = []
    scores_for_detector = []

    # Pure evaluation: disable autograd so no graph is kept alive for the forward passes.
    with torch.no_grad():
        progress = tqdm.tqdm(enumerate(dataloader), total=len(dataloader), desc='Testing...')
        for i, (inputs, inputs_adv, batch_targets) in progress:
            inputs = inputs.to(device)
            inputs_adv = inputs_adv.to(device)
            batch_targets = batch_targets.to(device)

            correct_clean, _, _, _ = extract_correct(model(inputs), batch_targets)
            correct_adv, _, _, _ = extract_correct(model(inputs_adv), batch_targets)

            # A "successful" attack: right on the clean image, wrong on the adversarial one.
            success_mask = torch.logical_and(correct_clean, torch.logical_not(correct_adv))
            num_success = int(success_mask.sum().item())
            if num_success == 0:
                continue

            inputs = inputs[success_mask]
            batch_targets = batch_targets[success_mask]
            inputs_adv = inputs_adv[success_mask]

            # NOTE(review): the forward passes are repeated on the filtered subset, as in the
            # original code; indexing the first-pass outputs would presumably be equivalent
            # if extract_correct is strictly per-sample -- confirm before changing.

            # clean pred
            correct_clean, correct_clean_pred, _, _ = extract_correct(model(inputs), batch_targets)
            correct_vit_clean, _, correct_vit_clean_prob, _ = classify_images_with_ref_vit(
                detect_model, transform, inputs, batch_targets, correct_clean_pred, device
            )

            # adv pred
            correct_adv, correct_adv_pred, _, _ = extract_correct(model(inputs_adv), batch_targets)
            correct_vit_adv, _, correct_vit_adv_prob, _ = classify_images_with_ref_vit(
                detect_model, transform, inputs_adv, batch_targets, correct_adv_pred, device
            )

            # detection: label 0 = clean, 1 = adversarial, in the same order as the scores below
            label_for_detector = np.concatenate([
                np.zeros(num_success, dtype=bool),
                np.ones(num_success, dtype=bool),
            ])

            clean_prob_delta = 1 - correct_vit_clean_prob
            adv_prob_delta = 1 - correct_vit_adv_prob

            score_for_detector = torch.cat((clean_prob_delta, adv_prob_delta)).detach().cpu().numpy()
            labels_for_detector.append(label_for_detector)
            scores_for_detector.append(score_for_detector)

            # accuracy (measured on the filtered subset only)
            n_kept = batch_targets.size(0)
            clean_acc.update(correct_clean.sum().item()/n_kept, n=n_kept)
            adv_acc.update(correct_adv.sum().item()/n_kept, n=n_kept)
            clean_cnn_acc.update(correct_vit_clean.sum().item()/n_kept, n=n_kept)
            adv_cnn_acc.update(correct_vit_adv.sum().item()/n_kept, n=n_kept)

            if i % log_interval == 0 and i != 0:
                _logger.info('TEST [{:>4d}/{}] '
                             'CLEAN: {clean.val:>6.4f} ({clean.avg:>6.4f}) '
                             'CLEAN CNN: {clean_cnn.val:>6.4f} ({clean_cnn.avg:>6.4f}) '
                             'ADV: {adv.val:>6.4f} ({adv.avg:>6.4f}) '
                             'ADV CNN: {adv_cnn.val:>6.4f} ({adv_cnn.avg:>6.4f}) '.format(
                                 i+1, len(dataloader),
                                 clean      = clean_acc,
                                 clean_cnn = clean_cnn_acc,
                                 adv        = adv_acc,
                                 adv_cnn   = adv_cnn_acc,
                             ))

    l2_distance = calculate_l2_distance(images, adv)
    print('L2 distance is {:.2f}'.format(l2_distance))

    # The final message has no positional fields, so the loop index is not needed here
    # (referencing it after the loop would fail on an empty loader).
    _logger.info('TEST [FINAL] '
                 'CLEAN: {clean.avg:>6.4f} '
                 'CLEAN CNN: {clean_cnn.avg:>6.4f} '
                 'ADV: {adv.avg:>6.4f} '
                 'ADV CNN: {adv_cnn.avg:>6.4f} '.format(
                     clean      = clean_acc,
                     clean_cnn = clean_cnn_acc,
                     adv        = adv_acc,
                     adv_cnn   = adv_cnn_acc
                 ))

    # detection
    if not labels_for_detector:
        # np.concatenate rejects an empty list; without any successful attack there is no ROC.
        _logger.warning('No successful adversarial examples found; skipping AUC computation.')
        return

    labels_for_detector = np.concatenate(labels_for_detector)
    scores_for_detector = np.concatenate(scores_for_detector)
    fprs_success, tprs_success, _ = roc_curve(labels_for_detector, scores_for_detector)
    roc_auc_success = auc(fprs_success, tprs_success)
    print('AUC score is {:.2f}'.format(roc_auc_success*100))

def save_args_as_json(args, save_path):
    """Write the attributes of ``args`` to ``save_path`` as 4-space-indented JSON.

    Args:
        args: Object whose attribute dictionary (``vars(args)``) is serialized,
            e.g. an ``argparse.Namespace``.
        save_path: Destination file path; the file is created or overwritten.
    """
    attributes = vars(args)
    with open(save_path, 'w') as out_file:
        json.dump(obj=attributes, fp=out_file, indent=4)

def run(args):
    """Load an experiment's saved clean/adversarial samples and run detection on them.

    Builds the classifier named by ``args.modelname`` and a fixed ``'vit_l_16'``
    reference model, loads ``successed_images.pkl`` from
    ``<args.savedir>/<args.exp_name>``, and passes its ``'clean'``, ``'adv'`` and
    ``'targets'`` entries to ``detection``.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``savedir``,
            ``exp_name``, ``device``, ``modelname``, ``dataname``,
            ``num_classes``, ``log_interval`` and ``batch_size``.

    Raises:
        FileNotFoundError: If the pickle file does not exist in the save directory.
    """
    setup_default_logging()
    torch_seed(args.seed)

    savedir = os.path.join(args.savedir, args.exp_name)
    os.makedirs(savedir, exist_ok=True)

    # Fall back to CPU when no CUDA device is available.
    device = 'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu'
    _logger.info('Device: {}'.format(device))

    detect_model_name = 'vit_l_16'

    # Build Model
    model = create_model(
        modelname             = args.modelname,
        dataname              = args.dataname,
        num_classes           = args.num_classes,
        device                = device
    )
    model.to(device)

    detect_model = create_model(
        modelname             = detect_model_name,
        dataname              = args.dataname,
        num_classes           = args.num_classes,
        device                = device
    )
    detect_model.to(device)

    # Resize/crop applied before images are given to the reference model.
    transform_vit = transforms.Compose([
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None),
        transforms.CenterCrop(size=(224, 224)),
    ])

    save_path = os.path.join(savedir, 'successed_images.pkl')
    # NOTE: pickle.load can execute arbitrary code; only load files written by a trusted run.
    # The context manager closes the file handle even if unpickling fails.
    with open(save_path, 'rb') as bucket_file:
        bucket = pickle.load(bucket_file)

    print('Detection starts.')
    detection(
        model        = model,
        detect_model = detect_model,
        transform    = transform_vit,
        images       = bucket['clean'],
        adv          = bucket['adv'],
        targets      = bucket['targets'],
        savedir      = savedir,
        log_interval = args.log_interval,
        batch_size   = args.batch_size,
        device       = device
    )


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), in help order.
    cli_options = [
        ('--exp_name', dict(type=str, default='PGD', help='save adversarial examples folder')),
        ('--modelname', dict(type=str, default='vgg19')),
        ('--device', dict(type=str, default='0', help='specify the used device')),

        # dataset
        ('--datadir', dict(type=str, default='datasets', help='data directory')),
        ('--savedir', dict(type=str, default='results/CIFAR10/saved_adv_samples', help='saved model directory')),
        ('--dataname', dict(type=str, default='CIFAR10', choices=['CIFAR10', 'ImageNet'], help='data name')),
        ('--num_classes', dict(type=int, default=10, help='the number of classes')),

        # evaluation / runtime
        ('--batch_size', dict(type=int, default=128, help='batch size')),
        ('--num-workers', dict(type=int, default=4, help='the number of workers (threads)')),
        ('--log-interval', dict(type=int, default=5, help='log interval')),
        ('--seed', dict(type=int, default=223, help='seed')),
    ]

    parser = argparse.ArgumentParser()
    for option_flag, option_kwargs in cli_options:
        parser.add_argument(option_flag, **option_kwargs)

    args = parser.parse_args()

    run(args)