import math

import torch


class PlayEnv:
    """Step one of several environments under either human or policy control.

    ``envs`` is indexed and unpacked as ``(name, env)`` pairs. Exactly one
    environment is active at a time; the ``next_*`` / ``prev_*`` methods cycle
    through the environments or toggle who chooses the actions.

    Every ``next_*`` / ``prev_*`` method returns ``True``.
    NOTE(review): presumably this tells the caller that a reset is required;
    confirm against the caller.
    """

    def __init__(self, agent, envs, action_names) -> None:
        """Store the collaborators and select the first environment.

        Args:
            agent: Object exposing ``device`` and an ``ac(obs, hx_cx)`` callable.
            envs: Sequence of ``(name, env)`` pairs.
            action_names: Sequence mapping an action index to a display name.
        """
        self.agent = agent
        self.envs = envs
        self.action_names = action_names
        self.is_human_player = False
        # Placeholders; switch_env() and reset() fill these in.
        self._env = None
        self.obs = None
        self._t = None
        self._return = None
        self.hx_cx = None
        self.env_id = None
        self.env_name = None
        self.switch_env(0)

    def next_mode(self):
        """Do nothing (this class has a single mode) and return True."""
        return True

    def next_axis_1(self):
        """Select the following environment, wrapping around at the end."""
        self.switch_env(self.env_id + 1)
        return True

    def prev_axis_1(self):
        """Select the preceding environment, wrapping around at the start."""
        self.switch_env(self.env_id - 1)
        return True

    def next_axis_2(self):
        """Toggle between human and policy control."""
        self.switch_controller()
        return True

    def prev_axis_2(self):
        """Toggle between human and policy control (same as next_axis_2)."""
        self.switch_controller()
        return True

    def switch_env(self, env_id):
        """Make environment ``env_id`` (taken modulo the number of envs) active.

        Only the selection changes; ``obs``, the counters and ``hx_cx`` are
        left untouched until ``reset`` is called.
        """
        wrapped_id = env_id % len(self.envs)
        self.env_id = wrapped_id
        self.env_name, self._env = self.envs[wrapped_id]

    def switch_controller(self):
        """Flip the human/policy control flag."""
        self.is_human_player = not self.is_human_player

    def reset(self):
        """Reset the active environment and clear the episode bookkeeping.

        Returns:
            ``(obs, None)``; the info returned by the environment is discarded.
        """
        first_obs, _ = self._env.reset()
        self.obs = first_obs
        self._t = 0
        self._return = 0
        self.hx_cx = None
        return self.obs, None

    @torch.no_grad()
    def step(self, act):
        """Advance the active environment by one step.

        Under human control ``act`` is wrapped in a one-element tensor on the
        agent's device. Under policy control ``act`` is ignored and an action
        is sampled from the categorical distribution given by the agent's
        logits.

        Returns:
            ``(obs, rew, end, trunc, info)`` where ``info['header']`` is a list
            of four lists of display strings.
        """
        if self.is_human_player:
            act = torch.tensor([act], device=self.agent.device)
            # No policy statistics are available for a human-chosen action.
            entropy_text = None
            value_text = None
        else:
            # hx_cx is threaded through successive agent calls
            # (presumably the recurrent state; confirm against the agent).
            logits_act, value, self.hx_cx = self.agent.ac(self.obs, self.hx_cx)
            policy = torch.distributions.categorical.Categorical(logits=logits_act)
            act = policy.sample()
            # Dividing by ln(2) converts the entropy from nats to bits.
            entropy_bits = policy.entropy() / math.log(2)
            entropy_text = f'{entropy_bits.item():.2f}'
            value_text = f'{value.item():.2f}'

        self.obs, rew, end, trunc, _ = self._env.step(act)
        reward = rew.item()
        self._return += reward

        controller = "human" if self.is_human_player else "policy"
        status_rows = [
            f'Env     : {self.env_name}',
            f'Control : {controller}',
            f'Timestep: {self._t}',
        ]
        reward_rows = [
            f'Action: {self.action_names[act[0]]}',
            f'Reward: {reward:.2f}',
            f'Return: {self._return:.2f}',
        ]
        termination_rows = [
            f'Trunc : {bool(trunc)}',
            f'Done  : {bool(end)}',
        ]
        policy_rows = [
            f'Entropy: {entropy_text}',
            f'Value  : {value_text}',
        ]
        info = {'header': [status_rows, reward_rows, termination_rows, policy_rows]}

        # The header shows the pre-step timestep, so the counter advances last.
        self._t += 1
        return self.obs, rew, end, trunc, info
