import LiteEFG
import pyspiel
import pandas as pd
import csv
from tqdm import tqdm

from enum import Enum


class STRATEGY_TYPE_NAME:
    """String identifiers selecting which iterate of a strategy to use.

    These plain strings are passed as-is to the LiteEFG environment methods in
    this file (``exploitability``, ``utility``, ``get_strategy``), and
    ``BEST_ITERATE`` is compared against the ``convergence_type`` argument of
    the solvers' ``train`` methods.
    """

    LAST_ITERATE: str = "last-iterate"
    AVG_ITERATE: str = "avg-iterate"
    LINEAR_AVG_ITERATE: str = "linear-avg-iterate"
    BEST_ITERATE: str = "best-iterate"

class GameValue:
    """Reference game values for the OpenSpiel games this module supports.

    Each constant is a per-seat list of values (presumably indexed by player
    id -- confirm against the source of these numbers). They are subtracted
    from counterfactual values to obtain individual NashConv figures.
    """

    KUHN_POKER = [-1.0 / 18.0, 1.0 / 18.0]
    LEBEL_POKER = [-0.08560642, 0.08560642]
    LIOURS_DICE_4_SIDES = [1.0 / 16.0, -1.0 / 16.0]
    GOOFSPIEL_4_CARDS = [0.0, 0.0]
    GOOFSPIEL_5_CARDS = [0.0, 0.0]

    # Correctly spelled aliases. The original misspelled names are kept so that
    # existing references keep working.
    LEDUC_POKER = LEBEL_POKER
    LIARS_DICE_4_SIDES = LIOURS_DICE_4_SIDES

    # Accepted OpenSpiel game strings -> name of the class attribute that holds
    # the value. The lookup goes by attribute name at call time, so the
    # constants above stay the single source of truth.
    _GAME_NAME_TO_ATTR = {
        "kuhn_poker": "KUHN_POKER",
        "leduc_poker": "LEBEL_POKER",
        "leduc_poker(suit_isomorphism=True)": "LEBEL_POKER",
        "liars_dice(dice_sides=4)": "LIOURS_DICE_4_SIDES",
        "turn_based_simultaneous_game(game=goofspiel(imp_info=True,num_cards=4,points_order=descending))": "GOOFSPIEL_4_CARDS",
        "turn_based_simultaneous_game(game=goofspiel(imp_info=True,num_cards=5,points_order=descending))": "GOOFSPIEL_5_CARDS",
    }

    @staticmethod
    def get_game_value(game_name: str) -> list[float]:
        """Return the game values for ``game_name``.

        Args:
            game_name: Exact OpenSpiel game string (as passed to
                ``pyspiel.load_game``).

        Returns:
            A fresh list copy of the stored values. Callers may mutate it
            without corrupting the shared class constant.

        Raises:
            ValueError: If ``game_name`` is not one of the supported games.
        """
        attr_name = GameValue._GAME_NAME_TO_ATTR.get(game_name)
        if attr_name is None:
            raise ValueError(f"Unknown game name for retrieving game value for {game_name}")
        return list(getattr(GameValue, attr_name))


def conterfactual_values(
    ind_regrets: list[float], utilities: list[float]
) -> list[float]:
    """Combine regrets and utilities read from the mirrored index.

    Entry ``k`` of the result is ``ind_regrets[-(k + 1)] + utilities[-(k + 1)]``.
    Both inputs are read back to front, so with two players each slot receives
    the other player's regret plus utility. Exactly ``len(ind_regrets)``
    entries are produced. ``utilities`` must be at least that long, otherwise
    ``IndexError`` is raised.
    """
    count = len(ind_regrets)
    return [
        ind_regrets[-offset] + utilities[-offset]
        for offset in range(1, count + 1)
    ]

def individual_nash_convs(conterfactual_values: list[float], game_name: str):
    """Subtract the reference game value from each counterfactual value.

    The game values for ``game_name`` are consumed in reverse order, mirroring
    the index reversal done by ``conterfactual_values``. The result has as many
    entries as the shorter of the two sequences.

    Raises:
        ValueError: Propagated from ``GameValue.get_game_value`` for an
            unsupported ``game_name``.
    """
    mirrored_values = reversed(GameValue.get_game_value(game_name))
    return [
        value - reference
        for value, reference in zip(conterfactual_values, mirrored_values)
    ]


