import random

import numpy as np


class RandomAgent:
    """Agent that ignores observation contents and emits random actions.

    Behaviour depends on ``actor_dist_disc``:

    * ``"twohot"``: a hand-built ``"action"`` array is returned in which
      every row has two entries set to 1.0 -- one at a random index in
      0..14 and one at ``15 + command`` with ``command`` in 0..2.
      (Presumably a low-level action plus a high-level command for a
      hierarchical setup -- TODO confirm against the consumer.)
    * anything else: each entry of ``act_space`` except ``"reset"`` is
      sampled once per batch element through its ``sample()`` method.
    """

    def __init__(self, act_space, actor_dist_disc):
        self.actor_dist_disc = actor_dist_disc
        self.act_space = act_space

    def policy(self, obs, state=None, unimix=1.0, horizon=128, mode="train"):
        """Return ``(actions, state)`` for one batched step.

        Args:
            obs: mapping of arrays/sequences; only the length of its first
                value is used, as the batch size.
            state: carried state. In the ``"twohot"`` mode it is ``None`` or
                a dict with ``"step"`` and ``"command"``; otherwise it is
                passed back untouched.
            unimix: accepted but not used by this method.
            horizon: in ``"twohot"`` mode the command is re-drawn whenever
                the step counter is a multiple of this value.
            mode: accepted but not used by this method.

        Returns:
            A dict of action arrays and the (possibly updated) state.
        """
        num_envs = len(next(iter(obs.values())))

        if self.actor_dist_disc == "twohot":
            # TODO: this is a hack for hierarchy
            onehots = np.zeros((num_envs, self.act_space["action"].shape[-1]))
            # Drawn before any command draw so the RNG sequence is stable.
            primitive = random.randint(0, 14)

            if state is not None:
                step, command = state["step"] + 1, state["command"]
            else:
                step, command = 0, random.randint(0, 2)
            # Step 0 also satisfies this, so a fresh state draws twice.
            if step % horizon == 0:
                command = random.randint(0, 2)

            # The same two indices are set for every row of the batch.
            onehots[:, primitive] = 1.0
            onehots[:, 15 + command] = 1.0
            return {"action": onehots}, {"step": step, "command": command}

        sampled = {}
        for name, space in self.act_space.items():
            if name == "reset":
                continue
            sampled[name] = np.stack([space.sample() for _ in range(num_envs)])
        return sampled, state
