import logging
import os
import json
import torch
import numpy as np
import pickle
import argparse
import tqdm
import clip
from PIL import Image

from dataloader import create_dataloader
from log import setup_default_logging
from models import create_model
from utils.utils import torch_seed, AverageMeter, extract_correct, choose_class_text, calculate_l2_distance
from utils.utils import ImageNet_load_data, clip_classify_images_with_ref_photo
from adv_attacks import create_attack
from sklearn.metrics import roc_curve, auc
from foolbox import PyTorchModel
from foolbox.attacks import LinfDeepFoolAttack
from torchvision import transforms
from dataloader import ImageNetValidationDataset_test
from torch.utils.data import TensorDataset, DataLoader

# Module-level logger used by detection() and run() for progress and summary messages.
_logger = logging.getLogger('adv sample')

def detection(model, clip_model, transform, images, adv, targets, class_texts, savedir, log_interval, batch_size, device='cpu'):
    """Evaluate the CLIP-based adversarial detector on paired clean/adversarial samples.

    For every batch, only the samples that ``model`` classifies correctly when
    clean and incorrectly when perturbed ("successful attacks") are kept. Those
    samples are scored with ``clip_classify_images_with_ref_photo`` and the
    detector score ``1 - prob`` is collected for both the clean (label 0) and
    the adversarial (label 1) version. Finally the ROC AUC of the detector, the
    running accuracies and the L2 distance between ``images`` and ``adv`` are
    written to ``<savedir>/detection_results.json``.

    Args:
        model: classifier under attack; called as ``model(batch)``.
        clip_model: CLIP model forwarded to ``clip_classify_images_with_ref_photo``.
        transform: preprocessing forwarded to ``clip_classify_images_with_ref_photo``.
        images: tensor of clean samples (wrapped in a ``TensorDataset``).
        adv: tensor of adversarial samples aligned with ``images``.
        targets: tensor of ground-truth labels aligned with ``images``.
        class_texts: class prompts forwarded to ``clip_classify_images_with_ref_photo``.
        savedir: directory that receives ``detection_results.json``.
        log_interval: log running accuracies every ``log_interval`` batches.
        batch_size: batch size of the evaluation loader.
        device: device the batches are moved to.

    Returns:
        None. Results are logged, printed and saved as JSON. When no successful
        adversarial sample is found the ``'AUC'`` entry is saved as ``null``.
    """
    clean_acc = AverageMeter()
    clean_clip_acc = AverageMeter()
    adv_acc = AverageMeter()
    adv_clip_acc = AverageMeter()

    dataloader = DataLoader(
        TensorDataset(images, adv, targets),
        batch_size  = batch_size,
        shuffle     = False,
        num_workers = 1
    )

    model.eval()
    clip_model.eval()

    labels_for_detector = []
    scores_for_detector = []

    # Pure evaluation: nothing below back-propagates, so autograd bookkeeping
    # (and the memory it holds on to) is disabled for the whole loop.
    with torch.no_grad():
        for i, (inputs, inputs_adv, batch_targets) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader), desc='Testing...'):
            inputs = inputs.to(device)
            inputs_adv = inputs_adv.to(device)
            batch_targets = batch_targets.to(device)

            # First pass: find the samples where the attack actually succeeded,
            # i.e. clean prediction correct and adversarial prediction wrong.
            correct_clean, _, _, _ = extract_correct(model(inputs), batch_targets)
            correct_adv, _, _, _ = extract_correct(model(inputs_adv), batch_targets)

            inds_success = torch.logical_and(correct_clean, torch.logical_not(correct_adv)).detach().cpu().numpy()
            num_success = int(np.sum(inds_success))

            if num_success == 0:
                continue

            inputs = inputs[inds_success]
            batch_targets = batch_targets[inds_success]
            inputs_adv = inputs_adv[inds_success]

            # clean pred (recomputed on the filtered subset)
            outputs = model(inputs)
            correct_clean, correct_clean_pred, _, _ = extract_correct(outputs, batch_targets)
            correct_clip_clean, _, correct_clip_clean_prob, _ = clip_classify_images_with_ref_photo(
                clip_model, transform, inputs, batch_targets, correct_clean_pred, class_texts, device
            )

            # adv pred (recomputed on the filtered subset)
            outputs_adv = model(inputs_adv)
            correct_adv, correct_adv_pred, _, _ = extract_correct(outputs_adv, batch_targets)
            correct_clip_adv, _, correct_clip_adv_prob, _ = clip_classify_images_with_ref_photo(
                clip_model, transform, inputs_adv, batch_targets, correct_adv_pred, class_texts, device
            )

            # detection: clean samples are negatives (False), adversarial ones positives (True)
            label_for_detector = torch.cat([
                torch.zeros(num_success, dtype=bool),
                torch.ones(num_success, dtype=bool)
            ]).numpy()

            clean_prob_delta = (1-correct_clip_clean_prob)
            adv_prob_delta = (1-correct_clip_adv_prob)

            score_for_detector = torch.cat((clean_prob_delta, adv_prob_delta)).detach().cpu().numpy()
            labels_for_detector.append(label_for_detector)
            scores_for_detector.append(score_for_detector)

            # accuracy (measured on the successfully attacked subset only)
            n_samples = batch_targets.size(0)
            clean_acc.update(correct_clean.sum().item()/n_samples, n=n_samples)
            adv_acc.update(correct_adv.sum().item()/n_samples, n=n_samples)
            clean_clip_acc.update(correct_clip_clean.sum().item()/n_samples, n=n_samples)
            adv_clip_acc.update(correct_clip_adv.sum().item()/n_samples, n=n_samples)

            if i % log_interval == 0 and i != 0:
                _logger.info('TEST [{:>4d}/{}] '
                             'CLEAN: {clean.val:>6.4f} ({clean.avg:>6.4f}) '
                             'CLEAN CLIP: {clean_clip.val:>6.4f} ({clean_clip.avg:>6.4f}) '
                             'ADV: {adv.val:>6.4f} ({adv.avg:>6.4f}) '
                             'ADV CLIP: {adv_clip.val:>6.4f} ({adv_clip.avg:>6.4f}) '.format(
                                 i+1, len(dataloader),
                                 clean      = clean_acc,
                                 clean_clip = clean_clip_acc,
                                 adv        = adv_acc,
                                 adv_clip   = adv_clip_acc,
                             ))

    l2_distance = calculate_l2_distance(images, adv)
    print('L2 distance is {:.2f}'.format(l2_distance))

    # Only named fields are used here; the loop index is deliberately not
    # referenced because it is undefined when the dataloader is empty.
    _logger.info('TEST [FINAL] '
                 'CLEAN: {clean.avg:>6.4f} '
                 'CLEAN CLIP: {clean_clip.avg:>6.4f} '
                 'ADV: {adv.avg:>6.4f} '
                 'ADV CLIP: {adv_clip.avg:>6.4f} '.format(
                     clean      = clean_acc,
                     clean_clip = clean_clip_acc,
                     adv        = adv_acc,
                     adv_clip   = adv_clip_acc
                 ))

    # detection
    if labels_for_detector:
        all_labels = np.concatenate(labels_for_detector)
        all_scores = np.concatenate(scores_for_detector)
        fprs_success, tprs_success, _ = roc_curve(all_labels, all_scores)
        roc_auc_success = auc(fprs_success, tprs_success)
        print('AUC score is {:.2f}'.format(roc_auc_success*100))
    else:
        # np.concatenate([]) would raise; report the situation explicitly instead.
        _logger.warning('No successful adversarial samples were found; AUC cannot be computed.')
        roc_auc_success = None

    # save results
    results = {
        'clean acc':clean_acc.avg,
        'adv acc':adv_acc.avg,
        'clean clip acc':clean_clip_acc.avg,
        'adv clip acc':adv_clip_acc.avg,
        'AUC':roc_auc_success,
        'L2':l2_distance.tolist()
    }
    # Context manager guarantees the file is flushed and closed.
    with open(os.path.join(savedir, 'detection_results.json'), 'w') as results_file:
        json.dump(results, results_file, indent=4)

