import pickle
import os

from games.jax_game import JaxGame
from games.jax_goofspiel import JaxGoofspiel
from games.jax_leduc import JaxLeduc
from games.jax_oshi_zumo import JaxOshiZumo

from games.jax_game_utils import get_game_name

def load_model(filepath: str):
  """Read and return the pickled object stored at ``filepath``.

  Args:
    filepath: Path of the pickle file to read.

  Returns:
    Whatever object ``pickle.load`` reconstructs from the file.

  Raises:
    FileNotFoundError: If ``filepath`` does not exist.

  Note:
    Unpickling can execute arbitrary code, so only load files you trust.
  """
  if os.path.exists(filepath):
    with open(filepath, "rb") as handle:
      return pickle.load(handle)
  raise FileNotFoundError(f"File {filepath} does not exist")

def save_model(filepath: str, data):
  """Pickle ``data`` to ``filepath``, creating parent directories as needed.

  Args:
    filepath: Destination path of the pickle file. An existing file is
      overwritten.
    data: Any picklable object.
  """
  parent = os.path.dirname(filepath)
  # A bare filename such as "model.pkl" has an empty dirname, and
  # os.makedirs("") raises FileNotFoundError, so only create the parent
  # directory when the path actually has one.
  if parent:
    os.makedirs(parent, exist_ok=True)
  with open(filepath, "wb") as f:
    pickle.dump(data, f)
def _require_param_count(game_name: str, params, expected: int):
  """Raise ValueError unless ``params`` holds exactly ``expected`` entries."""
  if len(params) != expected:
    raise ValueError(
        f"Game {game_name} expects {expected} parameter(s), got {len(params)}")


def decode_game_name(game_name: str):
  """Build a game instance from its ``|``-separated string encoding.

  The first ``|``-separated field selects the game; the remaining fields are
  converted with ``int`` and passed positionally to the game constructor:

    * ``"goofspiel|<a>"``     -> ``JaxGoofspiel(a)``
    * ``"leduc"``             -> ``JaxLeduc()``
    * ``"oshi_zumo|<a>|<b>"`` -> ``JaxOshiZumo(a, b)``

  Args:
    game_name: Encoded game name, e.g. ``"goofspiel|5"``.

  Returns:
    The constructed game object.

  Raises:
    ValueError: If the game is unknown, the number of parameters does not
      match the game, or a parameter is not a valid integer.
  """
  game_name, *params = game_name.split("|")
  # Explicit checks instead of `assert`: assertions are stripped under
  # `python -O`, which would let malformed names through (or surface as an
  # unrelated IndexError).
  if game_name == "goofspiel":
    _require_param_count(game_name, params, 1)
    return JaxGoofspiel(int(params[0]))
  elif game_name == "leduc":
    _require_param_count(game_name, params, 0)
    return JaxLeduc()
  elif game_name == "oshi_zumo":
    _require_param_count(game_name, params, 2)
    return JaxOshiZumo(int(params[0]), int(params[1]))
  else:
    raise ValueError(f"Game {game_name} not found")


def get_game_folder(game: JaxGame, folder_type:str):
  """Return (and create) the data folder for ``game`` and ``folder_type``.

  The folder is ``data/<subfolder>/<get_game_name(game)>``, where
  ``<subfolder>`` is chosen by ``folder_type``:

    * ``"strategy"``     -> ``strategies``
    * ``"plot"``         -> ``plots``
    * ``"nash"``         -> ``nash``
    * ``"similarities"`` -> ``similarities``

  Args:
    game: Game instance; only ``JaxGoofspiel`` is currently accepted.
    folder_type: One of the keys listed above.

  Returns:
    The relative folder path as a string. The directory is created if it
    does not already exist.

  Raises:
    ValueError: If ``game`` is not a ``JaxGoofspiel`` or ``folder_type`` is
      not one of the supported values.
  """
  folders_by_type = {
    "strategy": "strategies",
    "plot": "plots",
    "nash": "nash",
    "similarities": "similarities",
  }
  # NOTE(review): other games are rejected here even though this module can
  # decode them elsewhere — confirm whether that restriction is intentional.
  if not isinstance(game, JaxGoofspiel):
    raise ValueError("Invalid game")
  # Previously an unknown folder_type left the subfolder name unbound and
  # surfaced as an UnboundLocalError; fail with a clear message instead.
  if folder_type not in folders_by_type:
    raise ValueError(
        f"Invalid folder type {folder_type!r}; "
        f"expected one of {sorted(folders_by_type)}")
  folder = "/".join(("data", folders_by_type[folder_type], get_game_name(game)))
  os.makedirs(folder, exist_ok=True)
  return folder