import argparse
import os
import torch
from torch.utils.data import DataLoader
import pandas as pd

from advbench.lib import misc
from advbench import algorithms
from advbench import datasets
from advbench import evalulation_methods

def main(args, hparams, test_hparams,
         n_steps_list=(1, 5, 10, 20, 50, 100, 200), n_ckpts=50):
    """Evaluate saved checkpoints against PGD attacks of varying strength.

    A single algorithm instance is built, then for every PGD step count in
    ``n_steps_list`` each checkpoint
    ``<args.output_dir>/ckpts/model_ckpt_{i}.pkl`` (``i`` in
    ``range(n_ckpts)``) is loaded into it. Each checkpoint is scored by the
    ``PGD`` evaluator on the test split. The per-checkpoint result dicts,
    tagged with ``epoch`` (the checkpoint index) and ``n_steps``, are
    collected into one pandas DataFrame per step count. That DataFrame is
    pickled to ``<args.output_dir>/pgd_evaluated_ckpts_n_{n_steps}.pd``.

    Args:
        args: Namespace of training arguments; reads ``dataset``,
            ``data_dir``, ``algorithm`` and ``output_dir``.
        hparams: Hyperparameters passed to the algorithm constructor.
        test_hparams: Evaluation hyperparameters; ``batch_size`` sizes the
            test loader. It is not modified: each step count uses a copy
            with ``pgd_n_steps`` overridden.
        n_steps_list: PGD step counts to evaluate.
        n_ckpts: Number of checkpoints (indexed from 0) to evaluate.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = vars(datasets)[args.dataset](args.data_dir, device)
    algorithm = vars(algorithms)[args.algorithm](
        dataset.INPUT_SHAPE,
        dataset.NUM_CLASSES,
        hparams,
        device).to(device)

    test_loader = DataLoader(
        dataset=dataset.splits['test'],
        batch_size=test_hparams['batch_size'],
        num_workers=dataset.N_WORKERS,
        pin_memory=False,
        shuffle=False)

    ckpt_dir = os.path.join(args.output_dir, 'ckpts')

    for n_steps in n_steps_list:
        print(f'N steps: {n_steps}')
        # Copy rather than mutate the caller's dict. The override does not
        # depend on the checkpoint, so build it once per step count.
        eval_hparams = dict(test_hparams, pgd_n_steps=n_steps)

        all_results = []
        for i in range(n_ckpts):
            fname = os.path.join(ckpt_dir, f'model_ckpt_{i}.pkl')
            # map_location lets checkpoints saved on a GPU load on a
            # CPU-only host (device falls back to 'cpu' above).
            state = torch.load(fname, map_location=device)
            algorithm.load_state_dict(state['state_dict'])

            adv_evaluator = vars(evalulation_methods)['PGD'](
                algorithm=algorithm,
                device=device,
                output_dir=args.output_dir,
                test_hparams=eval_hparams)

            results = adv_evaluator.calculate(test_loader)
            results.update({'epoch': i, 'n_steps': n_steps})
            all_results.append(results)

        df = pd.DataFrame(all_results)
        df.to_pickle(os.path.join(
            args.output_dir, f'pgd_evaluated_ckpts_n_{n_steps}.pd'))


if __name__ == '__main__':
    # Single CLI option: the directory holding the configuration files that
    # a training run saved.
    parser = argparse.ArgumentParser(description='Evaluate robustness')
    parser.add_argument('--input_dir', type=str, default='train_output')
    args = parser.parse_args()

    run_dir = args.input_dir

    # Training-time CLI arguments, re-wrapped as a Namespace so ``main`` can
    # use attribute access (e.g. ``train_args.dataset``).
    # NOTE(review): ``main`` reads checkpoints from and writes results to
    # ``train_args.output_dir`` (the value stored in args.json), not to
    # ``--input_dir``. Confirm the two agree if the run directory was moved.
    train_args = argparse.Namespace(
        **misc.read_dict(os.path.join(run_dir, 'args.json')))

    # Algorithm and evaluation hyperparameters saved alongside the run.
    hparams, test_hparams = (
        misc.read_dict(os.path.join(run_dir, fname))
        for fname in ('hparams.json', 'test_hparams.json')
    )

    main(train_args, hparams, test_hparams)