import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy

@jax.jit
def neg_entropy(x):
    """Return the negative Shannon entropy ``sum_i x_i * log(x_i)``.

    ``xlogy(x, x)`` is used instead of ``x * log(x)`` so that zero entries
    contribute ``0`` (the limit of ``t * log(t)`` as ``t -> 0+``) rather than
    ``0 * log(0) = nan``. This matters for probability vectors that lie on
    the boundary of the simplex.

    NOTE: only the value is repaired at zero; the derivative of ``t*log(t)``
    is still unbounded there, so gradients at exact zeros are not finite.

    Args:
        x: Array of non-negative values (presumably a probability vector).

    Returns:
        Scalar array ``sum(x * log(x))``, with ``0 * log(0)`` taken as ``0``.
    """
    return jnp.sum(xlogy(x, x))

def loss_bilinear(payoff, l2_coeff=0., entr_coeff=0., l2_center=1./3):
    """Build a jitted, optionally regularized bilinear loss ``x^T payoff y``.

    Each regularizer is added for ``x`` and subtracted for ``y``. With positive
    coefficients the returned loss is therefore convex in ``x`` and concave in
    ``y`` (presumably ``x`` is the minimizing and ``y`` the maximizing player).

    Args:
        payoff: 2-D payoff matrix; ``payoff @ y`` must be defined for the
            ``y`` passed to the loss.
        l2_coeff: Non-negative weight of the quadratic penalty
            ``l2_coeff/2 * ||v - l2_center||^2`` (``+`` for x, ``-`` for y).
        entr_coeff: Non-negative weight of the negative-entropy penalty
            (``+`` for x, ``-`` for y).
        l2_center: Center of the quadratic penalty. The default ``1/3`` is the
            uniform strategy of a 3-action game; pass ``1/n`` for ``n`` actions.

    Returns:
        A ``jax.jit``-compiled function ``loss(x, y)`` returning a scalar.

    Raises:
        ValueError: If ``l2_coeff`` or ``entr_coeff`` is negative (or NaN).
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not (l2_coeff >= 0 and entr_coeff >= 0):
        raise ValueError(
            "l2_coeff and entr_coeff must be non-negative, got "
            f"l2_coeff={l2_coeff!r}, entr_coeff={entr_coeff!r}"
        )

    @jax.jit
    def loss(x, y):
        _loss = x.T @ (payoff @ y)

        # These are Python-level branches on closure constants. They are
        # resolved once at trace time, so a disabled term costs nothing.
        if l2_coeff > 0:
            _loss += (l2_coeff * jnp.sum((x - l2_center) ** 2) / 2
                      - l2_coeff * jnp.sum((y - l2_center) ** 2) / 2)

        if entr_coeff > 0:
            _loss += entr_coeff * neg_entropy(x) - entr_coeff * neg_entropy(y)
        return _loss

    return loss


def reg_matching_pennies(mu):
    """Return a jitted, L2-regularized matching-pennies loss ``loss(x, y)``.

    ``x`` and ``y`` are mapped through ``2*v - 1``. Presumably each is a
    probability of playing "heads", so the mapped value is an expected ±1
    payoff; confirm against callers. The loss is the negated bilinear term,
    plus ``mu/2 * ||x - 0.5||^2``, minus ``mu/2 * ||y - 0.5||^2``.

    Args:
        mu: Regularization strength applied to both players (with opposite
            signs).

    Returns:
        A ``jax.jit``-compiled function ``loss(x, y)`` returning a scalar.
    """
    def loss(x, y):
        coupling = jnp.dot(2*x - 1, 2*y - 1)
        x_penalty = jnp.sum((x - 0.5)**2)
        y_penalty = jnp.sum((y - 0.5)**2)
        return -coupling + mu*x_penalty/2 - mu*y_penalty/2

    return jax.jit(loss)

def reg_rps(mu):
    """Return a jitted, L2-regularized rock-paper-scissors loss ``loss(x, y)``.

    The loss is ``-x^T A y + mu/2 * ||x - 1/3||^2 - mu/2 * ||y - 1/3||^2``,
    where ``A`` is the antisymmetric RPS payoff matrix below. For ``mu > 0``
    the loss is strongly convex in ``x`` and strongly concave in ``y``. Its
    saddle point is then the uniform strategy for both players.

    Args:
        mu: Regularization strength applied to both players (with opposite
            signs).

    Returns:
        A ``jax.jit``-compiled function ``loss(x, y)`` returning a scalar;
        ``x`` and ``y`` are length-3 arrays.
    """
    A = jnp.array([[0., -1., 1.],
                   [1., 0., -1.],
                   [-1., 1., 0.]])

    @jax.jit
    def loss(x, y):
        _loss = -x.T @ A @ y
        _loss += mu * jnp.sum((x - 1. / 3) ** 2) / 2
        # Subtract the y penalty so the loss stays concave in y. Adding it, as
        # the x penalty does, would make it convex in y, which breaks the
        # saddle-point structure and repels ascent on y from the equilibrium.
        _loss -= mu * jnp.sum((y - 1. / 3) ** 2) / 2
        return _loss

    return loss