
import numpy as np
import argparse
import os

from experiments.utils import decode_game_name, save_model
from games.jax_game_utils import get_game_name
from lamis_train import LAMISTrain, LAMISTrainConfig

def _str2bool(value):
  """Convert a command-line string into a real boolean.

  argparse's `type=bool` calls `bool()` on the raw string, so every non-empty
  value (including "False" and "0") would be parsed as True. This converter
  interprets the text instead.

  Args:
    value: Raw command-line value (or an already-converted bool).

  Returns:
    The parsed boolean.

  Raises:
    argparse.ArgumentTypeError: If the text is not a recognised boolean.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ("true", "t", "yes", "y", "on", "1"):
    return True
  # The empty string stays False: it was the only value that disabled a flag
  # under the previous `type=bool` behaviour.
  if normalized in ("false", "f", "no", "n", "off", "0", ""):
    return False
  raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}".format(value))


# Module-level parser for the training script; `main` parses the command line with it.
parser = argparse.ArgumentParser()

# Training settings
parser.add_argument("--save_each", type=int, default=1000, help="Save network each amount of iterations")
parser.add_argument("--iterations", type=int, default=2, help="Training iterations,  the whole algorithm will run for --iterations * --save_each")
parser.add_argument("--save_folder", type=str, default="data/models", help="Path to the saved trained models")

# Algorithm settings
parser.add_argument("--sampling_epsilon", type=float, default=0.5, help="Epsilon for epsilon-on policy sampling")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")

# Switches selecting which components are trained (accept true/false, yes/no, 1/0).
parser.add_argument("--train_rnad", type=_str2bool, default=True, help="Train RNAD")
parser.add_argument("--train_mvs", type=_str2bool, default=True, help="Train MVS")
parser.add_argument("--train_transformations", type=_str2bool, default=True, help="Train transformations")
parser.add_argument("--train_abstraction", type=_str2bool, default=True, help="Train abstraction")
parser.add_argument("--train_dynamics", type=_str2bool, default=True, help="Train dynamics")
parser.add_argument("--train_legal_actions", type=_str2bool, default=True, help="Train legal actions")

# Abstraction settings
parser.add_argument("--use_abstraction", type=_str2bool, default=True, help="Use abstraction")
parser.add_argument("--abstraction_amount", type=int, default=2, help="Abstraction amount")
parser.add_argument("--abstraction_size", type=int, default=32, help="Abstraction size")
parser.add_argument("--similarity_metric", type=str, default="legal_actions", help="Similarity metric. Choices: policy, value, policy_value, legal_actions, legal_policy, legal_policy_value, action_history, action_history_policy, action_history_legal, action_history_legal_policy")
parser.add_argument("--similarity_noise", type=float, default=0.02, help="Similarity noise")

# K-means settings for abstraction and transformations
parser.add_argument("--abstraction_soft_k_means_temperature", type=float, default=1.0, help="Abstraction soft k means temperature")
parser.add_argument("--abstraction_soft_k_means_closeness_assignment", type=float, default=0.5, help="Abstraction soft k means closeness assignment")
parser.add_argument("--abstraction_soft_k_means_repulsive_force", type=float, default=3.0, help="Abstraction soft k means repulsive force")
parser.add_argument("--abstraction_hard_k_means_closeness", type=float, default=0.3, help="Abstraction hard k means closeness")
parser.add_argument("--transformation_soft_k_means_temperature", type=float, default=1.0, help="Transformation soft k means temperature")
parser.add_argument("--transformation_soft_k_means_closeness_assignment", type=float, default=0.5, help="Transformation soft k means closeness assignment")
parser.add_argument("--transformation_soft_k_means_repulsive_force", type=float, default=3.0, help="Transformation soft k means repulsive force")

parser.add_argument("--dynamics_type", type=str, default="public_state", help="Type of dynamics. Choices: iset, public_state")

# Hidden sizes of the individual networks
parser.add_argument("--ps_encoder_hidden_size", type=int, default=256, help="PS encoder hidden size")
parser.add_argument("--ps_decoder_hidden_size", type=int, default=64, help="PS decoder hidden size")
parser.add_argument("--iset_hidden_size", type=int, default=64, help="ISet hidden size")
parser.add_argument("--dynamics_hidden_size", type=int, default=128, help="Dynamics hidden size")
parser.add_argument("--similarity_hidden_size", type=int, default=128, help="Similarity hidden size")
parser.add_argument("--mvs_hidden_size", type=int, default=128, help="MVS hidden size")
parser.add_argument("--legal_actions_hidden_size", type=int, default=64, help="Legal actions hidden size")
parser.add_argument("--transformation_hidden_size", type=int, default=128, help="Transformation hidden size")
parser.add_argument("--rnad_hidden_size", type=int, default=128, help="RNAD hidden size")

parser.add_argument("--transformations", type=int, default=10, help="Number of transformations")
parser.add_argument("--matrix_valued_states", type=_str2bool, default=True, help="Matrix valued states")

# V-trace settings (float parsing accepts "inf" on the command line)
parser.add_argument("--c_iset_vtrace", type=float, default=1.0, help="C ISet VTrace")
parser.add_argument("--rho_iset_vtrace", type=float, default=np.inf, help="Rho ISet VTrace")
parser.add_argument("--c_state_vtrace", type=float, default=1.0, help="C State VTrace")
parser.add_argument("--rho_state_vtrace", type=float, default=np.inf, help="Rho State VTrace")

# Regularization and entropy schedule (the schedule options take one or more ints)
parser.add_argument("--eta_regularization", type=float, default=0.2, help="Eta regularization")
parser.add_argument("--entropy_schedule_repeats", type=int, nargs='+', default=[1], help="Entropy schedule repeats")
parser.add_argument("--entropy_schedule_size", type=int, nargs='+', default=[2000], help="Entropy schedule size")

# Optimization settings
parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--target_network_update", type=float, default=1e-3, help="Target network update")
parser.add_argument("--seed", type=int, default=73571, help="Random seed, use 0 to have totally random seed")

# Game settings
parser.add_argument("--game_details", type=str, default="goofspiel|4", help="Game details")


def train(args, game, save_folder):
  """Build a LAMIS trainer from parsed CLI arguments and run the checkpoint/train loop.

  A checkpoint `lamis_<i>.pkl` is written into `save_folder` for every
  i in 0..args.iterations (the first one before any training), with
  `args.save_each` training steps between consecutive checkpoints.

  Args:
    args: Parsed command-line namespace; its attributes are forwarded to
      `LAMISTrainConfig`, and `iterations` / `save_each` drive the loop.
    game: Game object; `game.max_trajectory_length()` sets `trajectory_max`.
    save_folder: Directory for the checkpoints; created if it does not exist.
  """
  # NOTE(review): `--transformation_hidden_size` is parsed by the CLI but is not
  # forwarded here -- confirm whether LAMISTrainConfig should receive it.
  config = LAMISTrainConfig(
    batch_size=args.batch_size,
    trajectory_max=game.max_trajectory_length(),
    sampling_epsilon=args.sampling_epsilon,

    train_rnad=args.train_rnad,
    train_mvs=args.train_mvs,
    train_transformations=args.train_transformations,
    train_abstraction=args.train_abstraction,
    train_dynamics=args.train_dynamics,
    train_legal_actions=args.train_legal_actions,

    use_abstraction=args.use_abstraction,
    abstraction_amount=args.abstraction_amount,
    abstraction_size=args.abstraction_size,
    similarity_metric=args.similarity_metric,
    similarity_noise=args.similarity_noise,

    abstraction_soft_k_means_temperature=args.abstraction_soft_k_means_temperature,
    abstraction_soft_k_means_closeness_assignment=args.abstraction_soft_k_means_closeness_assignment,
    abstraction_soft_k_means_repulsive_force=args.abstraction_soft_k_means_repulsive_force,
    abstraction_hard_k_means_closeness=args.abstraction_hard_k_means_closeness,
    transformation_soft_k_means_temperature=args.transformation_soft_k_means_temperature,
    transformation_soft_k_means_closeness_assignment=args.transformation_soft_k_means_closeness_assignment,
    transformation_soft_k_means_repulsive_force=args.transformation_soft_k_means_repulsive_force,

    dynamics_type=args.dynamics_type,

    ps_encoder_hidden_size=args.ps_encoder_hidden_size,
    ps_decoder_hidden_size=args.ps_decoder_hidden_size,
    iset_hidden_size=args.iset_hidden_size,
    dynamics_hidden_size=args.dynamics_hidden_size,
    similarity_hidden_size=args.similarity_hidden_size,
    mvs_hidden_size=args.mvs_hidden_size,
    legal_actions_hidden_size=args.legal_actions_hidden_size,
    rnad_hidden_size=args.rnad_hidden_size,

    transformations=args.transformations,
    matrix_valued_states=args.matrix_valued_states,

    c_iset_vtrace=args.c_iset_vtrace,
    rho_iset_vtrace=args.rho_iset_vtrace,
    c_state_vtrace=args.c_state_vtrace,
    rho_state_vtrace=args.rho_state_vtrace,

    eta_regularization=args.eta_regularization,
    entropy_schedule_repeats=args.entropy_schedule_repeats,
    entropy_schedule_size=args.entropy_schedule_size,

    learning_rate=args.learning_rate,
    target_network_update=args.target_network_update,
    seed=args.seed
  )
  train_algorithm = LAMISTrain(game, config)

  # exist_ok avoids the check-then-create race when several runs start at once;
  # an empty folder means "current directory" and needs no creation.
  if save_folder:
    os.makedirs(save_folder, exist_ok=True)

  for iteration in range(args.iterations + 1):
    # os.path.join keeps the checkpoint inside save_folder whether or not the
    # caller passed a trailing path separator.
    file_name = os.path.join(save_folder, "lamis_" + str(iteration) + ".pkl")
    print("Saving iteration", iteration, flush=True)
    save_model(file_name, train_algorithm)
    # Training after the final checkpoint would never be saved or returned,
    # so only train when another checkpoint follows.
    if iteration < args.iterations:
      train_algorithm.multiple_jax_steps(args.save_each)
    

def main(argv=None):
  """Parse the command line, resolve the seed and game, and start training.

  Args:
    argv: Optional list of argument strings; when None, argparse reads the
      process command line.

  Checkpoints go to `<save_folder>/<game name>/seed_<seed>/`.
  """
  args = parser.parse_args(argv)
  if args.seed == 0:
    # Seed 0 requests a random seed. The dtype is explicit because the default
    # integer dtype can be 32-bit on some platform/NumPy combinations, where
    # this upper bound is rejected as out of range.
    args.seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
  game = decode_game_name(args.game_details)
  # The trailing "" component keeps a path separator at the end of the folder.
  folder = os.path.join(args.save_folder, get_game_name(game), "seed_" + str(args.seed), "")

  train(args, game, folder)

# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
  main()