from __future__ import division
import chainer
from chainer.training.updaters import StandardUpdater
from chainer import functions as F
from chainer import Variable
from chainer.dataset import concat_examples
from chainer.backends.cuda import to_cpu, to_gpu
import numpy as np

from source import loss_fxns as lf
from source import yaml_utils as yu


class EachPenaltyUpdaterFullBatch(StandardUpdater):
    """Full-batch updater whose penalty is computed once per environment.

    On construction every environment's iterator is resized to its whole
    dataset, a single batch is drawn from it and the concatenated arrays are
    cached in ``memorized_dataset``.  Each call to ``update_core`` then
    re-uses those cached arrays instead of pulling new batches.

    Keyword arguments consumed here (all required; the remaining positional
    and keyword arguments are forwarded to ``StandardUpdater``):

    * ``envs`` -- sequence of environment names; each one is used both as an
      iterator name and as a key of ``memorized_dataset``.
    * ``base_model`` -- the model that is called on ``x`` to produce logits.
    * ``nll`` / ``penalty`` -- config dicts; their ``'fn'`` and ``'name'``
      entries are popped and handed to ``yu.load_module`` to obtain the loss
      and penalty callables.
    * ``penalty_weight`` -- multiplier applied to the averaged penalty.
    * ``loss_normalize`` -- if truthy, the total loss is divided by
      ``1 + penalty_weight``.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: ``self.kwargs`` aliases the very dict that is popped from
        # below, so afterwards it only holds what was passed on to the base.
        self.kwargs = kwargs
        self.envs = kwargs.pop('envs')
        self.env_num = len(self.envs)
        self.model = kwargs.pop('base_model')
        self.nll_config = kwargs.pop('nll')
        self.penalty_config = kwargs.pop('penalty')
        self.nll_func = yu.load_module(self.nll_config.pop('fn'),
                                       self.nll_config.pop('name'))
        self.penalty_func = yu.load_module(self.penalty_config.pop('fn'),
                                           self.penalty_config.pop('name'))
        self.penalty_weight = kwargs.pop('penalty_weight')
        self.loss_normalize = kwargs.pop('loss_normalize')
        self.xp = self.model.xp
        # Constant multiplier applied to the logits in ``update_core``.
        # NOTE(review): presumably the penalty function reads it through the
        # updater instance it receives -- confirm against the penalty module.
        self.scale = Variable(self.xp.array(1.0, dtype='f'))
        super(EachPenaltyUpdaterFullBatch, self).__init__(*args, **kwargs)

        self.memorized_dataset = {}
        self.data_num = []
        for env in self.envs:
            iterator = self.get_iterator(env)
            # One batch == the whole dataset of this environment.
            iterator.batch_size = len(iterator.dataset)
            batch = iterator.next()
            # Concatenate (and send to the device) once per environment; the
            # result does not depend on the key being copied.
            arrays = concat_examples(batch, device=self.device)
            self.memorized_dataset[env] = {
                key: arrays[key] for key in batch[0].keys()}

    def update_core(self):
        """Run one optimisation step over the cached data of all environments.

        For every environment the NLL, the penalty and the accuracy are
        computed on the memorized ``'x'`` / ``'y'`` arrays; the three
        quantities are averaged over environments, combined into
        ``nll + penalty_weight * penalty`` (optionally normalized), and the
        ``'main'`` optimizer is stepped on the result.
        """
        optimizer = self.get_optimizer('main')

        nll = 0.0
        acc = 0.0
        penalty = 0.0

        for env in self.envs:
            data = self.memorized_dataset[env]
            var_x = data['x']
            var_y = data['y']

            logits = self.scale * self.model(var_x)

            env_nll = self.nll_func(logits, var_y)
            nll += env_nll
            # The penalty callable receives the updater itself plus this
            # environment's loss.
            penalty += self.penalty_func(self, env_nll)
            acc += lf.mean_accuracy(logits, var_y).array

        # Take the average over environments.
        nll /= self.env_num
        acc /= self.env_num
        penalty /= self.env_num

        # Build the total loss.
        loss = nll + self.penalty_weight * penalty
        if self.loss_normalize:
            # Previously this flag was stored but ignored here; apply the
            # same normalization as TotalPenaltyUpdaterFullBatch.
            loss = loss / (1.0 + self.penalty_weight)

        self.model.cleargrads()
        loss.backward()
        optimizer.update()

        chainer.reporter.report({
            'train_total_loss': loss.array,
            'train_nll': nll.array,
            'train_penalty': penalty.array,
            'train_accuracy': acc,
        })

    # The epoch bookkeeping below is delegated to the iterator of the first
    # environment.  NOTE: this class only calls ``next()`` on the iterators
    # inside ``__init__``, so these values are not advanced by
    # ``update_core``.

    @property
    def epoch(self):
        """``epoch`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].epoch

    @property
    def epoch_detail(self):
        """``epoch_detail`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].epoch_detail

    @property
    def previous_epoch_detail(self):
        """``previous_epoch_detail`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].previous_epoch_detail

    @property
    def is_new_epoch(self):
        """``is_new_epoch`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].is_new_epoch
        

class TotalPenaltyUpdaterFullBatch(StandardUpdater):
    """Full-batch updater whose penalty is computed from all environments.

    On construction every environment's iterator is resized to its whole
    dataset, a single batch is drawn from it and the concatenated arrays are
    cached in ``memorized_dataset``.  Each call to ``update_core`` re-uses
    those cached arrays, collects the per-environment losses in a list and
    passes that list to the penalty callable in one go.

    Keyword arguments consumed here (all required; the remaining positional
    and keyword arguments are forwarded to ``StandardUpdater``):

    * ``envs`` -- sequence of environment names; each one is used both as an
      iterator name and as a key of ``memorized_dataset``.
    * ``base_model`` -- the model that is called on ``x`` to produce logits.
    * ``nll`` / ``penalty`` -- config dicts; their ``'fn'`` and ``'name'``
      entries are popped and handed to ``yu.load_module`` to obtain the loss
      and penalty callables.
    * ``penalty_weight`` -- multiplier applied to the penalty.
    * ``loss_normalize`` -- if truthy, the total loss is divided by
      ``1 + penalty_weight``.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: ``self.kwargs`` aliases the very dict that is popped from
        # below, so afterwards it only holds what was passed on to the base.
        self.kwargs = kwargs
        self.envs = kwargs.pop('envs')
        self.env_num = len(self.envs)
        self.model = kwargs.pop('base_model')
        self.nll_config = kwargs.pop('nll')
        self.penalty_config = kwargs.pop('penalty')
        self.nll_func = yu.load_module(self.nll_config.pop('fn'),
                                       self.nll_config.pop('name'))
        self.penalty_func = yu.load_module(self.penalty_config.pop('fn'),
                                           self.penalty_config.pop('name'))
        self.penalty_weight = kwargs.pop('penalty_weight')
        self.loss_normalize = kwargs.pop('loss_normalize')
        self.xp = self.model.xp
        # Constant multiplier applied to the logits in ``update_core``.
        # NOTE(review): presumably the penalty function reads it through the
        # updater instance it receives -- confirm against the penalty module.
        self.scale = Variable(self.xp.array(1.0, dtype='f'))
        # Private step counter, incremented at the end of ``update_core``.
        self._iteration = 0
        super(TotalPenaltyUpdaterFullBatch, self).__init__(*args, **kwargs)

        self.memorized_dataset = {}
        self.data_num = []
        for env in self.envs:
            iterator = self.get_iterator(env)
            # One batch == the whole dataset of this environment.
            iterator.batch_size = len(iterator.dataset)
            batch = iterator.next()
            # Concatenate (and send to the device) once per environment; the
            # result does not depend on the key being copied.
            arrays = concat_examples(batch, device=self.device)
            self.memorized_dataset[env] = {
                key: arrays[key] for key in batch[0].keys()}

    def update_core(self):
        """Run one optimisation step over the cached data of all environments.

        For every environment the NLL and the accuracy are computed on the
        memorized ``'x'`` / ``'y'`` arrays.  NLL and accuracy are averaged
        over environments, while the penalty is obtained from a single call
        of the penalty callable on the list of per-environment losses.  The
        total loss ``nll + penalty_weight * penalty`` (optionally divided by
        ``1 + penalty_weight``) is back-propagated and the ``'main'``
        optimizer is stepped.
        """
        optimizer = self.get_optimizer('main')

        nll = 0.0
        acc = 0.0
        loss_list = []

        for env in self.envs:
            data = self.memorized_dataset[env]
            var_x = data['x']
            var_y = data['y']

            logits = self.scale * self.model(var_x)

            env_nll = self.nll_func(logits, var_y)
            nll += env_nll
            # Each per-environment loss is stored with shape (1, 1) for the
            # penalty callable.
            loss_list.append(F.reshape(env_nll, (1, 1)))
            acc += lf.mean_accuracy(logits, var_y).array

        # Take the average over environments.
        nll /= self.env_num
        acc /= self.env_num

        # The penalty callable receives the updater itself plus all losses.
        penalty = self.penalty_func(self, loss_list)

        # Build the total loss.
        loss = nll + self.penalty_weight * penalty
        if self.loss_normalize:
            loss = loss / (1.0 + self.penalty_weight)

        self.model.cleargrads()
        loss.backward()
        optimizer.update()

        self._iteration += 1
        chainer.reporter.report({
            'train_total_loss': loss.array,
            'train_nll': nll.array,
            'train_penalty': penalty.array,
            'train_accuracy': acc,
        })

    # The epoch bookkeeping below is delegated to the iterator of the first
    # environment.  NOTE: this class only calls ``next()`` on the iterators
    # inside ``__init__``, so these values are not advanced by
    # ``update_core``.

    @property
    def epoch(self):
        """``epoch`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].epoch

    @property
    def epoch_detail(self):
        """``epoch_detail`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].epoch_detail

    @property
    def previous_epoch_detail(self):
        """``previous_epoch_detail`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].previous_epoch_detail

    @property
    def is_new_epoch(self):
        """``is_new_epoch`` of the first environment's iterator."""
        return self._iterators[self.envs[0]].is_new_epoch
        
