import argparse
import numpy as np
import os
import random
import sys

import torch
import torch.nn.functional as F
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, LSUN, ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from argument import print_args
from models import *
from utils.detect_utils import *

from defenses import *
from attacks import *

# Command-line arguments.
parser = argparse.ArgumentParser()

# Directory
# Root directory that holds the datasets (they are never downloaded).
parser.add_argument('--data_dir', default='/data_large/readonly', type=str)
# Root directory of checkpoints; the main script reads
# <model_dir>/cifar10/<model_type lower-cased>/<train_type>/ckpt.pt
parser.add_argument('--model_dir', default='./src/checkpoints', type=str)

# Model
# Name of a model class that must be in this module's namespace (via `from models import *`).
parser.add_argument('--model_type', default='WideResNet', type=str)
parser.add_argument('--train_type', default='s_pgd_adv_l2_tune', type=str)
parser.add_argument('--depth', default=34, type=int)
parser.add_argument('--widen_factor', default=10, type=int)

# Experiment
# Number of leading in-distribution images to skip before scoring starts.
parser.add_argument('--start', default=0, type=int)
# Stop after roughly this many in-distribution images past --start.
parser.add_argument('--num_images', default=10000, type=int)
parser.add_argument('--batch_size', default=50, type=int)
parser.add_argument('--seed', default=0, type=int)

# Attack
parser.add_argument('--attack_type', default='PGDAttackL2', type=str)
# NOTE(review): not read by the code visible in this file.
parser.add_argument('--attack_criterion', default='xent', type=str)
parser.add_argument('--epsilon', default=0.25, type=float)
parser.add_argument('--step_size', default=0.25, type=float)
parser.add_argument('--num_steps', default=1, type=int)
# A value > 0 enables a random start in defense.attack(...).
parser.add_argument('--random_starts', default=1, type=int)

# Safe spot
# Name of a defense class that must be in this module's namespace (via `from defenses import *`).
parser.add_argument('--defense_type', default='SafeSpotL2', type=str)
parser.add_argument('--delta', default=0.25, type=float)
# Number of defense.update() calls per batch.
parser.add_argument('--num_iters', default=20, type=int)
parser.add_argument('--lr', default=0.0002, type=float)
# NOTE(review): not read by the code visible in this file.
parser.add_argument('--hessian', action='store_true')

# Detection
# Restricting --ood to the supported names makes a typo fail fast with a clear
# argparse error instead of a NameError on an unbound loader later on.
parser.add_argument('--ood', default='LSUN', type=str,
                    choices=['gaussian', 'CIFAR100', 'SVHN', 'LSUN', 'TinyImageNet'],
                    help='Out-distribution dataset')
parser.add_argument('--recall_level', default=0.95, type=float, help='True positive rate')
# Weight on the plain max-softmax probability; (1 - lmbd) weights the
# max-softmax probability of the attacked safe spot.
parser.add_argument('--lmbd', default=0.5, type=float)

args = parser.parse_args()

# Per-channel normalization statistics handed to ModelWrapper. They match the
# commonly used CIFAR-10 values; presumably the checkpoint was trained with
# them (TODO confirm). Created on the GPU at import time, so CUDA is required.
mean = torch.tensor([0.4914, 0.4822, 0.4465]).cuda()
std = torch.tensor([0.2023, 0.1994, 0.2010]).cuda()

