import LiteEFG as leg
import pyspiel

import QFR

import LiteEFG.baselines.CFR as CFR
import LiteEFG.baselines.CFRplus as CFR_Plus
import LiteEFG.baselines.MMD as MMD
import LiteEFG.baselines.OS_MCCFR as OS_MCCFR
import LiteEFG.baselines.Balanced_OMD as BOMD
import LiteEFG.baselines.Balanced_FTRL as BFTRL
import LiteEFG.baselines.DCFR as DCFR
import LiteEFG.baselines.PCFR as PCFR

import numpy as np

import argparse
import logging
import json
import time
from tqdm import tqdm

# Command-line interface for the experiment driver. The parsed namespace is
# handed to Compute(), which reads every option declared here.
parser = argparse.ArgumentParser()

# --- Game and algorithm selection -------------------------------------------
parser.add_argument(
    "--game",
    default='leduc_poker(suit_isomorphism=True)',
    type=str,
    help="Game file",
)
parser.add_argument(
    "--algo",
    default='QFR',
    type=str,
    choices=['CFR', 'CFR+', 'DCFR', 'QFR', 'MMD', 'PCFR', 'BOMD', 'BFTRL'],
    help="Learning algorithm",
)

# --- Algorithm hyper-parameters ---------------------------------------------
parser.add_argument(
    "--eta",
    default=0.1,
    type=float,
    help="Learning rate",
)
parser.add_argument(
    "--tau",
    default=0.01,
    type=float,
    help="Regularization Coefficient",
)
parser.add_argument(
    "--gamma",
    default=0.01,
    type=float,
    help="Perturbation",
)
parser.add_argument(
    "--regularizer",
    default="Euclidean",
    type=str,
    choices=['Euclidean', 'Entropy'],
    help="Regularizer Type",
)
parser.add_argument(
    "--feedback",
    default="counterfactual",
    type=str,
    choices=['counterfactual', 'Q', 'traj-Q'],
    help="Feedback Type",
)

# --- Traversal mode ----------------------------------------------------------
# Boolean switch: present -> True, absent -> False.
parser.add_argument(
    "--sample",
    action='store_true',
    help="Sampling or not",
)

# --- Run length and reporting ------------------------------------------------
parser.add_argument(
    "--T",
    default=100000,
    type=int,
    help="Number of iterations",
)
parser.add_argument(
    "--json",
    type=str,
    help="Output JSON file",
)
# Stored on the namespace as ``out_freq`` (argparse maps '-' to '_').
parser.add_argument(
    "--out-freq",
    default=100,
    type=int,
    help="Output frequency",
)

