from tqdm import tqdm
import pyspiel
import csv
import time
import numpy as np
import math
import LiteEFG

def train(graph, traverse_type, convergence_type, iter, print_freq, game_env="leduc_poker", output_strategy=False, csv_filename = "exploitability_log.csv", out_to_file=True):
    """Run ``iter`` update steps of ``graph`` on an OpenSpiel game and log exploitability.

    Each step calls ``graph.update_graph(env)`` and then pushes
    ``graph.current_strategy()`` into the environment. On the first step and
    on every ``print_freq``-th step the summed exploitability is computed,
    shown on the progress bar and (optionally) appended to a CSV log.

    Args:
        graph: Object exposing ``update_graph(env)`` and ``current_strategy()``;
            it is registered with the environment via ``env.set_graph``.
        traverse_type: Forwarded to ``LiteEFG.OpenSpielEnv`` as ``traverse_type``.
        convergence_type: Forwarded to ``env.exploitability``. When equal to
            ``"best-iterate"``, ``env.update_strategy`` is called with
            ``update_best=True``.
        iter: Number of update steps to run.
        print_freq: Report exploitability every ``print_freq`` steps (the
            first step is always reported).
        game_env: Game name passed to ``pyspiel.load_game``.
        output_strategy: If true, write each DataFrame returned by
            ``env.get_strategy(..., "avg-iterate")`` to ``strategy_<k>.csv``.
        csv_filename: Path of the exploitability log.
        out_to_file: If false, the log file is neither created nor written.
    """
    # Local import so this block does not depend on a new file-level import.
    from contextlib import ExitStack

    game = pyspiel.load_game(game_env)
    env = LiteEFG.OpenSpielEnv(game, traverse_type=traverse_type, regenerate=False)
    env.set_graph(graph)

    with ExitStack() as stack:
        csv_file = None
        csv_writer = None
        if out_to_file:
            # Only touch the log file when logging was requested; opening it
            # unconditionally would truncate an existing log for no reason.
            csv_file = stack.enter_context(open(csv_filename, mode='w', newline=''))
            csv_writer = csv.writer(csv_file)
            # Header text (including the 'Runing Time' spelling) is kept as-is
            # so existing consumers of this column name keep working.
            csv_writer.writerow(['Iteration', 'Exploitability', 'Best Exploitability', 'Runing Time'])

        # Entering tqdm through the stack guarantees the bar is closed,
        # including when an exception escapes the loop.
        pbar = stack.enter_context(tqdm(total=iter))
        best_exp = 1e9
        reported_steps = 0  # number of steps already reflected on the bar
        time_start = time.time()
        for i in range(iter):
            graph.update_graph(env)
            env.update_strategy(graph.current_strategy(), update_best=(convergence_type == "best-iterate"))

            step = i + 1
            if step % print_freq == 0 or i == 0:
                exploitability = sum(env.exploitability(graph.current_strategy(), convergence_type))
                best_exp = min(best_exp, exploitability)
                pbar.set_description(f'iterations:{step}, Exploitability: {exploitability:.12f}, Best: {best_exp:.12f}')
                # Advance by the steps actually completed since the last
                # report, so the forced first report does not overshoot.
                pbar.update(step - reported_steps)
                reported_steps = step

                # Write current state to CSV
                if csv_writer is not None:
                    elapsed_minutes = (time.time() - time_start) / 60.0
                    csv_writer.writerow([step, exploitability, best_exp, elapsed_minutes])
                    # Flush so the log stays usable if the run is interrupted.
                    csv_file.flush()

        # Account for trailing steps when iter is not a multiple of print_freq.
        if reported_steps < iter:
            pbar.update(iter - reported_steps)

    if output_strategy:
        _, df_list = env.get_strategy(graph.current_strategy(), "avg-iterate")
        for player_idx, df in enumerate(df_list):
            # Escape embedded newlines so every infoset stays on one CSV row.
            df['Infoset'] = df['Infoset'].apply(lambda x: x.replace('\n', '\\n'))
            df.to_csv("strategy_" + str(player_idx) + ".csv", quoting=csv.QUOTE_MINIMAL, quotechar='"')
