"""
Concurrent learning API serves for Pipeline PSRO
"""

import asyncio
import traceback
import ray

from ray.util.queue import Queue

from expground.types import (
    Union,
    Dict,
    LambdaType,
    Sequence,
    AgentID,
    TrainingConfig,
    PolicyConfig,
    RolloutConfig,
    Any,
    List,
    Tuple,
)
from expground.logger import Log
from expground.utils.sampler import get_sampler
from expground.common.utils import MessageType, Message
from expground.learner.utils import work_flow

from .independent import IndependentLearner


@ray.remote
def PersistentRunning(recv: Queue, send: Queue, runner: type, runner_kwargs: Dict):
    """Run ``runner`` repeatedly inside a Ray task until told to stop.

    The task is driven by ``Message`` objects read from ``recv`` and reports
    back on ``send``:

    * ``KILL`` / ``TERMINATE``: leave the loop and send a final
      ``FEED_BACK`` message with ``{"status": 200, "iteration": ...}``.
    * ``WAIT``: skip this pass without running.
    * ``RUN``: run one iteration.
    * ``SYNC_POLICIES``: ``message.info`` maps agent id to a dict; its
      optional ``"policies"`` entries are passed to
      ``interface.policy.add_policy(**entry)`` and its optional
      ``"behavior_dist"`` to ``interface.policy.set_distribution``. One
      iteration is then run.

    When ``recv`` is empty the task blocks until the first command arrives;
    after the first completed iteration an empty queue means "keep running".
    Every completed iteration sends ``FEED_BACK`` with
    ``{"iteration": iteration}``.

    Args:
        recv: Queue the task reads commands from.
        send: Queue the task writes feedback to.
        runner: Callable invoked as ``runner(optimize_config, rollout_config,
            evaluation_config, last_time_step_th, last_episode_th)``; it must
            return ``(training_statistic, evaluation_statistic,
            last_time_step_th, last_episode_th)``.
        runner_kwargs: Dict with the keys ``"runtime"``, ``"rollout"``,
            ``"evaluation"`` and ``"optimization"``. ``"runtime"`` must hold
            ``"agent_interfaces"`` and is merged into both the rollout and
            the evaluation config. After merging, the rollout config must
            hold ``"sampler_config"``, which is unpacked into
            ``get_sampler(*sampler_config)`` and replaced by the resulting
            ``"sampler"`` entry.
            NOTE(review): earlier notes listed the expected contents as
            trainer (optimization); caller, behavior_policies,
            agent_interfaces, env_desc, fragment_length, max_step, episodic,
            train_every (rollout); policy_mappings, max_step,
            fragment_length, agent_interfaces, caller, env_desc, seed
            (evaluation) -- confirm against ``runner``.

    Any exception raised inside the loop is logged with its traceback and
    swallowed; in that case no final status message is sent.
    """

    iteration = 1
    last_time_step_th = 0
    last_episode_th = 0

    # build rollout_config and evaluation_config
    agent_interfaces = runner_kwargs["runtime"]["agent_interfaces"]
    rollout_config = runner_kwargs["rollout"]
    rollout_config.update(runner_kwargs["runtime"])

    sampler_config = rollout_config.pop("sampler_config")
    rollout_config["sampler"] = get_sampler(*sampler_config)

    evaluation_config = runner_kwargs["evaluation"]
    evaluation_config.update(runner_kwargs["runtime"])

    optimize_config = runner_kwargs["optimization"]

    try:
        while True:
            if recv.size() >= 1:
                message = recv.get_nowait()
            elif iteration == 1:
                # Nothing has been requested yet: block until the first
                # command arrives instead of spinning on the queue actor.
                message = recv.get()
            else:
                # No pending command after the first run: keep training.
                message = Message(MessageType.RUN, None)

            Log.info("got message: {}".format(message))
            if message.opt_type in [MessageType.KILL, MessageType.TERMINATE]:
                break
            elif message.opt_type == MessageType.WAIT:
                continue
            elif message.opt_type == MessageType.RUN:
                pass
            elif message.opt_type == MessageType.SYNC_POLICIES:
                # sync policies and dist to policy pool
                for aid, _info in message.info.items():
                    interface = agent_interfaces[aid]
                    if "policies" in _info:
                        for policy_kwargs in _info["policies"]:
                            interface.policy.add_policy(**policy_kwargs)
                    if "behavior_dist" in _info:
                        interface.policy.set_distribution(_info["behavior_dist"])

            (
                _training_statistic,
                _evaluation_statistic,
                last_time_step_th,
                last_episode_th,
            ) = runner(
                optimize_config,
                rollout_config,
                evaluation_config,
                last_time_step_th,
                last_episode_th,
            )
            send.put(Message(MessageType.FEED_BACK, {"iteration": iteration}))

            iteration += 1

        # Acknowledge termination so the learner can reclaim this task.
        send.put(
            Message(MessageType.FEED_BACK, {"status": 200, "iteration": iteration})
        )
    except Exception:
        Log.error(traceback.format_exc())


