from __future__ import division
from __future__ import print_function

import chainer
chainer.config.cudnn_deterministic = True
chainer.config.autotune = False
import warnings

import os
import numpy as np
from chainer import training
import matplotlib
import ipdb
# Disable interactive backend
matplotlib.use('Agg')
import argparse
import yaml
import functools
import copy
from source import yaml_utils as yu
from source import misc_functions as misc

from chainer.training import extension
from chainer.training import extensions


def _make_optimizer(config, model):
    """Create an optimizer from ``config['opt']``, attach it to *model*, and
    register a WeightDecay hook with rate ``config['opt']['wd_rate']``."""
    optimizer = yu.load_component(config['opt'])
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(config['opt']['wd_rate']))
    return optimizer


def _split_keys(value):
    """Split a comma-separated key string into stripped, non-empty keys.

    ``None`` yields an empty list, so an unset config entry adds nothing.
    """
    if value is None:
        return []
    return [key.strip() for key in value.split(',') if key.strip()]


def main(config):
    """Build datasets, models, optimizers, updater and extensions, then train.

    Args:
        config: nested configuration dict (as loaded from YAML). Keys read
            here include 'debug', 'seed', 'out', 'result_ext_name',
            'dataset', 'generator', 'env_ag_predictor', 'env_aw_predictor',
            'opt', 'gpu', 'updater', 'evaluation', 'iteration', 'batchsize',
            'test_batchsize', 'add_print', 'add_report', 'display_interval',
            'report_interval', 'lr_decay_extension', 'test_acc_eval_interval'
            and 'train_acc_eval_interval'.

    Returns:
        The output directory passed to the trainer (created if missing).
    """
    if config['debug']:
        chainer.global_config.cudnn_deterministic = True
        misc.set_random_seed(config['seed'])

    # Set up the saving directory.
    savedir = os.path.join(config['out'], str(config['result_ext_name']))
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # Get the lists of datasets, one entry per environment.
    dataset_fxn = yu.load_component_fxn(config['dataset'])
    dataset = dataset_fxn(**config['dataset']['args'])
    train_dataset_list = dataset.train_datasetlist
    test_dataset_list = dataset.test_datasetlist

    # Environment ids are the stringified indices '0', '1', ...; they key the
    # iterator dicts and are handed to the updater and the evaluators.
    tr_env_indices = np.array(range(len(train_dataset_list))).astype('str')
    ts_env_indices = np.array(range(len(test_dataset_list))).astype('str')

    # Per training environment: a repeating iterator (chainer's default
    # shuffling) for updates, and a single-pass, unshuffled one for evaluation.
    train_iterator_dict = {}
    train_iterator_dict_for_eval = {}
    for k, env in enumerate(tr_env_indices):
        train_iterator_dict[env] = chainer.iterators.MultithreadIterator(
            train_dataset_list[k], config['batchsize'])
        train_iterator_dict_for_eval[env] = chainer.iterators.MultithreadIterator(
            train_dataset_list[k], config['test_batchsize'],
            repeat=False, shuffle=False)

    # Test iterators use config['test_batchsize'] so that, when debugging,
    # results match irm_base exactly.
    test_iterator_dict = {}
    for k, env in enumerate(ts_env_indices):
        test_iterator_dict[env] = chainer.iterators.SerialIterator(
            test_dataset_list[k], config['test_batchsize'],
            repeat=False, shuffle=False)

    # Models and their optimizers. opt_dict keys: 'opt_gen', 'opt_env_ag',
    # then 'opt_env_aw0', 'opt_env_aw1', ... (one per training environment).
    opt_dict = {}
    gen_modelclass = yu.load_component_fxn(config['generator'])
    generator = gen_modelclass(**config['generator']['args'])
    opt_gen = _make_optimizer(config, generator)
    opt_dict['opt_gen'] = opt_gen

    out_dim = config['dataset']['args']['out_dim']
    env_ag_modelclass = yu.load_component_fxn(config['env_ag_predictor'])
    env_ag_predictor = env_ag_modelclass(out_dim=out_dim,
                                         **config['env_ag_predictor']['args'])
    opt_env_ag = _make_optimizer(config, env_ag_predictor)
    opt_dict['opt_env_ag'] = opt_env_ag

    # One environment-aware predictor (and optimizer) per training environment.
    env_aw_modelclass = yu.load_component_fxn(config['env_aw_predictor'])
    env_aw_model_list = []
    env_aw_opt_list = []
    for k in range(len(tr_env_indices)):
        env_aw_predictor = env_aw_modelclass(out_dim=out_dim,
                                             **config['env_aw_predictor']['args'])
        env_aw_model_list.append(env_aw_predictor)
        opt_env_aw = _make_optimizer(config, env_aw_predictor)
        env_aw_opt_list.append(opt_env_aw)
        opt_dict['opt_env_aw{}'.format(k)] = opt_env_aw

    if config['gpu'] >= 0:
        chainer.backends.cuda.get_device_from_id(config['gpu']).use()
        for model in [generator, env_ag_predictor] + env_aw_model_list:
            model.to_gpu()

    # NOTE(review): the training NLL function is loaded but never used; the
    # updater receives the raw config['updater']['args']['nll'] entry instead.
    # The call is kept so an invalid 'nll' spec still fails here -- confirm
    # whether the loaded function should be passed to the updater.
    yu.load_module(config['updater']['args']['nll']['fn'],
                   config['updater']['args']['nll']['name'])
    nll_fxn_eval = yu.load_module(config['evaluation']['args']['nll']['fn'],
                                  config['evaluation']['args']['nll']['name'])

    # Shallow copy so the update below does not touch the user's config.
    updater_kwargs = copy.copy(config['updater'].get('args', {}))
    updater_kwargs.update({
        'device': config['gpu'],
        'iterator': train_iterator_dict,
        'envs': tr_env_indices,
        'generator': generator,
        'env_ag_predictor': env_ag_predictor,
        'env_aw_predictor': env_aw_model_list,
        'optimizer': opt_dict,
    })
    updater_fxn = yu.load_module(config['updater']['fn'],
                                 config['updater']['name'])
    updater = updater_fxn(**updater_kwargs)

    # Set up the trainer.
    trainer = training.Trainer(updater,
                               (config['iteration'], 'iteration'), out=savedir)

    print_keys = ['iteration', 'nll_ag', 'nll_aw_all', 'acc_ag',
                  'acc_aw', 'test_ood_nll', 'nll_ag_exact', 'nll_aw_exact',
                  'test_ood_accuracy', 'acc_ag_exact', 'acc_aw_exact']
    print_keys += _split_keys(config['add_print'])

    report_keys = ['iteration', 'nll_ag', 'nll_aw_all', 'acc_ag',
                   'acc_aw', 'test_nll', 'test_ood_nll', 'nll_ag_exact', 'nll_aw_exact',
                   'test_accuracy', 'test_ood_accuracy', 'acc_ag_exact', 'acc_aw_exact',
                   'elapsed_time']
    report_keys += _split_keys(config['add_report'])

    trainer.extend(extensions.PrintReport(print_keys),
                   trigger=(config['display_interval'], 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(config['report_interval'],
                                                 'iteration')))

    if config['lr_decay_extension']['use_flag']:
        # The factory returns (extension, trigger) for a given optimizer.
        # Registration order (gen, env_ag, env_aw0, ...) is preserved.
        lr_decay_factory = yu.load_component_fxn(config['lr_decay_extension'])
        for optimizer in [opt_gen, opt_env_ag] + env_aw_opt_list:
            lr_decay_ext, trigger = lr_decay_factory(
                optimizer=optimizer,
                opt_config=config['opt'],
                decay_config=config['lr_decay_extension'])
            trainer.extend(lr_decay_ext, trigger=trigger)

    # Evaluation.
    test_evaluator_fxn = yu.load_component_fxn(config['evaluation']['test_data'])
    train_evaluator_fxn = yu.load_component_fxn(config['evaluation']['train_data'])

    trainer.extend(test_evaluator_fxn(generator=generator, env_ag_model=env_ag_predictor,
                                      nll_func=nll_fxn_eval,
                                      ts_env_indices=ts_env_indices,
                                      iterator_dict=test_iterator_dict,
                                      test_dataset_list=[]),
                   trigger=(config['test_acc_eval_interval'], 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(train_evaluator_fxn(generator=generator, env_ag_model=env_ag_predictor,
                                       env_aw_model_list=env_aw_model_list,
                                       nll_func=nll_fxn_eval,
                                       tr_env_indices=tr_env_indices,
                                       iterator_dict=train_iterator_dict_for_eval,
                                       test_dataset_list=[]),
                   trigger=(config['train_acc_eval_interval'], 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    # NOTE: model snapshotting (config['snap_flag'] / config['snapshot_interval'])
    # is currently disabled; the earlier draft referenced an undefined `model`.

    print("start training")
    trainer.run()
    return savedir


if __name__ == '__main__':
    config, args = yu.parse_args()

    # Apply command-line overrides of the form "a.b.c=value": walk the nested
    # config dict along the dotted path and replace the final key's value with
    # the YAML-parsed right-hand side. Only existing keys may be overridden.
    for attr in args.attrs:
        # Split on the first '=' only, so values may themselves contain '='.
        module, new_value = attr.split('=', 1)
        keys = module.split('.')
        target = functools.reduce(dict.__getitem__, keys[:-1], config)
        if keys[-1] not in target:
            # "sonna key naissu" ~ "there is no such key".
            raise ValueError('sonna key naissu...:{}'.format(keys))
        # safe_load: the value comes from the command line, so do not allow
        # arbitrary Python object construction (PyYAML >= 6 also requires an
        # explicit Loader for yaml.load).
        target[keys[-1]] = yaml.safe_load(new_value)

    if not args.warning:
        # Ignore warnings
        warnings.simplefilter('ignore')

    for k, v in sorted(config.items()):
        print("\t{}: {}".format(k, v))

    # Snapshot the config before training (presumably main() may mutate it)
    # so that exactly what was requested is saved next to the results.
    _config = copy.deepcopy(config)
    savedir = main(config)
    savepath = os.path.join(savedir, 'config.yml')
    with open(savepath, 'w') as config_file:
        config_file.write(yaml.dump(_config, default_flow_style=False))
