import numpy as np
from d3rlpy.metrics import EvaluatorProtocol
from d3rlpy.dataset import ReplayBufferBase
from d3rlpy.interface import QLearningAlgoProtocol
import os

class CustomEnvironmentEvaluator(EvaluatorProtocol):
    r"""Evaluate an algorithm online by rolling it out in an environment.

    For each of ``n_trials`` episodes the environment is reset and the
    algorithm acts for at most ``steps`` steps (fewer if the environment
    reports ``terminated`` or ``truncated``). The score of an episode is the
    maximum single-step reward observed in it; the evaluator returns the
    median of those per-episode scores.

    Args:
        env: Gym/Gymnasium-style environment. ``reset()`` must return
            ``(observation, info)`` and ``step()`` must return
            ``(observation, reward, terminated, truncated, info)``.
            ``env.action_space.sample()`` is required only when
            ``epsilon > 0``.
        n_trials: Number of episodes to evaluate.
        epsilon: Probability of taking a random action
            (``env.action_space.sample()``) instead of the algorithm's
            prediction at each step.
        steps: Maximum number of steps per episode. Must be at least 1.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """

    _n_trials: int
    _epsilon: float
    _steps: int

    def __init__(
        self,
        env,
        n_trials: int = 10,
        epsilon: float = 0.0,
        steps: int = 30,
    ):
        if steps < 1:
            # Every episode needs at least one reward to take a max over;
            # previously this failed later inside np.max on an empty list.
            raise ValueError(f"steps must be >= 1, got {steps}")
        self._env = env
        self._n_trials = n_trials
        self._epsilon = epsilon
        self._steps = steps

    @staticmethod
    def _add_batch_dim(observation):
        """Return ``observation`` with a leading batch axis of size 1.

        Raises:
            ValueError: If the observation is neither an ndarray nor a
                tuple/list of arrays.
        """
        if isinstance(observation, np.ndarray):
            return np.expand_dims(observation, axis=0)
        if isinstance(observation, (tuple, list)):
            # Tuple observations: batch each component separately.
            return [np.expand_dims(o, axis=0) for o in observation]
        raise ValueError(
            f"Unsupported observation type: {type(observation)}"
        )

    def _select_action(self, algo, observation):
        """Return a random action with probability epsilon, else algo's."""
        batched = self._add_batch_dim(observation)
        # Only draw from the RNG when exploration is enabled, so that the
        # default epsilon=0.0 leaves the global NumPy random state untouched.
        if self._epsilon > 0.0 and np.random.random() < self._epsilon:
            return self._env.action_space.sample()
        return algo.predict(batched)[0]

    def __call__(
        self, algo: QLearningAlgoProtocol, dataset: ReplayBufferBase
    ) -> float:
        """Run ``n_trials`` episodes and return the median per-episode max reward.

        ``dataset`` is unused; it is accepted to satisfy ``EvaluatorProtocol``.
        """
        episode_scores = []
        for _ in range(self._n_trials):
            observation, _ = self._env.reset()
            step_rewards = []

            for _ in range(self._steps):
                action = self._select_action(algo, observation)
                observation, reward, terminated, truncated, _ = self._env.step(
                    action
                )
                step_rewards.append(float(reward))
                # Stepping an episode that has already ended is undefined
                # under the Gym API, so stop at the first end signal.
                if terminated or truncated:
                    break

            # Position of nearest point to optimum (best reward reached).
            episode_scores.append(max(step_rewards))

        # Among all nearest points, return the median
        return float(np.median(episode_scores))


class CallbackList:
    """Fan a single epoch callback out to several callbacks.

    Args:
        callbacks: Iterable of callables, each accepting
            ``(model, epoch, total_steps)``. They are invoked in order.
    """

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def __call__(self, model, epoch, total_steps):
        """Invoke every callback in order with the same arguments."""
        # Plain loop: callbacks run for their side effects, so there is no
        # reason to build (and discard) a list of their return values.
        for callback in self.callbacks:
            callback(model, epoch, total_steps)


class SaveCallback:
    """Epoch callback that saves the model every ``save_interval`` epochs.

    Args:
        path: Directory in which checkpoints are written as
            ``epoch_{epoch}.pt``. It is created on first save if missing.
        save_interval: Save when ``epoch`` is a multiple of this value.
            Must be a positive integer.

    Raises:
        ValueError: If ``save_interval`` is smaller than 1.
    """

    def __init__(self, path, save_interval=1):
        if save_interval < 1:
            # A zero interval would raise ZeroDivisionError on every call.
            raise ValueError(
                f"save_interval must be >= 1, got {save_interval}"
            )
        # Attribute name kept as-is for backward compatibility.
        self.mode_save_dir = path
        self.interval = save_interval

    def __call__(self, algo, epoch, total_steps):
        """Save ``algo`` via ``algo.save_model`` when ``epoch`` is due.

        ``total_steps`` is unused; it is part of the callback signature.
        """
        if epoch % self.interval != 0:
            return
        # save_model writes to a file path; make sure its directory exists.
        os.makedirs(self.mode_save_dir, exist_ok=True)
        algo.save_model(os.path.join(self.mode_save_dir, f'epoch_{epoch}.pt'))
