import gym
import numpy as np

from expground.types import Dict, Any
from expground.utils.data import EpisodeKeys

from .env import Matrix, PayoffType
from .env import env as creator


def env_desc_gen(env_id: str, scenario_config: Dict):
    """Assemble the environment description for a matrix-game scenario.

    A ``Matrix`` instance is built from ``scenario_config`` (missing entries
    fall back to the defaults listed below) solely so that its agent list,
    spaces and policy set can be copied into the returned description; the
    instance itself is not returned.

    Args:
        env_id: Identifier stored verbatim under ``config["env_id"]``.
        scenario_config: Mapping that may provide ``num_players``,
            ``payoff_type``, ``dim`` and ``max_cycles``. It is also stored
            as-is under ``config["scenario_config"]``.

    Returns:
        A dict with the environment ``creator`` callable and a ``config``
        dict describing the environment.
    """
    # Fallback values used when the scenario config omits a key.
    fallbacks = {
        "num_players": 2,
        "payoff_type": PayoffType.RANDOM_SYMMETRIC,
        "dim": 10,
        "max_cycles": 1,
    }
    matrix_kwargs = {
        key: scenario_config.get(key, default) for key, default in fallbacks.items()
    }
    probe = Matrix(**matrix_kwargs)

    config = {
        "env_id": env_id,
        "possible_agents": probe.possible_agents,
        "action_spaces": probe.action_spaces,
        "observation_spaces": probe.observation_spaces,
        "scenario_config": scenario_config,
        "full_policy_set": probe.full_policy_set,
    }
    return {"creator": creator, "config": config}


def data_preprocessor(batch: Dict[str, Any]):
    """Attach accumulated (reward-to-go) values to a batch.

    A copy of the reward entry is walked from its last element to its first
    while keeping a running, undiscounted sum; each position is overwritten
    with the sum of itself and everything after it. The result is stored
    under the accumulated-reward key.

    Args:
        batch: Mapping that holds the reward sequence under
            ``EpisodeKeys.REWARD.value``. The original reward entry is left
            untouched; the mapping itself gains a new entry in place.

    Returns:
        The same ``batch`` object, now containing
        ``EpisodeKeys.ACC_REWARD.value``.
    """
    returns_to_go = batch[EpisodeKeys.REWARD.value].copy()
    count = len(returns_to_go)

    running = 0.0
    # Pair each slot, last to first, with the reward read from the reversed
    # sequence; every slot is read before it is overwritten.
    for slot, reward in zip(range(count - 1, -1, -1), np.flip(returns_to_go)):
        running += reward
        returns_to_go[slot] = running

    batch[EpisodeKeys.ACC_REWARD.value] = returns_to_go
    return batch


def basic_sampler_config(
    observation_space: gym.Space,
    action_space: gym.Space,
    preprocessor: object,
    capacity: int = 1000,
    learning_starts: int = 64,
):
    """Build the sampler configuration for episode data.

    Args:
        observation_space: Accepted for interface compatibility; not read by
            this function.
        action_space: Space exposing ``n``, which sizes the stored action
            distribution.
        preprocessor: Object exposing ``shape``, used as the shape of both
            the observation and next-observation entries.
        capacity: Stored under ``"capacity"``.
        learning_starts: Stored under ``"learning_starts"``.

    Returns:
        A dict with per-key ``"dtypes"`` and ``"data_shapes"``, the two
        numeric settings above, and ``data_preprocessor`` registered under
        ``"data_preprocessor"``.
    """
    # The builtin float/bool/int are used instead of np.float/np.bool/np.int:
    # those NumPy names were plain aliases of the builtins, were deprecated in
    # NumPy 1.20 and removed in 1.24, where accessing them raises
    # AttributeError. The builtins yield the same dtypes on every version.
    sampler_config = {
        "dtypes": {
            EpisodeKeys.REWARD.value: float,
            EpisodeKeys.NEXT_OBSERVATION.value: float,
            EpisodeKeys.DONE.value: bool,
            EpisodeKeys.OBSERVATION.value: float,
            EpisodeKeys.ACTION.value: int,
            EpisodeKeys.ACTION_DIST.value: float,
            EpisodeKeys.ACC_REWARD.value: float,
        },
        "data_shapes": {
            # Scalar-per-step entries use the empty shape.
            EpisodeKeys.REWARD.value: (),
            EpisodeKeys.NEXT_OBSERVATION.value: preprocessor.shape,
            EpisodeKeys.DONE.value: (),
            EpisodeKeys.OBSERVATION.value: preprocessor.shape,
            EpisodeKeys.ACTION.value: (),
            EpisodeKeys.ACTION_DIST.value: (action_space.n,),
            EpisodeKeys.ACC_REWARD.value: (),
        },
        "capacity": capacity,
        "learning_starts": learning_starts,
        "data_preprocessor": data_preprocessor,
    }
    return sampler_config
