
import LiteEFG
from utils import train
from LiteEFG.baselines.PCFR import graph as PCFRPlusGraph

def parse_args(argv=None):
    """Parse the command-line options for this PCFR+ training script.

    Args:
        argv: Argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``game``, ``traverse_type``, ``iter`` and
        ``print_freq`` attributes.

    Raises:
        SystemExit: if the arguments are invalid (argparse's behavior),
            including a non-positive ``--iter`` or ``--print_freq``.
    """
    import argparse

    def _positive_int(text):
        # Reject 0/negatives up front: a zero print frequency or iteration
        # count is never meaningful and would otherwise only surface deep
        # inside train().
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
        return value

    parser = argparse.ArgumentParser(description="Run the PCFR+ baseline graph with train().")
    parser.add_argument("--game", type=str, default="kuhn_poker",
                        help="game name forwarded to train() as game_env")
    parser.add_argument("--traverse_type", type=str, choices=["Enumerate", "External"], default="Enumerate",
                        help="traversal type forwarded to train()")
    parser.add_argument("--iter", type=_positive_int, default=20000,
                        help="number of iterations forwarded to train()")
    parser.add_argument("--print_freq", type=_positive_int, default=100,
                        help="print frequency forwarded to train()")
    return parser.parse_args(argv)


if __name__ == "__main__":
    import os

    args = parse_args()

    alg = PCFRPlusGraph()

    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(PROJECT_ROOT, 'data')
    # Make sure the output directory exists so writing the CSV cannot fail
    # after the whole training run has already been spent.
    os.makedirs(data_dir, exist_ok=True)
    csv_filename = os.path.join(data_dir, f'PCFRPlus_{args.game}.csv')

    train(alg, traverse_type=args.traverse_type, convergence_type="last-iterate", iter=args.iter, print_freq=args.print_freq, game_env=args.game, out_to_file=True, csv_filename=csv_filename)