class BaseSolver(LiteEFG.Graph):
    """Solver whose single LiteEFG graph is trained against one environment.

    ``train`` relies on ``update_graph`` and ``current_strategy`` being
    available on ``self`` (presumably from ``LiteEFG.Graph`` or a subclass --
    confirm). It periodically logs exploitability metrics to CSV and finally
    dumps the average- and last-iterate strategies.
    """

    def __init__(self):
        super().__init__()

    def train(
        self,
        traverse_type,
        convergence_type,
        iter,
        print_freq,
        game_env,
        output_dir,
    ):
        """Run ``iter`` training iterations and write CSV outputs.

        Args:
            traverse_type: Traversal mode forwarded to ``LiteEFG.OpenSpielEnv``.
            convergence_type: A ``STRATEGY_TYPE_NAME`` value. Best-iterate
                tracking is enabled only when it equals ``BEST_ITERATE``.
            iter: Number of iterations. The name shadows the builtin but is
                kept for keyword-argument compatibility.
            print_freq: Metrics are recorded every ``print_freq`` iterations,
                and a checkpoint log every ``10 * print_freq`` iterations.
                Must be positive.
            game_env: OpenSpiel game string. It is also used to look up the
                reference game value.
            output_dir: Existing directory that receives
                ``exploitability_log*.csv`` and the strategy CSVs.
        """
        game = pyspiel.load_game(game_env)
        env = LiteEFG.OpenSpielEnv(game, traverse_type=traverse_type, regenerate=False)
        env.set_graph(self)
        update_best = convergence_type == STRATEGY_TYPE_NAME.BEST_ITERATE
        result_dict = {}
        # The context manager closes the progress bar even if training raises.
        with tqdm(total=iter) as pbar:
            for i in range(iter):
                self.update_graph(env)
                env.update_strategy(self.current_strategy(), update_best=update_best)
                if i % print_freq == 0:
                    row = self._collect_metrics(env, game_env)
                    result_dict[i] = row
                    pbar.set_description(
                        f'Last Exploitability: {row["last_exploitability"]:.8f}, '
                        f'Avg Exploitability: {row["avg_exploitability"]:.8f}'
                    )
                    # Cap the step so the bar ends exactly at ``iter`` even when
                    # ``iter`` is not a multiple of ``print_freq``.
                    pbar.update(min(print_freq, iter - i))
                if i % (print_freq * 10) == 0 and i != 0:
                    self._write_log(result_dict, f'{output_dir}/exploitability_log_{i}.csv')

        self._write_log(result_dict, f'{output_dir}/exploitability_log.csv')
        self._dump_strategies(env, output_dir)

    def _collect_metrics(self, env, game_env):
        """Evaluate the current strategy on ``env`` and return one log row."""
        strategy = self.current_strategy()
        last_exploitabilities = env.exploitability(strategy, STRATEGY_TYPE_NAME.LAST_ITERATE)
        last_exploitability = sum(last_exploitabilities)
        avg_exploitabilities = env.exploitability(strategy, STRATEGY_TYPE_NAME.AVG_ITERATE)
        avg_exploitability = sum(avg_exploitabilities)
        last_cfvs = conterfactual_values(
            last_exploitabilities,
            env.utility(strategy, STRATEGY_TYPE_NAME.LAST_ITERATE),
        )
        avg_cfvs = conterfactual_values(
            avg_exploitabilities,
            env.utility(strategy, STRATEGY_TYPE_NAME.AVG_ITERATE),
        )
        last_convs = individual_nash_convs(last_cfvs, game_env)
        avg_convs = individual_nash_convs(avg_cfvs, game_env)
        return {
            "last_exploitability": last_exploitability,
            "avg_exploitability": avg_exploitability,
            # "player1" regret columns read index 1 (and "player2" index 0).
            # This matches the index reversal in conterfactual_values, from
            # which the NashConv columns below are derived.
            "last_ind_regrets_player1": last_exploitabilities[1],
            "last_ind_regrets_player2": last_exploitabilities[0],
            "avg_ind_regrets_player1": avg_exploitabilities[1],
            "avg_ind_regrets_player2": avg_exploitabilities[0],
            "last_ind_nash_conv_player1": last_convs[0],
            "last_ind_nash_conv_player2": last_convs[1],
            # NOTE(review): the "sum_*" columns sum counterfactual values, not
            # the individual NashConvs. The two agree only when the game values
            # sum to zero, which holds for every entry currently in GameValue.
            "sum_ind_nash_conv": sum(last_cfvs),
            "avg_ind_nash_conv_player1": avg_convs[0],
            "avg_ind_nash_conv_player2": avg_convs[1],
            "sum_avg_ind_nash_conv": sum(avg_cfvs),
        }

    @staticmethod
    def _write_log(result_dict, path):
        """Write the iteration-indexed metric rows to ``path`` as CSV."""
        pd.DataFrame.from_dict(result_dict, orient='index').to_csv(path)

    @staticmethod
    def _write_strategy_csvs(df_list, path_prefix):
        """Write each strategy table in ``df_list`` to ``{path_prefix}_{idx}.csv``."""
        for idx, df in enumerate(df_list):
            # Replace raw newlines in infoset strings with the two-character
            # sequence "\n" so that each row occupies a single physical line.
            df['Infoset'] = df['Infoset'].apply(lambda x: x.replace('\n', '\\n'))
            df.to_csv(f"{path_prefix}_{idx}.csv", quoting=csv.QUOTE_MINIMAL, quotechar='"')

    def _dump_strategies(self, env, output_dir):
        """Export the average- and last-iterate strategies as CSV files."""
        _, avg_df_list = env.get_strategy(self.current_strategy(), STRATEGY_TYPE_NAME.AVG_ITERATE)
        _, last_df_list = env.get_strategy(self.current_strategy(), STRATEGY_TYPE_NAME.LAST_ITERATE)
        self._write_strategy_csvs(avg_df_list, f"{output_dir}/avg_strategy")
        self._write_strategy_csvs(last_df_list, f"{output_dir}/last_strategy")


