
import argparse
import numpy as np
import os
import shutil
import datetime
import pandas as pd
import time
from config import cfg, process_args
from dataset import make_dataset, process_dataset
from metric import make_logger
from model import make_model
from module import check, process_control
import warnings
# Install a blanket "ignore" filter so no warnings are printed during the run.
warnings.filterwarnings("ignore")

import os
# Hide every CUDA device from this process so computation stays on the CPU.
# NOTE(review): this is set after the project modules above were imported; it
# presumably still takes effect because CUDA is initialised lazily, but confirm
# that none of those modules touches the GPU at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Disable GPU


# Build the command-line interface from the project configuration: every key
# in ``cfg`` becomes a ``--<key>`` option whose default is the current value
# and whose converter is the type of that value.
parser = argparse.ArgumentParser(description='cfg')
for k in cfg:
    # Register the option with a direct call rather than exec-ing generated
    # source text: same result, but keys containing quotes or backslashes can
    # no longer break (or inject into) the generated statement.
    # NOTE(review): ``type=type(cfg[k])`` is kept as-is for compatibility, but
    # for bool defaults argparse will call ``bool('<text>')``, which is True
    # for any non-empty string; confirm that process_args accounts for this.
    parser.add_argument('--{0}'.format(k), default=cfg[k], type=type(cfg[k]))
# Extra option that is not part of cfg's own keys at this point.
parser.add_argument('--control_name', default=None, type=str)
# Parsed options are handed to process_args as a plain dict.
args = vars(parser.parse_args())
process_args(args)


def main():
    """Run all configured experiments, one per seed.

    Seeds are the ``cfg['num_experiments']`` consecutive integers starting at
    ``cfg['init_seed']``. For each seed, ``cfg['tag']`` is set to the seed and
    ``cfg['control_name']`` joined by ``'_'`` (empty parts are dropped), then
    ``process_control()`` and ``runExperiment()`` are called.
    """
    first_seed = cfg['init_seed']
    np.random.seed(first_seed)
    seeds = list(range(first_seed, first_seed + cfg['num_experiments']))
    for seed in seeds:
        # filter(None, ...) drops a missing/empty control name from the tag.
        cfg['tag'] = '_'.join(filter(None, (str(seed), cfg['control_name'])))
        process_control()
        print('Experiment: {}'.format(cfg['tag']))
        runExperiment()


def runExperiment():
    """Run a single experiment identified by ``cfg['tag']``.

    The seed is parsed from the first ``'_'``-separated field of the tag and
    used to seed NumPy. Output locations and the step counter are written into
    ``cfg``; then the logger, dataset and model are built, and train/test
    rounds are repeated until ``cfg['step']`` reaches ``cfg['num_steps']``.
    After every round a checkpoint is written, and it is copied to the "best"
    directory whenever ``logger.compare('test')`` is truthy.
    """
    tag = cfg['tag']
    seed = int(tag.split('_')[0])
    cfg['seed'] = seed
    np.random.seed(seed)

    # Output layout: output/exp/<tag>/{checkpoint,best} plus a logger run dir.
    exp_root = os.path.join('output', 'exp')
    tag_dir = os.path.join(exp_root, tag)
    cfg['path'] = exp_root
    cfg['tag_path'] = tag_dir
    cfg['checkpoint_path'] = os.path.join(tag_dir, 'checkpoint')
    cfg['best_path'] = os.path.join(tag_dir, 'best')
    cfg['logger_path'] = os.path.join('output', 'logger', 'train', 'runs', tag)
    cfg['step'] = 0

    logger = make_logger(cfg['logger_path'], data_name=cfg['data_name'])
    dataset = process_dataset(make_dataset(cfg['data_name']))
    model = make_model(cfg['model'])

    train_split, test_split = dataset['train'], dataset['test']
    # train() advances cfg['step'] by one per call, which ends this loop.
    while cfg['step'] < cfg['num_steps']:
        train(data_train=train_split, model=model, logger=logger)
        test(data_test=test_split, model=model, logger=logger)
        snapshot = {
            'cfg': cfg,
            'model': model,
            'logger': logger.state_dict(),
        }
        check(snapshot, cfg['checkpoint_path'])
        if logger.compare('test'):
            # Keep a copy of the current checkpoint as the best one so far.
            shutil.copytree(cfg['checkpoint_path'], cfg['best_path'], dirs_exist_ok=True)
        logger.reset()



