import argparse
import jax.numpy as jnp
import numpy as np
import jax

from experiments.utils import load_model
from lamis_train import LAMISTrain
from lamis_gameplay import LAMISGameplay, LAMISGameplayConfig
from games.jax_game_utils import get_game_name

def _str2bool(value):
  """Convert a command-line token into a bool.

  ``type=bool`` must not be used with argparse: it calls ``bool()`` on the raw
  string, so every non-empty value (including ``"False"``) becomes ``True``.

  Raises:
    argparse.ArgumentTypeError: if the token is not a recognised boolean.
  """
  if isinstance(value, bool):
    return value
  text = value.strip().lower()
  if text in ("true", "t", "yes", "y", "1", "on"):
    return True
  # The empty string stays falsy, as it was under the old ``type=bool``.
  if text in ("false", "f", "no", "n", "0", "off", ""):
    return False
  raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# Model and resolving settings
parser.add_argument("--model_path", type=str, default="data/models/goofspiel_4_descending/seed_73571/lamis_2.pkl", help="Model path")
parser.add_argument("--opp_model_path", type=str, default="data/models/goofspiel_4_descending/seed_73571/lamis_2.pkl", help="Opponent model path")
parser.add_argument("--depth_limit", type=int, default=1, help="Depth limit for exploitability calculation in each infoset")
parser.add_argument("--resolve_iterations", type=int, default=1000, help="Number of CFR iterations")

# Gameplay settings
parser.add_argument("--player", type=int, default=0, help="Resolving player")
parser.add_argument("--rounds", type=int, default=10, help="Number of rounds to play until the end.")
# Still takes an explicit value (e.g. ``--verbose False``) so existing command lines keep working.
parser.add_argument("--verbose", type=_str2bool, default=True, help="Print played actions.")
parser.add_argument("--seed", type=int, default=4323432, help="Seed for the random number generator")

def get_rnad_policy(opp_iset, opp_legals, model:LAMISTrain):
  """Query the model's target-network policy and return it as a normalised float64 array.

  Args:
    opp_iset: information set forwarded to ``model._jit_get_policy``.
    opp_legals: legal-action data forwarded to ``model._jit_get_policy``.
    model: object exposing ``network_parameters.rnad_params_target`` and
      ``_jit_get_policy``.

  Returns:
    A float64 NumPy array rescaled so that its entries sum to 1.
  """
  target_params = model.network_parameters.rnad_params_target
  raw_policy = model._jit_get_policy(target_params, opp_iset, opp_legals)
  # Renormalise in float64 — presumably so the result passes NumPy's
  # sum-to-one check when it is later used as sampling probabilities.
  policy = np.asarray(raw_policy, dtype="float64")
  policy /= policy.sum()
  return policy

def run_heads_against_rnad(args):
  """Play head-to-head rounds between a resolving player and a network-policy opponent.

  The seat given by ``args.player`` acts according to ``LAMISGameplay.get_policy``
  built on the model from ``args.model_path``; the other seat acts according to
  ``get_rnad_policy`` evaluated on the model from ``args.opp_model_path``.
  Per-round rewards and their mean are printed; nothing is returned.

  Args:
    args: namespace providing ``model_path``, ``opp_model_path``, ``player``,
      ``resolve_iterations``, ``depth_limit``, ``rounds``, ``seed`` and,
      optionally, ``verbose`` (treated as ``True`` when absent).

  Raises:
    AssertionError: if the two models report different game names.
  """
  model = load_model(args.model_path)
  opp_model = load_model(args.opp_model_path)
  assert get_game_name(model.game) == get_game_name(opp_model.game), \
      "model and opponent model belong to different games"

  gp_config = LAMISGameplayConfig(player=args.player,
                                   resolve_iterations=args.resolve_iterations,
                                   depth_limit=args.depth_limit)
  gameplay = LAMISGameplay(model, gp_config)

  player = args.player
  verbose = getattr(args, "verbose", True)
  game = model.game
  # Actions are sampled from the index range [0, game.cards).
  action_ids = np.arange(game.cards)

  jax_key = jax.random.key(args.seed)
  rng = np.random.RandomState(args.seed)

  rewards = []
  for _ in range(args.rounds):
    gameplay.reset()
    # Use a fresh subkey for initialisation so the carried key is never
    # consumed directly and then split again.
    jax_key, init_key = jax.random.split(jax_key)
    game_state, legals = game.initialize_structures(init_key)
    terminal = False
    turn = 0
    reward = 0
    while not terminal:
      _, p1_iset, p2_iset, ps = game.get_info(game_state)

      # The non-resolving seat must use the opponent's model, not our own.
      if player == 0:
        p1_policy = gameplay.get_policy(ps, p1_iset)
        p2_policy = get_rnad_policy(p2_iset, legals[1], opp_model)
      else:
        p2_policy = gameplay.get_policy(ps, p2_iset)
        p1_policy = get_rnad_policy(p1_iset, legals[0], opp_model)

      p1_action = rng.choice(action_ids, p=p1_policy)
      p2_action = rng.choice(action_ids, p=p2_policy)

      if verbose:
        print("Applying:", p1_action, p2_action)

      jax_key, action_key = jax.random.split(jax_key)
      actions = jnp.stack([p1_action, p2_action], axis=0)
      game_state, terminal, game_rewards, legals = game.apply_action(game_state, action_key, turn, actions)
      turn += 1
      terminal = np.array(terminal)
      # NOTE(review): the reward is accumulated exactly as the game returns it.
      # Confirm it is from the resolving player's point of view; the summary
      # below labels it "Player" regardless of args.player.
      reward += np.array(game_rewards)
    print(f"Reward: {reward}", flush=True)
    rewards.append(reward)

  print("Player mean reward: ", np.mean(rewards))
  print("Opponent mean reward: ", -np.mean(rewards))
  
def main():
  """Script entry point: parse the command line and run the evaluation."""
  run_heads_against_rnad(parser.parse_args())
  
  
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
  main()