class BaseAsymSolver:
    """Trains one LiteEFG graph per player, each bound to its own environment.

    Subclasses implement ``create_solver`` to build the graph for a given
    player id. ``train`` steps both graphs once per iteration, logs metrics to
    CSV and finally dumps each graph's strategies.
    """

    def __init__(self):
        self.solvers = [self.create_solver(player_id=0), self.create_solver(player_id=1)]

    def create_solver(self, player_id: int) -> LiteEFG.Graph:
        """Build the solver graph for ``player_id``. Subclasses must override."""
        raise NotImplementedError

    def train(
        self,
        traverse_type,
        convergence_type,
        iter,
        print_freq,
        game_env,
        output_dir,
    ):
        """Run ``iter`` training iterations for both solvers and write CSV outputs.

        Args:
            traverse_type: Traversal mode forwarded to each ``LiteEFG.OpenSpielEnv``.
            convergence_type: A ``STRATEGY_TYPE_NAME`` value. Best-iterate
                tracking is enabled only when it equals ``BEST_ITERATE``.
            iter: Number of iterations. The name shadows the builtin but is
                kept for keyword-argument compatibility.
            print_freq: Metrics are recorded every ``print_freq`` iterations,
                and a checkpoint log every ``10 * print_freq`` iterations.
                Must be positive.
            game_env: OpenSpiel game string. It is also used to look up the
                reference game value.
            output_dir: Existing directory that receives the log and
                strategy CSVs.
        """
        # One independently loaded game/environment per solver.
        envs = [
            LiteEFG.OpenSpielEnv(
                pyspiel.load_game(game_env), traverse_type=traverse_type, regenerate=False
            ) for _ in range(2)
        ]
        for solver, env in zip(self.solvers, envs):
            env.set_graph(solver)
        update_best = convergence_type == STRATEGY_TYPE_NAME.BEST_ITERATE
        result_dict = {}
        # The context manager closes the progress bar even if training raises.
        with tqdm(total=iter) as pbar:
            for i in range(iter):
                for env, solver in zip(envs, self.solvers):
                    solver.update_graph(env)
                    env.update_strategy(solver.current_strategy(), update_best=update_best)
                if i % print_freq == 0:
                    row = self._collect_metrics(envs, game_env)
                    result_dict[i] = row
                    pbar.set_description(
                        f'Last Exploitability: {row["last_exploitability"]:.8f}, '
                        f'Avg Exploitability: {row["avg_exploitability"]:.8f}'
                    )
                    # Cap the step so the bar ends exactly at ``iter`` even when
                    # ``iter`` is not a multiple of ``print_freq``.
                    pbar.update(min(print_freq, iter - i))
                if i % (print_freq * 10) == 0 and i != 0:
                    df = pd.DataFrame.from_dict(result_dict, orient='index')
                    df.to_csv(f'{output_dir}/exploitability_log_{i}.csv')

        df = pd.DataFrame.from_dict(result_dict, orient='index')
        df.to_csv(f'{output_dir}/exploitability_log.csv')
        self._dump_strategies(envs, output_dir)

    def _collect_metrics(self, envs, game_env):
        """Evaluate every solver's current strategy and return one log row.

        Per-solver columns are suffixed ``player{i_e+1}``. The totals
        ``last_exploitability`` and ``avg_exploitability`` add each solver's
        own entry ``i_e`` of its counterfactual values.
        """
        row = {}
        last_exploitability = 0
        avg_exploitability = 0
        for i_e, (env, solver) in enumerate(zip(envs, self.solvers)):
            strategy = solver.current_strategy()

            last_ind_regrets = env.exploitability(strategy, STRATEGY_TYPE_NAME.LAST_ITERATE)
            avg_ind_regrets = env.exploitability(strategy, STRATEGY_TYPE_NAME.AVG_ITERATE)

            last_cfvs = conterfactual_values(
                last_ind_regrets,
                env.utility(strategy, STRATEGY_TYPE_NAME.LAST_ITERATE),
            )
            avg_cfvs = conterfactual_values(
                avg_ind_regrets,
                env.utility(strategy, STRATEGY_TYPE_NAME.AVG_ITERATE),
            )
            # Bare name resolves to the module-level function, not the method.
            last_convs = individual_nash_convs(last_cfvs, game_env)
            avg_convs = individual_nash_convs(avg_cfvs, game_env)

            last_exploitability += last_cfvs[i_e]
            avg_exploitability += avg_cfvs[i_e]

            row.update(
                {
                    f"last_exploitability_player{i_e+1}": sum(last_ind_regrets),
                    f"avg_exploitability_player{i_e+1}": sum(avg_ind_regrets),
                    f"last_ind_regrets_player{i_e+1}": last_ind_regrets[i_e],
                    f"avg_ind_regrets_player{i_e+1}": avg_ind_regrets[i_e],
                    f"last_ind_nash_conv_player{i_e+1}": last_convs[i_e],
                    f"avg_ind_nash_conv_player{i_e+1}": avg_convs[i_e],
                }
            )
        row["last_exploitability"] = last_exploitability
        row["avg_exploitability"] = avg_exploitability
        return row

    @staticmethod
    def _write_strategy_csvs(df_list, path_prefix):
        """Write each strategy table in ``df_list`` to ``{path_prefix}_{idx}.csv``."""
        for idx, df in enumerate(df_list):
            # Replace raw newlines in infoset strings with the two-character
            # sequence "\n" so that each row occupies a single physical line.
            df['Infoset'] = df['Infoset'].apply(lambda x: x.replace('\n', '\\n'))
            df.to_csv(f"{path_prefix}_{idx}.csv", quoting=csv.QUOTE_MINIMAL, quotechar='"')

    def _dump_strategies(self, envs, output_dir):
        """Export each solver's average- and last-iterate strategies as CSV."""
        for i_e, (env, solver) in enumerate(zip(envs, self.solvers)):
            _, avg_df_list = env.get_strategy(solver.current_strategy(), STRATEGY_TYPE_NAME.AVG_ITERATE)
            _, last_df_list = env.get_strategy(solver.current_strategy(), STRATEGY_TYPE_NAME.LAST_ITERATE)
            self._write_strategy_csvs(avg_df_list, f"{output_dir}/avg_strategy_{i_e}")
            self._write_strategy_csvs(last_df_list, f"{output_dir}/last_strategy_{i_e}")

    def individual_nash_convs(
        self, ind_regrets: list[float], utilities: list[float], gama_name: str
    ) -> list[float]:
        """Compute per-slot NashConv directly from regrets and utilities.

        Entry ``k`` is ``ind_regrets[-(k+1)] + utilities[-(k+1)] -
        game_values[-(k+1)]``, where ``game_values`` comes from
        ``GameValue.get_game_value``. ``gama_name`` is the game string; the
        misspelled parameter name is kept for keyword compatibility.

        Raises:
            ValueError: Propagated from ``GameValue.get_game_value``.
        """
        game_values = GameValue.get_game_value(gama_name)
        return [
            ind_regrets[k] + utilities[k] - game_values[k]
            for k in range(-1, -len(ind_regrets) - 1, -1)
        ]