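"""Min-max solver.

Alternates gradient ascent on a matrix with exponentiated-gradient descent on
weights that are renormalised to sum to one.
"""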

import jax
import jax.numpy as jnp


#def objective_fn(weights, m, a_arr, b, b0):
#    """Objective_fn."""
#    tmpMB = jnp.matmul(m, b).reshape(-1, 1)
#    out = a_arr + tmpMB
#    out1 = - 0.025*jnp.linalg.norm(tmpMB)**2
#    out2 = - 0.05*jnp.linalg.norm(b0 + tmpMB)**2
#    l2_out = jnp.linalg.norm(out, axis=0)**2
#    weighted_dist = sum(weights * l2_out) + out1 + out2
#    
#    return weighted_dist

def objective_fn(weights, m, a_arr, b):
    """Return the weighted sum of squared column norms of ``a_arr + m @ b``.

    The product ``m @ b`` is reshaped to a single column and broadcast-added to
    every column of ``a_arr``; the squared Euclidean norm of each resulting
    column is then combined using ``weights``.

    Args:
      weights: one weight per column of ``a_arr``.
      m: matrix multiplied with ``b``.
      a_arr: array whose columns (axis 1) are the points being shifted.
      b: vector multiplied by ``m``.

    Returns:
      Scalar ``sum_j weights[j] * ||a_arr[:, j] + m @ b||**2``.
    """
    shift = jnp.matmul(m, b).reshape(-1, 1)
    out = a_arr + shift
    # Sum of squares instead of ``norm(...)**2``: the value is the same, but
    # differentiating through the square root inside ``norm`` produces NaN
    # gradients whenever a column of ``out`` is exactly zero.
    l2_out = jnp.sum(jnp.square(out), axis=0)
    # ``jnp.sum`` replaces the Python builtin ``sum``, which would iterate the
    # array element by element. ``axis=0`` mirrors what the builtin reduced.
    # NOTE: a regulariser ``- 0.005 * ||m @ b||**2`` was tried here and is
    # currently disabled.
    weighted_dist = jnp.sum(weights * l2_out, axis=0)
    return weighted_dist


def min_max_solver(a_arr, M, b, T, eta=1e-3, eps=2e-5):  # pylint: disable=invalid-name
    """Alternate gradient ascent on ``M`` with descent on simplex weights.

    Starting from uniform weights over the columns of ``a_arr``, each iteration
    evaluates the gradients of ``objective_fn`` at the current ``(weights, M)``
    and then:

    * takes an ascent step on ``M`` and, if the norm of the result reaches
      0.55, rescales it to have norm exactly 0.55;
    * takes an exponentiated-gradient descent step on the weights and
      renormalises them to sum to one.

    Iteration stops early once the squared change of the weights is below
    ``10 * eps`` and the squared change of ``M`` is below ``eps``.

    Args:
      a_arr: array whose second axis indexes the weighted columns,
        e.g. shape (2, 2000).
      M: initial matrix, e.g. shape (2, 40).
      b: vector multiplied by ``M``, e.g. shape (40,).
      T: maximum number of iterations. This value is honoured; it is not
        replaced by a hard-coded iteration count.
      eta: base learning rate (scaled by 0.1 for ``M`` and by 0.001 for the
        weights).
      eps: convergence tolerance on the squared iterate differences.

    Returns:
      Tuple ``(M, weights)`` holding the final iterates. With ``T == 0`` the
      inputs are returned unchanged together with uniform weights.
    """
    m_step_scale = 0.1      # multiplier on eta for the ascent step on M
    w_step_scale = 0.001    # multiplier on eta for the descent step on weights
    max_norm = 0.55         # norm bound enforced on M after each step

    num_points = a_arr.shape[1]  # simplex dimension
    weights = jnp.ones(num_points) / num_points

    # Both updates use gradients taken at the same (weights, M), so a single
    # differentiation pass yields both.
    grad_fn = jax.grad(objective_fn, argnums=(0, 1))

    diff_weights = None
    diff_M = None
    for _ in range(T):
        grad_w, grad_m = grad_fn(weights, M, a_arr, b)

        # Gradient step (ascent on M), then rescale onto the norm bound.
        updated_M = M + eta * m_step_scale * grad_m
        norm_M = jnp.linalg.norm(updated_M)
        if norm_M >= max_norm:
            updated_M = max_norm * updated_M / norm_M

        # Gradient step (descent on simplex): multiplicative update followed
        # by renormalisation so the weights keep summing to one.
        updated_weights = weights * jnp.exp(-eta * w_step_scale * grad_w)
        updated_weights = updated_weights / jnp.sum(updated_weights)

        diff_weights = jnp.sum(jnp.square(weights - updated_weights))
        diff_M = jnp.sum(jnp.square(M - updated_M))

        # Commit before testing convergence so that the returned M and weights
        # always come from the same iteration.
        M, weights = updated_M, updated_weights

        if diff_weights < 10 * eps and diff_M < eps:
            print("gd convergence.")
            break

    # Diagnostic for runs that ended far from convergence; skipped when no
    # iteration was executed. Uses a looser 100 * eps bound on the weights.
    if diff_weights is not None and (diff_weights > 100 * eps or diff_M > eps):
        print(f"diff_weights {diff_weights} diff_M {diff_M}")

    return M, weights