def Compute(args, is_print=True):
    """Run one experiment: train the selected algorithm and record exploitability.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``game``, ``algo``, ``eta``, ``tau``, ``gamma``, ``regularizer``,
            ``feedback``, ``sample``, ``T``, ``out_freq`` and ``json``.
        is_print: When true and ``args.json`` is set, the collected data
            points are written to that file as JSON.

    Returns:
        A list whose first element is ``vars(args)`` (the namespace's own
        dict, not a copy) followed by one dict per reporting step with the
        keys ``'iteration'``, ``'time'`` and ``'regrets'``.
    """
    np.set_printoptions(linewidth=10000)

    # basicConfig only configures the root logger the first time it is
    # called, so repeated Compute() calls do not stack handlers.
    logging.basicConfig(
        level=logging.DEBUG,
        format='[%(asctime)s|>%(levelname)s] %(message)s')

    logging.info(f"=== Parsed game {args.game}")
    logging.info("===========================")

    def make_agent():
        """Build the computation graph for ``args.algo``."""
        feedback = args.feedback
        if args.sample:
            # Sampling mode overrides the requested feedback type.
            feedback = "Outcome"
        if args.algo == "CFR":
            return CFR.graph() if(not args.sample) else OS_MCCFR.graph(args.gamma, False)
        elif args.algo == "CFR+":
            return CFR_Plus.graph() if(not args.sample) else OS_MCCFR.graph(args.gamma, True)
        elif args.algo == "QFR":
            return QFR.graph(args.eta, args.tau, args.gamma, args.regularizer, feedback)
        elif args.algo == "MMD":
            return MMD.graph(args.eta, args.tau, args.gamma, args.regularizer, feedback)
        elif args.algo == "DCFR":
            # NOTE(review): eta/tau/gamma are passed positionally; confirm
            # they line up with the parameters DCFR.graph expects.
            return DCFR.graph(args.eta, args.tau, args.gamma)
        elif args.algo == "PCFR":
            return PCFR.graph()
        elif args.algo == "BOMD":
            return BOMD.graph(args.eta, args.eta / 20)
        elif args.algo == "BFTRL":
            return BFTRL.graph(args.eta, args.eta / 20)
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped under `python -O`, which would silently return None.
        raise AssertionError("Unrecognized algorithm")

    dps = [vars(args)]

    env = pyspiel.load_game(args.game)
    env = leg.OpenSpielEnv(env, traverse_type="Enumerate" if not args.sample else "Outcome")

    # time.time() * 1e7 is far larger than 2**32. Casting such a float
    # directly to uint32 is an out-of-range conversion whose result is
    # platform-dependent (it can saturate to a constant, giving every run the
    # same seed). Reduce modulo 2**32 in exact integer arithmetic first.
    seed = np.uint32(int(time.time() * 1e7) % (1 << 32))
    leg.set_seed(seed)
    alg = make_agent()
    env.set_graph(alg)

    # Accumulates only the time spent inside update_graph, excluding the
    # exploitability evaluation and logging below.
    time_usage = 0

    for t in tqdm(range(1, args.T + 1)):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        start_time = time.perf_counter()
        alg.update_graph(env)
        time_usage += time.perf_counter() - start_time

        # Report on iterations 1, 1 + out_freq, 1 + 2 * out_freq, ...
        if (t-1) % args.out_freq == 0:
            # NOTE: a disabled variant evaluated sampled CFR/CFR+ with the
            # "avg-iterate" strategy type; "default" is used for every
            # algorithm here.
            regrets = env.exploitability(alg.current_strategy(), "default")
            logging.info(f"Iteration {t:5} time {time_usage}  regrets  {regrets}   max_regret {max(regrets)}")
            dps.append({'iteration': t, 'time': time_usage, 'regrets': regrets})

    if args.json and is_print:
        with open(args.json, 'w') as outfile:
            json.dump(dps, outfile)
    return dps


if __name__ == '__main__':
    args = parser.parse_args()
    if args.sample:
        # In sampling mode the experiment is repeated and the per-report
        # values are aggregated across runs (mean time, mean and standard
        # deviation of the summed regrets).
        num_runs = 100
        regrets = []
        time_list = []
        for _ in range(num_runs):
            dps = Compute(args, is_print=False)
            # dps[0] is the argument record; the rest are data points.
            # Plain comprehensions replace the former pandas round-trip, so
            # pandas is no longer needed to run this script.
            points = dps[1:]
            regrets.append([np.sum(dp['regrets']) for dp in points])
            time_list.append([dp['time'] for dp in points])

        # Rows are runs, columns are reporting steps; aggregate over runs.
        # With zero reporting steps these are empty arrays and the loop
        # below is simply skipped.
        time_list = np.mean(np.array(time_list), axis=0)
        regrets = np.array(regrets)
        regrets, std = np.mean(regrets, axis=0), np.std(regrets, axis=0)

        # Reuse the last run's records as the output container and overwrite
        # their measurements with the aggregated values.
        for dp, mean_time, mean_regret, regret_std in zip(dps[1:], time_list, regrets, std):
            dp['time'] = mean_time
            # The summed regret is stored in both slots of the pair.
            dp['regrets'] = [mean_regret, mean_regret]
            dp['std'] = regret_std

        if args.json:
            with open(args.json, 'w') as outfile:
                json.dump(dps, outfile)
    else:
        Compute(args)