import numpy
from abc import abstractmethod
from modules.agents.Policy import Policy
from modules.utils.Log import Logger
from modules.train.TrainHelper import discount, unpack_trajectories

class Oracle:
    """Base class for oracles that compute a prediction from sampled trajectories.

    Callers use ``predict``. It first attaches discounted returns to every
    trajectory (see ``add_discounted_returns``) and then delegates to the
    subclass hook ``_predict``.
    """

    def __init__(self, env, discount_factor: float):
        """Store the environment and the discount factor.

        :param env: environment object; stored as ``self._env`` and not used here.
        :param discount_factor: gamma passed to ``discount`` when computing returns.
        """
        self._env = env
        self._gamma = discount_factor
        # Base class has no name; presumably subclasses assign one — confirm.
        self._name = None

    # NOTE(review): Oracle does not derive from abc.ABC, so @abstractmethod is not
    # enforced. The base class can be instantiated, and a subclass that does not
    # override _predict makes predict() return None silently.
    @abstractmethod
    def _predict(self, episode: int, policy: Policy, trajectories: list, logger: Logger):
        """Subclass hook that computes the oracle's prediction.

        Called by ``predict`` after every trajectory has been given a
        ``"discounted_returns"`` entry.

        :param episode: episode index forwarded from ``predict``.
        :param policy: policy forwarded from ``predict``.
        :param trajectories: trajectories, already annotated with discounted returns.
        :param logger: logger forwarded from ``predict`` (may be None).
        """
        pass

    def gamma(self) -> float:
        """Return the discount factor this oracle was constructed with."""
        return self._gamma

    def name(self) -> str:
        """Return the oracle's name (None unless set by a subclass)."""
        return self._name

    def add_discounted_returns(self, trajectories, logger: Logger) -> None:
        """Attach discounted returns to each trajectory and optionally log summary stats.

        For every trajectory, the rewards extracted by ``unpack_trajectories`` are
        discounted with ``self.gamma()``. The result is stored in place under the
        ``"discounted_returns"`` key. Element 0 of that result is collected as the
        trajectory's return; presumably this is the return from the first step
        (TODO confirm against ``discount``).

        If ``logger`` is given and at least one trajectory is present, the logger
        receives:

        * ``_AvgDiscountedRewardSum`` / ``_StdDiscountedRewardSum``: mean and std
          of the means of consecutive batches of size ``floor(sqrt(n))``.
        * ``_MinDiscountedRewardSum`` / ``_MaxDiscountedRewardSum``: min and max
          of the individual returns.

        :param trajectories: iterable of mutable mapping-like trajectories;
            modified in place.
        :param logger: object with a ``log(dict)`` method, or None to skip logging.
        """
        discounted_rewards = []

        for trajectory in trajectories:
            _, _, rewards = unpack_trajectories([trajectory])
            trajectory["discounted_returns"] = discount(rewards, self.gamma())
            discounted_rewards.append(trajectory["discounted_returns"][0])

        # Skip the statistics when they would not be logged. Also skip them when
        # there are no returns: batch_size would then be 0, which is an invalid
        # range step, and numpy.min/max of an empty list raise ValueError.
        if logger is None or not discounted_rewards:
            return

        # Batch means: split the returns into ~sqrt(n) consecutive chunks so the
        # std below reflects the spread of chunk averages, not single episodes.
        batch_size = int(len(discounted_rewards) ** 0.5)
        avg_discounted_rewards = [
            numpy.mean(discounted_rewards[i:i + batch_size])
            for i in range(0, len(discounted_rewards), batch_size)
        ]
        logger.log({"_AvgDiscountedRewardSum": numpy.mean(avg_discounted_rewards),
                    "_StdDiscountedRewardSum": numpy.std(avg_discounted_rewards),
                    "_MinDiscountedRewardSum": numpy.min(discounted_rewards),
                    "_MaxDiscountedRewardSum": numpy.max(discounted_rewards)})

    def predict(self, episode: int, policy: Policy, trajectories: list, logger: Logger):
        """Annotate trajectories with discounted returns, then run ``_predict``.

        :param episode: episode index forwarded to ``_predict``.
        :param policy: policy forwarded to ``_predict``.
        :param trajectories: trajectories; each gains a ``"discounted_returns"`` key.
        :param logger: logger used for return statistics and forwarded to
            ``_predict`` (may be None).
        :return: whatever the subclass's ``_predict`` returns.
        """
        # Compute the discounted sums of rewards before the subclass sees the data.
        self.add_discounted_returns(trajectories, logger)
        return self._predict(episode, policy, trajectories, logger)