import argparse

import os

from experiments.utils import decode_game_name, load_model
from experiments.exploitability_experiment import exploitability_lamis
from games.jax_game_utils import get_game_name

def _str2bool(value):
  """Convert a command-line string such as "true", "False", "1" or "no" to a bool.

  argparse's ``type=bool`` calls ``bool(str)``, which is True for every
  non-empty string (including "False"), so it cannot express False.

  Raises:
    argparse.ArgumentTypeError: if ``value`` is not a recognised boolean string.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ("1", "true", "t", "yes", "y", "on"):
    return True
  if normalized in ("0", "false", "f", "no", "n", "off"):
    return False
  raise argparse.ArgumentTypeError("expected a boolean value (true/false), got %r" % value)


parser = argparse.ArgumentParser()

# Exploitability evaluation settings
parser.add_argument("--depth_limit", type=int, default=1, help="Depth limit for exploitability calculation in each infoset")
parser.add_argument("--resolve_iterations", type=int, default=1000, help="Number of CFR iterations")
parser.add_argument("--experiment_type", type=str, default="isets_per_depth", help="Experiment type. Choices: rnad, full_game, each_iset, isets_per_depth, no_dynamics_full_game, trained_dynamics_full_game, trained_legals")

parser.add_argument("--seed", type=int, default=73571, help="Random seed")
# Interpreted as arguments to range(): stop | start stop | start stop step.
parser.add_argument("--model_range", type=int, nargs="+", default=[1, 3, 1], help="Model indices to evaluate, given as range() arguments: stop | start stop | start stop step")

# _str2bool instead of type=bool so that "--load_saved_strategy False" really yields False.
parser.add_argument("--load_saved_strategy", type=_str2bool, default=True, help="Load saved strategy (true/false)")

# Game settings
parser.add_argument("--game_details", type=str, default="goofspiel|4", help="Game details")



def main():
  """Evaluate the exploitability of a series of saved LAMIS models.

  Parses the command line. For every index ``i`` in ``range(*args.model_range)``:
    * loads ``data/models/<game>/seed_<seed>/lamis_<i>.pkl``;
    * runs ``exploitability_lamis`` on the model;
    * prints both players' exploitability.

  Strategy files are addressed under
  ``data/strategies/<game>/seed_<seed>/<experiment_type>/lamis_<i>.pkl``.
  That directory is created up front.

  Returns:
    A tuple ``(p1_exploitabilities, p2_exploitabilities)`` of lists with one
    entry per evaluated model, in evaluation order.
  """
  args = parser.parse_args()

  # Validate with parser.error rather than assert: asserts are stripped under -O.
  # Checking the resulting range also accepts a single "stop" argument, which
  # previously crashed on model_range[1].
  if len(args.model_range) > 3:
    parser.error("--model_range works for up to 3 arguments (stop | start stop | start stop step)")
  try:
    model_indices = range(*args.model_range)
  except ValueError as err:  # e.g. a step of 0
    parser.error("--model_range is invalid: %s" % err)
  if len(model_indices) == 0:
    parser.error("--model_range is invalid: it selects no models")

  game = decode_game_name(args.game_details)
  game_name = get_game_name(game)
  seed_dir = "seed_" + str(args.seed)

  model_prefix = os.path.join("data", "models", game_name, seed_dir, "lamis_")

  strategy_dir = os.path.join("data", "strategies", game_name, seed_dir, args.experiment_type)
  os.makedirs(strategy_dir, exist_ok=True)
  strategy_prefix = os.path.join(strategy_dir, "lamis_")

  p1_exploitabilities, p2_exploitabilities = [], []

  for i in model_indices:
    print(i, flush=True)
    model = load_model(model_prefix + str(i) + ".pkl")

    # When --load_saved_strategy is false, pass None instead of a strategy file path.
    strategy_path = strategy_prefix + str(i) + ".pkl" if args.load_saved_strategy else None

    _, _, p1_expl, p2_expl = exploitability_lamis(
        model, strategy_path, args.experiment_type, args.resolve_iterations, args.depth_limit)
    print("P1: ", p1_expl)
    print("P2: ", p2_expl)
    p1_exploitabilities.append(p1_expl)
    p2_exploitabilities.append(p2_expl)

  return p1_exploitabilities, p2_exploitabilities


if __name__ == "__main__":
  main()