

import LiteEFG
from typing import Literal


class OMWUGraph(LiteEFG.Graph):
    """Optimistic Multiplicative Weights Update (OMWU) expressed as a LiteEFG graph.

    Keeps two strategy vectors per infoset:

    * ``u``: the strategy passed to ``env.update`` and returned by
      ``current_strategy``.
    * ``bar_u``: the auxiliary iterate that serves as the reference point
      for both proximal updates.

    Each update uses the infoset's utility plus values aggregated with
    ``LiteEFG.aggregate`` (presumably from the child infosets — confirm
    against LiteEFG docs). The regularizer is scaled per infoset by
    ``coef = alpha / eta``; ``alpha`` is either 1.0 or a bottom-up weight
    (the "weighted dilated regularizer" CLI option).
    """

    def __init__(self, eta: float = 0.1, regularizer: Literal["Euclidean", "Entropy"]="Euclidean", weighted: bool = False) -> None:
        """Build the initialization graph and the per-iteration update graph.

        Args:
            eta: Learning rate; the regularizer scale is ``alpha / eta``.
            regularizer: "Euclidean" selects an L2 proximal step. Any other
                value (including "Entropy") takes the entropy /
                multiplicative-weights branch in ``_update`` and ``_get_ev``.
            weighted: If true, ``alpha`` is a per-infoset weight computed
                from aggregated child weights; otherwise it is the constant 1.0.
        """
        super().__init__()
        self.eta = eta
        self.regularizer = regularizer

        # Static part of the graph: per-infoset constants and initial strategies.
        # NOTE(review): assumed from LiteEFG's is_static flag that this block is
        # evaluated only once at setup — confirm.
        with LiteEFG.backward(is_static=True):
        
            self.alpha = 1.0
            if weighted:
                # Per-infoset weight built bottom-up:
                # alpha = (max of the aggregated child weights + 1) * 2.
                self.alpha = LiteEFG.const(1, 1.0)
                self.alpha.inplace(LiteEFG.aggregate(self.alpha, "sum"))
                self.alpha.inplace((self.alpha.max() + 1) * 2)

            # Regularized infoset values. They are filled in by _get_ev and
            # aggregated into the parent's gradient in the update block below.
            ev = LiteEFG.const(size=1, val=0.0)
            bar_ev = LiteEFG.const(size=1, val=0.0)

            # Proximal scale; gradients are divided by this in _update.
            self.coef = self.alpha / eta
            # Both iterates start at the uniform strategy over the infoset's actions.
            self.u = LiteEFG.const(self.action_set_size, 1.0 / self.action_set_size)
            self.bar_u = self.u.copy()

        # Per-iteration update graph (presumably re-run on every env.update).
        with LiteEFG.backward():

            bar_gradient = LiteEFG.aggregate(bar_ev, "sum") + self.utility
            gradient = LiteEFG.aggregate(ev, "sum") + self.utility

            # Snapshot bar_u before it is overwritten; _get_ev needs the old
            # value as the reference for bar_ev's regularization term.
            self.prev_bar_u =  self.bar_u.copy()
            # Order matters: bar_u is updated first (anchored at itself), then
            # u is computed with the *new* bar_u as its reference point.
            self._update(bar_gradient, self.bar_u, self.bar_u)
            self._update(gradient, self.u, self.bar_u)

            self._get_ev(bar_gradient, bar_ev, self.bar_u, self.prev_bar_u)
            self._get_ev(gradient, ev, self.u, self.bar_u)

        print("===============Graph is ready for OMWU===============")
        print("eta: %f, regularizer: %s" % (self.eta, self.regularizer))
        print("=====================================================\n")
    
    def _get_ev(self, gradient, ev, strategy, ref_strategy) -> None:
        """Write into ``ev`` the regularized value of ``strategy``.

        The value is ``<gradient, strategy>`` minus ``coef`` times a divergence
        from ``ref_strategy``:

        * "Euclidean": ``LiteEFG.euclidean(strategy - ref_strategy)``
          (presumably a squared-norm term — confirm against LiteEFG docs).
        * Otherwise: ``sum(strategy * log(strategy / ref_strategy))``, the KL
          divergence when both vectors are distributions.
        """
        if self.regularizer == "Euclidean":
            ev.inplace(LiteEFG.dot(gradient, strategy) - LiteEFG.euclidean(strategy - ref_strategy) * self.coef)
        else:
            kl = LiteEFG.dot((strategy / ref_strategy).log(), strategy) * self.coef
            ev.inplace(LiteEFG.dot(gradient, strategy) - kl)

    def _update(self, gradient, upd_u, ref_u) -> None:
        """Write into ``upd_u`` a proximal step from ``ref_u`` along ``gradient``.

        ``upd_u`` and ``ref_u`` may be the same node (the ``bar_u`` update in
        ``__init__`` passes it twice). ``ref_u`` is only read by the first
        ``inplace``, before ``upd_u`` is overwritten.
        """
        gradient_div = gradient / self.coef
        
        if self.regularizer == "Euclidean":
            # Gradient step, then project(distance="L2") (presumably onto the simplex).
            upd_u.inplace(ref_u + gradient_div)
            upd_u.inplace(upd_u.project(distance="L2"))

        else:
            # Multiplicative-weights step computed in log space:
            # ref_u * exp(gradient / coef).
            upd_u.inplace(ref_u.log() + gradient_div)
            # Shift by the max before exp() so the largest exponent is 0 (avoids overflow).
            upd_u.inplace(upd_u - upd_u.max())
            upd_u.inplace(upd_u.exp())
            # project(distance="KL"), presumably normalization onto the simplex.
            upd_u.inplace(upd_u.project(distance="KL"))
    
    def update_graph(self, env : LiteEFG.Environment) -> None:
        """Pass the current strategy node ``u`` to ``env.update``."""
        env.update(self.u)
    
    def current_strategy(self) -> LiteEFG.GraphNode:
        """Return the node holding ``u``, the strategy evaluated by the environment."""
        return self.u


