import importlib
import pytest
import ray

from gym import spaces
from ray.util.queue import Queue
from expground.algorithms.random_policy import loss

from expground.types import LearningMode, PolicyConfig, TrainingConfig

from expground.logger import Log

from expground.utils import rollout
from expground.utils.preprocessor import get_preprocessor
from expground.common.utils import MessageType, Message
from expground.common.policy_pool import PolicyPool

from expground.algorithms.random_policy import RandomPolicy
from expground.algorithms.random_policy.trainer import RandomTrainer
from expground.algorithms.random_policy.loss import RandomLoss

from expground.learner.utils import work_flow
from expground.learner.independent.concurrent import CILearner, PersistentRunning

from expground.envs.agent_interface import AgentInterface
from expground.utils.sampler import get_sampler


def _gen_policy_arguments(
    start_idx: int, num: int, obs_space: spaces.Space, act_space: spaces.Space
):
    """Build ``num`` fixed ``RandomPolicy`` descriptors for a SYNC_POLICIES payload.

    Args:
        start_idx: Index of the first policy; keys run from
            ``policy_{start_idx}`` to ``policy_{start_idx + num - 1}``.
        num: How many descriptors to create (none when ``num <= 0``).
        obs_space: Observation space handed to every ``RandomPolicy``.
        act_space: Action space handed to every ``RandomPolicy``.

    Returns:
        A list of dicts with keys ``key``, ``policy`` and ``fixed`` (always
        ``True``). Each entry holds a freshly constructed ``RandomPolicy``.
    """
    return [
        {
            "key": f"policy_{idx}",
            "policy": RandomPolicy(
                observation_space=obs_space,
                action_space=act_space,
                model_config={},
                custom_config={},
            ),
            "fixed": True,
        }
        for idx in range(start_idx, start_idx + num)
    ]


def _gen_behavior_dist(num):
    return dict(zip([f"policy_{i}" for i in range(num)], [1 / num] * num))


