import collections
import ray
import os
import time

import numpy as np

from expground import settings

from expground.types import (
    Any,
    List,
    Dict,
    Union,
    Tuple,
    AgentID,
    PolicyID,
    LambdaType,
    PolicyConfig,
    EnvDescription,
    TrainingConfig,
    RolloutConfig,
)
from expground.learner.base_learner import Learner
from expground.gt import MetaSolver
from expground.learner.utils import (
    generate_random_from_shapes,
    dispatch_resources_for_agent,
)

from .psro import PSROLearner
from expground.logger import Log
from expground.utils.logging import write_to_tensorboard, append_to_table
from expground.learner.independent import IndependentLearner, SubOracle
from expground.learner.centralized import CentralizedLearner
from expground.utils.stoppers import get_stopper, DEFAULT_STOP_CONDITIONS
from expground.gt.payoff_matrix import PayoffMatrix
from expground.gt.payoff_server import Identifier, PayoffServer
from expground.common.exploitability import measure_exploitability

from expground.algorithms.base_policy import Policy


class P2SROLearner(PSROLearner):
    def __init__(
        self,
        meta_solver: str,
        policy_config: PolicyConfig,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = False,
        train_every: int = 1,
        use_learnable_dist: bool = False,
        experiment: str = None,
        ray_mode: bool = False,
        seed: int = None,
        mixed_at_every_step: bool = False,
        independent_learning: bool = True,
        distribution_training_kwargs: Dict = None,
        centralized_critic_config: Dict = None,
        rectifier_type: int = 0,
        active_init_num: int = 3,
        fixed_init_num: int = 1,
        parallel_num: int = 4,
        plateaus_threshold: float = 0.03,
        save_indep_curve_every: int = 20,
        train_at_least: int = 4,
        mini_epoch: int = 10,
        resource_limit: Dict = None,
        use_cuda: bool = False,
        exp_config=None,
    ):
        """Create ``parallel_num`` levels of per-agent learners and payoff tables.

        Args:
            meta_solver: Solver type name handed to ``MetaSolver.from_type``.
            policy_config: Forwarded unchanged to every sub-learner.
            env_description: Forwarded to the base learner and every
                sub-learner; ``env_description["config"]["possible_agents"]``
                defines the agents.
            rollout_config: Forwarded to the base learner and every
                sub-learner; ``vector_mode`` is read for independent learners.
            training_config: Forwarded unchanged to every sub-learner.
            loss_func: Forwarded unchanged to every sub-learner.
            learning_mode: Forwarded unchanged to every sub-learner.
            episodic_training: Forwarded unchanged to every sub-learner.
            train_every: Forwarded unchanged to every sub-learner.
            use_learnable_dist: Stored and placed in the policy-pool config.
            experiment: Experiment tag; ``"psro"`` when ``None``.
            ray_mode: When true, the learner class is wrapped with
                ``as_remote(**learner_resources).remote``.
            seed: Passed to the base learner, the sub-learners and the
                policy-pool config.
            mixed_at_every_step: Placed in the policy-pool config.
            independent_learning: Selects ``IndependentLearner`` (true) or
                ``CentralizedLearner`` (false) as the sub-learner class.
            distribution_training_kwargs: Placed in the policy-pool config.
            centralized_critic_config: Only used when
                ``independent_learning`` is false.
            rectifier_type: Passed to every ``PayoffMatrix``.
            active_init_num: Stored on the instance; not read here.
            fixed_init_num: Stored on the instance; not read here.
            parallel_num: Number of levels; one learner dict and one payoff
                matrix are created per level.
            plateaus_threshold: Stored on the instance.
            save_indep_curve_every: Stored; read by ``_reset_indep_summary``.
            train_at_least: Stored on the instance; not read here.
            mini_epoch: Forwarded unchanged to every sub-learner.
            resource_limit: Resource limits; an empty dict when ``None``.
            use_cuda: Forwarded to ``dispatch_resources_for_agent``.
            exp_config: Forwarded to the base learner and every sub-learner.
        """
        # These must exist before ``init_agent_learners`` runs, because it
        # reads ``self.parallel_num``.
        self.active_init_num = active_init_num
        self.fixed_init_num = fixed_init_num
        self.parallel_num = parallel_num
        self.plateaus_threshold = plateaus_threshold
        self.train_at_least = train_at_least

        if experiment is None:
            experiment = "psro"
        # ``super(PSROLearner, self)`` starts the MRO lookup *after*
        # ``PSROLearner``, so ``PSROLearner.__init__`` itself is not executed;
        # the learners and payoff tables are built below instead.
        super(PSROLearner, self).__init__(
            experiment,
            seed=seed,
            ray_mode=ray_mode,
            exp_config=exp_config,
            rollout_config=rollout_config,
            env_desc=env_description,
            resource_config={"evaluation_worker_num": 0},
        )

        # Remote payoff serving is hard-wired off; the ``SubOracle`` branch
        # below is therefore never taken.
        use_remote_payoff = False

        np.set_printoptions(linewidth=999)

        possible_agents = env_description["config"]["possible_agents"]
        resource_limit = resource_limit or {}
        # Resources are split over one learner per agent *per level*.
        learner_resources = dispatch_resources_for_agent(
            possible_agents * self.parallel_num,
            use_cuda=use_cuda,
            resource_limit=resource_limit,
        )
        # Report the number of learners the resources were actually split
        # over (agents x levels), not just the number of agents.
        num_learners = len(possible_agents) * self.parallel_num
        Log.info(
            "Resource per learner (%s in total): %s",
            num_learners,
            learner_resources,
        )

        # Pick the sub-learner class (optionally wrapped as a ray actor).
        if independent_learning:
            learner_cls = SubOracle if use_remote_payoff else IndependentLearner
        else:
            learner_cls = CentralizedLearner
        if ray_mode:
            learner_cls = learner_cls.as_remote(**learner_resources).remote

        policy_pool_config = {
            "mixed_at_every_step": mixed_at_every_step,
            "use_learnable_dist": use_learnable_dist,
            "distribution_training_kwargs": distribution_training_kwargs,
            "start_fixed_support_num": 1,
            "start_active_support_num": 0,
            "seed": seed,
        }

        if independent_learning:
            sub_learner_custom_config = {
                "use_vector_env": rollout_config.vector_mode,
                "enable_evaluation_pool": True,
            }
        else:
            sub_learner_custom_config = {
                "centralized_critic_config": centralized_critic_config,
                "groups": None,  # none-specific groups for sub learners
                "share_critic": False,
            }

        self._agent_learners = self.init_agent_learners(
            learner_cls=learner_cls,
            policy_config=policy_config,
            env_description=env_description,
            rollout_config=rollout_config,
            training_config=training_config,
            loss_func=loss_func,
            learning_mode=learning_mode,
            episodic_training=episodic_training,
            train_every=train_every,
            seed=seed,
            possible_agents=possible_agents,
            enable_policy_pool=True,
            policy_pool_config=policy_pool_config,
            custom_config=sub_learner_custom_config,
            resource_config=resource_limit,
            mini_epoch=mini_epoch,
            exp_config=exp_config,
        )

        # create meta solver
        self._meta_solver = MetaSolver.from_type(meta_solver)

        # One payoff table per level.
        self._payoff_server = None
        self._payoff_matrix = [
            PayoffMatrix(tuple(possible_agents), rectifier_type=rectifier_type)
            for _ in range(self.parallel_num)
        ]

        self._env_description = env_description
        self._rollout_config = rollout_config
        self._use_learnable_dist = use_learnable_dist
        self._use_remote_payoff = use_remote_payoff
        self._save_indep_curve_every = save_indep_curve_every

        self.register_env_agents(possible_agents)

        # NOTE: no policy sync or payoff expansion happens here;
        # ``_init_parallel_workers`` performs both at the start of ``learn``.

    def init_agent_learners(
        self,
        learner_cls: type,
        policy_config,
        env_description,
        rollout_config,
        training_config,
        loss_func,
        learning_mode,
        episodic_training,
        train_every,
        seed: int,
        possible_agents: List[AgentID],
        enable_policy_pool: bool,
        policy_pool_config: Dict[str, Any],
        custom_config: Dict[str, Any],
        resource_config: Dict[str, Any],
        mini_epoch: int,
        exp_config: Dict[str, Any],
    ) -> List[Dict[AgentID, Learner]]:
        agent_learners = []
        for k in range(self.parallel_num):
            agent_learner = {
                # initialize policy pool for each ego agent
                aid: learner_cls(
                    policy_config=policy_config,
                    env_description=env_description,
                    rollout_config=rollout_config,
                    training_config=training_config,
                    loss_func=loss_func,
                    learning_mode=learning_mode,
                    episodic_training=episodic_training,
                    train_every=train_every,
                    ego_agents=[aid],
                    experiment=f"{self.experiment_tag}/{aid}",
                    seed=seed,
                    ray_mode=self.ray_mode,
                    exp_config=exp_config,
                    enable_policy_pool=enable_policy_pool,
                    summary_writer=self.summary_writer if not self.ray_mode else None,
                    turn_off_logging=True,
                    policy_pool_config=policy_pool_config,
                    resource_config=resource_config,
                    custom_config=custom_config,
                    reset_every_learn=False,
                    inner_eval=False,
                    pretrain_mode=k > 0,
                    mini_epoch=mini_epoch,
                )
                for aid in possible_agents
            }
            agent_learners.append(agent_learner)

        return agent_learners

    def _sync_up_fixed(self):
        fixed_policies = {}
        if self.ray_mode:
            res = ray.get(
                [
                    learner.get_ego_fixed_policies.remote("cpu")
                    for learner in self._agent_learners[0].values()
                ]
            )
            for e in res:
                fixed_policies.update(e)
            # print("sync fixed, ", fixed_policies)

            for k in range(self.parallel_num):
                ray.get(
                    [
                        learner.sync_policies.remote(
                            fixed_policies, force_sync=True, sync_type="union"
                        )
                        for learner in self._agent_learners[k].values()
                    ]
                )
        else:
            for agent, learner in self._agent_learners[0].items():
                fixed_policies.update(learner.get_ego_fixed_policies("cpu"))
            for k in range(self.parallel_num):
                for aid, learner in self._agent_learners[k].items():
                    learner.sync_policies(
                        fixed_policies, force_sync=True, sync_type="union"
                    )

    def _sync_up(self, step: int):
        """Sync fixed policies to learners, and expand payoff/simulation table."""
        active_policies = [{} for _ in range(self.parallel_num)]
        num_agent = len(self._agent_learners[0])
        # get all active policies
        if self.ray_mode:
            ray_list = []
            for k in range(self.parallel_num - 1):
                ray_list.extend(
                    [
                        learner.get_ego_active_policies.remote()
                        for learner in self._agent_learners[k].values()
                    ]
                )
            res = ray.get(ray_list)
            for k in range(self.parallel_num):
                res_k = res[k * num_agent : (k + 1) * num_agent]
                for e in res_k:
                    active_policies[k].update(e)
            #     if k > 0:
            #         active_policies[k].update(active_policies[k-1])

            ray_list = []
            sync_policies = {}
            for k in range(1, self.parallel_num):
                res_k = res[(k - 1) * num_agent : (k) * num_agent]
                for e in res_k:
                    for agent_id in e:
                        for policy_id in e[agent_id]:
                            e[agent_id][policy_id].is_fixed = True
                        if sync_policies.get(agent_id) is None:
                            sync_policies[agent_id] = {}
                        sync_policies[agent_id].update(e[agent_id])
                # print("at level {}, the sync_policies is {}".format(k, sync_policies))
                ray_list.extend(
                    [
                        learner.sync_policies.remote(
                            sync_policies, force_sync=True, sync_type="union"
                        )
                        for learner in self._agent_learners[k].values()
                    ]
                )

            ray.get(ray_list)
        else:
            sync_policies = {}
            for k in range(self.parallel_num - 1):
                for agent, learner in self._agent_learners[k].items():
                    policies = learner.get_ego_active_policies()
                    for agent_id in policies:
                        for policy_id in policies[agent_id]:
                            policies[agent_id][policy_id].is_fixed = True
                        if sync_policies.get(agent_id) is None:
                            sync_policies[agent_id] = {}
                        sync_policies[agent_id].update(policies[agent_id])

                for agent, learner in self._agent_learners[k].items():
                    learner.sync_policies(
                        sync_policies, force_sync=True, sync_type="union"
                    )

        # TODO: payoff matrix

    def _run_centralized_meta_solver(
        self,
    ) -> Tuple[
        List[Dict[AgentID, Dict[PolicyID, float]]],
        List[Dict[AgentID, Dict[PolicyID, Policy]]],
        List[Any],
    ]:
        """Refresh every level's payoff table and solve one equilibrium per level.

        For each level the payoff matrix is expanded with that level's fixed
        policy ids, missing simulations are generated and evaluated on the
        level's learners, and the results are written back.  Afterwards the
        meta solver is run on the sub-matrix spanned by each level's fixed
        policies.

        NOTE: the simulation/evaluation part talks to the learners through
        ``ray`` unconditionally; only the final equilibrium loop has a
        non-ray branch.

        Returns:
            A 3-tuple ``(equilibria, fixed_policies, update_dict)``:
            ``equilibria[k]`` maps agent id -> {policy id -> solver output},
            ``fixed_policies[k]`` holds the fixed policies of level ``k``, and
            ``update_dict`` is the flat list of evaluation results of all
            levels (lowest level first).
        """
        equilibrium_parallel = []
        fixed_policies_parallel = []
        num_agent = len(self._agent_learners[0])

        # TODO: real parallelize it
        # Active policy ids of the previous level; ``None`` for level 0.
        active_fix_policies = None
        ray_list = []
        total_simus = []
        for k in range(self.parallel_num):
            # Make sure matrix k knows every fixed policy of level k.
            res = ray.get(
                [
                    learner.get_ego_fixed_policy_ids.remote()
                    for learner in self._agent_learners[k].values()
                ]
            )
            current_fix_policies = {}
            for r in res:
                current_fix_policies.update(r)
            self._payoff_matrix[k].expand(current_fix_policies)

            # for kk in range(0, k):
            #     self._payoff_matrix[k].update_payoff_and_simulation_status(simu_results[kk])
            simus = self._payoff_matrix[k].gen_simulations(
                active_fix_policies, split=True
            )
            # Indexing simus[0] and simus[1] assumes the split yields (at
            # least) two parts.
            total_simus.append(len(simus[0]) + len(simus[1]))

            # Remember this level's active ids for the next iteration.
            res = ray.get(
                [
                    learner.get_ego_active_policy_ids.remote()
                    for learner in self._agent_learners[k].values()
                ]
            )
            active_fix_policies = {}
            for r in res:
                active_fix_policies.update(r)

            # Dispatch one evaluation per learner; learners and simulation
            # splits are paired positionally by ``zip``.
            ray_list.extend(
                [
                    learner.evaluation.remote(
                        simulation,
                        max_step=self._rollout_config.max_step,
                        fragment_length=self._rollout_config.num_simulation
                        * self._rollout_config.max_step,
                        # ========= for test =========
                        # fragment_length=1
                        #                 * self._rollout_config.max_step,
                    )
                    for learner, simulation in zip(
                        self._agent_learners[k].values(), simus
                    )
                ]
            )
            # print("ray_list length", len(ray_list))
        # print("ray_list", ray_list)
        # Block once for the evaluations of all levels.
        eval_results = ray.get(ray_list)
        Log.info(
            "\t* evaluation finished for %s workers, %s[%s] simus",
            len(eval_results),
            np.sum(total_simus),
            total_simus,
        )
        # print(eval_results)

        # ``update_dict`` / ``set_dones`` are never reset inside the loop, so
        # matrix k receives the results of levels 0..k.  Only level-0 results
        # carry a ``True`` flag in ``set_dones``.
        update_dict = []
        set_dones = []
        for k in range(self.parallel_num):
            # Slice the flat result list, assuming ``num_agent`` entries per
            # level.
            for kk in range((k) * num_agent, (k + 1) * num_agent):
                update_dict.extend(eval_results[kk])
                set_done = [k == 0] * len(eval_results[kk])
                set_dones.extend(set_done)
            # print("eval results:", update_dict, set_dones)

            self._payoff_matrix[k].update_payoff_and_simulation_status(
                update_dict, set_dones
            )

        # calculate equilibrium for all k workers
        for k in range(self.parallel_num):
            fixed_policies = {}
            if self.ray_mode:
                res = ray.get(
                    [
                        learner.get_ego_fixed_policies.remote()
                        for learner in self._agent_learners[k].values()
                    ]
                )
                for e in res:
                    fixed_policies.update(e)
            else:
                for agent, learner in self._agent_learners[k].items():
                    fixed_policies.update(learner.get_ego_fixed_policies())
            # a dict of payoff tables
            fixed_policy_mapping = {
                aid: list(_policies.keys()) for aid, _policies in fixed_policies.items()
            }
            matrix = self._payoff_matrix[k].get_sub_matrix(fixed_policy_mapping)
            equilibrium = self._meta_solver.solve(matrix)
            # Pair each fixed policy id with its solver output.  The ``k``
            # inside the comprehension is local to it and does not clobber the
            # level index of the enclosing loop.
            equilibrium = {
                k: dict(
                    zip(
                        fixed_policy_mapping[k],
                        equilibrium[k],
                    )
                )
                for k, v in matrix.items()
            }
            equilibrium_parallel.append(equilibrium)
            fixed_policies_parallel.append(fixed_policies)
            # print("equilibrium", k, equilibrium, fixed_policies)

        return equilibrium_parallel, fixed_policies_parallel, update_dict

    def _print_cur_workers(self):
        """Log every policy id and its ``is_fixed`` flag, per level and learner.

        Debug helper; fetches all policies from the learners through ``ray``.
        """
        for level in range(self.parallel_num):
            Log.info("level: {}".format(level))
            handles = [
                learner.get_all_policies.remote()
                for learner in self._agent_learners[level].values()
            ]
            for idx, policies in enumerate(ray.get(handles)):
                # agent id -> [(policy id, is_fixed), ...]
                status = {
                    agent_id: [
                        (policy_id, policies[agent_id][policy_id].is_fixed)
                        for policy_id in policies[agent_id]
                    ]
                    for agent_id in policies
                }
                Log.info("{}, {}".format(idx, status))

    def _init_parallel_workers(self):
        # init policy support
        ray_list = []
        for i in range(self.parallel_num):
            if self.ray_mode:
                ray_list.extend(
                    [
                        learner.add_policy.remote(n_support=1 + i)
                        for learner in self._agent_learners[i].values()
                    ]
                )
            else:
                for agent, learner in self._agent_learners[i].items():
                    learner.add_policy(n_support=1 + i)

        if self.ray_mode:
            ray.get(ray_list)

        self._sync_up_fixed()

        for k in range(1, self.parallel_num):
            if self.ray_mode:
                res = ray.get(
                    [
                        learner.get_ego_active_policy_ids.remote()
                        for learner in self._agent_learners[k].values()
                    ]
                )
                ray_list = []
                for learner, policies in zip(self._agent_learners[k].values(), res):
                    policies_to_fix = policies.copy()
                    for agent_id in policies:
                        policies_to_fix[agent_id] = policies[agent_id][:-1]

                    #     policies[agent_id] = policies[agent_id][:-1]
                    ray_list.append(
                        learner.set_ego_policy_fixed.remote(policies_to_fix)
                    )
                ray.get(ray_list)
            else:
                res = [
                    learner.get_ego_active_policy_ids()
                    for learner in self._agent_learners[k].values()
                ]

                for learner, policies in zip(self._agent_learners[k].values(), res):
                    policies_to_fix = policies.copy()
                    for agent_id in policies:
                        policies_to_fix[agent_id] = policies[agent_id][:-1]
                    learner.set_ego_policy_fixed(policies_to_fix)

        # self._print_cur_workers()

        # sync policy
        self._sync_up(0)

        active_fix_policies = None

        simu_results = []
        for k in range(self.parallel_num):
            if self.ray_mode:
                res = ray.get(
                    [
                        learner.get_ego_fixed_policy_ids.remote()
                        for learner in self._agent_learners[k].values()
                    ]
                )
                current_fix_policies = {}
                for r in res:
                    current_fix_policies.update(r)
                self._payoff_matrix[k].expand(current_fix_policies)

        self._run_centralized_meta_solver()

        # self._print_cur_workers()

    def _add_new_level(self, equilibrium, full_update_dict):
        """Retire the lowest level and recycle its learners as the new top level.

        The lowest level's active policies are marked fixed, the learners
        receive the active policies of every level (flagged as fixed), get one
        fresh policy added and are rotated to the end of
        ``self._agent_learners``.  The payoff table is rotated and re-solved
        by ``_update_latest_level``.  All learner calls go through ``ray``.

        Args:
            equilibrium: Not read by this method.
            full_update_dict: Evaluation results forwarded to
                ``_update_latest_level``.

        Returns:
            ``(last_equilibrium, last_fixed_policy)`` as produced by
            ``_update_latest_level`` for the new top level.
        """
        active_policies = []
        # Keep a handle on the current lowest level; it is re-appended as the
        # top level after the rotation below.
        lowest_learner = self._agent_learners[0]
        self.count_time("new_level/sync")
        # Snapshot the active policies of every level before anything is
        # changed.
        for k in range(self.parallel_num):
            res = ray.get(
                [
                    learner.get_ego_active_policies.remote()
                    for learner in self._agent_learners[k].values()
                ]
            )
            policies = {}
            for r in res:
                policies.update(r)
            active_policies.append(policies)

        # Mark the lowest level's active policies as fixed.
        policy_ids = {}
        res = ray.get(
            [
                learner.get_ego_active_policy_ids.remote()
                for learner in self._agent_learners[0].values()
            ]
        )
        for r in res:
            policy_ids.update(r)
        ray.get(
            [
                learner.set_ego_policy_fixed.remote(policy_ids)
                for learner in self._agent_learners[0].values()
            ]
        )
        # set fix
        # NOTE(review): this repeats the query/set-fixed round trip that was
        # just performed above on the same learners -- looks redundant,
        # confirm before removing.
        active_policy_ids = {}
        res = ray.get(
            [
                learner.get_ego_active_policy_ids.remote()
                for learner in self._agent_learners[0].values()
            ]
        )
        for e in res:
            active_policy_ids.update(e)
        ray.get(
            [
                learner.set_ego_policy_fixed.remote(active_policy_ids)
                for learner in self._agent_learners[0].values()
            ]
        )

        # sync
        # Union of the active policies of all levels, keyed by the agent ids
        # present on level 0 (an agent id missing there raises KeyError).
        tmp_active_policy = {}
        for aid, v in active_policies[0].items():
            tmp_active_policy[aid] = v.copy()
        for k in range(self.parallel_num):
            for aid, v in active_policies[k].items():
                tmp_active_policy[aid].update(active_policies[k][aid])
        # The policy objects are shared with ``active_policies``, so the flag
        # set here is also visible through that list.
        for aid, ps in tmp_active_policy.items():
            for pid in ps:
                ps[pid].is_fixed = True

        ray.get(
            [
                learner.sync_policies.remote(
                    tmp_active_policy, force_sync=True, sync_type="union"
                )
                for learner in self._agent_learners[0].values()
            ]
        )
        # Give the recycled learners one fresh policy.
        ray.get(
            [
                learner.add_policy.remote(n_support=1)
                for learner in self._agent_learners[0].values()
            ]
        )

        # Rotate: the old lowest level becomes the new top level.
        self._agent_learners.pop(0)
        self._agent_learners.append(lowest_learner)

        # self._print_cur_workers()

        self.count_time("new_level/sync", False)
        self.count_time("new_level/update_latest", True)
        last_equilibrium, last_fixed_policy = self._update_latest_level(
            active_policies, full_update_dict
        )
        self.count_time("new_level/update_latest", False)

        # reset training timestep for DQN
        # After the rotation, index 0 is the former level 1.
        ray.get(
            [
                learner.finish_pretrain.remote()
                for learner in self._agent_learners[0].values()
            ]
        )
        ray.get(
            [
                learner.reset_all.remote()
                for learner in self._agent_learners[self.parallel_num - 1].values()
            ]
        )

        return last_equilibrium, last_fixed_policy

    def _update_latest_level(self, active_policies, full_update_dict):
        """Rotate the payoff tables and solve an equilibrium for the top level.

        The lowest payoff table is moved to the end of ``self._payoff_matrix``,
        expanded with the active policies of every level, filled with the
        already known results and completed by evaluating the remaining
        simulations on the top-level learners.  The resulting equilibrium is
        installed as those learners' behavior distribution.  All learner calls
        go through ``ray``.

        Args:
            active_policies: One ``{agent_id: {policy_id: policy}}`` dict per
                level, indexed by level.
            full_update_dict: Evaluation results that are already available;
                passed as-is to ``update_payoff_and_simulation_status``.

        Returns:
            ``(last_equilibrium, last_fixed_policy)`` where
            ``last_equilibrium`` maps agent id -> {policy id -> solver output}
            and ``last_fixed_policy`` holds the fixed policies of the top-level
            learners.
        """
        top = self.parallel_num - 1

        # Move the oldest payoff table to the end of the list.
        lowest_mat = self._payoff_matrix[0]
        self._payoff_matrix.pop(0)
        self._payoff_matrix.append(lowest_mat)

        for k in range(self.parallel_num):
            lowest_mat.expand(active_policies[k])
        lowest_mat.update_payoff_and_simulation_status(full_update_dict)

        # Evaluate whatever is still missing in the payoff table.
        simus = lowest_mat.gen_simulations(split=True)
        self.count_time("add_new/simu")
        eval_result = ray.get(
            [
                learner.evaluation.remote(
                    simulation,
                    max_step=self._rollout_config.max_step,
                    fragment_length=self._rollout_config.num_simulation
                    * self._rollout_config.max_step,
                )
                # learners and simulation splits are paired positionally
                for learner, simulation in zip(
                    self._agent_learners[top].values(), simus
                )
            ]
        )
        self.count_time("add_new/simu", False)

        # Indexing simus[0] and simus[1] assumes the split yields (at least)
        # two parts.
        Log.info("\t+ re-evaluate with %s simus", len(simus[0]) + len(simus[1]))
        update_dict = []
        for r in eval_result:
            update_dict.extend(r)
        lowest_mat.update_payoff_and_simulation_status(
            update_dict, [False] * len(update_dict)
        )

        # Solve the new equilibrium over the top level's fixed policies.
        last_fixed_policy = {}
        res = ray.get(
            [
                learner.get_ego_fixed_policies.remote()
                for learner in self._agent_learners[top].values()
            ]
        )
        for e in res:
            last_fixed_policy.update(e)
        fixed_policy_mapping = {
            aid: list(_policies.keys()) for aid, _policies in last_fixed_policy.items()
        }
        matrix = self._payoff_matrix[top].get_sub_matrix(fixed_policy_mapping)

        self.count_time("add_new/meta_solver")
        solved = self._meta_solver.solve(matrix)
        # Stop the timer (``start=False``) so the table records the elapsed
        # duration; calling it without ``False`` would merely restart it.
        self.count_time("add_new/meta_solver", False)

        # Pair each fixed policy id with its solver output.
        last_equilibrium = {
            aid: dict(zip(fixed_policy_mapping[aid], solved[aid])) for aid in matrix
        }

        # update behavior policy of latest policy
        ray.get(
            [
                learner.set_behavior_dist.remote(last_equilibrium)
                for learner in self._agent_learners[top].values()
            ]
        )

        return last_equilibrium, last_fixed_policy

    def _eval_lowest_reward(self, equilibrium, fixed_agent):
        reward = []
        fixed_policy_mapping = {
            aid: list(_policies.keys()) for aid, _policies in fixed_agent.items()
        }
        matrix = self._payoff_matrix[1].get_sub_matrix(fixed_policy_mapping)
        # print("eval elements", fixed_policy_mapping, matrix, equilibrium)
        for i, agent_id in enumerate(fixed_agent):
            v_idx = [slice(None)] * len(fixed_agent)
            v_idx[i] = -1
            v_idx = tuple(v_idx)
            payoff = matrix[agent_id][v_idx].reshape([-1])[:-1]
            p_prob = [equilibrium[agent_id][p] for p in sorted(equilibrium[agent_id])]
            p_prob = np.array(p_prob)
            r = np.dot(payoff, p_prob)
            # print("performance agent {}, {}, {}, {}".format(agent_id, payoff, p_prob, r))
            reward.append(r)
        return np.mean(reward)

    def _reset_indep_summary(self, round):
        set_none = True
        if (round + 1) % self._save_indep_curve_every == 0 or round == 0:
            set_none = False

        if self.ray_mode:
            ray_list = []
            for agent, learner in self._agent_learners[0].items():
                indep_log_path = ""
                if not set_none:
                    indep_log_path = self._exp_config.get_path(f"{agent}_{round}")
                ray_list.append(
                    learner.reset_summary_writer.remote(indep_log_path, set_none)
                )
            res = ray.get(ray_list)
        else:
            for agent, learner in self._agent_learners[0].items():
                indep_log_path = ""
                if not set_none:
                    indep_log_path = self._exp_config.get_path(f"{agent}_{round}")
                learner.reset_summary_writer(indep_log_path, set_none)

    def _update_agent_matrix(self, equilibrium):
        """Append the current equilibrium and payoff rows to per-agent CSV files.

        For every learner of the lowest level one row is appended to
        ``equilibrium_<agent>.csv`` (probabilities in sorted policy-id order)
        and one to ``payoff_<agent>.csv`` under ``self._exp_config.log_path``.
        The payoff row is the last row of the agent's table for
        ``"player_0"`` and the last column otherwise.

        Args:
            equilibrium: ``{agent_id: {policy_id: probability}}``.

        Raises:
            AssertionError: If a learner does not report exactly one entry of
                active policy ids, or a payoff table does not have two agents.
        """
        for agent, learner in self._agent_learners[0].items():
            if self.ray_mode:
                active_policy_ids = ray.get(learner.get_ego_active_policy_ids.remote())
            else:
                active_policy_ids = learner.get_ego_active_policy_ids()
            assert len(active_policy_ids) == 1
            # Looked up exactly as before, although nothing below reads it.
            train_id = list(active_policy_ids.values())[0]

            # equilibrium table
            equilibrium_path = os.path.join(
                self._exp_config.log_path, f"equilibrium_{agent}.csv"
            )
            agent_dist = equilibrium[agent]
            equilibrium_row = [agent_dist[pid] for pid in sorted(agent_dist.keys())]
            append_to_table(equilibrium_path, [], equilibrium_row)

            # payoff table
            payoff_path = os.path.join(
                self._exp_config.log_path, f"payoff_{agent}.csv"
            )
            agent_payoff = self._payoff_matrix[0].payoff_matrix[agent]
            assert len(agent_payoff.agents) == 2, "only support agents=2 for now"
            if agent == "player_0":
                payoff_row = agent_payoff.table[-1].tolist()
            else:
                payoff_row = agent_payoff.table[:, -1].reshape([-1]).tolist()
            append_to_table(payoff_path, [], payoff_row)

    def count_time(self, name, start=True):
        if start:
            self.time_cost_table[name] = time.time()
        else:
            self.time_cost_table[name] = time.time() - self.time_cost_table[name]

    def print_time(self):
        """Log the contents of ``self.time_cost_table`` at info level."""
        timings = self.time_cost_table
        Log.info("\t* Time Cost: %s", timings)

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict,
    ):
        """Run the outer training loop until the global stopper terminates.

        Every global iteration performs, in order:
            1. run the centralized meta solver and push the resulting
               per-level behavior distribution to every learner,
            2. evaluate the reward of the lowest level and compare it with
               the previous measurement against ``self.plateaus_threshold``,
            3. if the reward plateaued and the current level has been trained
               for at least ``self.train_at_least`` rounds, add a new level,
            4. train the learners of every level,
            5. measure exploitability and write it, together with the number
               of added policies, to tensorboard,
            6. sync up the active models and step the stopper.

        Args:
            sampler_config: forwarded unchanged to each learner's ``learn``.
            stop_conditions: conditions for the global stopper; when falsy,
                ``DEFAULT_STOP_CONDITIONS`` is used instead.
            inner_stop_conditions: forwarded as ``stop_conditions`` to each
                learner's ``learn``.

        Raises:
            NotImplementedError: if ``self._use_learnable_dist`` is truthy.
        """
        # 0. init all parallel workers
        self._init_parallel_workers()

        stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()

        # `round_idx` rather than `round`, to avoid shadowing the builtin.
        round_idx = 0
        last_add_round = 0
        # -inf guarantees the very first plateau check evaluates to False.
        last_measure_reward = -np.inf
        policy_number = 0

        while not stopper.is_terminal():
            Log.info("Global iteration on: %s", stopper.counter)
            self._reset_indep_summary(round_idx)
            self.time_cost_table = {}

            if self._use_learnable_dist:
                raise NotImplementedError(
                    "learn() does not support a learnable distribution yet"
                )

            # 1. recalculate payoff table
            # 2. calculate equilibrium and set behavior policy
            self.count_time("meta_solver", True)
            (
                equilibrium,
                fixed_policies,
                full_update_dict,
            ) = self._run_centralized_meta_solver()
            self.count_time("meta_solver", False)

            Log.info(
                "\t* oldest behavior dist: %s",
                equilibrium[0],
            )
            Log.info(
                "\t* latest behavior dist: %s", equilibrium[self.parallel_num - 1]
            )

            if self.ray_mode:
                for k in range(self.parallel_num):
                    ray.get(
                        [
                            learner.set_behavior_dist.remote(equilibrium[k])
                            for learner in self._agent_learners[k].values()
                        ]
                    )
            else:
                for k in range(self.parallel_num):
                    for learner in self._agent_learners[k].values():
                        learner.set_behavior_dist(equilibrium[k])

            # TODO: dump equilibrium to log files

            # 3. Evaluate the lowest level
            current_reward = self._eval_lowest_reward(equilibrium[0], fixed_policies[1])
            need_to_add_policy = (
                current_reward < last_measure_reward + self.plateaus_threshold
            )
            Log.info(
                "\t* check threshold: reward {}, last reward {}, diff: {} < {} = {}".format(
                    current_reward,
                    last_measure_reward,
                    current_reward - last_measure_reward,
                    self.plateaus_threshold,
                    need_to_add_policy,
                )
            )
            last_measure_reward = current_reward

            # A freshly added level must be trained for a minimum number of
            # rounds before it may be replaced by yet another level.
            rounds_on_level = round_idx - last_add_round
            if rounds_on_level < self.train_at_least:
                Log.info(
                    "\t* Train a new level at least {} rounds, {} rounds now.".format(
                        self.train_at_least, rounds_on_level
                    )
                )
                need_to_add_policy = False

            # 4. if necessary,
            #   a. fixed lowest level,
            #   b. add policies for all the levels
            #   c. update equ and behavior policy
            if need_to_add_policy:
                policy_number += 1
                self.count_time("add_new_level", True)
                last_add_round = round_idx
                last_equilibrium, last_fixed_policy = self._add_new_level(
                    equilibrium, full_update_dict
                )
                self.count_time("add_new_level", False)

                # Drop the oldest level's entries and append the new level's.
                equilibrium.pop(0)
                equilibrium.append(last_equilibrium)
                fixed_policies.pop(0)
                fixed_policies.append(last_fixed_policy)

                self._update_agent_matrix(equilibrium[0])

                last_measure_reward = self._eval_lowest_reward(
                    equilibrium[0], fixed_policies[1]
                )
                Log.info(
                    "\t+ update new level, new eval reward {}, equilibrium {}".format(
                        last_measure_reward, str(equilibrium[0])
                    )
                )

            # 5. train all workers for k epochs
            learning_res = {}
            self.count_time("train", True)
            if self.ray_mode:
                train_list = []
                for k in range(self.parallel_num):
                    train_list.extend(
                        [
                            learner.learn.remote(
                                sampler_config, stop_conditions=inner_stop_conditions
                            )
                            for learner in self._agent_learners[k].values()
                        ]
                    )
                ray.get(train_list)
            else:
                # Train every level, mirroring the ray branch. Previously this
                # branch reused a stale `k` and trained only the last level.
                for k in range(self.parallel_num):
                    learning_res[k] = {
                        aid: learner.learn(
                            sampler_config, stop_conditions=inner_stop_conditions
                        )
                        for aid, learner in self._agent_learners[k].items()
                    }
            self.count_time("train", False)

            # dump
            if self._use_learnable_dist:
                # NOTE: currently unreachable because of the guard at the top
                # of the loop; kept for when learnable dist is implemented.
                (
                    dpsro_equilibrium,
                    dpsro_fixed_policies,
                ) = self._run_decentralized_meta_solver()

                Log.info("\t* computed equilibrium as: %s", dpsro_equilibrium)
                dpsro_nash_conv = measure_exploitability(
                    self._env_description["config"]["env_id"],
                    dpsro_fixed_policies,
                    dpsro_equilibrium,
                )
                nash_report = {"NashConv/dpsro": dpsro_nash_conv.nash_conv}
                Log.info("\t* dpsro nash conv: %s", dpsro_nash_conv)
            else:
                nash_conv = measure_exploitability(
                    self._env_description["config"]["env_id"],
                    fixed_policies[0],
                    equilibrium[0],
                )
                nash_report = {"NashConv/base": nash_conv.nash_conv}
                Log.info("\t* nash conv: %s", nash_conv)

            # Write every scalar once per global step (the NashConv entry used
            # to be written twice at the same step).
            nash_report["NashConv/policy_number"] = policy_number
            write_to_tensorboard(
                self.summary_writer,
                info=nash_report,
                global_step=stopper.counter,
                prefix="",
            )

            # 6. sync active models
            self.count_time("sync_up", True)
            self._sync_up(stopper.counter)
            self.count_time("sync_up", False)
            Log.info("\t* synced up")
            self.print_time()
            # stopper is used to compute exploitability
            stopper.step(None, None, None, None)
            Log.info("")
            round_idx += 1
