import os
import sys
import math

import numpy as np

import chainer
chainer.config.cudnn_deterministic = True
chainer.config.autotune = False
import chainer.cuda
from chainer import Variable
from chainer import functions as F
from source import loss_fxns as lf
import pdb

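'''
Standard (iterator-based) evaluation extensions: each environment's metrics are
computed with one full pass over that environment's dataset iterator.
'''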
def test_evaluation(generator=None, env_ag_model=None, nll_func=None,
                    ts_env_indices=None, iterator_dict=None,
                    test_dataset_list=None, **kwargs):
    """Create a trainer extension that scores the model on test environments.

    For each key in ``ts_env_indices`` the extension makes one full pass over
    ``iterator_dict[key]`` and evaluates ``env_ag_model(generator(x))`` on
    every batch. It then averages ``nll_func`` and ``lf.mean_accuracy`` over
    all examples of that environment. Finally it reports:

    * ``test_nll`` / ``test_accuracy`` -- mean over the test environments;
    * ``test_ood_nll`` / ``test_ood_accuracy`` -- the worst test environment,
      i.e. the largest NLL and the smallest accuracy.

    Args:
        generator: Feature extractor. Its ``xp`` attribute selects the array
            module, and batches are sent to its ``device``.
        env_ag_model: Predictor applied to the generator's features.
        nll_func: Callable ``(logits, y)`` returning a Variable. Its value is
            multiplied by the batch size, so it is presumably a batch mean --
            confirm against the loss definition.
        ts_env_indices: Sized iterable of environment keys into
            ``iterator_dict``. ``None`` (the default) means no environments.
        iterator_dict: Mapping from environment key to a dataset iterator.
            Each iterator must stop after one pass, otherwise evaluation never
            finishes. Its batches must be turned by ``concat_examples`` into a
            dict with ``'x'`` and ``'y'`` entries. Iterators are reset after
            each pass.
        test_dataset_list: Unused; kept for interface compatibility.
        **kwargs: Ignored.

    Returns:
        The ``evaluate(trainer)`` extension produced by
        ``chainer.training.make_extension``.

    The returned extension raises ``ValueError`` in two cases: there are no
    test environments, or an environment's iterator yields no examples.
    """
    # ``None`` sentinels replace shared mutable ``[]`` defaults.
    if ts_env_indices is None:
        ts_env_indices = []
    if iterator_dict is None:
        iterator_dict = {}

    @chainer.training.make_extension()
    def evaluate(trainer):
        xp = generator.xp
        n_envs = len(ts_env_indices)
        if n_envs == 0:
            raise ValueError(
                'test_evaluation: no test environments to evaluate')
        nll_list = xp.zeros(n_envs)
        acc_list = xp.zeros(n_envs)
        with chainer.using_config('train', False), \
                chainer.using_config('enable_backprop', False):
            for env_index, env in enumerate(ts_env_indices):
                iterator = iterator_dict[env]
                env_nll = 0.0
                env_acc = 0.0
                env_count = 0
                try:
                    for batch in iterator:
                        arrays = chainer.dataset.concat_examples(
                            batch, device=generator.device)
                        var_x = arrays['x']
                        var_y = arrays['y']
                        batch_size = len(var_x)
                        logits = env_ag_model(generator(var_x))
                        # Weight per-batch metrics by batch size so that a
                        # smaller final batch does not skew the average.
                        env_nll += nll_func(logits, var_y).data * batch_size
                        env_acc += (lf.mean_accuracy(logits, var_y).data
                                    * batch_size)
                        env_count += batch_size
                finally:
                    # Rewind even if evaluation fails part-way, so the next
                    # call starts from the beginning of the dataset.
                    iterator.reset()
                if env_count == 0:
                    raise ValueError(
                        'test_evaluation: iterator for environment {!r} '
                        'yielded no examples'.format(env))
                nll_list[env_index] = env_nll / env_count
                acc_list[env_index] = env_acc / env_count

        chainer.reporter.report({
            'test_nll': xp.mean(nll_list),
            'test_accuracy': xp.mean(acc_list),
            # Worst case over test environments: the highest NLL together
            # with the lowest accuracy.
            'test_ood_nll': xp.max(nll_list),
            'test_ood_accuracy': xp.min(acc_list),
        })

    return evaluate


def train_evaluation(generator=None, env_ag_model=None, env_aw_model_list=None,
                     nll_func=None,
                     tr_env_indices=None, iterator_dict=None,
                     train_dataset_list=None, **kwargs):
    """Create a trainer extension that scores both predictors on train envs.

    For each position ``i`` and key ``env = tr_env_indices[i]``, the extension
    makes one full pass over ``iterator_dict[env]``. Features
    ``generator(x)`` are computed once per batch and fed to two predictors:

    * ``env_ag_model``, which is shared across all environments;
    * ``env_aw_model_list[i]``, the predictor paired with this environment.

    Per-environment NLL (``nll_func``) and accuracy (``lf.mean_accuracy``)
    are averaged over all examples of the environment. Their means over
    environments are reported as ``nll_ag_exact`` / ``acc_ag_exact`` (shared
    predictor) and ``nll_aw_exact`` / ``acc_aw_exact`` (paired predictors).

    Args:
        generator: Feature extractor. Its ``xp`` attribute selects the array
            module, and batches are sent to its ``device``.
        env_ag_model: Predictor shared by all environments.
        env_aw_model_list: Required. Sequence of predictors indexed in
            parallel with ``tr_env_indices``.
        nll_func: Callable ``(logits, y)`` returning a Variable. Its value is
            multiplied by the batch size, so it is presumably a batch mean --
            confirm against the loss definition.
        tr_env_indices: Sized iterable of environment keys into
            ``iterator_dict``. ``None`` (the default) means no environments.
        iterator_dict: Mapping from environment key to a dataset iterator.
            Each iterator must stop after one pass, and its batches must be
            turned by ``concat_examples`` into a dict with ``'x'`` and ``'y'``.
            Iterators are reset after each pass.
        train_dataset_list: Unused; kept for interface compatibility.
        **kwargs: Ignored.

    Returns:
        The ``evaluate(trainer)`` extension produced by
        ``chainer.training.make_extension``.

    The returned extension raises ``ValueError`` if an environment's iterator
    yields no examples.
    """
    # ``None`` sentinels replace shared mutable ``[]`` defaults.
    if tr_env_indices is None:
        tr_env_indices = []
    if iterator_dict is None:
        iterator_dict = {}

    @chainer.training.make_extension()
    def evaluate(trainer):
        xp = generator.xp
        n_envs = len(tr_env_indices)
        nll_ag_list = xp.zeros(n_envs)
        acc_ag_list = xp.zeros(n_envs)
        nll_aw_list = xp.zeros(n_envs)
        acc_aw_list = xp.zeros(n_envs)
        with chainer.using_config('train', False), \
                chainer.using_config('enable_backprop', False):
            for env_idx, env in enumerate(tr_env_indices):
                env_aw_predictor = env_aw_model_list[env_idx]
                iterator = iterator_dict[env]
                env_nll_ag = 0.0
                env_acc_ag = 0.0
                env_nll_aw = 0.0
                env_acc_aw = 0.0
                env_count = 0
                try:
                    for batch in iterator:
                        arrays = chainer.dataset.concat_examples(
                            batch, device=generator.device)
                        var_x = arrays['x']
                        var_y = arrays['y']
                        batch_size = len(var_x)
                        # Compute features once and share them between both
                        # predictors.
                        feature = generator(var_x)
                        logits_ag = env_ag_model(feature)
                        logits_aw = env_aw_predictor(feature)
                        # Weight per-batch metrics by batch size so that a
                        # smaller final batch does not skew the average.
                        env_nll_ag += (nll_func(logits_ag, var_y).data
                                       * batch_size)
                        env_acc_ag += (lf.mean_accuracy(logits_ag, var_y).data
                                       * batch_size)
                        env_nll_aw += (nll_func(logits_aw, var_y).data
                                       * batch_size)
                        env_acc_aw += (lf.mean_accuracy(logits_aw, var_y).data
                                       * batch_size)
                        env_count += batch_size
                finally:
                    # Rewind even if evaluation fails part-way, so the next
                    # call starts from the beginning of the dataset.
                    iterator.reset()
                if env_count == 0:
                    raise ValueError(
                        'train_evaluation: iterator for environment {!r} '
                        'yielded no examples'.format(env))
                nll_ag_list[env_idx] = env_nll_ag / env_count
                acc_ag_list[env_idx] = env_acc_ag / env_count
                nll_aw_list[env_idx] = env_nll_aw / env_count
                acc_aw_list[env_idx] = env_acc_aw / env_count

        chainer.reporter.report({
            'nll_ag_exact': xp.mean(nll_ag_list),
            'acc_ag_exact': xp.mean(acc_ag_list),
            'nll_aw_exact': xp.mean(nll_aw_list),
            'acc_aw_exact': xp.mean(acc_aw_list),
        })

    return evaluate
