from tasks.generate_tasks import generate_tasks
import torch
from training.training import training
from training.evaluation import evaluation
from argparse import ArgumentParser
import importlib
import numpy as np
from training.utils import write_matrix,write_line,str2bool
import time

# Command-line interface of the experiment driver; all options are mandatory.
parser = ArgumentParser()

parser.add_argument(
    '--config', '-c', required=True,
    help="name of the module inside the 'configs' package to import")
parser.add_argument(
    '--transfer', '-t', required=True,
    help='boolean-like string; converted with str2bool and stored as '
         'cfg.transfer_posterior')
# Convert the run count at parse time so a malformed value is rejected by
# argparse with a usage message instead of raising ValueError later in main().
parser.add_argument(
    '--sim', '-s', required=True, type=int,
    help='number of simulation runs; run indices start at 0 and each index '
         'is used as the torch seed for its run')


def main():
    """Run the train/evaluate pipeline for the requested number of runs.

    Parses the command-line arguments, imports ``configs.<config>`` and, for
    every run index ``sim`` in ``range(int(args.sim))``:

    1. builds the train/test data loader lists with ``generate_tasks``,
    2. seeds torch with ``sim``,
    3. calls ``training`` and then ``evaluation``,
    4. appends the returned accuracies to the accumulated results,
    5. prints the mean and standard deviation along axis 0 of the results
       gathered so far and appends them to ``cfg.results_file`` via
       ``write_line``,
    6. saves all results gathered so far to
       ``<cfg.folder>/results<cfg.transfer_posterior>.npy`` (the same path is
       rewritten on every run).

    Side effects on the imported config module: ``cfg.transfer_posterior`` is
    set from ``--transfer`` and ``'-<transfer_posterior>'`` is appended to
    ``cfg.folder``.
    """
    args = parser.parse_args()
    results_all = []

    cfg = importlib.import_module('configs.' + args.config)
    cfg.transfer_posterior = str2bool(args.transfer)
    cfg.folder = cfg.folder + '-' + str(cfg.transfer_posterior)

    for sim in range(int(args.sim)):
        train_dl_list, test_dl_list = generate_tasks(cfg.tasks_description,
                                                     cfg.datadir,
                                                     cfg.batch_size_train,
                                                     cfg.batch_size_test,
                                                     sim)
        # NOTE(review): torch is seeded only after the data loaders have been
        # built; generate_tasks receives `sim` itself -- confirm it handles
        # its own seeding if task generation must be reproducible.
        torch.manual_seed(sim)

        print('start training', flush=True)

        start_time = time.time()
        selected_model = training(cfg, train_dl_list, sim, start_time)
        print('selected_model', selected_model)
        print('time after training:', time.time() - start_time, flush=True)

        # Debugging aid (presumably to evaluate previously saved models
        # without retraining -- confirm): `selected_model` can be replaced by
        # a list of saved model paths such as
        # 'results/results-split-mnist-False/model-0.save'.

        accuracies = evaluation(cfg, test_dl_list, selected_model, sim,
                                start_time)
        print('time after evaluation:', time.time() - start_time, flush=True)

        results_all.append(accuracies)
        print(results_all)

        # Statistics across the runs completed so far (axis 0 = run index).
        # Each one is computed once and reused for stdout and the results file.
        mean_so_far = np.mean(results_all, axis=0)
        print('axis0 mean', mean_so_far)
        write_line(cfg.folder, cfg.results_file, 'mean')
        write_line(cfg.folder, cfg.results_file, str(mean_so_far))

        std_so_far = np.std(results_all, axis=0)
        print('axis0 std', std_so_far)
        write_line(cfg.folder, cfg.results_file, 'std')
        write_line(cfg.folder, cfg.results_file, str(std_so_far))

        # np.save returns None, so it is called for its side effect only
        # (the previous code wrapped it in print() and emitted 'None').
        np.save(
            str(cfg.folder) + '/results' + str(cfg.transfer_posterior) +
            '.npy', results_all)

    print(results_all)


# Run the experiment only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or training.
if __name__ == '__main__':
    main()