class CILearner(IndependentLearner):
    """Learner that drives ``PersistentRunning`` Ray tasks through queues.

    Each running task is tracked by one entry in ``self.conn`` (a
    ``(to_worker, from_worker)`` queue pair) and the entry at the same index
    in ``self.persistent_actors`` (the ``ObjectRef`` returned by
    ``PersistentRunning.remote``).

    Subclasses must implement ``check_stop_condition`` and
    ``_build_runner_kwargs``.
    """

    def __init__(
        self,
        policy_config: PolicyConfig,
        env_description: Dict[str, Any],
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = True,
        train_every: int = 1,
        ego_agents: Sequence[AgentID] = None,
        enable_policy_pool: bool = False,
        experiment: str = None,
        seed: int = None,
        agent_mapping: LambdaType = ...,
        mini_epoch: int = 5,
        **kwargs
    ):
        """Forward every argument unchanged to ``IndependentLearner.__init__``
        and initialise the (empty) bookkeeping for persistent tasks."""
        super().__init__(
            policy_config,
            env_description,
            rollout_config,
            training_config,
            loss_func,
            learning_mode,
            episodic_training=episodic_training,
            train_every=train_every,
            ego_agents=ego_agents,
            enable_policy_pool=enable_policy_pool,
            experiment=experiment,
            seed=seed,
            agent_mapping=agent_mapping,
            mini_epoch=mini_epoch,
            **kwargs
        )

        # One (to_worker, from_worker) queue pair per persistent training
        # task, index-aligned with the task references below.
        self.conn: List[Tuple[Queue, Queue]] = []
        self.persistent_actors: List[ray.ObjectRef] = []

    def _build_runner_kwargs(self, sampler_config: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``runner_kwargs`` dict for one ``PersistentRunning`` task.

        The worker reads the keys ``"runtime"`` (which must contain
        ``"agent_interfaces"``), ``"rollout"``, ``"evaluation"`` and
        ``"optimization"``, and expects a ``"sampler_config"`` entry in the
        merged rollout/runtime config that it unpacks into
        ``get_sampler(*sampler_config)``.
        NOTE(review): presumably that entry should be
        ``(self._ego_agents, sampler_config)`` -- confirm against
        ``get_sampler``.

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError

    def build_sub_task(self, num: int, sampler_config: Dict[str, Any]):
        """Start ``num`` persistent training tasks and register them.

        For each task a fresh queue pair is created, the task is launched
        with ``work_flow`` as its runner, and an initial ``RUN`` message is
        queued for it.

        Args:
            num: Number of tasks to start.
            sampler_config: Sampler configuration handed to
                ``_build_runner_kwargs``.
        """
        for _ in range(num):
            # Build the kwargs first so a failure here does not leave
            # orphaned queue actors behind.
            task_kwargs = self._build_runner_kwargs(sampler_config)
            _conn = (Queue(), Queue())
            actor = PersistentRunning.remote(
                *_conn, runner=work_flow, runner_kwargs=task_kwargs
            )
            _conn[0].put(Message(MessageType.RUN, None))
            self.conn.append(_conn)
            self.persistent_actors.append(actor)

    def recycle(self, i):
        """Stop the ``i``-th persistent task and drop it from the bookkeeping.

        Sends ``TERMINATE``, then reads the feedback queue until the message
        carrying a ``"status"`` field arrives; per-iteration feedback that was
        queued earlier is discarded. If the status is 200 the task is given
        10 seconds to finish and is force-cancelled otherwise.

        Args:
            i: Index of the task in ``self.conn`` / ``self.persistent_actors``.
        """
        conn = self.conn.pop(i)
        task = self.persistent_actors.pop(i)
        # send terminate signal to this task
        conn[0].put(Message(MessageType.TERMINATE, None))
        # Wait for the termination acknowledgement, skipping stale
        # per-iteration feedback that carries no "status".
        while True:
            message: Message = conn[1].get()
            if isinstance(message.info, dict) and "status" in message.info:
                break
        if message.info["status"] == 200:
            # ``task`` is the ObjectRef of a remote function, not an actor
            # handle, so wait on the ref itself and cancel it if needed.
            _done, not_done = ray.wait([task], timeout=10)
            if not_done:
                ray.cancel(task, force=True)

    def check_stop_condition(self, feed_back: Dict[str, Any]):
        """Decide whether the task that produced ``feed_back`` should stop.

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError

    def learn(
        self, sampler_config: Union[Dict, LambdaType], stop_conditions: Dict = None
    ):
        """Run the concurrent learning loop for 100 rounds.

        Each round reads one feedback message from every active task. If the
        first task's feedback satisfies ``check_stop_condition`` that task is
        recycled and a new one is started; every other task is sent a
        ``SYNC_POLICIES`` message carrying ``self._policies``.

        Tasks that are still active when the loop ends are left running.

        Args:
            sampler_config: Sampler configuration passed to ``build_sub_task``.
            stop_conditions: Currently unused.
        """
        # start N persistent running actors
        self.build_sub_task(num=1, sampler_config=sampler_config)

        # suppose we run 100 big episodes
        big_episodes = 100
        for _ in range(big_episodes):
            # Iterate over a snapshot: recycling and rebuilding mutate
            # ``self.conn`` while this round is in progress.
            for i, (send_conn, recv_conn) in enumerate(list(self.conn)):
                message: Message = recv_conn.get()
                if i == 0 and self.check_stop_condition(message.info):
                    # 0 is always the lowest active training task
                    # judge whether this training task achieved stop condition
                    # pop this task into pending_fixed policy queue and generate new active
                    self.recycle(i)
                    self.build_sub_task(num=1, sampler_config=sampler_config)
                else:
                    # sync policies
                    send_conn.put(Message(MessageType.SYNC_POLICIES, self._policies))
