import argparse
import pickle
import pprint
import yaml
import time
import os
import ray
import matplotlib.pyplot as plt
import numpy as np
import logging

from expground.settings import BASE_DIR, LOG_DIR, DATA_SUB_DIR_NAME
from expground.types import List
from expground.logger import Log
from expground.algorithms.tabular import minmax, psro


def run_p2sro(args):
    """Placeholder for the "p2sro" runner; currently does nothing.

    Args:
        args: Experiment configuration (ignored by this stub).

    Returns:
        None, always -- no result is produced yet.
    """
    return None


def run_sp(args):
    """Placeholder for the "sp" runner; currently does nothing.

    Args:
        args: Experiment configuration (ignored by this stub).

    Returns:
        None, always -- no result is produced yet.
    """
    return None


def run_fsp(args):
    """Placeholder for the "fsp" runner; currently does nothing.

    Args:
        args: Experiment configuration (ignored by this stub).

    Returns:
        None, always -- no result is produced yet.
    """
    return None


def run_algo(algo, args):
    """Dispatch to the runner registered for ``algo`` and time the call.

    Args:
        algo: Algorithm name; one of "fsp", "sp", "p2sro", "psro", "epsro".
        args: Configuration namespace forwarded unchanged to the runner. Its
            ``__dict__`` is pretty-printed before the run.

    Returns:
        Whatever the selected runner returns.

    Raises:
        ValueError: If ``algo`` is not a supported algorithm name. Previously
            this surfaced as an ``UnboundLocalError`` on ``res`` after the
            logging had already happened.
    """
    # Lambdas keep name resolution lazy: only the selected runner is looked up.
    runners = {
        "fsp": lambda: run_fsp(args),
        "sp": lambda: run_sp(args),
        "p2sro": lambda: run_p2sro(args),
        "psro": lambda: psro.multi_learn(args),
        "epsro": lambda: minmax.multi_learn(args),
    }
    runner = runners.get(algo)
    if runner is None:
        raise ValueError(
            "Unsupported algorithm {!r}, expected one of {}".format(
                algo, sorted(runners)
            )
        )

    Log.info("Run algorithm {} with configuration".format(algo))
    pprint.pprint(args.__dict__)
    start = time.time()
    res = runner()
    Log.info("\t* time consumption: {}".format(time.time() - start))
    return res


def smooth(learning_curves: List[np.ndarray], ratio: float):
    """Apply exponential smoothing to every curve, in place.

    Each curve is replaced by a plain list where every entry is
    ``ratio * sample + previous * (1.0 - ratio)``, with ``previous`` starting
    at 0 for the first sample of each curve.

    Args:
        learning_curves: Sequence of curves; its entries are overwritten.
        ratio: Weight given to the current sample.

    Returns:
        The same ``learning_curves`` object that was passed in.
    """

    def _running_average(curve):
        # Carries the previously emitted value forward through the curve.
        previous = 0
        for sample in curve:
            previous = ratio * sample + previous * (1.0 - ratio)
            yield previous

    for idx, curve in enumerate(learning_curves):
        learning_curves[idx] = list(_running_average(curve))
    return learning_curves


if __name__ == "__main__":
    # Script entry point: parse the CLI, load the YAML experiment config, run
    # every configured algorithm, then plot NashConv and pickle the results.
    parser = argparse.ArgumentParser("Tabular experiments and visualization.")

    parser.add_argument(
        "--env",
        type=str,
        help="specify environment.",
        # A tuple (not a set) keeps the order shown in --help stable.
        choices=("non-transitive", "alphastar", "random"),
        required=True,
    )
    parser.add_argument(
        "--algos", type=str, help="speficy algorithm, could be multiple"
    )
    parser.add_argument("--config_name", type=str, help="configuration file path")
    parser.add_argument(
        "--dim_a", type=int, help="original policy space size for player A.", default=3
    )
    parser.add_argument(
        "--dim_b", type=int, help="original policy space size for player B.", default=3
    )
    parser.add_argument(
        "--symmetric_init",
        action="store_true",
        help="Init constraint policy space in symmetric mode or not.",
    )
    parser.add_argument("--n_episode", type=int, help="training epochs.", default=1000)
    parser.add_argument("--n_group", default=5, type=int)
    parser.add_argument(
        "--max_support_size",
        type=int,
        help="determine how big the final support is.",
        default=10,
    )
    parser.add_argument("--seed", type=int, default=100)
    parser.add_argument("--payoff_config", type=str, default=None)
    parser.add_argument("--local_mode", action="store_true")

    args = parser.parse_args()

    # Only the config-file workflow is implemented; --algos is parsed but unused.
    if args.config_name is None:
        raise RuntimeError(
            "You must speficy algorithms or given an available file path for experiment configuration."
        )

    config_path = os.path.join(BASE_DIR, args.config_name)
    if not os.path.exists(config_path):
        Log.error("Illegal config path, please check it again: {}".format(config_path))
        # SystemExit instead of the site-provided exit(), which is not always defined.
        raise SystemExit(1)
    Log.info("Load configuration from: {}".format(config_path))
    with open(config_path, "r") as f:
        # An empty YAML document loads as None; treat it as an empty mapping.
        config = yaml.safe_load(f) or {}
    if config.get("debug", False):
        Log.setLevel(logging.DEBUG)

    algo_spec = config.pop("algorithms", None)
    if not algo_spec:
        Log.error(
            "Configuration has no 'algorithms' section: {}".format(config_path)
        )
        raise SystemExit(1)

    if config.get("payoff_config", None) is not None:
        config["payoff_config"] = os.path.expanduser(config["payoff_config"])
    # Settings shared by all algorithms override the CLI values.
    vars(args).update(config)
    Log.info("Updated common args are")
    pprint.pprint(vars(args))

    res = {}
    for algo, overrides in algo_spec.items():
        # Give each algorithm its own namespace so that its overrides do not
        # leak into the algorithms that run after it.
        algo_args = argparse.Namespace()
        vars(algo_args).update(vars(args))
        vars(algo_args).update(overrides or {})
        res[algo] = run_algo(algo, algo_args)

    # NOTE(review): ray is initialised only after all algorithms have run;
    # if the learners are meant to honour --local_mode this probably belongs
    # before the loop -- confirm against psro/minmax.multi_learn.
    if not ray.is_initialized():
        ray.init(local_mode=args.local_mode)

    log_path = os.path.join(LOG_DIR, "imgs")
    result_dir = os.path.join(LOG_DIR, "results")
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    plt.style.use("ggplot")
    plt.yscale("log")
    for name, algo_res in res.items():
        nash_conv = algo_res["NashConv"]
        # Size the x axis from the curve itself rather than from a shared
        # argument that per-algorithm overrides may have changed.
        x = np.arange(len(nash_conv["mean"]))
        # Label the line directly: positional legend labels would also be
        # matched against the fill_between artists and end up misassigned.
        plt.plot(x, nash_conv["mean"], label=name)
        plt.fill_between(x, nash_conv["min"], nash_conv["max"], alpha=0.2)
    plt.legend()
    plt.title("NashConv, game={}".format(args.env))
    plt.savefig("{}/nash_conv_{}_{}.png".format(log_path, args.env, int(time.time())))

    # Also save the raw results next to the figure for later analysis.
    result_path = os.path.join(result_dir, "{}_{}.pkl".format(args.env, time.time()))
    with open(result_path, "wb") as f:
        pickle.dump(res, f)