# Main script
if __name__ == '__main__':
    # Fix random seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Parse arguments
    print('\nArguments')
    print_args(args)

    # Load dataset
    print('\nLoading dataset..')
    dataset_in = CIFAR10(args.data_dir, train=False,
                         transform=transforms.ToTensor(),
                         download=False)
    loader_in = DataLoader(dataset_in, batch_size=args.batch_size, shuffle=True)

    if args.ood == 'gaussian':
        targets_out = torch.ones(len(dataset_in.targets))
        data_out = torch.clamp(torch.randn(size=(targets_out.size(0), 3, 32, 32)) + 0.5, 0, 1)
        dataset_out = torch.utils.data.TensorDataset(data_out, targets_out)
        loader_out= torch.utils.data.DataLoader(dataset_out, batch_size=args.batch_size, shuffle=True)

    elif args.ood == 'CIFAR100':
        dataset_out = CIFAR100(args.data_dir, train=False,
                               transform=transforms.ToTensor(),
                               download=False)
        loader_out = DataLoader(dataset_out, batch_size=args.batch_size, shuffle=True)

    elif args.ood == 'SVHN':
        dataset_out = SVHN(os.path.join(args.data_dir, 'svhn'), split='test',
                           transform=transforms.ToTensor(),
                           download=False)
        loader_out = DataLoader(dataset_out, batch_size=args.batch_size, shuffle=True)

    elif args.ood == 'LSUN':
        dataset_out = LSUN(root=os.path.join(args.data_dir, 'lsun'), classes='test',
                           transform=transforms.Compose([transforms.Resize(32), transforms.CenterCrop(32), transforms.ToTensor()]))
        loader_out = DataLoader(dataset_out, batch_size=args.batch_size, shuffle=True)

    elif args.ood == 'TinyImageNet':
        dataset_out = ImageFolder(root=os.path.join(args.data_dir, 'tiny-imagenet-200', 'val'),
                                  transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()]))
        loader_out = DataLoader(dataset_out, batch_size=args.batch_size, shuffle=True)

    # Load model
    print('\nLoading model..')
    model_class = getattr(sys.modules[__name__], args.model_type)

    base_model = model_class(depth=args.depth, widen_factor=args.widen_factor)
    base_model = base_model.cuda()
    base_model = torch.nn.DataParallel(base_model)
    checkpoint_path = os.path.join(args.model_dir, 'cifar10', args.model_type.lower(), args.train_type, 'ckpt.pt')
    checkpoint = torch.load(checkpoint_path)
    checkpoint = {'module.'+k:v for k,v in checkpoint['state_dict'].items()}
    base_model.load_state_dict(checkpoint)
    model = ModelWrapper(base_model, mean, std)
    model.eval()

    # Create defense
    defense_class = getattr(sys.modules[__name__], args.defense_type)
    defense = defense_class(model,
                            args.attack_type, args.epsilon, args.step_size, args.num_steps, args.random_starts,
                            args.delta, args.lr)

    # Detect
    count = 0
    msps_in = None
    msps_out = None
    scores_in = None
    scores_out = None

    for ((image_in, label_in), (image_out, label_out)) in zip(loader_in, loader_out):
        count += args.batch_size

        if count <= args.start:
            continue

        print('\nImage {}-{}'.format(count - args.batch_size, count))

        # Generate safe spots for in-distribution samples
        image_in, label_in = image_in.cuda(), label_in.cuda()
        output_in = model(image_in)
        prob_in = F.softmax(output_in, dim=1)
        _, pred_in = output_in.max(1)

        defense.initialize(image_in, pred_in)
        for i in range(args.num_iters):
            defense.update()

        image_safe_in = defense.get_safe_spot()
        image_safe_adv_in = defense.attack(image_safe_in, pred_in, random_start=(args.random_starts > 0))
        output_safe_adv_in = model(image_safe_adv_in)
        prob_safe_adv_in = F.softmax(output_safe_adv_in, dim=1)

        # Get scores for in-distribution samples
        max_prob_in = torch.max(prob_in, dim=1)[0].detach().cpu().numpy()
        max_prob_safe_adv_in = torch.max(prob_safe_adv_in, dim=1)[0].detach().cpu().numpy()

        msp_in = -max_prob_in
        msps_in = msp_in if msps_in is None else np.concatenate([msps_in, msp_in], axis=0)

        score_in = -(args.lmbd * max_prob_in + (1 - args.lmbd) * max_prob_safe_adv_in)
        scores_in = score_in if scores_in is None else np.concatenate([scores_in, score_in], axis=0)

        # Generate safe spots for out-distribution samples
        image_out, label_out = image_out.cuda(), label_out.cuda()
        output_out = model(image_out)
        prob_out = F.softmax(output_out, dim=1)
        _, pred_out = output_out.max(1)

        defense.initialize(image_out, pred_out)
        for i in range(args.num_iters):
            defense.update()

        image_safe_out = defense.get_safe_spot()
        image_safe_adv_out = defense.attack(image_safe_out, pred_out, random_start=(args.random_starts > 0))
        output_safe_adv_out = model(image_safe_adv_out)
        prob_safe_adv_out = F.softmax(output_safe_adv_out, dim=1)

        # Get scores for out-distribution samples
        max_prob_out = torch.max(prob_out, dim=1)[0].detach().cpu().numpy()
        max_prob_safe_adv_out = torch.max(prob_safe_adv_out, dim=1)[0].detach().cpu().numpy()

        msp_out = -max_prob_out
        msps_out = msp_out if msps_out is None else np.concatenate([msps_out, msp_out], axis=0)

        score_out = -(args.lmbd * max_prob_out + (1 - args.lmbd) * max_prob_safe_adv_out)
        scores_out = score_out if scores_out is None else np.concatenate([scores_out, score_out], axis=0)

        # Print
        auroc, aupr, fpr = get_measures(msps_out, msps_in, recall_level=args.recall_level)
        print('\nMSP, AUROC: {}, AUPR: {}, FPR: {}'.format(auroc, aupr, fpr))

        auroc, aupr, fpr = get_measures(scores_out, scores_in, recall_level=args.recall_level)
        print('\nOurs, AUROC: {}, AUPR: {}, FPR: {}'.format(auroc, aupr, fpr))

        if count >= args.start + args.num_images:
            break

