import argparse
import os
import torch
from torch.utils.data import DataLoader
import pandas as pd
import collections

from advbench.lib import misc
from advbench import algorithms
from advbench import datasets
from advbench import networks
from advbench import evalulation_methods

# Inclusive range of checkpoint epochs that ``main`` re-evaluates. Both ends
# are 199, so by default exactly one checkpoint (epoch 199) is scored.
START_EPOCH, STOP_EPOCH = 199, 199

def _make_eval_loader(split_data, batch_size, num_workers, drop_last):
    """Build an unshuffled evaluation DataLoader over one dataset split.

    Shuffling is disabled so every checkpoint sees the batches in the same
    order; memory pinning is left off.
    """
    return DataLoader(
        dataset=split_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        shuffle=False,
        drop_last=drop_last)


def _records_to_dataframe(records, split, metadata):
    """Flatten one split of the per-epoch ``records`` into a DataFrame.

    Args:
        records: list of dicts shaped ``{'Epoch': int, 'Validation': {...},
            'Test': {...}}``, as built by ``main``.
        split: ``'Validation'`` or ``'Test'``; selects which metric dict to
            read and is written verbatim to the ``Split`` column.
        metadata: run-level values broadcast to every row as extra columns.

    Returns:
        DataFrame with one row per record, with columns in this order: the
        metric columns, ``Split``, the ``metadata`` columns, then ``Epoch``.
    """
    # Both splits are scored by the same evaluators, so the first record's
    # validation keys name the metric columns for either split.
    metric_names = list(records[0]['Validation']) if records else []
    # An explicit index keeps one row per record even if no metrics exist,
    # so the Epoch column below always lines up with the rows.
    df = pd.DataFrame(
        {name: [record[split][name] for record in records]
         for name in metric_names},
        index=pd.RangeIndex(len(records)))
    df['Split'] = split
    df = df.join(pd.DataFrame(metadata, index=df.index))
    df['Epoch'] = [record['Epoch'] for record in records]
    return df


def main(args, hparams, test_hparams, start_epoch=None, stop_epoch=None):
    """Re-evaluate saved checkpoints and pickle a validation/test metrics table.

    Each checkpoint ``<args.output_dir>/ckpts/model_ckpt_alg-<epoch>.pkl`` in
    the inclusive epoch range is processed as follows:
      1. Load it into the algorithm built from ``args`` and ``hparams``.
      2. Run every evaluator named in ``args.evaluators`` on the validation
         split and on the test split.

    All results are then written to
    ``<args.output_dir>/selection-test-time-rerun-no-ams.pd`` as a pickled
    DataFrame.

    Args:
        args: namespace of training arguments. Reads ``dataset``,
            ``data_dir``, ``architecture``, ``algorithm``, ``output_dir``,
            ``evaluators``, ``trial_seed`` and ``seed``.
        hparams: hyperparameters passed to the algorithm constructor.
        test_hparams: evaluation hyperparameters. ``batch_size`` sizes the
            loaders, and the whole dict is forwarded to every evaluator.
        start_epoch: first checkpoint epoch to evaluate. Defaults to the
            module-level ``START_EPOCH`` (looked up at call time).
        stop_epoch: last checkpoint epoch to evaluate, inclusive. Defaults to
            the module-level ``STOP_EPOCH`` (looked up at call time).
    """
    if start_epoch is None:
        start_epoch = START_EPOCH
    if stop_epoch is None:
        stop_epoch = STOP_EPOCH

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = vars(datasets)[args.dataset](args.data_dir, device)

    classifier = vars(networks)[args.architecture](
        dataset.INPUT_SHAPE,
        dataset.NUM_CLASSES
    ).to(device)

    algorithm = vars(algorithms)[args.algorithm](
        classifier=classifier,
        hparams=hparams,
        device=device
    ).to(device)

    validation_loader = _make_eval_loader(
        dataset.splits['validation'],
        test_hparams['batch_size'],
        dataset.N_WORKERS,
        algorithm.DROP_LAST_BATCH)
    test_loader = _make_eval_loader(
        dataset.splits['test'],
        test_hparams['batch_size'],
        dataset.N_WORKERS,
        algorithm.DROP_LAST_BATCH)

    records = []
    for epoch in range(start_epoch, stop_epoch + 1):
        fname = os.path.join(
            args.output_dir, 'ckpts', f'model_ckpt_alg-{epoch}.pkl')
        # Without map_location, a checkpoint saved on a GPU cannot be
        # restored when ``device`` has fallen back to 'cpu'.
        checkpoint = torch.load(fname, map_location=device)
        algorithm.load_state_dict(checkpoint['state_dict'])

        # Evaluators are rebuilt for every checkpoint.
        evaluators = [
            vars(evalulation_methods)[e](
                algorithm=algorithm,
                device=device,
                output_dir=args.output_dir,
                test_hparams=test_hparams)
            for e in args.evaluators]

        results = {'Epoch': epoch, 'Validation': {}, 'Test': {}}
        for evaluator in evaluators:
            results['Validation'].update(
                evaluator.calculate(validation_loader))
            results['Test'].update(evaluator.calculate(test_loader))
            print(results['Test'])

        records.append(results)

    metadata = {
        'Algorithm': args.algorithm,
        'trial_seed': args.trial_seed,
        'Architecture': args.architecture,
        'seed': args.seed,
        'path': args.output_dir,
    }
    selection_df = pd.concat(
        [_records_to_dataframe(records, split, metadata)
         for split in ('Validation', 'Test')],
        ignore_index=True)

    selection_df.to_pickle(os.path.join(
        args.output_dir,
        'selection-test-time-rerun-no-ams.pd'
    ))

if __name__ == '__main__':
    # Command-line entry point: load the saved training configuration from
    # --input_dir and re-run evaluation with the requested evaluators.
    cli = argparse.ArgumentParser(description='Evaluate robustness')
    cli.add_argument('--input_dir', type=str, default='train_output')
    cli.add_argument('--evaluators', type=str, nargs='+', default=['Clean'])
    cli_args = cli.parse_args()

    def _read_run_file(name):
        """Read one saved config file from the training output directory."""
        return misc.read_dict(os.path.join(cli_args.input_dir, name))

    # Training-time arguments, with the evaluator list overridden from the CLI.
    train_args = argparse.Namespace(**_read_run_file('args.json'))
    train_args.evaluators = cli_args.evaluators

    main(
        train_args,
        _read_run_file('hparams.json'),
        _read_run_file('test_hparams.json'),
    )