from abc import ABCMeta, abstractmethod

from expground.types import Dict, Any, Sequence
from expground.logger import Log
from expground.algorithms.base_policy import Policy
from expground.algorithms.loss_func import LossFunc
from expground.utils.sampler import SamplerInterface


# Fallback training configuration that ``Trainer.__init__`` uses when the
# caller passes ``training_config=None``. ``batch_size`` is the number of
# samples requested from the sampler on each training iteration.
DEFAULT_TRAINING_CONFIG = {"batch_size": 32}


class Trainer(metaclass=ABCMeta):
    """Abstract driver of the sample -> loss -> optimizer-step training loop.

    Subclasses implement the ``_before_loss`` and ``_after_loss`` hooks; the
    loop itself lives in ``__call__``.
    """

    def __init__(
        self,
        loss_func: "LossFunc",
        training_config: Dict[str, Any] = None,
        policy_instance: "Policy" = None,
    ):
        """Create a trainer bound to a loss function.

        Args:
            loss_func: The loss function driven by this trainer. Must not be None.
            training_config: Training options; ``"batch_size"`` is read by
                ``__call__``. When None, a copy of ``DEFAULT_TRAINING_CONFIG``
                is used and a warning is logged.
            policy_instance: The policy handed to the hooks. May be None and
                supplied later through ``reset``.

        Raises:
            AssertionError: If ``loss_func`` is None.
        """
        # Validate before touching any state. An explicit raise (rather than an
        # ``assert`` statement) keeps the check alive under ``python -O``; the
        # exception type is unchanged for existing callers.
        if loss_func is None:
            raise AssertionError("Loss func cannot be None")

        self._policy = policy_instance
        self._loss_func = loss_func

        if training_config is None:
            Log.warning(
                "No training config specified, will use default={}".format(
                    DEFAULT_TRAINING_CONFIG
                )
            )
            # Copy, so that changes made through one trainer never leak into
            # the module-level default shared by every other trainer.
            self._training_config = dict(DEFAULT_TRAINING_CONFIG)
        else:
            self._training_config = training_config
        self._step_counter = 0
        # NOTE: the loss function is not reset here; call ``reset`` to bind
        # the policy and the configuration to it.

    @property
    def counter(self):
        """Return the step counter.

        It is 0 after construction or ``reset`` and otherwise equals the
        ``time_step`` given to the most recent ``__call__``.
        """
        return self._step_counter

    @property
    def loss_func(self):
        """Return the loss function registered in this trainer.

        Returns:
            LossFunc: A loss function.
        """

        return self._loss_func

    def __call__(
        self,
        sampler: "SamplerInterface",
        time_step: int,
        agent_filter: Sequence = None,
        n_inner_loop: int = 1,
    ) -> Dict[str, Any]:
        """Run the training loop and return the computed feedback.

        Args:
            sampler (Sampler): A sampler instance to draw batches from.
            time_step (int): Current step; stored as the step counter and
                passed on to ``_after_loss``.
            agent_filter (Sequence[AgentID]): Determines which agents are
                governed by this trainer; forwarded to ``sampler.sample``.
            n_inner_loop (int): Number of sample/update iterations to run.
                A value <= 0 performs no update.

        Returns:
            Dict: Training feedback of the last iteration: the dict returned by
                ``_before_loss`` updated with the loss function's result and
                then with the (truthy) result of ``_after_loss``. Empty when
                the sampler holds fewer than ``batch_size`` items or when no
                iteration was run.
        """

        self._step_counter = time_step
        # the loss function must have been registered
        assert self.loss_func is not None

        batch_size = self._training_config["batch_size"]
        if sampler.size < batch_size:
            Log.warning(
                f"No enough training data. size={sampler.size} batch_size={batch_size}"
            )
            return {}

        # Defined up front so that ``n_inner_loop <= 0`` yields an empty
        # result instead of an UnboundLocalError after the loop.
        feedback = {}
        other_feedback = None
        for _ in range(n_inner_loop):
            batch = sampler.sample(
                batch_size=batch_size,
                agent_filter=agent_filter,
            )
            # 1. clear gradients, then let the subclass preprocess the batch
            self.loss_func.zero_grad()
            batch, feedback = self._before_loss(self._policy, batch)
            # 2. compute loss with the sampled batch data
            feedback.update(self.loss_func(batch))
            # 3. step optimizer
            self.loss_func.step()
            # 4. let the subclass react to the updated network
            other_feedback = self._after_loss(self._policy, self._step_counter)

        # only the feedback of the last iteration is collected
        if other_feedback:
            feedback.update(other_feedback)

        return feedback

    def reset(self, policy_instance=None, configs=None):
        """Reset the step counter and re-initialize the loss function.

        Args:
            policy_instance: New policy to train. When None, the currently
                held policy is kept.
            configs: Configuration passed to the loss function. When None or
                empty, the trainer's training config is passed instead.
        """
        self._step_counter = 0
        if policy_instance is not None:
            self._policy = policy_instance
        # Pass the policy the trainer actually holds, so that the trainer and
        # the loss function stay in sync when no new policy is given.
        self.loss_func.reset(self._policy, configs or self._training_config)

    @abstractmethod
    def _before_loss(self, policy, batch):
        """Hook executed before the loss is computed.

        Use it for policy setup or batch preprocessing.

        Args:
            policy (Policy): The policy instance held by this trainer.
            batch: The batch returned by the sampler.

        Returns:
            Tuple: ``(batch, feedback)`` - the batch handed to the loss
                function and a dict the training feedback is merged into.
        """

    @abstractmethod
    def _after_loss(self, policy, step_counter: int):
        """Hook executed after the optimizer step.

        Args:
            policy (Policy): The policy instance held by this trainer.
            step_counter (int): The current step counter.

        Returns:
            Optional[Dict]: Extra feedback; when truthy, the value from the
                last iteration is merged into the result of ``__call__``.
        """