@pytest.mark.parametrize(
    "module_path,env_id,algorithm,episodic",
    [("expground.envs.gym", "CartPole-v0", RandomPolicy, False)],
)
def test_persistent_running(module_path, env_id, algorithm, episodic):
    """Drive a ``PersistentRunning`` actor through RUN, SYNC_POLICIES and TERMINATE.

    Agent ``possible_agents[0]`` is marked active and given a ``RandomTrainer``.
    Every other agent is sent extra fixed ``RandomPolicy`` instances, plus a
    uniform behavior distribution, in a ``SYNC_POLICIES`` message. The test
    puts messages on ``send`` and reads the actor's replies from ``recv``.

    Args:
        module_path: Import path of the env module; it must expose
            ``env_desc_gen`` and ``basic_sampler_config``.
        env_id: Environment id passed to ``env_desc_gen``.
        algorithm: Policy class used to build the ``PolicyConfig``.
        episodic: Forwarded as ``optimization.episodic`` to the runner.

    NOTE(review): there are no assertions. The test only checks that the
    message exchange completes without raising.
    """
    if not ray.is_initialized():
        ray.init(local_mode=False)

    try:
        module = importlib.import_module(module_path)
        env_desc_gen = module.env_desc_gen
        sampler_config_gen = module.basic_sampler_config

        init_fixed_policy_num = 2
        init_active_policy_num = 1
        # Extra fixed policies pushed to every non-ego agent via SYNC_POLICIES.
        extra_fixed_policy_num = 2

        env_desc = env_desc_gen(env_id, {})
        env_config = env_desc["config"]
        possible_agents = env_config["possible_agents"]
        obs_spaces = env_config["observation_spaces"]
        act_spaces = env_config["action_spaces"]
        policy_config = PolicyConfig(
            algorithm,
            observation_space=lambda x: obs_spaces[x],
            action_space=lambda x: act_spaces[x],
        )

        ego_agent = possible_agents[0]
        ego_obs_space = obs_spaces[ego_agent]
        ego_act_space = act_spaces[ego_agent]

        send, recv = Queue(), Queue()
        agent_interfaces = {
            agent: AgentInterface(
                policy_name=None,
                policy=PolicyPool(
                    agent,
                    policy_config.copy(key=agent),
                    start_fixed_support_num=init_fixed_policy_num,
                    # Only the 0th (ego) agent starts with active supports.
                    start_active_support_num=(
                        init_active_policy_num if i == 0 else 0
                    ),
                    mixed_at_every_step=False,
                ),
                observation_space=obs_spaces[agent],
                action_space=act_spaces[agent],
                is_active=(i == 0),
            )
            for i, agent in enumerate(possible_agents)
        }

        # A trainer is created only for the ego agent.
        trainer = RandomTrainer(RandomLoss())
        sampler_config = sampler_config_gen(
            ego_obs_space,
            ego_act_space,
            get_preprocessor(ego_obs_space)(ego_obs_space),
            learning_starts=1,
        )

        runner_kwargs = {
            "runtime": {
                "agent_interfaces": agent_interfaces,
                "env_desc": env_desc,
                "max_step": 10,
                "fragment_length": 100,
                "caller": rollout.simultaneous_rollout,
            },
            "rollout": {
                "sampler_config": ([ego_agent], sampler_config),
                "behavior_policies": None,
            },
            "evaluation": {"policy_mappings": None, "seed": None},
            "optimization": {
                "trainer": {ego_agent: trainer},
                "train_every": 10,
                "episodic": episodic,
            },
        }

        # Fixed policies policy_{init_fixed_policy_num}.. for each opponent.
        # The behavior distribution is uniform over policy_0 .. policy_{N-1}
        # (presumably the pool's initial fixed policies plus the new ones).
        info = {
            agent: {
                "policies": _gen_policy_arguments(
                    start_idx=init_fixed_policy_num,
                    num=extra_fixed_policy_num,
                    obs_space=obs_spaces[agent],
                    act_space=act_spaces[agent],
                ),
                "behavior_dist": _gen_behavior_dist(
                    init_fixed_policy_num + extra_fixed_policy_num
                ),
            }
            for agent in possible_agents[1:]
        }

        # Keep the handle referenced: Ray terminates an actor once every
        # handle to it has gone out of scope.
        task = PersistentRunning.remote(send, recv, work_flow, runner_kwargs)

        send.put(Message(MessageType.RUN, None))
        message = recv.get()
        Log.info("tester got message from sub proc: {}".format(message.info))

        send.put(Message(MessageType.SYNC_POLICIES, info))
        recv.get()  # wait for the actor's reply to the sync request

        send.put(Message(MessageType.TERMINATE, None))
        message = recv.get()
        Log.info("tester got final message from sub proc: {}".format(message))
    finally:
        # Always tear Ray down, including the actor it hosts, even when an
        # exchange above raises. Otherwise later tests inherit a live runtime.
        ray.shutdown()


class TestConcurrent:
    """Placeholder unit tests for ``CILearner``.

    NOTE(review): every ``test_*`` method below is an empty stub, so pytest
    reports them as passing without exercising any ``CILearner`` code.
    """

    def init(self):
        """Build a ``CILearner`` with every constructor argument set to ``None``.

        NOTE(review): pytest does not call a method named ``init`` (its
        per-test hook is ``setup_method``), and none of the stubs call it.
        As written, this is never executed by the suite.
        """
        # dict.fromkeys maps each keyword to None.
        learner_kwargs = dict.fromkeys(
            (
                "policy_config",
                "env_description",
                "rollout_config",
                "training_config",
                "loss_func",
                "learning_mode",
                "episodic_training",
                "train_every",
                "ego_agents",
                "enable_policy_pool",
            )
        )
        self.learner = CILearner(**learner_kwargs)

    def test_build_sub_tasks(self):
        """Placeholder for sub-task building coverage (not implemented)."""

    def test_check_stop_condition(self):
        """Placeholder for stop-condition coverage (not implemented)."""

    def test_recycle(self):
        """Placeholder for recycle coverage (not implemented)."""

    def test_learn(self):
        """Placeholder for learn coverage (not implemented)."""
