"""
Test sequential rollout for rl policies.

Checking for:
    - action computing
    - data collection
"""

import pytest
import numpy as np

from open_spiel.python import rl_environment

from expground.types import PolicyConfig, RolloutConfig
from expground.utils.data import EpisodeKeys
from expground.utils.preprocessor import get_preprocessor
from expground.utils.rollout import sequential_rollout
from expground.utils.sampler import Sampler
from expground.algorithms.base_policy import Policy
from expground.algorithms.random_policy import RandomPolicy
from expground.algorithms.ddpg import DDPG
from expground.envs import open_spiel_adapters
from expground.envs.agent_interface import AgentInterface


@pytest.fixture(scope="session")
def env_desc():
    """Describe a two-player ``kuhn_poker`` OpenSpiel environment.

    An environment instance is created here only to read its observation
    spec, action spec and player count; the instance itself is not part of
    the returned description.

    Returns:
        A dict with two entries: ``"creator"`` (the environment class) and
        ``"config"`` (env id, agent ids, per-agent observation/action
        spaces, the OpenSpiel adapters and the scenario config).
    """
    env_cls = rl_environment.Environment
    game_name = "kuhn_poker"
    game_settings = {"players": 2}

    probe_env = rl_environment.Environment(game_name, **game_settings)
    obs_spec = probe_env.observation_spec()
    act_spec = probe_env.action_spec()
    agents = list(range(probe_env.num_players))

    # Every agent gets its own space object built from the shared specs.
    obs_space_by_agent = {}
    for agent in agents:
        obs_space_by_agent[agent] = open_spiel_adapters.ObservationSpace(obs_spec)

    act_space_by_agent = {}
    for agent in agents:
        act_space_by_agent[agent] = open_spiel_adapters.ActionSpace(act_spec)

    config = {
        "env_id": game_name,
        "possible_agents": agents,
        "action_spaces": act_space_by_agent,
        "observation_spaces": obs_space_by_agent,
        "observation_adapter": open_spiel_adapters.observation_adapter,
        "action_adapter": open_spiel_adapters.action_adapter,
        "scenario_config": game_settings,
    }
    return {"creator": env_cls, "config": config}


@pytest.mark.parametrize(
    "rollout_config,policy_cls",
    [
        [RolloutConfig(caller=sequential_rollout, fragment_length=20), RandomPolicy],
        [RolloutConfig(caller=sequential_rollout, fragment_length=20), DDPG],
    ],
)
def test_sequential_rollout(
    env_desc, rollout_config: RolloutConfig, policy_cls: Policy
):
    """Run the configured rollout once for each parametrized policy class.

    For every agent in the environment description an ``AgentInterface`` is
    built around a fresh policy instance, a ``Sampler`` is created to hold
    the collected data, and ``rollout_config.caller`` is invoked with them.
    The test contains no explicit assertions: it passes when the rollout
    completes without raising.

    Args:
        env_desc: Environment description produced by the ``env_desc``
            fixture (``"creator"`` plus ``"config"``).
        rollout_config: Rollout settings; provides the rollout callable,
            the number of episodes and the fragment length.
        policy_cls: Policy class (not an instance) handed to
            ``PolicyConfig`` to build one policy per agent.
    """
    env_config = env_desc["config"]
    agent_ids = env_config["possible_agents"]
    # Agent 0's spaces are used to size the sampler buffers.
    action_space = env_config["action_spaces"][0]
    observation_space = env_config["observation_spaces"][0]

    policy_config = PolicyConfig(
        policy=policy_cls,
        mapping=lambda agent_id: agent_id,
        observation_space=lambda k: env_config["observation_spaces"][k],
        action_space=lambda k: env_config["action_spaces"][k],
    )

    # build policies
    agent_interfaces = {
        pid: AgentInterface(
            policy_name="",
            policy=policy_config.new_policy_instance(pid),
            observation_space=env_config["observation_spaces"][pid],
            action_space=env_config["action_spaces"][pid],
            observation_adapter=env_config["observation_adapter"],
            action_adapter=env_config["action_adapter"],
        )
        for pid in agent_ids
    }

    # Observation and next-observation buffers share the preprocessed shape,
    # so compute it once.
    obs_shape = get_preprocessor(observation_space)(observation_space).shape
    action_mask_shape = (action_space.n,)

    # The `np.float` / `np.int` / `np.bool` aliases were removed in
    # NumPy 1.24. They were plain aliases of the Python builtins, so the
    # builtins select exactly the same dtypes.
    sampler = Sampler(
        agent_ids,
        dtypes={
            EpisodeKeys.ACTION_MASK.value: float,
            EpisodeKeys.REWARD.value: float,
            EpisodeKeys.NEXT_OBSERVATION.value: float,
            EpisodeKeys.DONE.value: bool,
            EpisodeKeys.OBSERVATION.value: float,
            EpisodeKeys.ACTION.value: int,
            "next_action_mask": float,
        },
        data_shapes={
            EpisodeKeys.ACTION_MASK.value: action_mask_shape,
            EpisodeKeys.REWARD.value: (),
            EpisodeKeys.NEXT_OBSERVATION.value: obs_shape,
            EpisodeKeys.DONE.value: (),
            EpisodeKeys.OBSERVATION.value: obs_shape,
            # NOTE(review): actions are stored with shape (n,) and an integer
            # dtype -- confirm this matches what the policies emit.
            EpisodeKeys.ACTION.value: (action_space.n,),
            "next_action_mask": action_mask_shape,
        },
        capacity=1000,
    )
    rollout_config.caller(
        sampler=sampler,
        agent_interfaces=agent_interfaces,
        env_description=env_desc,
        num_episode=rollout_config.num_episodes,
        fragment_length=rollout_config.fragment_length,
    )
