import os
import sys

import numpy as np
import torch

from utils.misc_utils import load_data, load_model, save_dir, evaluate, evaluate_rand
from argument import parser, print_args

from defenses import *
from attacks import *

# Main script
if __name__ == '__main__':
    # Script flow: parse arguments, load data and model, compute a "safe spot"
    # for every batch with the selected defense, optionally evaluate clean and
    # adversarial accuracy of original vs. safe images, then sanity-check the
    # value range and optionally save the results as .npy files.

    # Fix random seeds and request deterministic cuDNN behaviour
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Parse arguments
    args = parser()
    print('\nArguments')
    print_args(args)

    # Load dataset
    print('\nLoading dataset..')
    test_loader = load_data(args)

    # Load model
    print('\nLoading model..')
    model = load_model(args)

    # Create safe spot algorithm. The class is resolved by name from this
    # module's namespace (populated by the star imports at the top of the file).
    defense_class = getattr(sys.modules[__name__], args.defense_type)
    defense = defense_class(model,
                            args.attack_type, args.epsilon, args.step_size, args.num_steps, args.random_starts,
                            args.delta, args.lr, hessian=args.hessian, scale=args.scale, num_samples=args.num_samples)

    run_eval = args.eval or args.eval_rand
    if run_eval:
        # Create attack algorithm (resolved by name, same as the defense)
        attack_class = getattr(sys.modules[__name__], args.attack_type_test)
        attack = attack_class(model, args.epsilon_test, args.step_size_test, args.num_steps_test, scale=args.scale, num_samples=args.num_samples_test)

        # Pick the scoring routine once. Both variants yield a value that is
        # accumulated below as the number of correct predictions in the batch.
        if args.eval:
            # base classifier
            def _count_correct(inputs, targets):
                return evaluate(model, inputs, targets)
        else:
            # randomized smoothing: element [1] of the result is used as the
            # correct count (element [0] is used as the prediction further down)
            def _count_correct(inputs, targets):
                return evaluate_rand(model, inputs, targets, args.scale, args.num_samples_test)[1]

    # Create directory
    if args.save:
        safe_spot_path = save_dir(args)
        # exist_ok=True already tolerates an existing directory
        os.makedirs(safe_spot_path, exist_ok=True)

    # Loop-invariant: predictions come from randomized smoothing for
    # 'gaussian' training types, except the 'gaussian_max' variants.
    smoothed_prediction = 'gaussian' in args.train_type and 'gaussian_max' not in args.train_type

    # Run safe spot algorithm
    count = 0          # images seen so far, including skipped ones
    num_evaluated = 0  # images actually processed (denominator for totals)
    correct_nat = 0
    correct_adv = 0
    correct_safe = 0
    correct_safe_adv = 0

    total_orig_images = []
    total_safe_images = []
    total_labels = []

    for image, label in test_loader:
        # Use the real batch length: the last batch of a loader may be smaller
        # than args.batch_size, which would otherwise skew counts and accuracies.
        batch_len = image.size(0)
        batch_begin = count
        count += batch_len

        # Skip batches that lie entirely before the requested start index
        if count <= args.start:
            continue

        num_evaluated += batch_len
        print('\nImage {}-{}'.format(batch_begin, count))

        image = image.cuda()
        label = label.cuda()

        if smoothed_prediction:
            # randomized smoothing
            pred = evaluate_rand(model, image, label, args.scale, args.num_samples_test)[0]
        else:
            # base classifier
            output = model(image)
            _, pred = torch.max(output, 1)

        # initialize safe spot (note: seeded with the model's prediction,
        # not the ground-truth label)
        defense.initialize(image, pred)

        # update safe spot for some iterations
        for _ in range(args.num_iters):
            defense.update()

        # get safe spot
        safe_spot = defense.get_safe_spot()

        total_orig_images.append(image.cpu().numpy())
        total_safe_images.append(safe_spot.cpu().numpy())
        total_labels.append(label.cpu().numpy().reshape(-1, 1))

        # evaluate safe spots
        if run_eval:
            # The call order below is kept fixed so that any randomness used
            # by the attack / evaluation is consumed in the same sequence.

            # evaluate original images (clean)
            correct_nat_batch = _count_correct(image, label)

            # evaluate original images (adv)
            image_adv = attack(image, label).detach()
            correct_adv_batch = _count_correct(image_adv, label)

            # evaluate safe images (clean)
            correct_safe_batch = _count_correct(safe_spot, label)

            # evaluate safe images (adv)
            safe_spot_adv = attack(safe_spot, label).detach()
            correct_safe_adv_batch = _count_correct(safe_spot_adv, label)

            print('batch')
            print('nat acc: {:.3f}, adv acc: {:.3f}, '
                  'safe acc: {:.3f}, safe_adv acc: {:.3f}'.format(correct_nat_batch/batch_len,
                                                                  correct_adv_batch/batch_len,
                                                                  correct_safe_batch/batch_len,
                                                                  correct_safe_adv_batch/batch_len))

            correct_nat += correct_nat_batch
            correct_adv += correct_adv_batch
            correct_safe += correct_safe_batch
            correct_safe_adv += correct_safe_adv_batch

            print('total')
            print('nat acc: {:.3f} ({}/{}), '
                  'adv acc: {:.3f} ({}/{}), '
                  'safe acc: {:.3f} ({}/{}), '
                  'safe_adv acc: {:.3f} ({}/{})'.format(correct_nat/num_evaluated, correct_nat, num_evaluated,
                                                        correct_adv/num_evaluated, correct_adv, num_evaluated,
                                                        correct_safe/num_evaluated, correct_safe, num_evaluated,
                                                        correct_safe_adv/num_evaluated, correct_safe_adv,
                                                        num_evaluated))

        # Stop once the requested window [start, start + num_images) is covered
        if (count-args.start) >= args.num_images:
            break

    # Fail with a clear message instead of numpy's "need at least one array"
    if not total_orig_images:
        raise ValueError('No images were processed; check args.start against the dataset size')

    # sanity check
    total_orig_images = np.concatenate(total_orig_images)
    total_safe_images = np.concatenate(total_safe_images)
    total_labels = np.concatenate(total_labels)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    # The `not (...)` form also rejects NaN values, as the asserts did.
    for kind, images in (('original', total_orig_images), ('safe', total_safe_images)):
        highest = np.amax(images)
        lowest = np.amin(images)
        if not (highest < 1 + 1e-4 and lowest > 0 - 1e-4):
            raise ValueError('{} images fall outside [0, 1]: min={}, max={}'.format(kind, lowest, highest))

    # save safe spots
    if args.save:
        print('Saving..')

        safe_spot_path = os.path.join(safe_spot_path, 'image-{}_{}'.format(args.start, args.start+args.num_images))
        os.makedirs(safe_spot_path, exist_ok=True)
        np.save(os.path.join(safe_spot_path, 'orig_image.npy'), total_orig_images)
        np.save(os.path.join(safe_spot_path, 'safe_image.npy'), total_safe_images)
        np.save(os.path.join(safe_spot_path, 'label.npy'), total_labels)