from tqdm import tqdm
import pyspiel
import csv
import time
import numpy as np
import math
def train(graph, traverse_type, convergence_type, iter, print_freq, game_env="leduc_poker", output_strategy=False, csv_filename = "exploitability_log.csv", out_to_file=True):
    """Run ``graph`` on an OpenSpiel game and log its exploitability.

    Args:
        graph: Algorithm graph exposing ``update_graph(env)`` and
            ``current_strategy()`` (e.g. ``OMWUGraph``).
        traverse_type: Traversal mode passed to ``LiteEFG.OpenSpielEnv``.
        convergence_type: Iterate whose exploitability is measured; forwarded
            to ``env.exploitability``. "best-iterate" also asks the environment
            to track the best iterate.
        iter: Number of iterations to run. The name shadows the builtin but is
            kept because callers pass it by keyword.
        print_freq: Report (progress-bar description and CSV row) after the
            first iteration and then every ``print_freq`` iterations.
        game_env: OpenSpiel game name passed to ``pyspiel.load_game``.
        output_strategy: If true, write each DataFrame returned by
            ``env.get_strategy`` to ``strategy_<i>.csv`` in the working directory.
        csv_filename: Path of the exploitability log.
        out_to_file: If false, the CSV log is neither created nor written.

    Raises:
        ValueError: If ``print_freq`` is smaller than 1.
    """
    import contextlib

    # Fail fast instead of a ZeroDivisionError (or silent nonsense for
    # negative values) deep inside the training loop.
    if print_freq < 1:
        raise ValueError(f"print_freq must be a positive integer, got {print_freq!r}")

    game = pyspiel.load_game(game_env)
    env = LiteEFG.OpenSpielEnv(game, traverse_type=traverse_type, regenerate=False)
    env.set_graph(graph)

    track_best = convergence_type == "best-iterate"

    with contextlib.ExitStack() as stack:
        csv_file = None
        csv_writer = None
        if out_to_file:
            # Only touch the log file when logging was requested.
            csv_file = stack.enter_context(open(csv_filename, mode='w', newline=''))
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(['Iteration', 'Exploitability', 'Best Exploitability', 'Runing Time'])

        # Registered with the stack so the bar is closed even if training raises.
        pbar = stack.enter_context(tqdm(total=iter))
        best_exp = float("inf")
        reported = 0  # iterations already credited to the progress bar
        time_start = time.perf_counter()  # monotonic, unaffected by clock changes
        for i in range(iter):
            graph.update_graph(env)
            env.update_strategy(graph.current_strategy(), update_best=track_best)

            done = i + 1
            if done % print_freq == 0 or i == 0:
                exploitability = sum(env.exploitability(graph.current_strategy(), convergence_type))
                best_exp = min(best_exp, exploitability)
                pbar.set_description(f'iterations:{done}, Exploitability: {exploitability:.12f}, Best: {best_exp:.12f}')
                # Credit only the iterations completed since the last report;
                # a fixed update(print_freq) overshot the total after i == 0.
                pbar.update(done - reported)
                reported = done

                if csv_writer is not None:
                    elapsed_min = (time.perf_counter() - time_start) / 60.0
                    csv_writer.writerow([done, exploitability, best_exp, elapsed_min])
                    # Flush per row so the log can be monitored while training runs.
                    csv_file.flush()

        # Credit trailing iterations when iter is not a multiple of print_freq.
        remaining = iter - reported
        if remaining > 0:
            pbar.update(remaining)

    if output_strategy:
        # NOTE(review): always exports the average iterate, independent of
        # convergence_type — confirm this is intended for last-iterate runs.
        _, df_list = env.get_strategy(graph.current_strategy(), "avg-iterate")
        for i, df in enumerate(df_list):
            # Escape embedded newlines so each infoset stays on one CSV row.
            df['Infoset'] = df['Infoset'].apply(lambda x: x.replace('\n', '\\n'))
            df.to_csv(f"strategy_{i}.csv", quoting=csv.QUOTE_MINIMAL, quotechar='"')

if __name__ == "__main__":
    # CLI entry point: train OMWU on an OpenSpiel game and log exploitability
    # to <script dir>/data/OMWU_<game>.csv.
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--game", type=str, default="kuhn_poker")
    parser.add_argument("--traverse_type", type=str, choices=["Enumerate", "External", "Outcome"], default="Enumerate")
    parser.add_argument("--iter", type=int, default=20000)
    parser.add_argument("--print_freq", type=int, default=100)

    parser.add_argument("--eta", help="learning rate", type=float, default=0.1)
    parser.add_argument("--regularizer", type=str, choices=["Euclidean", "Entropy"], default="Entropy")
    parser.add_argument("--weighted", help="weighted dilated regularizer or not", action="store_true")

    args = parser.parse_args()

    alg = OMWUGraph(args.eta, args.regularizer, args.weighted)

    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(PROJECT_ROOT, 'data')
    # The log directory may not exist on a fresh checkout; without this,
    # open() inside train() raises FileNotFoundError.
    os.makedirs(data_dir, exist_ok=True)
    csv_filename = os.path.join(data_dir, f'OMWU_{args.game}.csv')

    train(alg, traverse_type=args.traverse_type, convergence_type="last-iterate", iter=args.iter, print_freq=args.print_freq, game_env=args.game, out_to_file=True, csv_filename=csv_filename)
