

from games.jax_goofspiel import JaxGoofspiel 
from games.jax_game_algorithms import exploitability_jax_game, nash_equilibrium_jax_game
from games.jax_game_utils import stringify

import jax

   
  
def test_nash(cards: int = 3, points_order: str = "descending", tol: float = 1e-4):
    """Check that the solver's Nash policy for Goofspiel matches its reported value.

    Builds a ``JaxGoofspiel`` game and solves it with
    ``nash_equilibrium_jax_game``. It then evaluates the returned policy with
    ``exploitability_jax_game`` and asserts that both per-player values from
    that evaluation lie within ``tol`` of the Nash value reported by the
    solver.

    Args:
        cards: Number of cards, forwarded to ``JaxGoofspiel``. It has a default
            so pytest does not mistake the argument for a fixture request.
        points_order: Point-card ordering, forwarded to ``JaxGoofspiel``.
        tol: Absolute tolerance used when comparing the values.
    """
    game = JaxGoofspiel(cards=cards, points_order=points_order)

    # The first return value (named `cfr` originally, presumably the solver
    # state) is not needed for this check.
    _, nash_policy, nash_value = nash_equilibrium_jax_game(game)

    # The first two return values (originally `p1_br` / `p2_br`, presumably the
    # best-response policies) are unused; only the per-player values are checked.
    _, _, jax_p1_exp, jax_p2_exp = exploitability_jax_game(game, nash_policy)

    assert abs(jax_p1_exp - nash_value) < tol, (
        f"player 1 value {jax_p1_exp} differs from Nash value {nash_value} "
        f"by >= {tol} (cards={cards}, points_order={points_order!r})"
    )
    assert abs(jax_p2_exp - nash_value) < tol, (
        f"player 2 value {jax_p2_exp} differs from Nash value {nash_value} "
        f"by >= {tol} (cards={cards}, points_order={points_order!r})"
    )
  
  
  
if __name__ == "__main__":
    # Run the Nash check for increasingly large decks, smallest first.
    for deck_size in (3, 4, 5):
        test_nash(deck_size)