from stable_baselines3 import DQN
import wandb
import numpy as np
from typing import Union, Optional
from stable_baselines3.common.noise import ActionNoise
from gym import spaces
import gymnasium as gym
from rl.policies import DQNPolicy
from utils.config import Config
import random

class ImagePoserDQN(DQN):
    """DQN whose random actions are restricted to the indices allowed by the env's mode.

    Every sub-environment exposes ``image_to_use``, ``t2i_model_idx`` and
    ``i2i_model_idx``. When ``image_to_use`` is falsy the allowed action indices
    are ``t2i_model_idx``, otherwise ``i2i_model_idx``. They are used by:

    - ``predict()``: epsilon-greedy random actions are drawn from the allowed
      indices; greedy actions are delegated to ``self.policy.predict`` after
      setting ``self.policy.expert_idxs``.
    - ``_sample_action()``: warm-up actions (before ``learning_starts``) are drawn
      from the allowed indices. This method is defined in ``OffPolicyAlgorithm``,
      which DQN inherits from; overriding it here replaces that behavior.
    """

    def __init__(self, env: gym.Env, config: Optional[Config] = None, *args, **kwargs):
        """Build the agent either from a ``Config`` or from plain DQN arguments.

        :param env: environment, forwarded to ``DQN``.
        :param config: if given, supplies the policy (``DQNPolicy``), the
            hyper-parameters listed below and ``image_poser_logger``. Keyword
            arguments in ``kwargs`` override the corresponding config values.
        :param args: forwarded positionally to ``DQN``; only allowed without ``config``.
        :param kwargs: forwarded to ``DQN``.
        :raises TypeError: if positional ``args`` are combined with ``config``
            (they would collide with the ``policy`` argument).
        """
        if config is None:
            # Plain DQN construction: ``policy`` must come from ``args`` (first
            # positional) or ``kwargs``.
            self.config = None
            self.image_poser_logger = None
            super().__init__(*args, env=env, **kwargs)
            return

        if args:
            raise TypeError(
                "ImagePoserDQN does not accept extra positional arguments together "
                "with `config`; pass DQN options as keyword arguments instead."
            )

        self.config = config
        self.image_poser_logger = config.logger

        dqn_kwargs = {
            "policy": DQNPolicy,
            "learning_rate": config.learning_rate,
            "gamma": config.gamma,
            "exploration_fraction": config.exploration_fraction,
            "exploration_initial_eps": config.exploration_initial_eps,
            "exploration_final_eps": config.exploration_final_eps,
            "learning_starts": config.learning_starts,
            "buffer_size": config.buffer_size,
            "verbose": config.verbose,
            "tensorboard_log": config.tensorboard_log,
        }
        # Explicit keyword arguments take precedence over config values instead
        # of failing with "got multiple values for keyword argument".
        dqn_kwargs.update(kwargs)
        super().__init__(env=env, **dqn_kwargs)

    def _expert_idxs_per_env(self) -> list:
        """Return, for each sub-environment, the action indices allowed right now.

        ``t2i_model_idx`` is chosen when the env's ``image_to_use`` is falsy,
        ``i2i_model_idx`` otherwise.
        """
        assert self.env is not None, "self.env was not set"
        if hasattr(self.env, "get_attr"):
            # VecEnv API: does not require an ``envs`` attribute (works for
            # subprocess-based VecEnvs) and returns one value per sub-env.
            images = self.env.get_attr("image_to_use")
            t2i_idxs = self.env.get_attr("t2i_model_idx")
            i2i_idxs = self.env.get_attr("i2i_model_idx")
        else:
            # Non-vectorized env: a single "sub-environment".
            images = [self.env.image_to_use]
            t2i_idxs = [self.env.t2i_model_idx]
            i2i_idxs = [self.env.i2i_model_idx]
        return [
            t2i if not image else i2i
            for image, t2i, i2i in zip(images, t2i_idxs, i2i_idxs)
        ]

    def predict(
        self,
        observation: Union[np.ndarray, dict[str, np.ndarray]],
        state: Optional[tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration
        restricted to the currently allowed expert indices.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        per_env_idxs = self._expert_idxs_per_env()
        # Fallback when the batch cannot be matched to sub-envs one-to-one.
        default_idxs = per_env_idxs[0]

        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]

                # One row per sub-env: sample each row from that env's own indices.
                if n_batch == len(per_env_idxs):
                    row_idxs = per_env_idxs
                else:
                    row_idxs = [default_idxs] * n_batch
                action = np.array([random.choice(idxs) for idxs in row_idxs])
            else:
                action = np.array(random.choice(default_idxs))
        else:
            # NOTE(review): the policy receives a single index set (env 0's); if
            # several envs can be in different modes, confirm how DQNPolicy
            # should handle per-env expert_idxs.
            self.policy.expert_idxs = default_idxs
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Sample an action for data collection.

        Before ``learning_starts`` timesteps (warm-up), each env gets a uniformly
        random action drawn from its allowed expert indices. Afterwards the action
        comes from ``self.predict(..., deterministic=False)`` (epsilon-greedy).

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Action noise added to the scaled action; only used
            when the action space is a ``Box``.
        :param n_envs: number of parallel environments to sample actions for.
        :return: action to take in the environment
            and the action that will be stored in the replay buffer.
            The two differ only when the action space is a ``Box``.
        """
        assert self.env is not None, "self.env was not set"

        # Select action randomly or according to policy
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            # Warmup phase: sample each env's action from its own allowed indices.
            per_env_idxs = self._expert_idxs_per_env()
            if len(per_env_idxs) == n_envs:
                row_idxs = per_env_idxs
            else:
                row_idxs = [per_env_idxs[0]] * n_envs
            unscaled_action = np.array([random.choice(idxs) for idxs in row_idxs])
        else:
            assert self._last_obs is not None, "self._last_obs was not set"
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        # ``gym`` is gymnasium here; the legacy ``gym.spaces.Box`` would never match
        # a gymnasium action space.
        if isinstance(self.action_space, gym.spaces.Box):
            # Rescale the action from [low, high] to [-1, 1]
            scaled_action = self.policy.scale_action(unscaled_action)

            # Add noise to the action (improve exploration)
            if action_noise is not None:
                scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

            # We store the scaled action in the buffer
            buffer_action = scaled_action
            action = self.policy.unscale_action(scaled_action)
        else:
            # Discrete case, no need to normalize or clip
            buffer_action = unscaled_action
            action = buffer_action
        return action, buffer_action

