from typing import Dict
import gym
from numpy.lib.utils import info

from expground.types import DataArray, Any
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from . import TabularPolicy


class QLearning(TabularPolicy):
    """Tabular Q-learning policy over a discrete action space.

    Q-values are stored lazily in ``self._q_table``, keyed first by an
    information-state string and then by integer action index.
    """

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, is_fixed: bool
    ):
        """Create an empty Q-table for ``action_space``.

        Args:
            observation_space: Forwarded unchanged to the parent initializer.
            action_space: Must be a ``gym.spaces.Discrete``; its ``n`` sets how
                many Q-values are allocated per information state.
            is_fixed: Forwarded unchanged to the parent initializer.

        Raises:
            AssertionError: If ``action_space`` is not ``gym.spaces.Discrete``.
        """
        # ``Discrete`` lives in ``gym.spaces`` (plural); ``gym.space`` does not
        # exist, so the previous check raised AttributeError on every call.
        assert isinstance(action_space, gym.spaces.Discrete), action_space
        # NOTE(review): ``super(TabularPolicy, self)`` begins the MRO lookup
        # *after* TabularPolicy, so TabularPolicy.__init__ itself is skipped.
        # Confirm this is intentional; otherwise switch to ``super().__init__``.
        super(TabularPolicy, self).__init__(observation_space, action_space, is_fixed)

        # info-state string -> {action index: Q-value}; rows are created on
        # first visit in ``_compute_action``.
        self._q_table: Dict[str, Dict[int, float]] = {}

    def _compute_action(self, info_str: str, action_mask: DataArray, evaluate: bool):
        """Select an action for the information state ``info_str``.

        A state seen for the first time gets a fresh row in which every action
        has a Q-value of 0.0.

        Args:
            info_str: Key identifying the information state in the Q-table.
            action_mask: Not used yet. Presumably marks legal actions — TODO
                confirm against callers.
            evaluate: Chooses between the evaluation branch and the training
                branch. Both branches are still unimplemented.

        Returns:
            ``None`` for now, because both selection branches are stubs.
        """
        if info_str not in self._q_table:
            # ``self._action_space`` is presumably set by the parent
            # initializer — TODO confirm.
            self._q_table[info_str] = {a: 0.0 for a in range(self._action_space.n)}
        if evaluate:
            # TODO: evaluation-time (e.g. greedy) action selection.
            pass
        else:
            # TODO: training-time (e.g. exploratory) action selection.
            pass


class QLoss(LossFunc):
    """Loss function placeholder paired with :class:`QLearning`.

    Defines nothing of its own yet; all behaviour is inherited from
    :class:`LossFunc`.
    """


class QTrainer(Trainer):
    """Trainer pairing a :class:`QLoss` with a :class:`QLearning` policy.

    This is currently a thin wrapper; all construction is delegated to
    :class:`Trainer`.
    """

    def __init__(
        self,
        loss_func: QLoss,
        training_config: Dict[str, Any],
        policy_instance: QLearning,
    ):
        """Forward every argument to the :class:`Trainer` initializer.

        Args:
            loss_func: Passed positionally as the base trainer's first argument.
            training_config: Passed through as the ``training_config`` keyword.
            policy_instance: Passed through as the ``policy_instance`` keyword.
        """
        base_kwargs = {
            "training_config": training_config,
            "policy_instance": policy_instance,
        }
        super().__init__(loss_func, **base_kwargs)
