from __future__ import division
import chainer
from chainer.training.updaters import StandardUpdater
from chainer import functions as F
from chainer import Variable
from chainer.dataset import concat_examples
from chainer.backends.cuda import to_cpu, to_gpu
import numpy as np

from source import loss_fxns as lf
from source import yaml_utils as yu


class IRUpdater(StandardUpdater):
    """Updater training a feature generator together with two kinds of predictors.

    On every iteration one minibatch is drawn from the iterator of each
    environment named in ``envs``.  The generator maps the batch inputs to
    features, which are scored by

    * ``env_ag_predictor`` -- one predictor shared by all environments, and
    * ``env_aw_predictor[i]`` -- a separate predictor for environment ``i``

    (presumably "environment-agnostic" / "environment-aware"; the names are
    not expanded anywhere in this file).  Each predictor minimises its own
    NLL, while the generator minimises
    ``nll_ag + penalty_weight * relu(nll_ag - nll_aw_all)``, with both NLLs
    averaged over environments.

    Keyword Args (consumed here; every other argument is forwarded to
    ``StandardUpdater``):
        envs: Sequence of iterator names, one per environment.  Each name must
            be a key of the iterators given to ``StandardUpdater``.
        generator: Model mapping a batch's ``'x'`` entry to features; its
            ``xp`` attribute is reused as ``self.xp``.
        env_ag_predictor: Predictor applied to the features of every
            environment.
        env_aw_predictor: Sequence of predictors, indexed like ``envs``.
        nll: Dict holding ``'fn'`` and ``'name'`` entries, which are passed to
            ``yaml_utils.load_module`` to obtain the loss function.  Any other
            entries are kept in ``self.nll_config``.  The caller's dict is not
            modified.
        penalty_weight: Multiplier on the hinge penalty in the generator loss.

    Optimizers are looked up under the names ``'opt_gen'``, ``'opt_env_ag'``
    and ``'opt_env_aw{i}'`` for ``i`` in ``range(len(envs))``.

    Raises:
        ValueError: if ``envs`` is empty.
    """

    def __init__(self, *args, **kwargs):
        # Same dict object as the one popped below, so after __init__ it holds
        # only the kwargs that were forwarded to StandardUpdater.
        self.kwargs = kwargs
        self.envs = kwargs.pop('envs')
        self.env_num = len(self.envs)
        if self.env_num == 0:
            # update_core averages over environments and the epoch properties
            # read the first environment's iterator; both need at least one.
            raise ValueError('envs must contain at least one environment')
        self.generator = kwargs.pop('generator')
        self.env_ag_predictor = kwargs.pop('env_ag_predictor')
        self.env_aw_model_list = kwargs.pop('env_aw_predictor')
        # Work on a copy so popping 'fn'/'name' does not mutate the caller's
        # config (which may be reused, e.g. for a second updater or a dump).
        self.nll_config = dict(kwargs.pop('nll'))
        self.nll_func = yu.load_module(self.nll_config.pop('fn'),
                                       self.nll_config.pop('name'))
        self.penalty_weight = kwargs.pop('penalty_weight')
        self.xp = self.generator.xp
        super(IRUpdater, self).__init__(*args, **kwargs)

    def update_core(self):
        """Run one training step using one minibatch from every environment.

        All forward passes are done first, and the generator step is taken
        before any predictor optimizer runs.  Chainer optimizers update
        parameter arrays in place, and backward passes read the weight arrays
        retained during the forward pass.  Updating the predictors first would
        therefore give the generator a gradient that mixes pre-update
        activations with post-update predictor weights.

        Reports ``nll_ag``, ``nll_aw_all``, ``acc_ag`` and ``acc_aw``, each
        averaged over environments.
        """
        opt_gen = self.get_optimizer('opt_gen')
        opt_env_ag = self.get_optimizer('opt_env_ag')

        features = []        # per-env generator outputs, cut from the graph later
        nll_aw_per_env = []  # per-env losses of the per-env predictors
        nll_ag = 0.0
        nll_aw_all = 0.0
        acc_ag = 0.0
        acc_aw = 0.0

        # Forward pass only: no parameter is touched in this loop.
        for env_idx, env in enumerate(self.envs):
            batch = next(self.get_iterator(env))
            # Examples are indexed by 'x' and 'y' after concatenation.
            variabledict = concat_examples(batch, device=self.device)
            var_x = variabledict['x']
            var_y = variabledict['y']

            feature = self.generator(var_x)
            features.append(feature)
            logits_ag = self.env_ag_predictor(feature)
            logits_aw = self.env_aw_model_list[env_idx](feature)

            nll_ag += self.nll_func(logits_ag, var_y)
            acc_ag += lf.mean_accuracy(logits_ag, var_y)

            nll_aw = self.nll_func(logits_aw, var_y)
            nll_aw_per_env.append(nll_aw)
            nll_aw_all += nll_aw
            acc_aw += lf.mean_accuracy(logits_aw, var_y)

        # Average over environments.
        nll_ag /= self.env_num
        nll_aw_all /= self.env_num
        acc_ag /= self.env_num
        acc_aw /= self.env_num

        # Generator step, while every parameter still holds the values used in
        # the forward pass above.  The generator is pushed to lower nll_ag and
        # is penalised whenever nll_ag exceeds nll_aw_all.
        loss_gen = nll_ag + self.penalty_weight * F.relu(nll_ag - nll_aw_all)
        self.generator.cleargrads()
        loss_gen.backward()
        opt_gen.update()

        # Cut the graph at the features so the predictor losses below only
        # backpropagate into the predictors.  This leaves the generator's
        # gradients from loss_gen untouched and skips a pass through the
        # (already updated) generator.
        for feature in features:
            feature.unchain()

        # Predictor steps.  loss_gen.backward() also wrote gradients into the
        # predictors, so each is cleared before its own loss is backpropagated.
        self.env_ag_predictor.cleargrads()
        nll_ag.backward()
        opt_env_ag.update()

        for env_idx, nll_aw in enumerate(nll_aw_per_env):
            env_aw_predictor = self.env_aw_model_list[env_idx]
            opt_env_aw = self.get_optimizer('opt_env_aw{}'.format(env_idx))
            env_aw_predictor.cleargrads()
            nll_aw.backward()
            opt_env_aw.update()

        chainer.reporter.report({'nll_ag' : nll_ag.array})
        chainer.reporter.report({'nll_aw_all' : nll_aw_all.array})
        chainer.reporter.report({'acc_ag' : acc_ag.array})
        chainer.reporter.report({'acc_aw': acc_aw.array})

    def _epoch_iterator(self):
        """Return the iterator whose progress defines the updater's epoch.

        Only the first environment's iterator is consulted. Iterators of other
        environments may be at different positions within their own epochs.
        """
        return self._iterators[self.envs[0]]

    @property
    def epoch(self):
        """Epoch count of the first environment's iterator."""
        return self._epoch_iterator().epoch

    @property
    def epoch_detail(self):
        """Fractional epoch of the first environment's iterator."""
        return self._epoch_iterator().epoch_detail

    @property
    def previous_epoch_detail(self):
        """Previous fractional epoch of the first environment's iterator."""
        return self._epoch_iterator().previous_epoch_detail

    @property
    def is_new_epoch(self):
        """Whether the first environment's iterator just started a new epoch."""
        return self._epoch_iterator().is_new_epoch
        
