import copy
import gym
import numpy as np
import ray

from dataclasses import dataclass

from expground.logger import Log
from expground.types import DataArray, Sequence, PolicyID, Union, Any, Callable
from expground.common.policy_pool import PolicyPool
from expground.algorithms.base_policy import Policy


def _identity_observation_adapter(x, obs_spec):
    """Return the observation ``x`` unchanged; ``obs_spec`` is ignored."""
    return x


def _identity_action_adapter(x):
    """Return the action ``x`` unchanged."""
    return x


# Identity adapters: use these when an environment needs no observation
# reconstruction or action mapping.
DEFAULT_OBSERVATION_ADAPTER = _identity_observation_adapter
DEFAULT_ACTION_ADAPTER = _identity_action_adapter


@dataclass
class AgentInterface:
    """Bind an agent's policy to its observation/action spaces and adapters.

    The interface adapts raw observations (``observation_adapter`` followed by
    the policy's preprocessor), derives action masks from raw observations,
    and computes batched actions with the wrapped policy.
    """

    # Human-readable policy name.
    policy_name: str

    # Policy, either a local ``Policy`` instance or a ray ``ObjectRef``;
    # object refs are resolved with ``ray.get`` in ``__post_init__``.
    policy: Union[ray._raylet.ObjectRef, Policy]

    # Observation space of this agent.
    observation_space: gym.spaces.Space

    # Action space of this agent.
    action_space: gym.spaces.Space

    # Called as ``observation_adapter(observation, observation_space)`` before
    # preprocessing, for environments that require observation reconstruction.
    observation_adapter: Callable = DEFAULT_OBSERVATION_ADAPTER

    # For environments that require mapping the policy's abstract actions to
    # human-readable actions.
    action_adapter: Callable = DEFAULT_ACTION_ADAPTER

    # Whether this interface uses the active policy as its behavior policy;
    # forwarded to ``policy.reset``.
    is_active: bool = False

    def __post_init__(self):
        # Number of times the "no preprocessor" warning has been emitted.
        self._warn_time = 0
        self._discrete = isinstance(
            self.action_space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)
        )
        if isinstance(self.policy, ray._raylet.ObjectRef):
            # Fetch the referenced policy object so it can be used locally.
            self.policy = ray.get(self.policy)
        elif isinstance(self.policy, PolicyPool):
            # Enable reward recording: (support, reward).
            self._previous_rewards = []

    def compute_action(
        self,
        observation: Sequence[Any],
        action_mask: Sequence[DataArray] = None,
        evaluate: bool = True,
    ) -> Sequence[Any]:
        """Compute actions for a (possibly batched) array of observations.

        All leading batch dimensions of ``observation`` (everything before
        ``self.policy.preprocessor.shape``) are flattened into one batch axis
        before calling the policy, then restored on every returned array.

        Args:
            observation (Sequence[Any]): Transformed observations whose trailing
                dimensions equal the preprocessor's shape.
            action_mask (Sequence[DataArray], optional): Action masks sharing the
                observation's leading batch dimensions. Defaults to None.
            evaluate (bool): Use evaluation mode or not. Defaults to True.

        Raises:
            DeprecationWarning: The policy is a dict (no longer supported).
            TypeError: Unsupported policy type.

        Returns:
            Sequence[Any]: A tuple ``(action, action_dist, logits)`` reshaped to
            the original batch dimensions. ``action`` is returned unchanged when
            it is not a numpy array.
        """
        # Check the policy type first so unsupported policies fail with the
        # documented exception rather than an AttributeError on `.preprocessor`.
        if isinstance(self.policy, dict):
            raise DeprecationWarning(
                "Policy dict for agent interface has been deprecated, do not use it!"
            )
        if not isinstance(self.policy, Policy):
            raise TypeError(f"Unexpected policy type: {type(self.policy)}")

        observation = np.asarray(observation)
        obs_shape = tuple(self.policy.preprocessor.shape)
        # Slice by ndim: `shape[:-len(obs_shape)]` would give an empty batch
        # when obs_shape == () because -0 == 0.
        original_batch = observation.shape[: observation.ndim - len(obs_shape)]
        observation = observation.reshape((-1,) + obs_shape)
        if action_mask is not None:
            action_mask = np.asarray(action_mask)
            # Flatten the same leading batch dims as the observation so masks
            # stay row-aligned with observations.
            action_mask = action_mask.reshape(
                (-1,) + action_mask.shape[len(original_batch):]
            )

        # Policy outputs are (batch, inner_shape...).
        action, action_dist, logits = self.policy.compute_action(
            observation, action_mask, evaluate
        )

        is_array_action = isinstance(action, np.ndarray)
        # Per-sample action shape: drop the flattened batch axis when present.
        a_shape = action.shape[1:] if is_array_action and action.ndim > 1 else ()
        # Distributions/logits are treated as batched only when they have >1 dim.
        start = 1 if len(action_dist.shape) > 1 else 0
        dist_shape = action_dist.shape[start:]
        logits_shape = logits.shape[start:]
        return (
            action.reshape(original_batch + a_shape) if is_array_action else action,
            action_dist.reshape(original_batch + dist_shape),
            logits.reshape(original_batch + logits_shape),
        )

    def action_mask(self, raw_observation) -> DataArray:
        """Generate an action mask from a raw observation.

        For dict observations, ``legal_actions`` (indexed by ``current_player``
        when that key is present and not None) takes priority over an explicit
        ``action_mask`` entry. If neither is present, or the observation is not
        a dict, an all-ones mask is returned.

        Args:
            raw_observation: Raw environment observation.

        Returns:
            DataArray: A mask of shape ``(action_space.n,)`` for discrete spaces
            or ``action_space.shape`` otherwise; an explicit ``action_mask``
            entry from the observation is returned as-is.
        """
        # NOTE(review): gym's MultiDiscrete exposes `nvec` rather than `n`;
        # confirm MultiDiscrete action spaces reach this method correctly.
        shape = (self.action_space.n,) if self._discrete else self.action_space.shape
        if isinstance(raw_observation, dict):
            legal_actions = raw_observation.get("legal_actions")
            if legal_actions is not None:
                legal_actions = (
                    legal_actions[raw_observation["current_player"]]
                    if raw_observation.get("current_player") is not None
                    else legal_actions
                )
                action_mask = np.zeros(shape)
                # FIXME: multi-dimensional action spaces are not handled here.
                action_mask[legal_actions] = 1
            else:
                action_mask = raw_observation.get("action_mask")
                if action_mask is None:
                    action_mask = np.ones(shape)
            return action_mask
        return np.ones(shape)

    def transform_observation(self, observation) -> DataArray:
        """Adapt a raw observation and apply the policy's preprocessor.

        Args:
            observation: Raw environment observation.

        Returns:
            DataArray: The preprocessed observation, or the adapted observation
            unchanged when the policy has no preprocessor (a warning is logged
            on the first such call only).
        """
        observation = self.observation_adapter(observation, self.observation_space)
        if self.policy.preprocessor is not None:
            return self.policy.preprocessor.transform(observation)
        if self._warn_time == 0:
            Log.warning(
                "AgentInterface:: No callable preprocessor, will return the original observation"
            )
            self._warn_time += 1
        return observation

    def reset(self, **kwargs):
        """Reset the wrapped policy, forwarding ``is_active`` and ``kwargs``."""
        self.policy.reset(is_active=self.is_active, **kwargs)

    def copy(self):
        """Return a new interface holding a shallow copy of the policy.

        Spaces and adapters are shared with this instance, and
        ``__post_init__`` runs again on the new instance.
        """
        # `copy` below is the stdlib module (method names are not in scope here).
        return AgentInterface(
            policy_name=self.policy_name,
            policy=copy.copy(self.policy),
            observation_adapter=self.observation_adapter,
            observation_space=self.observation_space,
            action_space=self.action_space,
            action_adapter=self.action_adapter,
            is_active=self.is_active,
        )