def train(data_train, model, logger):
    """Run one training step on the whole training split and log it.

    The model is switched to training mode and called once on the input built
    by ``process_baseline_input``. The evaluation returned by the logger is
    appended under ``'train'``, a progress/ETA summary is appended and printed,
    and ``cfg['step']`` is incremented by one.

    Args:
        data_train: dataset object exposing ``data``, ``target`` and
            ``sensitive`` attributes.
        model: callable model with a ``train(mode)`` method.
        logger: project logger; its ``profiler`` context manager wraps the step.
    """
    tic = time.time()
    model.train(True)
    batch = process_baseline_input(data_train)
    with logger.profiler:
        num_rows = batch['data'].shape[0]
        prediction = model(batch)
        logger.append(logger.evaluate('train', 'batch', batch, prediction), 'train', num_rows)

        # Estimate remaining time from the duration of this single step.
        elapsed = time.time() - tic
        epoch_eta = datetime.timedelta(seconds=round((cfg['eval_period'] - 1) * elapsed))
        run_eta = datetime.timedelta(seconds=round((cfg['num_steps'] - (cfg['step'] + 1)) * elapsed))

        summary_lines = [
            'Model: {}'.format(cfg['tag']),
            'Train Epoch: {}({:.0f}%)'.format((cfg['step'] // cfg['eval_period']) + 1, 100),
            'Epoch Finished Time: {}'.format(epoch_eta),
            'Experiment Finished Time: {}'.format(run_eta),
        ]
        logger.append({'info': summary_lines}, 'train')
        print(logger.write('train'))
        cfg['step'] += 1


def test(data_test, model, logger):
    """Evaluate the model on the test split and log the results.

    The model is switched out of training mode and called once on the input
    built by ``process_baseline_input``. Both the ``'batch'`` and the
    ``'full'`` evaluations from the logger are appended under ``'test'``,
    followed by a summary that is printed; finally ``logger.save(True)`` is
    called.

    Args:
        data_test: dataset object exposing ``data``, ``target`` and
            ``sensitive`` attributes.
        model: callable model with a ``train(mode)`` method.
        logger: project logger used for evaluation and output.
    """
    model.train(False)
    batch = process_baseline_input(data_test)
    prediction = model(batch)

    logger.append(logger.evaluate('test', 'batch', batch, prediction), 'test')
    # Register the input/output pair before requesting the full-batch evaluation.
    logger.add('test', batch, prediction)
    logger.append(logger.evaluate('test', 'full', batch, prediction), 'test')

    summary_lines = [
        'Model: {}'.format(cfg['tag']),
        'Test Epoch: {}({:.0f}%)'.format(cfg['step'] // cfg['eval_period'], 100.),
    ]
    logger.append({'info': summary_lines}, 'test')
    print(logger.write('test'))
    logger.save(True)


def process_baseline_input(dataset):
    """Bundle a dataset's fields into the input dict passed to the model.

    Args:
        dataset: object exposing ``data``, ``target`` and ``sensitive``
            attributes.

    Returns:
        A dict with ``'data'`` (``dataset.data`` wrapped in a pandas
        DataFrame), ``'target'`` and ``'sensitive'`` (both passed through
        unchanged).
    """
    return {
        'data': pd.DataFrame(data=dataset.data),
        'target': dataset.target,
        'sensitive': dataset.sensitive,
    }

# Start the experiment loop only when this file is executed as a script.
if __name__ == "__main__":
    main()