from abc import abstractmethod
from typing import Tuple

import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Discrete

from core.msa import MSAFlat


class ConditionalActionEnv(gym.Env):
    """
    Gym environment in which only a subset of the discrete actions may be available
    at the current state.

    Subclasses must override ``available_mask``. ``gym.Env`` does not use ``ABCMeta``,
    so the ``@abstractmethod`` marker below is not enforced when a subclass is
    instantiated; the base implementation therefore raises ``NotImplementedError``
    instead of silently returning ``None``.
    """

    @property
    @abstractmethod
    def available_mask(self) -> Tuple:
        """
        Return a binary array specifying which options can be run at the current state
        (one entry per action, 1 = available, 0 = unavailable).

        :raises NotImplementedError: if a subclass did not override this property
        """
        raise NotImplementedError(
            f"{type(self).__name__} must override the 'available_mask' property"
        )

    def sample_action(self, valid_only=True) -> np.int64:
        """
        Randomly pick an action
        :param valid_only: whether only valid actions (those enabled in
            ``available_mask``) should be picked
        :return: an action
        :raises TypeError: if the action space is not ``Discrete``
        """
        # Explicit check instead of ``assert`` so the guard survives ``python -O``.
        if not isinstance(self.action_space, Discrete):
            raise TypeError(
                "sample_action requires a Discrete action space, got "
                f"{type(self.action_space).__name__}"
            )
        if valid_only:
            # gymnasium's Discrete.sample takes the mask as an int8 array.
            mask = np.array(self.available_mask, dtype=np.int8)
            return self.action_space.sample(mask=mask)
        return self.action_space.sample()


class MSAWrapperEnv(gym.Wrapper, ConditionalActionEnv):
    """
    Wrap a ``ConditionalActionEnv`` so that every observation it emits is replaced by
    its MSA encoding.

    Each raw observation returned by ``reset``/``step`` is converted to a float32
    tensor on the MSA model's device and passed through ``state_transform``. It is
    then encoded with ``MSAFlat.encode`` and returned as a NumPy array. Action
    masking (``available_mask``) is delegated to the wrapped environment. Attributes
    not found on the wrapper are looked up on the wrapped environment.

    NOTE(review): ``observation_space`` is not overridden here, so it presumably
    still describes the wrapped env's raw observations rather than the encodings.
    Confirm whether downstream code relies on it.

    Example usage:
    ```
    env = gym.make('FourRooms-v0')
    env = MSAWrapperEnv(env,
                        msa_folder='save/model_folder',
                        state_transform=lambda x: x.reshape(-1) / 255.0)
    ```
    """

    @property
    def available_mask(self) -> Tuple:
        """Return the wrapped environment's action mask unchanged."""
        return self._env.available_mask

    def __init__(self, env: ConditionalActionEnv, msa_folder, state_transform):
        """
        :param env: environment to wrap; must expose ``available_mask``
        :param msa_folder: location passed to ``MSAFlat`` to load the model
        :param state_transform: callable applied to each observation (already a
            float32 tensor) before it is encoded
        """
        super().__init__(env)
        # Also kept under a private name; ``available_mask`` reads the env from here.
        self._env = env
        self.state_transform = state_transform
        # The model is moved to CPU once; observations are created on its device.
        self.msa = MSAFlat(msa_folder).cpu()

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return ``(encoded_obs, info)``."""
        obs, info = self.env.reset(*args, **kwargs)
        return self._encode(obs), info

    def step(self, action):
        """Step the wrapped env. The observation is encoded; the other fields pass through."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._encode(obs), reward, terminated, truncated, info

    def _encode(self, obs):
        """Map a raw observation to its MSA encoding as a NumPy array."""
        # inference_mode: no autograd graph is recorded while encoding.
        with torch.inference_mode():
            obs = torch.tensor(obs, device=self.msa.device, dtype=torch.float32)
            obs = self.state_transform(obs)
            return self.msa.encode(obs).cpu().numpy()

    def _decode(self, obs):
        """Map an encoding through ``MSAFlat.decode`` (for visualisation only)."""
        # Unlike ``_encode``, ``state_transform`` is not applied or inverted here.
        with torch.inference_mode():
            obs = torch.tensor(obs, device=self.msa.device, dtype=torch.float32)
            return self.msa.decode(obs).cpu().numpy()

    def __getattr__(self, item):
        """Forward attribute lookups that fail on the wrapper to the wrapped env."""
        # ``__getattr__`` only runs after normal lookup has failed. This can happen
        # when ``env`` itself is missing, e.g. on an instance created via
        # ``__new__`` during copy/unpickling. In that case forwarding would call
        # ``self.env`` -> ``__getattr__('env')`` endlessly and end in
        # RecursionError. Raise AttributeError instead so ``hasattr``/``copy``
        # behave normally.
        if item == "env":
            raise AttributeError(item)
        return getattr(self.env, item)
