import jax
import jax.numpy as jnp
from jax.tree_util import tree_map



def pgd_target(params, model_out_grads, step, apply_fn):
    """Return the output-space gradient-step target at ``params``.

    Each leaf of the result is ``out - step * grad``, where ``out`` is the
    matching leaf of ``apply_fn(params)`` and ``grad`` is the matching leaf
    of ``model_out_grads``.

    Args:
        params: Parameters passed to ``apply_fn``.
        model_out_grads: Pytree of gradients w.r.t. the model outputs; its
            structure must match that of ``apply_fn(params)``.
        step: Step size multiplying the gradients.
        apply_fn: Callable mapping ``params`` to a pytree of model outputs.

    Returns:
        A pytree with the same structure as ``apply_fn(params)``.
    """
    def _step_leaf(out, grad):
        return out - step * grad

    return tree_map(_step_leaf, apply_fn(params), model_out_grads)


def pgd_residual_fn(params, model_out_grads, step, apply_fn):
    """Build a residual function measuring distance to the PGD output target.

    The target is ``pgd_target(params, model_out_grads, step, apply_fn)``.
    It is computed once, here, at the given ``params`` and then held fixed.

    Args:
        params: Parameters at which the target is computed.
        model_out_grads: Pytree of output gradients matching the structure of
            ``apply_fn(params)``.
        step: Step size used for the target.
        apply_fn: Callable mapping parameters to a pytree of model outputs.

    Returns:
        A function ``residual_fn(params)``. It returns a 1-D array holding
        ``apply_fn(params) - target``: each pytree leaf is flattened, and the
        results are concatenated in ``jax.tree_util.tree_leaves`` order.
    """
    target = pgd_target(params, model_out_grads, step, apply_fn)

    def residual_fn(params):
        residuals = tree_map(lambda out, t: out - t, apply_fn(params), target)
        leaves = jax.tree_util.tree_leaves(residuals)
        if not leaves:
            # Matches what jnp.array(()) produced for an empty output pytree.
            return jnp.zeros((0,))
        # Ravel each leaf before concatenating. jnp.array(residuals) stacks the
        # leaves, so it only works when they all share one shape; the two
        # players' outputs in the min-max case generally do not.
        return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    return residual_fn

def pgd_surrogate_loss(params, model_out_grads, step, apply_fn):
    """Build a squared-error surrogate loss toward the PGD output target.

    The target is ``pgd_target(params, model_out_grads, step, apply_fn)``.
    It is computed once, here, at the given ``params`` and then held fixed.

    Args:
        params: Parameters at which the target is computed.
        model_out_grads: Pytree of output gradients matching the structure of
            ``apply_fn(params)``.
        step: Step size used for the target.
        apply_fn: Callable mapping parameters to a pytree of model outputs.

    Returns:
        A function ``loss(params)`` returning
        ``0.5 * sum over leaves of sum((apply_fn(params) - target) ** 2)``.
    """
    target = pgd_target(params, model_out_grads, step, apply_fn)

    def loss(params):
        errors = tree_map(
            lambda out, t: jnp.sum((out - t) ** 2), apply_fn(params), target
        )
        # Sum over the pytree leaves, not the container. The builtin sum() on
        # a bare-array output would iterate a 0-d array (TypeError), and on a
        # dict output it would try to add the keys.
        return sum(jax.tree_util.tree_leaves(errors)) / 2

    return loss


def min_max_output_grads(loss, params, apply_fn):
    """Return per-player gradients of ``loss`` w.r.t. the two model outputs.

    ``apply_fn(params)`` must return a pair ``(model_x, model_y)``.
    ``loss(model_x, model_y)`` is differentiated w.r.t. both arguments.

    Returns:
        A tuple ``(grad_x, -grad_y)``. The y-gradient is negated, so stepping
        against the returned gradients descends ``loss`` in x and ascends it
        in y.
    """
    out_x, out_y = apply_fn(params)
    grad_x, grad_y = jax.grad(loss, argnums=(0, 1))(out_x, out_y)
    # The y player maximizes, so flip the sign of its gradient.
    flipped_y = tree_map(lambda g: -g, grad_y)
    return grad_x, flipped_y

# `loss(model_x, model_y)` takes the model outputs of both players and returns a scalar.
def minmax_pgd_surrogate_loss(loss, params, step, apply_fn):
    """Build the PGD surrogate loss for a two-player min-max game.

    ``apply_fn(params)`` must return ``(model_x, model_y)``. The output
    gradients come from ``min_max_output_grads``, which negates the y-player
    gradient. As a result, the fixed target moves x downhill on ``loss`` and y
    uphill.

    Returns:
        The ``loss(params)`` closure produced by ``pgd_surrogate_loss``.
    """
    return pgd_surrogate_loss(
        params,
        min_max_output_grads(loss, params, apply_fn),
        step,
        apply_fn,
    )

def minmax_pgd_residual_fn(loss, params, step, apply_fn):
    """Build the PGD residual function for a two-player min-max game.

    ``apply_fn(params)`` must return ``(model_x, model_y)``, and
    ``loss(model_x, model_y)`` is the scalar game objective. The output
    gradients come from ``min_max_output_grads`` (the y-player gradient is
    negated).

    Returns:
        The ``residual_fn(params)`` closure produced by ``pgd_residual_fn``.
    """
    return pgd_residual_fn(
        params,
        min_max_output_grads(loss, params, apply_fn),
        step,
        apply_fn,
    )


