from enum import Enum

import os
 
from experiments.utils import load_model, save_model
from lamis_train import LAMISTrain 
from experiments.policy_extraction import solve_game_full, solve_game_each_infoset, solve_game_full_trained_dynamics, solve_game_full_no_dynamics, extract_rnad_policy, solve_game_replace_legals, solve_game_per_depth
from games.jax_game_algorithms import exploitability_jax_game

class ExploitabilityExperimentType(str, Enum):
  """Selects how `exploitability_lamis` obtains the strategy it evaluates.

  The class mixes in `str`, so each member compares equal to its raw string
  value (e.g. ``"rnad" == ExploitabilityExperimentType.RNAD``). Callers may
  therefore pass either the member or the plain string.
  """
  RNAD = "rnad"  # extract_rnad_policy(model)
  FULL_GAME = "full_game"  # solve_game_full(model, resolve_iterations)
  EACH_ISET = "each_iset"  # solve_game_each_infoset(model, depth_limit, resolve_iterations)
  NO_DYNAMICS_FULL_GAME = "no_dynamics_full_game"  # solve_game_full_no_dynamics(model, resolve_iterations)
  TRAINED_DYNAMICS_FULL_GAME = "trained_dynamics_full_game"  # solve_game_full_trained_dynamics(model, resolve_iterations)
  TRAINED_LEGALS = "trained_legals"  # solve_game_replace_legals(model, resolve_iterations, True)
  ISETS_PER_DEPTH = "isets_per_depth"  # solve_game_per_depth(model, resolve_iterations, depth_limit)

def exploitability_lamis(model: LAMISTrain, strategy_path: str, experiment_type: str, resolve_iterations: int, depth_limit: int):
  """Compute the exploitability of a strategy derived from a LAMIS model.

  If ``strategy_path`` names an existing file, the strategy stored there is
  loaded and evaluated instead of being recomputed. Otherwise the strategy is
  produced by the solver selected by ``experiment_type``. When
  ``strategy_path`` is not ``None``, the result is saved there for later reuse,
  together with the model config and the experiment type.

  Args:
    model: The trained model. Its ``config`` and ``game`` attributes are read.
    strategy_path: Cache file for the strategy, or ``None`` to disable both
      loading and saving.
    experiment_type: An ``ExploitabilityExperimentType`` member or its string
      value.
    resolve_iterations: Forwarded to every solver except the ``RNAD`` one.
    depth_limit: Forwarded to the ``EACH_ISET`` and ``ISETS_PER_DEPTH`` solvers.

  Returns:
    The value of ``exploitability_jax_game(model.game, strategy_dict)``.

  Raises:
    ValueError: If ``experiment_type`` is not a valid experiment type. Also
      raised if the cached file records a different experiment type than the
      one requested.
    AssertionError: If a cached strategy's config differs from ``model.config``.
  """
  # Validate before touching the cache, so that a bad experiment type is
  # never masked by an existing strategy file.
  try:
    experiment = ExploitabilityExperimentType(experiment_type)
  except ValueError as err:
    raise ValueError("Invalid experiment type") from err

  if strategy_path is not None and os.path.exists(strategy_path):
    strategy = load_model(strategy_path)
    print("Found saved strategy")
    assert strategy["config"] == model.config, "Strategy and model config must match"
    # Files written before the experiment type was recorded lack this key.
    # They are accepted unchecked, as before.
    saved_type = strategy.get("experiment_type")
    if saved_type is not None and saved_type != experiment.value:
      raise ValueError(
        f"Saved strategy at {strategy_path} was computed for experiment type "
        f"'{saved_type}', not '{experiment.value}'"
      )
    return exploitability_jax_game(model.game, strategy["strategy_dict"])

  # NOTE(review): solve_game_each_infoset receives (depth_limit,
  # resolve_iterations) while solve_game_per_depth receives
  # (resolve_iterations, depth_limit). This order is kept from the original;
  # confirm it against the solver signatures.
  solvers = {
    ExploitabilityExperimentType.RNAD: lambda: extract_rnad_policy(model),
    ExploitabilityExperimentType.FULL_GAME: lambda: solve_game_full(model, resolve_iterations),
    ExploitabilityExperimentType.NO_DYNAMICS_FULL_GAME: lambda: solve_game_full_no_dynamics(model, resolve_iterations),
    ExploitabilityExperimentType.TRAINED_DYNAMICS_FULL_GAME: lambda: solve_game_full_trained_dynamics(model, resolve_iterations),
    ExploitabilityExperimentType.TRAINED_LEGALS: lambda: solve_game_replace_legals(model, resolve_iterations, True),
    ExploitabilityExperimentType.EACH_ISET: lambda: solve_game_each_infoset(model, depth_limit, resolve_iterations),
    ExploitabilityExperimentType.ISETS_PER_DEPTH: lambda: solve_game_per_depth(model, resolve_iterations, depth_limit),
  }
  strategy_dict = solvers[experiment]()

  if strategy_path is not None:
    # Store the plain string value so the file does not depend on the enum
    # class to be read back.
    save_model(strategy_path, {
      "config": model.config,
      "strategy_dict": strategy_dict,
      "experiment_type": experiment.value,
    })
  return exploitability_jax_game(model.game, strategy_dict)
 