import os
import sys
import math

import numpy as np

import chainer
chainer.config.cudnn_deterministic = True
chainer.config.autotune = False
import chainer.cuda
from chainer import Variable
from chainer import functions as F
from source import loss_fxns as lf

'''
Evaluation extensions that stream each environment's data through its
iterator in mini-batches.
'''
def test_evaluation(model=None, nll_func=None, ts_env_indices=None,
                    iterator_dict=None, test_dataset_list=None, **kwargs):
    """Build a trainer extension that reports mini-batch test metrics.

    For every key in ``ts_env_indices`` the extension iterates
    ``iterator_dict[env]`` once, computes the sample-weighted average NLL and
    accuracy for that environment, and then resets the iterator so the next
    evaluation sees the full data again.

    Reported values:
        test_nll / test_accuracy: mean over environments of the per-env
            averages.
        test_ood_nll / test_ood_accuracy: minimum over environments.

    Args:
        model: Chainer link called on a batch of inputs; must expose ``xp``
            and ``device``.
        nll_func: callable ``(logits, labels) -> Variable``; presumably a
            batch-mean loss, since it is re-weighted by batch size here.
        ts_env_indices: environment keys to evaluate.
        iterator_dict: indexable by environment key, giving an iterator whose
            batches ``concat_examples`` turns into a dict with 'x' and 'y'.
            The iterator must terminate (presumably ``repeat=False``).
        test_dataset_list: unused; accepted for a uniform factory signature.
        **kwargs: ignored.

    Raises (when the extension runs):
        ValueError: if an environment's iterator yields no examples.
    """
    # None defaults avoid shared mutable list literals in the signature.
    ts_env_indices = [] if ts_env_indices is None else ts_env_indices
    iterator_dict = [] if iterator_dict is None else iterator_dict

    @chainer.training.make_extension()
    def evaluate(trainer):
        xp = model.xp
        nll_list = xp.zeros(len(ts_env_indices))
        acc_list = xp.zeros(len(ts_env_indices))
        with chainer.using_config('train', False):
            with chainer.using_config('enable_backprop', False):
                for env_index, env in enumerate(ts_env_indices):
                    env_nll = 0.0
                    env_acc = 0.0
                    env_test_count = 0
                    for batch in iterator_dict[env]:
                        variabledict = chainer.dataset.concat_examples(
                            batch, device=model.device)
                        var_x = variabledict['x']
                        var_y = variabledict['y']
                        batch_size = len(var_x)

                        logits = model(var_x)
                        # Weight each batch by its size so a short final batch
                        # does not skew the per-environment average.
                        env_nll += nll_func(logits, var_y).array * batch_size
                        env_acc += (lf.mean_accuracy(logits, var_y).array
                                    * batch_size)
                        env_test_count += batch_size

                    # Rewind so the next evaluation sees the full data again.
                    iterator_dict[env].reset()
                    if env_test_count == 0:
                        raise ValueError(
                            'test_evaluation: iterator for environment %r '
                            'yielded no examples' % (env,))
                    nll_list[env_index] = env_nll / env_test_count
                    acc_list[env_index] = env_acc / env_test_count

        chainer.reporter.report({'test_nll': xp.mean(nll_list)})
        chainer.reporter.report({'test_accuracy': xp.mean(acc_list)})
        # NOTE(review): the OOD accuracy is the worst environment (min), but the
        # OOD NLL is also a min, i.e. the *best* environment. If a worst-case
        # NLL was intended this should be xp.max -- confirm before changing.
        chainer.reporter.report({'test_ood_nll': xp.min(nll_list)})
        chainer.reporter.report({'test_ood_accuracy': xp.min(acc_list)})

    return evaluate


def train_evaluation(model=None, nll_func=None, ts_env_indices=None,
                     iterator_dict=None, test_dataset_list=None, **kwargs):
    """Build a trainer extension that reports exact full-pass training metrics.

    Every environment in ``ts_env_indices`` is swept once through
    ``iterator_dict[env]`` in eval mode with backprop disabled. The
    sample-weighted NLL and accuracy are averaged per environment. The
    iterator is reset afterwards.

    Reported values:
        train_nll_exact / train_acc_exact: mean over environments of the
            per-env averages.

    Args:
        model: Chainer link called on a batch of inputs; must expose ``xp``
            and ``device``.
        nll_func: callable ``(logits, labels) -> Variable``; presumably a
            batch-mean loss, since it is re-weighted by batch size here.
        ts_env_indices: environment keys to evaluate.
        iterator_dict: indexable by environment key, giving an iterator whose
            batches ``concat_examples`` turns into a dict with 'x' and 'y'.
            The iterator must terminate (presumably ``repeat=False``).
        test_dataset_list: unused; accepted for a uniform factory signature.
        **kwargs: ignored.

    Raises (when the extension runs):
        ValueError: if an environment's iterator yields no examples.
    """
    # None defaults avoid shared mutable list literals in the signature.
    ts_env_indices = [] if ts_env_indices is None else ts_env_indices
    iterator_dict = [] if iterator_dict is None else iterator_dict

    @chainer.training.make_extension()
    def evaluate(trainer):
        xp = model.xp
        nll_list = xp.zeros(len(ts_env_indices))
        acc_list = xp.zeros(len(ts_env_indices))
        with chainer.using_config('train', False):
            with chainer.using_config('enable_backprop', False):
                for env_index, env in enumerate(ts_env_indices):
                    env_nll = 0.0
                    env_acc = 0.0
                    env_test_count = 0
                    for batch in iterator_dict[env]:
                        variabledict = chainer.dataset.concat_examples(
                            batch, device=model.device)
                        var_x = variabledict['x']
                        var_y = variabledict['y']
                        batch_size = len(var_x)

                        logits = model(var_x)
                        # Weight each batch by its size so a short final batch
                        # does not skew the per-environment average.
                        env_nll += nll_func(logits, var_y).array * batch_size
                        env_acc += (lf.mean_accuracy(logits, var_y).array
                                    * batch_size)
                        env_test_count += batch_size

                    # Rewind so the regular training loop is unaffected.
                    iterator_dict[env].reset()
                    if env_test_count == 0:
                        raise ValueError(
                            'train_evaluation: iterator for environment %r '
                            'yielded no examples' % (env,))
                    nll_list[env_index] = env_nll / env_test_count
                    acc_list[env_index] = env_acc / env_test_count

        chainer.reporter.report({'train_nll_exact': xp.mean(nll_list)})
        chainer.reporter.report({'train_acc_exact': xp.mean(acc_list)})

    return evaluate
    
def train_evaluation_with_penalty(model=None, nll_func=None,
                                  plt_func=None, updater_name=None,
                                  ts_env_indices=None, iterator_dict=None,
                                  test_dataset_list=None, **kwargs):
    """Build a trainer extension that reports exact training NLL, penalty and accuracy.

    Every environment is swept once through ``iterator_dict[env]`` in eval
    mode. Backprop is kept enabled because ``plt_func`` is handed Variables
    and presumably differentiates through them.

    The penalty mode is chosen by substrings of ``updater_name``:
        'Each': ``plt_func(scale, batch_nll)`` is evaluated per mini-batch and
            sample-weighted into a per-environment penalty. The reported
            penalty is the mean over environments.
        'Total': per-environment NLLs are kept as Variables, with their graphs
            retained for the whole sweep. They are passed together as
            ``plt_func(model, nll_list)`` once after the sweep.
    If both substrings occur, NLL handling follows 'Total' and the 'Each'
    penalty is the one reported.

    Reported values: train_nll_exact, train_penalty_exact, train_acc_exact.

    Raises (when the extension runs):
        ValueError: if ``updater_name`` is None or names neither mode, or if
            an environment's iterator yields no examples.
    """
    # None defaults avoid shared mutable list literals in the signature.
    ts_env_indices = [] if ts_env_indices is None else ts_env_indices
    iterator_dict = [] if iterator_dict is None else iterator_dict
    # Tolerate updater_name=None at construction; the missing mode is
    # reported when the extension actually runs.
    mode_name = updater_name or ''
    total_mode = 'Total' in mode_name
    each_mode = 'Each' in mode_name

    @chainer.training.make_extension()
    def evaluate(trainer):
        if not (total_mode or each_mode):
            # Without a mode no penalty is ever computed, which would otherwise
            # surface as an UnboundLocalError on ``penalty`` at report time.
            raise ValueError(
                "train_evaluation_with_penalty: updater_name must contain "
                "'Total' or 'Each', got %r" % (updater_name,))
        xp = model.xp
        # Constant 1.0 multiplier on the logits; presumably the variable
        # plt_func differentiates against in 'Each' mode.
        scale = Variable(xp.array(1.0, dtype='f'))
        n_envs = len(ts_env_indices)
        # 'Total' keeps per-env NLL Variables (graph intact) for plt_func;
        # otherwise only raw per-env values are stored.
        nll_list = [] if total_mode else xp.zeros(n_envs)
        penalty_list = xp.zeros(n_envs) if each_mode else None
        acc_list = xp.zeros(n_envs)
        with chainer.using_config('train', False):
            # Backprop deliberately stays on: the penalty needs the graph.
            with chainer.using_config('enable_backprop', True):
                for env_index, env in enumerate(ts_env_indices):
                    env_nll = 0.0
                    env_penalty = 0.0
                    env_acc = 0.0
                    env_test_count = 0
                    for batch in iterator_dict[env]:
                        variabledict = chainer.dataset.concat_examples(
                            batch, device=model.device)
                        var_x = variabledict['x']
                        var_y = variabledict['y']
                        batch_size = len(var_x)

                        logits = scale * model(var_x)
                        env_nll_part = nll_func(logits, var_y)
                        if each_mode:
                            env_penalty += (plt_func(scale, env_nll_part).array
                                            * batch_size)
                        if total_mode:
                            # Keep the Variable so its graph survives until
                            # plt_func(model, nll_list) below.
                            env_nll += env_nll_part * batch_size
                        else:
                            env_nll += env_nll_part.array * batch_size
                        env_acc += (lf.mean_accuracy(logits, var_y).array
                                    * batch_size)
                        env_test_count += batch_size

                    # Rewind so the regular training loop is unaffected.
                    iterator_dict[env].reset()
                    if env_test_count == 0:
                        raise ValueError(
                            'train_evaluation_with_penalty: iterator for '
                            'environment %r yielded no examples' % (env,))
                    env_nll /= env_test_count
                    if each_mode:
                        penalty_list[env_index] = env_penalty / env_test_count
                    if total_mode:
                        nll_list.append(env_nll)
                    else:
                        nll_list[env_index] = env_nll
                    acc_list[env_index] = env_acc / env_test_count

                if total_mode:
                    penalty = plt_func(model, nll_list).array

        if total_mode:
            nll = 0.0
            for env_nll_var in nll_list:
                nll += env_nll_var.array / len(nll_list)
        else:
            nll = xp.mean(nll_list)
        if each_mode:
            # Takes precedence over the 'Total' penalty when both are named.
            penalty = xp.mean(penalty_list)
        acc = xp.mean(acc_list)

        chainer.reporter.report({'train_nll_exact': nll})
        chainer.reporter.report({'train_penalty_exact': penalty})
        chainer.reporter.report({'train_acc_exact': acc})

    return evaluate

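'''
Full-batch variant: evaluates each environment on its in-memory test
dataset in a single forward pass instead of iterating over mini-batches.
'''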
def test_evaluation_fullbatch(model=None, nll_func=None, ts_env_indices=None,
                              iterator_dict=None, test_dataset_list=None,
                              **kwargs):
    """Build a trainer extension that reports full-batch test metrics.

    Each environment's whole test set, ``test_dataset_list[int(env)]`` (a
    mapping with 'x' and 'y'), goes through the model in one forward pass.
    No mini-batching is done.

    Reported values:
        test_nll / test_accuracy: mean over environments.
        test_ood_nll / test_ood_accuracy: minimum over environments.

    Args:
        model: Chainer link; must expose ``xp``.
        nll_func: callable ``(logits, labels) -> Variable``.
        ts_env_indices: environment keys; ``int(env)`` indexes
            ``test_dataset_list``.
        iterator_dict: indexable by environment key. The iterators are only
            reset here, never consumed.
        test_dataset_list: per-environment in-memory datasets.
        **kwargs: ignored; accepted for parity with the other factories.
    """
    # None defaults avoid shared mutable list literals in the signature.
    ts_env_indices = [] if ts_env_indices is None else ts_env_indices
    iterator_dict = [] if iterator_dict is None else iterator_dict
    test_dataset_list = [] if test_dataset_list is None else test_dataset_list

    @chainer.training.make_extension()
    def evaluate(trainer):
        # ``xp`` is not a module-level name, so it must be bound here; without
        # this the OOD reduction below raises NameError.
        xp = model.xp
        nll = 0.0
        acc = 0.0
        nll_list = []
        acc_list = []
        with chainer.using_config('train', False):
            with chainer.using_config('enable_backprop', False):
                for env in ts_env_indices:
                    dataset = test_dataset_list[int(env)]
                    var_x = dataset['x']
                    var_y = dataset['y']
                    logits = model(var_x)
                    env_nll = nll_func(logits, var_y).array
                    env_acc = lf.mean_accuracy(logits, var_y).array
                    nll_list.append(env_nll)
                    acc_list.append(env_acc)
                    nll += env_nll
                    acc += env_acc
                    # Not consumed above; reset anyway, as the mini-batch
                    # extensions do after each sweep.
                    iterator_dict[env].reset()

        acc /= len(ts_env_indices)
        nll /= len(ts_env_indices)
        # Stack the 0-d per-env results into one array first: CuPy reductions
        # do not accept a Python list of arrays.
        # NOTE(review): min NLL is the *best* environment while min accuracy is
        # the worst; confirm min (not max) is intended for test_ood_nll.
        ood_nll = xp.min(xp.stack(nll_list))
        ood_acc = xp.min(xp.stack(acc_list))

        chainer.reporter.report({'test_nll': nll})
        chainer.reporter.report({'test_accuracy': acc})
        chainer.reporter.report({'test_ood_nll': ood_nll})
        chainer.reporter.report({'test_ood_accuracy': ood_acc})

    return evaluate