def save_args_as_json(args, save_path):
    """Write the attributes of ``args`` to ``save_path`` as indented JSON.

    Args:
        args: object whose ``vars()`` mapping is serialised (e.g. an
            ``argparse.Namespace``).
        save_path: destination file path; an existing file is overwritten.
    """
    namespace_fields = vars(args)
    with open(save_path, 'w') as out_file:
        json.dump(namespace_fields, out_file, indent=4)

def run(args):
    """Load saved adversarial samples and run the CLIP-based detection evaluation.

    Builds the classifier and the CLIP ``ViT-L/14`` model, prepares the CLIP
    preprocessing for the selected dataset, loads
    ``<savedir>/<exp_name>/successed_images.pkl`` and forwards its ``'clean'``,
    ``'adv'`` and ``'targets'`` entries to ``detection``.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``seed``, ``savedir``, ``exp_name``, ``device``, ``modelname``,
            ``dataname``, ``num_classes``, ``log_interval`` and ``batch_size``.

    Raises:
        FileNotFoundError: if ``successed_images.pkl`` does not exist in the
            experiment directory.
    """
    setup_default_logging()
    torch_seed(args.seed)

    savedir = os.path.join(args.savedir,args.exp_name)
    os.makedirs(savedir, exist_ok=True)

    device = 'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu'
    _logger.info('Device: {}'.format(device))

    # Build Model
    model = create_model(
        modelname             = args.modelname,
        dataname              = args.dataname,
        num_classes           = args.num_classes,
        device                = device
    )

    clip_model, _ = clip.load("ViT-L/14", device=device)
    class_texts = choose_class_text(args.dataname)

    # Normalisation shared by both dataset branches.
    clip_normalize = transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    )
    if args.dataname == 'CIFAR10':
        # CIFAR10 inputs are additionally resized and centre-cropped to 224x224.
        transform_clip = transforms.Compose([
            transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None),
            transforms.CenterCrop(size=(224, 224)),
            clip_normalize,
        ])
    else:
        transform_clip = transforms.Compose([
            clip_normalize,
        ])

    save_path = os.path.join(savedir, 'successed_images.pkl')
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this script at sample files produced by a trusted pipeline.
    with open(save_path, 'rb') as bucket_file:
        bucket = pickle.load(bucket_file)

    print('Detection starts.')
    detection(
        model        = model,
        clip_model   = clip_model,
        transform    = transform_clip,
        images       = bucket['clean'],
        adv          = bucket['adv'],
        targets      = bucket['targets'],
        class_texts  = class_texts,
        savedir      = savedir,
        log_interval = args.log_interval,
        batch_size   = args.batch_size,
        device       = device
    )


if __name__ == '__main__':
    # Command-line entry point: collect the experiment settings and hand them to run().
    parser = argparse.ArgumentParser()

    # experiment / model selection
    parser.add_argument('--exp_name', type=str, default='PGD', help='save adversarial examples folder')
    parser.add_argument('--modelname', type=str, default='vgg19')
    parser.add_argument('--device', type=str, default='0', help='specify the used device')

    # dataset
    parser.add_argument('--datadir', type=str, default='datasets', help='data directory')
    parser.add_argument('--savedir', type=str, default='results/CIFAR10/saved_adv_samples', help='saved model directory')
    parser.add_argument('--dataname', type=str, default='CIFAR10', choices=['CIFAR10', 'ImageNet'], help='data name')
    parser.add_argument('--num_classes', type=int, default=10, help='the number of classes')

    # evaluation loop settings (hyphenated flags are exposed as args.num_workers / args.log_interval)
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--num-workers', type=int, default=4, help='the number of workers (threads)')
    parser.add_argument('--log-interval', type=int, default=5, help='log interval')
    parser.add_argument('--seed', type=int, default=223, help='seed')

    args = parser.parse_args()
    run(args)