import argparse
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from config import cfg, process_args
from dataset import make_dataset,  process_dataset
from metric import make_logger
from model import make_model
from module import save, resume, process_control
import pandas as pd
import pickle


# Turn on cuDNN's benchmark mode for the whole process.
cudnn.benchmark = True
parser = argparse.ArgumentParser(description='cfg')
# Expose every cfg entry as a ``--<key>`` flag whose default is the current
# cfg value and whose converter is that value's type.
# NOTE(review): ``type=type(value)`` converts any non-empty string to True for
# bool entries and cannot parse values for None entries - confirm cfg holds
# neither, or that process_args compensates.
for k in cfg:
    parser.add_argument('--{}'.format(k), default=cfg[k], type=type(cfg[k]))
parser.add_argument('--control_name', default=None, type=str)
args = vars(parser.parse_args())
process_args(args)


def main():
    """Run ``runExperiment`` once per seed.

    Seeds are the ``cfg['num_experiments']`` consecutive integers starting at
    ``cfg['init_seed']``. For each one, ``cfg['tag']`` is set to the seed and
    ``cfg['control_name']`` joined by ``_`` (empty parts are dropped), then
    ``process_control`` is called before the experiment runs.
    """
    first_seed = cfg['init_seed']
    for seed in range(first_seed, first_seed + cfg['num_experiments']):
        tag_parts = (str(seed), cfg['control_name'])
        cfg['tag'] = '_'.join(part for part in tag_parts if part)
        process_control()
        print('Experiment: {}'.format(cfg['tag']))
        runExperiment()
    return


def runExperiment():
    """Evaluate the best saved model for the experiment named by ``cfg['tag']``.

    The seed is parsed from the first ``_``-separated field of the tag and used
    to seed NumPy. The output locations are written into ``cfg``, the state
    stored under ``cfg['best_path']`` is loaded into a freshly built model, the
    model is evaluated on the ``'test'`` split, and a dict holding ``cfg``, the
    model and the logger state is saved to ``cfg['result_path']``.

    Raises:
        ValueError: If ``resume`` returns ``None`` for ``cfg['best_path']``.
    """
    cfg['seed'] = int(cfg['tag'].split('_')[0])
    np.random.seed(cfg['seed'])

    # All per-experiment locations are derived from the tag.
    cfg['path'] = os.path.join('output', 'exp')
    cfg['tag_path'] = os.path.join(cfg['path'], cfg['tag'])
    cfg['checkpoint_path'] = os.path.join(cfg['tag_path'], 'checkpoint')
    cfg['best_path'] = os.path.join(cfg['tag_path'], 'best')
    cfg['logger_path'] = os.path.join('output', 'logger', 'train', 'runs', cfg['tag'])
    cfg['result_path'] = os.path.join('output', 'result', 'mean', cfg['tag'])
    cfg['pred_path'] = os.path.join('output', 'pred', cfg['tag'])
    cfg['step'] = 0

    result = resume(cfg['best_path'])
    if result is None:
        raise ValueError('No valid model, please train model first')

    model = make_model(cfg['model'])
    model.load_state_dict(result['model'])
    dataset = make_dataset(cfg['data_name'])
    dataset = process_dataset(dataset)
    data_test = dataset['test']
    logger = make_logger(cfg['logger_path'], data_name=cfg['data_name'])

    test(data_test=data_test, model=model, logger=logger)

    # Only state produced by this run is saved, so the training checkpoint
    # does not need to be loaded here.
    # NOTE(review): the model object itself is stored, whereas the loaded
    # result holds a state dict under 'model' - confirm consumers expect this.
    result = {'cfg': cfg, 'model': model, 'logger': logger.state_dict()}
    save(result, cfg['result_path'])
    return

def test(data_test, model, logger):
    """Evaluate ``model`` on ``data_test`` and record the metrics in ``logger``.

    Predictions are cached as a pickle file at ``cfg['pred_path']``. If that
    file exists it is loaded and the model is not called; otherwise the model
    is run on the full test input and its output is written there.

    Args:
        data_test: Test split; must expose ``data``, ``target`` and
            ``sensitive`` attributes (read by ``process_baseline_input``).
        model: Callable model with a ``train(mode)`` method.
        logger: Metric logger providing ``evaluate``, ``append``, ``add``,
            ``write`` and ``save``.
    """
    model.train(False)
    test_input = process_baseline_input(data_test)

    pred_path = cfg['pred_path']
    if os.path.exists(pred_path):
        # NOTE: unpickling can execute arbitrary code; the cache file must
        # come from a trusted earlier run of this script.
        with open(pred_path, 'rb') as f:
            output = pickle.load(f)
    else:
        output = model(test_input)
        # Save output for faster evaluation next time. The parent directory
        # may not exist yet on a first run, so create it before writing.
        pred_dir = os.path.dirname(pred_path)
        if pred_dir:
            os.makedirs(pred_dir, exist_ok=True)
        with open(pred_path, 'wb') as f:
            pickle.dump(output, f)

    evaluation = logger.evaluate('test', 'batch', test_input, output)
    logger.append(evaluation, 'test')

    # For full batch evaluation.
    logger.add('test', test_input, output)
    evaluation = logger.evaluate('test', 'full', test_input, output)
    logger.append(evaluation, 'test')

    info = {'info': ['Model: {}'.format(cfg['tag']),
                     'Test Epoch: {}({:.0f}%)'.format(cfg['step'] // cfg['eval_period'], 100.)]}
    logger.append(info, 'test')
    print(logger.write('test'))
    logger.save(True)
    return


def process_baseline_input(dataset):
    """Pack a dataset split into the dict consumed by the model and logger.

    Args:
        dataset: Object exposing ``data``, ``target`` and ``sensitive``.

    Returns:
        dict with ``'data'`` (``dataset.data`` wrapped in a ``pandas.DataFrame``),
        ``'target'`` and ``'sensitive'`` (both passed through unchanged).
    """
    return {
        'data': pd.DataFrame(data=dataset.data),
        'target': dataset.target,
        'sensitive': dataset.sensitive,
    }

# Run the experiments only when this file is executed as a script.
if __name__ == "__main__":
    main()


