import os
import os.path
import json
import copy
import time
import torch
import logging

from utils.factory import get_model
from utils.toolkit import make_logdir, print_args, format_elapsed_time
from utils.toolkit import setup_logging, set_device, set_random
from dataloaders.data_manager import DataManager


def train(args):
    """Run one full experiment per seed listed in ``args['seed']``.

    The configuration is written once to ``<logdir>/config.json`` before any
    training starts; ``_train`` is then called once per seed.

    Args:
        args: Experiment configuration dict. It is mutated in place:
            ``'logdir'`` is set from ``make_logdir``, and ``'seed'`` /
            ``'device'`` are overwritten with the per-run values before each
            ``_train`` call (``_train`` adds further keys of its own).
            ``args['seed']`` may be a list/tuple of seeds or a single seed.
            ``args['device']`` may be a comma-separated string such as
            ``"0,1"`` or a list/tuple of device ids.
    """
    seeds = args['seed']
    # Accept a single seed as well as a collection of seeds.
    seed_list = list(seeds) if isinstance(seeds, (list, tuple)) else [seeds]

    raw_device = args['device']
    if isinstance(raw_device, str):
        # "0, 1" -> ['0', '1']: strip so stray spaces don't end up in device ids.
        device = [d.strip() for d in raw_device.split(',')]
    elif isinstance(raw_device, (list, tuple)):
        device = list(raw_device)
    else:
        device = [raw_device]

    args['logdir'] = make_logdir(args)
    config_path = os.path.join(args['logdir'], 'config.json')
    # Dumped before the per-seed loop, so the file records the original seed
    # list and device spec rather than the per-run values.
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(args, f, indent=4)

    for seed in seed_list:
        args['seed'] = seed
        # Fresh copy per run: downstream code (e.g. set_device) receives and may
        # rewrite args['device'], so no run should see a list altered by a
        # previous seed.
        args['device'] = list(device)
        _train(args)


def _timed(label, fn, *fn_args):
    """Call ``fn(*fn_args)``, log ``'<label> time: ...'``, and return fn's result."""
    time_start = time.time()
    result = fn(*fn_args)
    time_end = time.time()
    logging.info('{} time: {}'.format(label, format_elapsed_time(time_start, time_end)))
    return result


def _train(args):
    """Train and evaluate a model over every incremental task for one seed.

    For each task of the ``DataManager`` split this calls, in order,
    ``model.before_task``, ``model.incremental_train`` (timed),
    ``model.incremental_test`` (timed) and ``model.after_task``. It then logs
    the accumulated accuracy curves. Unless ``args['no_ckp']`` is set, it also
    saves ``model.network.state_dict()`` to ``<logfilename>/task_<i>.pth``.

    Args:
        args: Experiment configuration dict (as prepared by ``train``). Mutated
            in place: ``'logfilename'`` and ``'class_order'`` are set here, and
            ``set_random`` / ``set_device`` receive it and may update it too.
    """
    args['logfilename'] = os.path.join(args['logdir'], 'seed{}'.format(args['seed']))
    setup_logging(args['logfilename'], args['no_ckp'])
    logging.info('Load pretrained model: {}'.format(args['load']))
    logging.info('Save model to: {}'.format(args['logfilename']))

    # random seed and device
    set_random(args)
    set_device(args)
    print_args(args)

    # datamanager
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'], args)
    args['class_order'] = data_manager._class_order

    # model
    model = get_model(args['model_name'], args)

    save_ckp = not args['no_ckp']
    if save_ckp:
        # Checkpoints are written inside this directory. Whether setup_logging
        # creates it is not visible here, so ensure it exists now. This fails
        # fast instead of after the first task has finished training.
        os.makedirs(args['logfilename'], exist_ok=True)

    cnn_curve, cnn_curve_with_task, cnn_curve_task = {'top1': []}, {'top1': []}, {'top1': []}

    # Train and Eval sequentially for N tasks
    for task in range(data_manager.task_num):
        model.before_task(data_manager)

        # learning on the new task (train)
        _timed('Training', model.incremental_train, task, data_manager)

        # evaluate the model (eval)
        cnn_accy, cnn_accy_with_task, cnn_accy_task = _timed('Evaluation', model.incremental_test, data_manager)

        model.after_task()

        # logging
        logging.info('CNN Accuracy: {}'.format(cnn_accy['grouped']))
        cnn_curve['top1'].append(cnn_accy['top1'])
        cnn_curve_with_task['top1'].append(cnn_accy_with_task['top1'])
        # NOTE(review): unlike the two lines above, this value is appended
        # as-is (not indexed by 'top1'). Presumably incremental_test already
        # returns a scalar here; confirm.
        cnn_curve_task['top1'].append(cnn_accy_task)
        logging.info('(curve) CNN top1: {}'.format(cnn_curve['top1']))  # Average Accuracy (A_t)
        logging.info('(curve) CNN top1 with task: {}'.format(cnn_curve_with_task['top1']))  # Average Accuracy with task id
        logging.info('(curve) CNN top1 task: {}'.format(cnn_curve_task['top1']))
        print('='*100)

        if save_ckp:
            ckp_path = os.path.join(args['logfilename'], "task_{}.pth".format(task))
            torch.save(model.network.state_dict(), ckp_path)
