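'''Generate perturbed copies (blur, noise, adversarial or elastic) of a dataset
split and save the perturbed images, their labels and the model's predictions
as .pt files.'''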
import os
import sys
import argparse
import random
import torch
import numpy as np

from robustness import defaults
from robustness.datasets import DATASETS
from robustness.main import setup_args

from utils.model import load_model, get_perturbed_images
from utils.helper import create_logger


def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Combines this script's own options with the robustness library's
    CONFIG, MODEL_LOADER, TRAINING and PGD argument groups. The parser is
    created with ``conflict_handler="resolve"`` so that options re-declared
    after those groups (``--eps``, ``--attack-lr``, ``--attack-steps``,
    ``--constraint``) replace any same-named option a group already added.

    Returns:
        argparse.Namespace: the raw parsed arguments. The caller is expected
        to pass them through ``robustness.main.setup_args`` afterwards.
    """
    parser = argparse.ArgumentParser(conflict_handler="resolve")
    # Required: main() builds the output directory from this value, so leaving
    # it unset would otherwise fail later with an opaque TypeError in
    # os.path.join instead of a clear usage error here.
    parser.add_argument('--perturb_method', required=True, type=str,
        choices=['blur', 'noise', 'adv', 'elastic'],
        help='Method for perturbing instances of dataset.')
    parser.add_argument('--seed', default=0, type=int, help='Seed')
    # '' selects the test/validation loader; 'train' selects the train loader.
    parser.add_argument('--data_split', default='', type=str, choices=['', 'train'],
        help='The split from which to get the images')

    parser = defaults.add_args_to_parser(defaults.CONFIG_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.MODEL_LOADER_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.TRAINING_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.PGD_ARGS, parser)

    # Attack settings (used when --perturb_method adv). argparse does not apply
    # `type` to non-string defaults, so --eps / --attack-lr stay floats unless
    # given on the command line (then they are strings). NOTE(review): this
    # presumably mirrors robustness's str-typed PGD args — confirm downstream
    # consumers accept both forms.
    parser.add_argument('--eps', default=0.5, type=str, help='Eps')
    parser.add_argument('--attack-lr', default=0.5, type=str, help='Attack LR')
    parser.add_argument('--attack-steps', default=100, type=int, help='Attack steps')
    parser.add_argument('--constraint', default='2', type=str, help='Constraint')

    parser.add_argument('--num_classes', default=10, type=int, help='Number of classes')
    return parser.parse_args()

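# NOTE(review): leftover fragment `ElasticTransform(alpha=250.0)`; presumably the
# setting used for the 'elastic' perturb_method — confirm in utils.model and
# delete this note if it no longer applies.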
def main():
    """Generate perturbed images for one dataset split and save them to disk.

    Parses and completes the CLI arguments, seeds torch / random / numpy,
    loads the dataset and model, and calls
    ``utils.model.get_perturbed_images``. Its three results are saved as
    ``images.pt``, ``labels.pt`` and ``preds.pt`` under::

        <out_dir>/<dataset>/<perturb_method>/<data_split>

    For ``--perturb_method adv`` two extra directory levels are inserted
    after the method name. The first encodes the attack settings
    (``eps.._atlr.._atsteps..``). The second is the name of the directory
    that contains the ``--resume`` checkpoint.

    Exits with an error message if ``--perturb_method`` is unset, or if
    ``--resume`` is unset for the ``adv`` method.
    """
    args = get_args()
    args = setup_args(args)
    print(args)

    if args.perturb_method is None:
        sys.exit('error: --perturb_method is required')

    perturb_path = args.perturb_method
    if args.perturb_method == 'adv':
        if not args.resume:
            sys.exit('error: --resume is required when --perturb_method is adv')
        # Separate adversarial outputs per attack configuration and per source
        # model (identified by the directory holding its checkpoint).
        attack_cfg = (f'eps{str(args.eps)}_atlr{args.attack_lr}'
                      f'_atsteps{args.attack_steps}')
        model_dir = os.path.basename(os.path.dirname(args.resume))
        perturb_path = os.path.join(perturb_path, attack_cfg, model_dir)
    path = os.path.join(args.out_dir, args.dataset, perturb_path, args.data_split)
    print(path)
    # exist_ok avoids a check-then-create race; the script runs unattended.
    os.makedirs(path, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    dataset = DATASETS[args.dataset](args.data)
    # No shuffling and no augmentation, presumably so the saved tensors follow
    # the dataset's own order.
    trainloader, testloader = dataset.make_loaders(
        args.workers, args.batch_size,
        shuffle_train=False, shuffle_val=False, data_aug=False,
        val_batch_size=args.batch_size)
    loader = trainloader if args.data_split == 'train' else testloader

    model, result_model = load_model(
        args, args.arch, args.resume, dataset, testloader=None)
    print(f'Eval model 1: {result_model}')

    perturbed_imgs, labels, preds = get_perturbed_images(args, model, loader, device)
    torch.save(perturbed_imgs, os.path.join(path, 'images.pt'))
    torch.save(labels, os.path.join(path, 'labels.pt'))
    torch.save(preds, os.path.join(path, 'preds.pt'))

if __name__ == '__main__':
    # Run only when executed as a script, not when imported as a module.
    main()
