from __future__ import division
from __future__ import print_function

import chainer
chainer.config.cudnn_deterministic = True
chainer.config.autotune = False
import warnings

import os
import numpy as np
from chainer import training
import matplotlib
import ipdb
# Disable interactive backend
matplotlib.use('Agg')
import argparse
import yaml
import functools
import copy
from source import yaml_utils as yu
from source import misc_functions as misc

from chainer.training import extension
from chainer.training import extensions


def _parse_key_list(spec):
    """Turn a comma-separated string of report keys into a list of names.

    Whitespace around each name is stripped and empty entries (e.g. from a
    trailing comma or ``'a,,b'``) are dropped, since neither could ever match
    a key reported during training.
    """
    return [key.strip() for key in spec.split(',') if key.strip()]


def main(config):
    """Build datasets, model, updater and trainer from ``config`` and train.

    Args:
        config: nested dict (loaded from YAML) describing the dataset, model,
            optimizer, updater, evaluation extensions and trainer schedule.

    Returns:
        The output directory ``os.path.join(config['out'],
        str(config['result_ext_name']))``, created if missing, which is used
        as the trainer's ``out`` directory (logs and snapshots go there).
    """
    if config['debug']:
        chainer.global_config.cudnn_deterministic = True
        misc.set_random_seed(config['seed'])

    # Set up the saving directory.
    savedir = os.path.join(config['out'], str(config['result_ext_name']))
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # Get the per-environment lists of train / test datasets.
    dataset_fxn = yu.load_component_fxn(config['dataset'])
    dataset = dataset_fxn(**config['dataset']['args'])
    train_dataset_list = dataset.train_datasetlist
    test_dataset_list = dataset.test_datasetlist

    # Environment ids are the strings '0', '1', ... held in a numpy str array;
    # they key the iterator dicts and are handed to the updater/evaluators.
    tr_env_indices = np.array(range(len(train_dataset_list))).astype('str')
    ts_env_indices = np.array(range(len(test_dataset_list))).astype('str')

    # One training iterator per environment for the updater (chainer's default
    # repeat/shuffle settings), plus a single-pass, unshuffled copy used only
    # for evaluating on the training environments.
    train_iterator_dict = {}
    train_iterator_dict_for_eval = {}
    for env, train_data in zip(tr_env_indices, train_dataset_list):
        train_iterator_dict[env] = chainer.iterators.MultithreadIterator(
            train_data, config['batchsize'])
        train_iterator_dict_for_eval[env] = \
            chainer.iterators.MultithreadIterator(
                train_data, config['test_batchsize'],
                repeat=False, shuffle=False)

    # For debugging, config['test_batchsize'] is used so that the results are
    # exactly the same as irm_base.
    test_iterator_dict = {
        env: chainer.iterators.SerialIterator(
            test_data, config['test_batchsize'], repeat=False, shuffle=False)
        for env, test_data in zip(ts_env_indices, test_dataset_list)
    }

    # Create the base model and its optimizer, then move the model to the GPU.
    modelclass = yu.load_component_fxn(config['model'])
    model = modelclass(out_dim=config['dataset']['args']['out_dim'],
                       **config['model']['args'])
    allmodels = {'base_model': model}
    opt = yu.load_component(config['opt'])
    opt.setup(model)
    opt.add_hook(chainer.optimizer.WeightDecay(config['opt']['wd_rate']))
    if config['gpu'] >= 0:
        chainer.backends.cuda.get_device_from_id(config['gpu']).use()
        for each_model in allmodels.values():
            each_model.to_gpu()

    # Loss / penalty functions.
    # NOTE(review): nll_fxn is not used below (the updater receives the nll
    # spec through updater_kwargs); the load is kept so behaviour is unchanged.
    nll_fxn = yu.load_module(config['updater']['args']['nll']['fn'],
                             config['updater']['args']['nll']['name'])
    plt_func = yu.load_module(config['evaluation']['args']['penalty']['fn'],
                              config['evaluation']['args']['penalty']['name'])
    updater_name = config['updater']['name']
    nll_fxn_eval = yu.load_module(config['evaluation']['args']['nll']['fn'],
                                  config['evaluation']['args']['nll']['name'])

    # Shallow-copy the configured updater args so adding the runtime objects
    # below does not modify `config`.
    updater_cfg = config['updater']
    updater_kwargs = copy.copy(updater_cfg['args']) if 'args' in updater_cfg else {}
    updater_kwargs.update({
        'device': config['gpu'],
        'iterator': train_iterator_dict,
        'envs': tr_env_indices,
        'base_model': model,
        'optimizer': {'main': opt},
    })

    updater_fxn = yu.load_module(updater_cfg['fn'], updater_cfg['name'])
    updater = updater_fxn(**updater_kwargs)

    # Set up the trainer.
    trainer = training.Trainer(updater,
                               (config['iteration'], 'iteration'), out=savedir)

    print_keys = ['iteration', 'train_total_loss', 'train_nll', 'train_penalty',
                  'train_accuracy', 'test_ood_nll', 'train_nll_exact',
                  'test_ood_accuracy', 'train_acc_exact', 'train_penalty_exact']
    if config['add_print'] is not None:
        print_keys.extend(_parse_key_list(config['add_print']))

    report_keys = ['iteration', 'train_total_loss', 'train_nll', 'train_penalty',
                   'train_accuracy', 'test_nll', 'test_ood_nll', 'train_nll_exact',
                   'test_accuracy', 'test_ood_accuracy', 'train_acc_exact',
                   'train_penalty_exact', 'elapsed_time']
    if config['add_report'] is not None:
        report_keys.extend(_parse_key_list(config['add_report']))

    trainer.extend(extensions.PrintReport(print_keys),
                   trigger=(config['display_interval'], 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(config['report_interval'],
                                                 'iteration')))

    if config['lr_decay_extension']['use_flag']:
        # The configured factory returns the decay extension and its trigger.
        lr_decay_factory = yu.load_component_fxn(config['lr_decay_extension'])
        lr_decay_fxn, trigger = lr_decay_factory(
            optimizer=opt,
            opt_config=config['opt'],
            decay_config=config['lr_decay_extension'])
        trainer.extend(lr_decay_fxn, trigger=trigger)

    # Evaluation on the test environments and on the training environments.
    # NOTE(review): registered with PRIORITY_WRITER, presumably so their
    # values are observed before the report extensions read them.
    test_evaluator_fxn = yu.load_component_fxn(config['evaluation']['test_data'])
    train_evaluator_fxn = yu.load_component_fxn(config['evaluation']['train_data'])

    trainer.extend(test_evaluator_fxn(model=model, nll_func=nll_fxn_eval,
                                      ts_env_indices=ts_env_indices,
                                      iterator_dict=test_iterator_dict,
                                      test_dataset_list=[]),
                   trigger=(config['test_acc_eval_interval'], 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(train_evaluator_fxn(model=model, nll_func=nll_fxn_eval,
                                       plt_func=plt_func, updater_name=updater_name,
                                       ts_env_indices=tr_env_indices,
                                       iterator_dict=train_iterator_dict_for_eval,
                                       test_dataset_list=[]),
                   trigger=(config['train_acc_eval_interval'], 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    if config['snap_flag']:
        trainer.extend(extensions.snapshot_object(
                           model,
                           model.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(config['snapshot_interval'], 'iteration'))

    if config['jump_extension']['use_flag']:
        # At the scheduled iteration(s), hand the updater a copy of its kwargs
        # with penalty_weight replaced by the configured jump value.
        jump_config = config['jump_extension']
        in_middle_modification = yu.load_component_fxn(jump_config)
        newconfig = copy.copy(updater_kwargs)
        newconfig['penalty_weight'] = jump_config['args']['penalty_jump_to']
        trainer.extend(in_middle_modification(trainer, updater, newconfig),
                       trigger=training.triggers.ManualScheduleTrigger(
                           jump_config['args']['penalty_jump_at'],
                           'iteration'))

    print("start training")
    trainer.run()
    return savedir


def apply_config_overrides(config, attrs, parse=None):
    """Apply ``dotted.key=value`` command-line overrides to ``config`` in place.

    Each entry of ``attrs`` has the form ``a.b.c=value``. The dotted path is
    walked through the nested mappings of ``config`` and the final key is set
    to ``parse(value)``. Only keys that already exist may be overridden, which
    catches typos in override names.

    Args:
        config: nested dict loaded from the YAML config file.
        attrs: iterable of ``'dotted.key=value'`` strings. Only the first
            ``=`` separates key from value, so values may contain ``=``.
        parse: callable converting the raw value string to a Python object.
            Defaults to ``yaml.safe_load``, so e.g. ``3``, ``0.5``, ``true``
            and ``[1, 2]`` get their natural YAML types.

    Raises:
        KeyError: if an intermediate key on the dotted path does not exist.
        ValueError: if an entry contains no ``=`` or the final key is missing.
    """
    if parse is None:
        # safe_load: plain yaml.load() without a Loader is deprecated and is
        # a TypeError on PyYAML >= 6.
        parse = yaml.safe_load
    for attr in attrs:
        module, new_value = attr.split('=', 1)
        keys = module.split('.')
        target = config
        for key in keys[:-1]:
            target = target[key]
        if keys[-1] not in target:
            # Message is romanized Japanese, roughly "no such key".
            raise ValueError('sonna key naissu...:{}'.format(keys))
        target[keys[-1]] = parse(new_value)


if __name__ == '__main__':
    config, args = yu.parse_args()
    apply_config_overrides(config, args.attrs)

    if not args.warning:
        # Ignore warnings
        warnings.simplefilter('ignore')

    for k, v in sorted(config.items()):
        print("\t{}: {}".format(k, v))

    # Deep-copy before main() runs so the saved config.yml is unaffected by
    # anything main() (or the components it builds) does to `config`.
    _config = copy.deepcopy(config)
    savedir = main(config)
    savepath = os.path.join(savedir, 'config.yml')
    with open(savepath, 'w') as config_file:
        config_file.write(yaml.dump(_config, default_flow_style=False))
