import argparse
import os
import torch
from torch.utils.data import DataLoader

from advbench.lib import misc
from advbench import algorithms
from advbench import datasets
from advbench import evalulation_methods

def main(args, hparams, test_hparams):
    """Restore the final checkpoint of a training run and evaluate it.

    Builds the dataset and algorithm named in ``args`` and restores the weights
    stored under ``'state_dict'`` in
    ``<args.output_dir>/ckpts/model_ckpt_final.pkl``. It then runs the
    ``StochasticPGD`` (adversarial) and ``StochasticClean`` evaluators, in that
    order, over the test split and prints each evaluator's result.

    Args:
        args: Namespace of training arguments. Reads ``dataset``, ``data_dir``,
            ``algorithm`` and ``output_dir``.
        hparams: Hyperparameters forwarded to the algorithm constructor.
        test_hparams: Hyperparameters forwarded to each evaluator.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = vars(datasets)[args.dataset](args.data_dir, device)
    algorithm = vars(algorithms)[args.algorithm](
        dataset.INPUT_SHAPE,
        dataset.NUM_CLASSES,
        hparams,
        device).to(device)

    test_loader = DataLoader(
        dataset=dataset.splits['test'],
        batch_size=100,
        num_workers=dataset.N_WORKERS,
        pin_memory=False,
        shuffle=False)

    ckpt_path = os.path.join(args.output_dir, 'ckpts', 'model_ckpt_final.pkl')
    # Without map_location, torch.load recreates each tensor on the device it
    # was saved from. A checkpoint written on a GPU box would then fail to load
    # on a CPU-only machine, even though `device` above already fell back to
    # 'cpu'.
    checkpoint = torch.load(ckpt_path, map_location=device)
    algorithm.load_state_dict(checkpoint['state_dict'])

    # Adversarial evaluation first, then clean; each result is printed as soon
    # as it is computed.
    for evaluator_name in ('StochasticPGD', 'StochasticClean'):
        evaluator = vars(evalulation_methods)[evaluator_name](
            algorithm=algorithm,
            device=device,
            output_dir=args.output_dir,
            test_hparams=test_hparams)
        print(evaluator.calculate(test_loader))






if __name__ == '__main__':
    # CLI: a single option pointing at the directory a training run wrote.
    cli = argparse.ArgumentParser(description='Photonic')
    cli.add_argument('--input_dir', type=str, default='train_output')
    run_dir = cli.parse_args().input_dir

    def _read_run_file(filename):
        """Load one of the JSON dictionaries stored in the run directory."""
        return misc.read_dict(os.path.join(run_dir, filename))

    # The saved training arguments are re-wrapped as a Namespace so `main`
    # can access them by attribute, exactly as during training.
    saved_args = argparse.Namespace(**_read_run_file('args.json'))

    # Arguments are evaluated left to right, so hparams.json is read before
    # test_hparams.json.
    main(
        saved_args,
        _read_run_file('hparams.json'),
        _read_run_file('test_hparams.json'),
    )