import os
import sys

import numpy as np
import torch

from utils.misc_utils import load_data, load_model, save_dir, evaluate
from argument import parser, print_args

from attacks import *

# Main script
if __name__ == '__main__':
    # Script flow: parse arguments, load the test data and the model, compute a
    # "safe spot" for every test batch with the configured defense, optionally
    # evaluate clean/adversarial accuracy on the original and the safe images,
    # and optionally save the original images, safe images and labels as .npy.

    # Fix random seed
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Parse arguments
    args = parser()
    print('\nArguments')
    print_args(args)

    # Load dataset
    print('\nLoading dataset..')
    test_loader = load_data(args)

    # Load model
    print('\nLoading model..')
    model = load_model(args)

    # Create safe spot algorithm.
    # The class is looked up by name among the names visible in this module.
    defense_class = getattr(sys.modules[__name__], args.defense_type)
    defense = defense_class(model, args.delta, args.lr, args.num_iters)

    if args.eval:
        # Create attack algorithm (only needed when evaluating)
        attack_class = getattr(sys.modules[__name__], args.attack_type_test)
        attack = attack_class(model, args.epsilon_test, args.step_size_test, args.num_steps_test)

    # Create directory
    if args.save:
        safe_spot_path = save_dir(args)
        # exist_ok=True already tolerates an existing directory, so no
        # separate (and racy) os.path.exists() check is needed.
        os.makedirs(safe_spot_path, exist_ok=True)

    # Run safe spot algorithm
    count = 0          # images seen so far, including skipped batches
    num_processed = 0  # images actually processed (denominator of the running accuracy)
    correct_nat = 0
    correct_adv = 0
    correct_safe = 0
    correct_safe_adv = 0

    total_orig_images = []
    total_safe_images = []
    total_labels = []

    for image, label in test_loader:
        # Use the real size of this batch: the last batch delivered by the
        # loader may hold fewer than args.batch_size images.
        batch_size = image.size(0)
        count += batch_size

        # Skip every batch that ends at or before the requested start index
        if count <= args.start:
            continue

        print('\nImage {}-{}'.format(count - batch_size, count))
        num_processed += batch_size

        image = image.cuda()
        label = label.cuda()

        # base classifier
        output = model(image)
        _, pred = torch.max(output, 1)

        # get safe spot (targeted towards the model's own prediction, not the label)
        safe_spot = defense(image, pred, random_start=False, targeted=True).detach()

        total_orig_images.append(image.cpu().detach().numpy())
        total_safe_images.append(safe_spot.cpu().numpy())
        total_labels.append(label.cpu().numpy().reshape(-1, 1))

        # evaluate safe spots
        if args.eval:
            # evaluate original images (clean)
            correct_nat_batch = evaluate(model, image, label)

            # evaluate original images (adv)
            image_adv = attack(image, label).detach()
            correct_adv_batch = evaluate(model, image_adv, label)

            # evaluate safe images (clean)
            correct_safe_batch = evaluate(model, safe_spot, label)

            # evaluate safe images (adv)
            safe_spot_adv = attack(safe_spot, label).detach()
            correct_safe_adv_batch = evaluate(model, safe_spot_adv, label)

            print('batch')
            print('nat acc: {:.3f}, adv acc: {:.3f}, '
                  'safe acc: {:.3f}, safe_adv acc: {:.3f}'.format(correct_nat_batch / batch_size,
                                                                  correct_adv_batch / batch_size,
                                                                  correct_safe_batch / batch_size,
                                                                  correct_safe_adv_batch / batch_size))

            correct_nat += correct_nat_batch
            correct_adv += correct_adv_batch
            correct_safe += correct_safe_batch
            correct_safe_adv += correct_safe_adv_batch

            # Running accuracy over the images evaluated so far. Dividing by
            # num_processed (rather than count - args.start) stays correct when
            # args.start is not a multiple of the batch size.
            print('total')
            print('nat acc: {:.3f} ({}/{}), '
                  'adv acc: {:.3f} ({}/{}), '
                  'safe acc: {:.3f} ({}/{}), '
                  'safe_adv acc: {:.3f} ({}/{})'.format(correct_nat / num_processed, correct_nat, num_processed,
                                                        correct_adv / num_processed, correct_adv, num_processed,
                                                        correct_safe / num_processed, correct_safe, num_processed,
                                                        correct_safe_adv / num_processed, correct_safe_adv,
                                                        num_processed))

        # Stop once the requested number of images past args.start is covered
        if (count - args.start) >= args.num_images:
            break

    # np.concatenate rejects an empty list with an unhelpful message; fail
    # with the same exception type but say what actually went wrong.
    if not total_safe_images:
        raise ValueError('No images were processed: args.start ({}) skipped every '
                         'batch of the test loader.'.format(args.start))

    # sanity check: original and safe images must stay in [0, 1] (small tolerance)
    total_orig_images = np.concatenate(total_orig_images)
    total_safe_images = np.concatenate(total_safe_images)
    total_labels = np.concatenate(total_labels)

    assert np.amax(total_orig_images) < 1 + 1e-4
    assert np.amin(total_orig_images) > 0 - 1e-4
    assert np.amax(total_safe_images) < 1 + 1e-4
    assert np.amin(total_safe_images) > 0 - 1e-4

    # save safe spots
    if args.save:
        print('Saving..')

        # One sub-directory per requested image range
        safe_spot_path = os.path.join(safe_spot_path, 'image-{}_{}'.format(args.start, args.start+args.num_images))
        os.makedirs(safe_spot_path, exist_ok=True)
        np.save(os.path.join(safe_spot_path, 'orig_image.npy'), total_orig_images)
        np.save(os.path.join(safe_spot_path, 'safe_image.npy'), total_safe_images)
        np.save(os.path.join(safe_spot_path, 'label.npy'), total_labels)
