from __future__ import print_function

import numpy as np
import chainer
import chainer.functions as F


def mean_sigmoid_nll(logits, y, **kwargs):
    """Return the mean sigmoid cross-entropy of ``logits`` against ``y``.

    Extra keyword arguments are accepted for call-site compatibility and
    ignored.
    """
    mean_loss = F.sigmoid_cross_entropy(logits, y)
    return mean_loss
    
def sigmoid_nll(logits, y, **kwargs):
    """Return the unreduced (per-element) sigmoid cross-entropy.

    ``reduce='no'`` keeps one loss value per element instead of averaging.
    Extra keyword arguments are ignored.
    """
    per_element_loss = F.sigmoid_cross_entropy(logits, y, reduce='no')
    return per_element_loss
    
def mean_softmax_nll(logits, y, **kwargs):
    """Return the mean softmax cross-entropy of ``logits`` against labels ``y``.

    ``y`` is flattened to 1-D before use. Double backprop is enabled so the
    gradient of this loss can itself be differentiated. Extra keyword
    arguments are ignored.
    """
    labels = F.flatten(y)
    loss = F.softmax_cross_entropy(
        logits,
        labels,
        enable_double_backprop=True,
    )
    return loss

def mean_accuracy(logits, y, **kwargs):
    """Return classification accuracy of ``logits`` against targets ``y``.

    If ``logits`` has more than one column (axis 1), it is treated as
    multi-class scores and compared against the flattened labels with
    ``F.accuracy``. Otherwise (one column, or 1-D logits) it is treated as
    binary and evaluated with ``F.binary_accuracy``.

    ``logits`` may be a chainer ``Variable`` or a raw array: only ``.ndim``
    and ``.shape`` are inspected, and both types provide them. (The previous
    ``.array.shape`` access failed for raw arrays.) Extra keyword arguments
    are ignored.
    """
    # 1-D logits have no axis 1; treat them as binary instead of raising
    # IndexError.
    if logits.ndim > 1 and logits.shape[1] > 1:
        return F.accuracy(logits, F.flatten(y))
    return F.binary_accuracy(logits, y)

def irm_penalty(self, loss, **kwargs):
    """Return the squared gradient of ``loss`` with respect to ``self.scale``.

    Takes ``self`` explicitly; presumably it is attached to a class as a
    method (verify against callers). Double backprop is enabled so the
    returned penalty is itself differentiable. Extra keyword arguments are
    ignored.
    """
    grads = chainer.grad([loss], [self.scale], enable_double_backprop=True)
    scale_grad = grads[0]
    return F.sum(scale_grad ** 2)
    
def irm_penalty_for_evaluation(scale, loss, **kwargs):
    """Return the squared gradient of ``loss`` with respect to ``scale``.

    Free-function counterpart of ``irm_penalty`` that takes the scale
    variable directly instead of reading it from ``self``. Extra keyword
    arguments are ignored.
    """
    grads = chainer.grad([loss], [scale], enable_double_backprop=True)
    scale_grad = grads[0]
    return F.sum(scale_grad ** 2)
    
def get_ordered_params(link):
    """Return ``link``'s parameters ordered by parameter name.

    ``link.namedparams()`` yields ``(name, param)`` pairs. Sorting by name
    gives a deterministic ordering, so per-parameter vectors from different
    calls line up.
    """
    by_name = sorted(link.namedparams(), key=lambda pair: pair[0])
    return [param for _name, param in by_name]
    
def flatten_and_concat_variables(vs):
    """Flatten each variable in ``vs`` and join them into one 1-D vector.

    The pieces are concatenated along axis 0 in the order ``vs`` yields them.
    """
    flat_parts = list(map(F.flatten, vs))
    return F.concat(flat_parts, axis=0)
    

def grad_varinace_penalty(self, loss_list):
    """Return the thresholded variance of per-loss parameter gradients.

    For each loss in ``loss_list``, the gradient with respect to every
    parameter of ``self.model`` (ordered by name) is flattened into one
    vector. The raw penalty is the sum, over losses, of the squared distance
    between that vector and the mean vector. The function returns
    ``relu(raw_penalty - thresh)``.

    Reads ``self.penalty_config``:
        db_flag: if truthy, the mean gradient stays connected to the graph,
            so gradients also flow through it. Otherwise the mean is
            detached and treated as a constant.
        thresh: value subtracted from the raw penalty before the relu.

    Args:
        loss_list: sequence of scalar loss Variables (presumably one per
            environment; confirm with callers).

    Returns:
        A scalar Variable holding ``max(raw_penalty - thresh, 0)``.

    Raises:
        ValueError: if ``loss_list`` is empty. Previously this surfaced as an
            obscure chainer type error from ``F.relu`` on a Python float.
    """
    db_flag = self.penalty_config['db_flag']
    thresh = self.penalty_config['thresh']

    env_num = len(loss_list)
    if env_num == 0:
        raise ValueError('loss_list must contain at least one loss')

    model_param = get_ordered_params(self.model)
    grad_list = []
    grad_avg = 0.0
    for loss in loss_list:
        # Double backprop keeps the gradients differentiable, so the penalty
        # built from them can be backpropagated.
        grad = chainer.grad([loss], model_param, enable_double_backprop=True)
        grad_flatten = flatten_and_concat_variables(grad)
        grad_avg += grad_flatten / env_num
        grad_list.append(grad_flatten)

    if not db_flag:
        # Detach the mean once (a fresh Variable with no creator) instead of
        # rebuilding the same constant on every loop iteration.
        grad_avg = chainer.Variable(grad_avg.array)

    train_penalty = 0.0
    for grad_flatten in grad_list:
        train_penalty += F.sum((grad_flatten - grad_avg) ** 2.0)
    return F.relu(train_penalty - thresh)
    
def grad_varinace_penalty_for_evaluation(model, loss_list):
    """Return the unthresholded gradient-variance penalty, for evaluation.

    Each loss's gradient with respect to ``model``'s parameters (ordered by
    name) is flattened into one vector. The result is the sum, over losses,
    of the squared distance between that vector and the mean vector.
    Gradients are computed without double backprop, so the result is not
    meant to be differentiated again. An empty ``loss_list`` yields ``0.0``.
    """
    params = get_ordered_params(model)
    n_envs = len(loss_list)

    flat_grads = [
        flatten_and_concat_variables(
            chainer.grad([loss], params, enable_double_backprop=False))
        for loss in loss_list
    ]
    # A start value of 0.0 mirrors the original accumulator, so empty input
    # still returns the float 0.0.
    mean_grad = sum((g / n_envs for g in flat_grads), 0.0)
    return sum((F.sum((g - mean_grad) ** 2.0) for g in flat_grads), 0.0)
