import math
from pathlib import Path

import torch


class PlayEnv:
    """Step an environment with either human or policy-sampled actions.

    The wrapper keeps track of the currently selected environment and agent
    checkpoint, lets the caller cycle through both, and builds a text header
    describing each step.
    """

    def __init__(self, agent, envs, action_names) -> None:
        """Select the first environment and load the last checkpoint.

        Args:
            agent: Object exposing ``device``, ``load_state_dict`` and ``ac``.
            envs: Sequence of ``(name, env)`` pairs.
            action_names: Container mapping an action id to a display name.

        Raises:
            ValueError: If ``envs`` is empty.
            FileNotFoundError: If ``checkpoints/all`` is missing or empty.
        """
        self.agent = agent
        self.envs = envs
        # Ordered by path; assumes file names sort in epoch order
        # (e.g. zero-padded suffixes) -- TODO confirm.
        self.all_ckpt_paths = sorted(Path('checkpoints/all').iterdir())
        self.action_names = action_names
        self.is_human_player = False

        # Episode / selection state, filled in by the calls below and reset().
        self._env = None
        self.obs = None
        self._t = None
        self._return = None
        self.hx_cx = None
        self.env_id = None
        self.env_name = None
        self.ckpt_id = None
        self.epoch = None

        self.switch_env(0)
        self.load_ckpt(-1)

    def next_mode(self):
        """Toggle between human and policy control; always returns True."""
        self.switch_controller()
        return True

    def next_axis_1(self):
        """Select the next environment (wrapping); always returns True."""
        self.switch_env(self.env_id + 1)
        return True

    def prev_axis_1(self):
        """Select the previous environment (wrapping); always returns True."""
        self.switch_env(self.env_id - 1)
        return True

    def next_axis_2(self):
        """Load the next checkpoint (wrapping); always returns True."""
        self.load_ckpt(self.ckpt_id + 1)
        return True

    def prev_axis_2(self):
        """Load the previous checkpoint (wrapping); always returns True."""
        self.load_ckpt(self.ckpt_id - 1)
        return True

    def load_ckpt(self, ckpt_id):
        """Load checkpoint ``ckpt_id`` (taken modulo the checkpoint count).

        The epoch is parsed from the integer after the last ``_`` in the
        file stem.

        Raises:
            FileNotFoundError: If there are no checkpoints to choose from.
        """
        if not self.all_ckpt_paths:
            raise FileNotFoundError('no checkpoints found in checkpoints/all')
        self.ckpt_id = ckpt_id % len(self.all_ckpt_paths)
        ckpt_path = self.all_ckpt_paths[self.ckpt_id]
        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints should be placed in the checkpoint directory.
        state_dict = torch.load(ckpt_path, map_location=self.agent.device)
        self.agent.load_state_dict(state_dict)
        self.epoch = int(ckpt_path.stem.split('_')[-1])

    def switch_env(self, env_id):
        """Select environment ``env_id`` (taken modulo the number of envs).

        Raises:
            ValueError: If ``envs`` is empty.
        """
        if not self.envs:
            raise ValueError('envs must contain at least one (name, env) pair')
        self.env_id = env_id % len(self.envs)
        self.env_name, self._env = self.envs[self.env_id]

    def switch_controller(self):
        """Flip between human control and policy control."""
        self.is_human_player = not self.is_human_player

    def reset(self):
        """Reset the current environment and the per-episode counters.

        Returns:
            The pair ``(obs, None)``.
        """
        self.obs, _ = self._env.reset()
        self._t, self._return, self.hx_cx = 0, 0, None
        return self.obs, None

    @torch.no_grad()
    def step(self, act):
        """Advance the current environment by one step.

        Under human control ``act`` is wrapped in a one-element tensor and
        used as the action. Under policy control ``act`` is ignored and an
        action is sampled from the agent's actor-critic output.

        Returns:
            ``(obs, rew, end, trunc, info)`` where ``info['header']`` is a
            list of four lists of display strings.
        """
        if self.is_human_player:
            act = torch.tensor([act], device=self.agent.device)
            # No policy statistics exist for a human-chosen action.
            entropy_str = None
            value_str = None
        else:
            logits_act, value, self.hx_cx = self.agent.ac(self.obs, self.hx_cx)
            dist = torch.distributions.categorical.Categorical(logits=logits_act)
            act = dist.sample()
            # Dividing by ln(2) converts the entropy from nats to bits.
            entropy_bits = dist.entropy() / math.log(2)
            entropy_str = f'{entropy_bits.item():.2f}'
            value_str = f'{value.item():.2f}'

        self.obs, rew, end, trunc, _ = self._env.step(act)
        reward = rew.item()
        self._return += reward

        controller = 'human' if self.is_human_player else 'policy'
        # Plain int so the lookup works for any sequence or int-keyed mapping.
        action_name = self.action_names[int(act[0])]
        header = [
            [
                f'Env     : {self.env_name}',
                f'Control : {controller}',
                f'Epoch   : {self.epoch}',
                f'Timestep: {self._t}',
            ],
            [
                f'Action: {action_name}',
                f'Reward: {reward:.2f}',
                f'Return: {self._return:.2f}',
            ],
            [
                f'Trunc : {bool(trunc)}',
                f'Done  : {bool(end)}',
            ],
            [
                f'Entropy: {entropy_str}',
                f'Value  : {value_str}',
            ],
        ]
        info = {'header': header}
        self._t += 1
        return self.obs, rew, end, trunc, info
