from collections.abc import Mapping

import flax
import jax
import jax.numpy as jnp

def masked_mse_loss(pred, target, mask):
    """Return the mean of the squared error between ``pred`` and ``target`` scaled by ``mask``.

    Each element's squared error is multiplied by the matching ``mask`` entry
    before averaging. The average is taken over *all* elements (``jnp.mean``),
    not over ``mask.sum()``. Masked-out entries contribute zero to the
    numerator but still count in the denominator.
    """
    squared_error = jnp.square(pred - target)
    weighted = squared_error * mask
    return jnp.mean(weighted)
    
def mse_loss(val, target):
    """Return the mean of the element-wise squared difference between ``val`` and ``target``."""
    difference = val - target
    return jnp.mean(jnp.square(difference))

def value_and_multi_grad(fun, n_outputs, argnums=0, has_aux=False):
    """Build a function returning the value and gradient of each output of ``fun``.

    ``fun`` must return an indexable ``x`` (for example an array or a tuple of
    scalars). With ``has_aux=True`` it must instead return a pair ``(x, aux)``.
    For every ``i`` in ``range(n_outputs)`` a separate ``jax.value_and_grad``
    is taken of ``x[i]``, which must therefore be a scalar, with respect to
    ``argnums``.

    Note that ``fun`` is evaluated once per output (``n_outputs`` forward and
    backward passes per call).

    Args:
        fun: Function to differentiate.
        n_outputs: Number of entries of ``x`` to differentiate.
        argnums: Forwarded to ``jax.value_and_grad``.
        has_aux: Whether ``fun`` returns ``(x, aux)``.

    Returns:
        A function with the same call signature as ``fun`` that returns
        ``(values, grads)`` when ``has_aux`` is False, or
        ``((values, aux), grads)`` when it is True. ``values`` and ``grads``
        are tuples of length ``n_outputs``. ``aux`` comes from the last
        evaluation of ``fun`` (``None`` if ``n_outputs`` is 0).
    """
    def select_output(index):
        # ``index`` is bound per call, so each wrapper picks its own output.
        def wrapped(*args, **kwargs):
            if has_aux:
                out, aux = fun(*args, **kwargs)
                return out[index], aux
            return fun(*args, **kwargs)[index]
        return wrapped

    grad_fns = tuple(
        jax.value_and_grad(select_output(i), argnums=argnums, has_aux=has_aux)
        for i in range(n_outputs)
    )

    def multi_grad_fn(*args, **kwargs):
        values = []
        grads = []
        aux = None  # stays None when there are no outputs to differentiate
        for grad_fn in grad_fns:
            if has_aux:
                # jax returns ((value, aux), grad) when has_aux=True.
                (value, aux), grad = grad_fn(*args, **kwargs)
            else:
                # ...and (value, grad) otherwise; value is a scalar and
                # cannot be star-unpacked.
                value, grad = grad_fn(*args, **kwargs)
            values.append(value)
            grads.append(grad)
        if has_aux:
            return (tuple(values), aux), tuple(grads)
        return tuple(values), tuple(grads)

    return multi_grad_fn


def merge_with_args(ray_config, build_parser):
    """Build default args via ``build_parser`` and overlay values from ``ray_config``.

    ``build_parser(init=None)`` supplies the base args object. If
    ``ray_config`` is None, that object is returned untouched. Otherwise each
    ``ray_config`` entry whose key already exists on the args object
    overwrites that attribute, and unknown keys are ignored. A
    ``"hidden_dims"`` entry is additionally copied to ``feat_dim``, whether
    or not ``feat_dim`` already exists.
    """
    args = build_parser(init=None)
    if ray_config is None:
        return args

    for name, value in ray_config.items():
        if name == "hidden_dims":
            # hidden_dims is mirrored into feat_dim.
            setattr(args, "feat_dim", value)
        if name in vars(args):
            setattr(args, name, value)
    return args


def recursive_sum(dictionary):
    """Sum ``jnp.mean`` of every non-string leaf in a nested mapping.

    Recurses into any nested mapping, such as a flax ``FrozenDict`` or a plain
    ``dict``. Every other value that is not a ``str`` contributes
    ``jnp.mean(value)`` to the total.

    Args:
        dictionary: A possibly nested mapping, e.g. a parameter tree.

    Returns:
        The accumulated sum, or ``0`` for an empty mapping.
    """
    current_sum = 0
    for value in dictionary.values():
        if isinstance(value, Mapping):
            current_sum = current_sum + recursive_sum(value)
        elif not isinstance(value, str):
            # String leaves have no numeric mean; skip them.
            current_sum = current_sum + jnp.mean(value)
    return current_sum

def recursive_len(dictionary):
    """Count the non-string leaves in a nested mapping.

    Recurses into any nested mapping, such as a flax ``FrozenDict`` or a plain
    ``dict``. Every other value that is not a ``str`` counts as one leaf.

    Args:
        dictionary: A possibly nested mapping, e.g. a parameter tree.

    Returns:
        The number of counted leaves, or ``0`` for an empty mapping.
    """
    current_len = 0
    for value in dictionary.values():
        if isinstance(value, Mapping):
            current_len = current_len + recursive_len(value)
        elif not isinstance(value, str):
            # String leaves are not counted.
            current_len = current_len + 1
    return current_len