import gym

from expground.types import Dict, DataArray, Sequence, Tuple, Any
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from . import TabularPolicy


class PolicyIteration(TabularPolicy):
    """Tabular policy intended to be driven by policy iteration.

    NOTE(review): this class is still a skeleton. ``_compute_action`` only
    makes sure an entry exists in ``self.state_dict`` for the queried info
    state; it does not yet select an action and returns ``None``.
    """

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, is_fixed: bool
    ):
        # All three arguments are forwarded unchanged to TabularPolicy.
        super().__init__(observation_space, action_space, is_fixed)

    def _compute_action(
        self, info_str: str, action_mask: DataArray, evaluate: bool
    ) -> Tuple[int, Sequence[float]]:
        """Look up (lazily creating) the tabular entry for ``info_str``.

        Args:
            info_str: Key identifying the information state in ``state_dict``.
            action_mask: Currently unused by this stub.
            evaluate: Selects between the (not yet implemented) evaluation and
                training code paths.

        Returns:
            Currently always ``None``. TODO: return the ``(action, probs)``
            tuple promised by the annotation once selection is implemented.
        """
        # First visit to this info state: store whatever the distribution
        # handler hands back (presumably a placeholder distribution, per the
        # ``_ph`` suffix -- confirm in action_dist_handler).
        if info_str not in self.state_dict:
            self.state_dict[info_str] = self.action_dist_handler.proba_distribution_ph()

        if not evaluate:
            pass  # TODO: training-time action selection not implemented yet.
        else:
            pass  # TODO: evaluation-time action selection not implemented yet.
        return None


class PILoss(LossFunc):
    """Loss function placeholder for tabular policy iteration.

    NOTE(review): ``step`` and ``__call__`` are unimplemented stubs that return
    ``None``; the remaining hooks delegate unchanged to ``LossFunc``.
    """

    def __init__(self, mute_critic_loss: bool):
        # Must be the bound (zero-arg) form. ``super(PILoss)`` without an
        # instance is an unbound super object, so ``.__init__`` on it resolves
        # to ``super.__init__`` itself: LossFunc.__init__ never runs, and the
        # keyword argument makes CPython raise
        # "TypeError: super() takes no keyword arguments".
        super().__init__(mute_critic_loss=mute_critic_loss)

    def zero_grad(self):
        """Delegate to ``LossFunc.zero_grad``."""
        return super().zero_grad()

    def setup_extras(self):
        """Delegate to ``LossFunc.setup_extras``."""
        return super().setup_extras()

    def setup_optimizers(self, *args, **kwargs):
        """Delegate to ``LossFunc.setup_optimizers`` with all arguments."""
        return super().setup_optimizers(*args, **kwargs)

    def step(self) -> Any:
        """Apply one update step. TODO: not implemented; returns ``None``."""
        return None

    def __call__(self, batch: Dict[str, DataArray]) -> Dict[str, Any]:
        """Compute loss statistics for ``batch``.

        TODO: not implemented; currently returns ``None`` rather than the
        ``Dict[str, Any]`` the annotation promises.
        """
        return None


class PITrainer(Trainer):
    """Trainer wiring a ``PILoss`` to a ``PolicyIteration`` policy.

    All behavior is inherited from ``Trainer``; the hook overrides below
    simply delegate to the base implementation.
    """

    def __init__(
        self,
        loss_func: PILoss,
        training_config: Dict[str, Any],
        policy_instance: PolicyIteration,
    ):
        # loss_func is passed positionally, the rest by keyword, exactly as
        # the base Trainer receives them.
        super().__init__(
            loss_func,
            training_config=training_config,
            policy_instance=policy_instance,
        )

    def _before_loss(self, policy, batch):
        """Pre-loss hook; delegates to ``Trainer._before_loss``."""
        return super()._before_loss(policy, batch)

    def _after_loss(self, policy, step_counter: int):
        """Post-loss hook; delegates to ``Trainer._after_loss``."""
        return super()._after_loss(policy, step_counter)
