import gym
import numpy as np

from expground.types import Dict
from expground.utils.data import EpisodeKeys

from .env import env as Poker


def env_desc_gen(env_id: str, scenario_config: Dict):
    """Build a description dict for a Poker environment.

    A probe instance of ``Poker`` is constructed once so that its agent list
    and per-agent spaces can be copied into the description.

    Args:
        env_id: Identifier forwarded to the ``Poker`` constructor.
        scenario_config: Scenario options forwarded to ``Poker`` and also
            stored verbatim in the returned config.

    Returns:
        A dict with ``"creator"`` (the ``Poker`` callable itself) and
        ``"config"`` (constructor arguments plus the probed agents/spaces and
        a fixed agent grouping).
    """
    probe = Poker(env_id=env_id, scenario_config=scenario_config)

    # NOTE(review): grouping is hard-coded to two single-member groups and is
    # not derived from probe.possible_agents -- assumes the agents are exactly
    # "player_0" and "player_1"; confirm for every env_id this is used with.
    agent_groups = {"player_0": ["player_0"], "player_1": ["player_1"]}

    config = {}
    config["env_id"] = env_id
    config["possible_agents"] = probe.possible_agents
    config["action_spaces"] = probe.action_spaces
    config["observation_spaces"] = probe.observation_spaces
    config["scenario_config"] = scenario_config
    config["group"] = agent_groups

    return {"creator": Poker, "config": config}


def basic_sampler_config(
    observation_space: gym.Space,
    action_space: gym.Space,
    preprocessor: object,
    capacity: int = 1000,
    learning_starts: int = 64,
):
    """Assemble the sampler configuration (per-field dtypes and shapes).

    Args:
        observation_space: Accepted for interface compatibility; it is not
            read by this function (observation shapes come from
            ``preprocessor.shape``).
        action_space: Must expose ``.n`` (e.g. a discrete space); ``n`` sets
            the length of the action-mask, action-distribution, and
            action-logits fields.
        preprocessor: Any object with a ``.shape`` attribute, used as the
            shape of both the observation and next-observation fields.
        capacity: Stored under the ``"capacity"`` key (default 1000);
            presumably the buffer size -- confirm against the sampler.
        learning_starts: Stored under the ``"learning_starts"`` key
            (default 64); presumably a warm-up sample count -- confirm.

    Returns:
        A dict with ``"dtypes"``, ``"data_shapes"``, ``"capacity"`` and
        ``"learning_starts"`` entries.
    """
    # The data_shapes dict reads action_space.n and preprocessor.shape, so
    # these are only evaluated after dtypes has been built.
    dtypes = {
        EpisodeKeys.ACTION_MASK.value: float,
        EpisodeKeys.REWARD.value: float,
        EpisodeKeys.NEXT_OBSERVATION.value: float,
        EpisodeKeys.DONE.value: bool,
        EpisodeKeys.OBSERVATION.value: float,
        EpisodeKeys.ACTION.value: np.int32,
        EpisodeKeys.ACTION_DIST.value: float,
        "next_action_mask": float,
        EpisodeKeys.ACTION_LOGITS.value: float,
    }

    per_action = (action_space.n,)  # one slot per discrete action
    scalar = ()
    obs_shape = preprocessor.shape

    # Key order intentionally mirrors the original layout (logits before
    # next_action_mask here, unlike in dtypes).
    data_shapes = {
        EpisodeKeys.ACTION_MASK.value: per_action,
        EpisodeKeys.REWARD.value: scalar,
        EpisodeKeys.NEXT_OBSERVATION.value: obs_shape,
        EpisodeKeys.DONE.value: scalar,
        EpisodeKeys.OBSERVATION.value: obs_shape,
        EpisodeKeys.ACTION.value: scalar,
        EpisodeKeys.ACTION_DIST.value: per_action,
        EpisodeKeys.ACTION_LOGITS.value: per_action,
        "next_action_mask": per_action,
    }

    return {
        "dtypes": dtypes,
        "data_shapes": data_shapes,
        "capacity": capacity,
        "learning_starts": learning_starts,
    }
