import torch
from torch.utils.data import DataLoader
import pandas as pd
import os
import argparse

from robustbench.data import load_cifar10
from robustbench.utils import load_model

from advbench import datasets
from advbench import algorithms
from advbench import evalulation_methods
from advbench import hparams_registry

# RobustBench model identifiers evaluated by main(), in evaluation order.
# Each name is passed to robustbench's load_model with dataset='cifar10' and
# threat_model='Linf'.
# NOTE(review): presumably the top 10 entries of that leaderboard at the time
# of writing -- confirm before relying on the ranking.
MODEL_NAMES = [
    "Wang2023Better_WRN-70-16",
    "Wang2023Better_WRN-28-10",
    "Rebuffi2021Fixing_70_16_cutmix_extra",
    "Gowal2021Improving_70_16_ddpm_100m",
    "Huang2022Revisiting_WRN-A4",
    "Rebuffi2021Fixing_106_16_cutmix_ddpm",
    "Rebuffi2021Fixing_70_16_cutmix_ddpm",
    "Kang2021Stable",
    "Gowal2021Improving_28_10_ddpm_100m",
    "Pang2022Robustness_WRN70_16",
]

def main(args, hparams, test_hparams):
    """Evaluate every model in MODEL_NAMES on the CIFAR10 test split with BETA.

    Each model is loaded through robustbench's ``load_model``, wrapped in
    advbench's ``EmptyAlgorithm`` and handed to the ``BETA`` evaluator. The
    accumulated results are re-pickled after every model so that a crash
    part-way through keeps the models already evaluated.

    Args:
        args: Parsed command-line namespace; ``data_dir`` and ``output_dir``
            are read.
        hparams: Hyperparameters forwarded to ``EmptyAlgorithm``.
        test_hparams: Evaluation hyperparameters; ``batch_size`` sizes the
            test loader and the whole mapping is forwarded to the evaluator.

    Side effects:
        Creates ``args.output_dir`` if needed and (re)writes
        ``test_performance.pd`` inside it after each model.
    """

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dataset = vars(datasets)['CIFAR10'](args.data_dir, device)

    # NOTE(review): drop_last=True discards a trailing partial batch, so some
    # test samples are skipped whenever batch_size does not evenly divide the
    # test split -- confirm this is intended for an evaluation run.
    test_loader = DataLoader(
        dataset=dataset.splits['test'],
        batch_size=test_hparams['batch_size'],
        num_workers=dataset.N_WORKERS,
        pin_memory=False,
        shuffle=False,
        drop_last=True)

    os.makedirs(args.output_dir, exist_ok=True)
    results_path = os.path.join(args.output_dir, 'test_performance.pd')

    all_results = []
    # Count from 1 so the progress line runs 1/N .. N/N instead of 0/N .. N-1/N.
    for idx, model in enumerate(MODEL_NAMES, start=1):
        print(f'Model {idx}/{len(MODEL_NAMES)} | Name: {model}')
        classifier = load_model(
            model_name=model,
            dataset='cifar10',
            threat_model='Linf').to(device)

        algorithm = vars(algorithms)['EmptyAlgorithm'](
            classifier=classifier,
            hparams=hparams,
            device=device).to(device)

        evaluator = vars(evalulation_methods)['BETA'](
            algorithm=algorithm,
            device=device,
            output_dir='',
            test_hparams=test_hparams)

        results = evaluator.calculate(test_loader)
        results['Model'] = model
        all_results.append(results)

        # re-save results after each run
        df = pd.DataFrame(all_results)
        df.to_pickle(results_path)

        # Drop every reference to this model before the next one is loaded,
        # so two large networks are never resident on the device at once.
        del evaluator, algorithm, classifier
        if device == 'cuda':
            torch.cuda.empty_cache()


if __name__ == '__main__':

    # Command-line interface: (flag, type, default) for each option, in the
    # order they are registered with the parser.
    cli = argparse.ArgumentParser(description='Adversarial robustness')
    for flag, kind, fallback in (
        ('--data_dir', str, './advbench/data'),
        ('--output_dir', str, 'robustbench_eval'),
        ('--beta_n_steps', int, 20),
        ('--beta_lr', float, 1e-2),
        ('--batch_size', int, 16),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    args = cli.parse_args()

    # Start from the registry's 'PGD' / 'CIFAR10' hyperparameters.
    hparams = hparams_registry.default_hparams('PGD', 'CIFAR10')
    test_hparams = hparams_registry.test_hparams('PGD', 'CIFAR10')

    # Overlay the command-line values onto the evaluation hyperparameters.
    for key in ('beta_lr', 'beta_n_steps', 'batch_size'):
        test_hparams[key] = getattr(args, key)

    main(args, hparams, test_hparams)