import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
import pickle

from stable_baselines3.common.buffers import ReplayBuffer
from src.agent.BaseAgentsSB3 import OffPolicyModelBasedAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork

# Type variable bound to CADQN so that methods annotated ``self: SelfCADQN -> SelfCADQN``
# (e.g. ``CADQN.learn``) are typed as returning the caller's own (sub)class.
SelfCADQN = TypeVar("SelfCADQN", bound="CADQN")


class CADQN(OffPolicyModelBasedAlgorithm):
    """
    DQN variant that also learns a "complexity" value function and, at action-selection
    time, prefers low-complexity actions among the near-greedy ones.

    Besides the standard Q-network, a second network ``q_complexity`` (initialised as a
    copy of ``q_net``) is regressed onto the undiscounted, differential TD target
    ``c(s, a) - ht + (1 - done) * min_a' q_complexity_target(s', a')``. Here ``c`` comes from
    ``self.dynamic_model.predict_complexity`` and ``ht`` is a running estimate of the
    per-step complexity rate (logged as ``train/entropy_rate``). ``self.dynamic_model`` and
    ``self.simulator`` are not created in this class; presumably they are provided by
    ``OffPolicyModelBasedAlgorithm`` (TODO confirm).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param model_batch: Passed through to the base class; presumably the minibatch size for
        model-based updates (TODO confirm against ``OffPolicyModelBasedAlgorithm``).
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update.
        Also used for the complexity target network.
    :param gamma: the discount factor (not applied to the complexity TD target)
    :param eta: Multiplier on the learning rate used when updating the running complexity rate ``ht``.
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param model_rollout: Passed through to the base class (semantics defined there — TODO confirm).
    :param model_kwargs: Passed through to the base class; presumably keyword arguments for the
        dynamics model (TODO confirm).
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param model_grad_steps: Passed through to the base class; presumably gradient steps for the
        dynamics model (TODO confirm).
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param monitor_wrapper: Passed through to the base class (presumably whether to wrap the
        env in a ``Monitor`` — TODO confirm).
    :param max_grad_norm: The maximum value for the gradient clipping (Q-network only)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
    q_net: QNetwork
    q_net_target: QNetwork
    policy: DQNPolicy

    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 50000,
        batch_size: int = 32,
        model_batch: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        eta: float = 2,
        train_freq: Union[int, Tuple[int, str]] = 4,
        model_rollout: int = 1,
        model_kwargs: Optional[Dict[str, Any]] = None,
        gradient_steps: int = 1,
        model_grad_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 1000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        monitor_wrapper: bool = True,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ) -> None:
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            model_batch,
            tau,
            gamma,
            train_freq,
            model_rollout,
            gradient_steps,
            model_grad_steps=model_grad_steps,
            monitor_wrapper=monitor_wrapper,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            model_kwargs=model_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
        )

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        # For updating the target network with multiple envs:
        self._n_calls = 0
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0
        # Running estimate of the per-step complexity rate and its step-size multiplier.
        # Initialised before ``_setup_model`` so the setup never sees them undefined.
        self.ht = 0
        self.eta = eta
        self.learning_rate_complex = 1
        self.ht_target = 0
        if _init_setup_model:
            self._setup_model()

    def _lr_at(self, progress_remaining: float) -> float:
        """
        Return the learning rate at ``progress_remaining``.

        ``learning_rate`` may be a constant or a schedule (a callable of the progress
        remaining, from 1 to 0); both are handled here.
        """
        if callable(self.learning_rate):
            return self.learning_rate(progress_remaining)
        return self.learning_rate

    def _setup_model(self) -> None:
        """Build networks, aliases, exploration schedule and the complexity optimizer."""
        super()._setup_model()
        self._create_aliases()
        # Copy running stats, see GH issue #996
        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )
        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        if self.n_envs > 1:
            if self.n_envs > self.target_update_interval:
                warnings.warn(
                    "The number of environments used is greater than the target network "
                    f"update interval ({self.n_envs} > {self.target_update_interval}), "
                    "therefore the target network will be updated after each call to env.step() "
                    f"which corresponds to {self.n_envs} steps."
                )

            self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
        self._setup_complexity_q()
        # The complexity network uses 5x the initial learning rate (evaluated so that a
        # schedule is accepted as well as a float).
        self.complex_optimizer = th.optim.Adam(self.q_complexity.parameters(), lr=self._lr_at(1.0) * 5)

    def _setup_complexity_q(self) -> None:
        """Create the complexity network and its target as independent copies of ``q_net``."""
        import copy

        self.q_complexity = copy.deepcopy(self.q_net)
        self.q_complexity_target = copy.deepcopy(self.q_net)

    def _create_aliases(self) -> None:
        """Expose the policy's Q-networks directly on the algorithm."""
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _on_step(self) -> None:
        """
        Update the exploration rate and target networks if needed.

        This method is called in ``collect_rollouts()`` after each step in the environment.
        Every ``target_update_interval`` calls, both the Q target and the complexity target
        are Polyak-updated with coefficient ``tau``.
        """
        self._n_calls += 1
        if self._n_calls % self.target_update_interval == 0:
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
            # Copy running stats, see GH issue #996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
            polyak_update(self.q_complexity.parameters(), self.q_complexity_target.parameters(), self.tau)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` updates of the Q-network and the complexity network.

        Each step samples one minibatch from ``self.replay_buffer`` and:

        1. updates ``q_net`` with the 1-step DQN Huber loss against ``q_net_target``;
        2. moves the running complexity rate ``ht`` by ``eta * lr * mean(h_delta)``, where
           ``h_delta = c - ht + (1 - done) * min_a' q_complexity_target(s', a')``;
        3. regresses ``q_complexity`` onto ``h_delta`` with a Huber loss clamped to ``[0, 5]``.

        :param gradient_steps: Number of gradient steps to perform.
        :param batch_size: Minibatch size sampled from the replay buffer for each step.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        self.q_complexity.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        losses_c = []
        complxs = None  # last batch of predicted complexities, kept for logging
        for _ in range(gradient_steps):
            # Model-free: sample real transitions from the replay buffer.
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
            with th.no_grad():
                # Compute the next values using the target networks
                next_q_values = self.q_net_target(replay_data.next_observations)
                next_qc_values = self.q_complexity_target(replay_data.next_observations)
                # Greedy w.r.t. reward (max) and w.r.t. complexity (min)
                next_q_values, _ = next_q_values.max(dim=1)
                next_qc_values, _ = next_qc_values.min(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                next_qc_values = next_qc_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
                complxs = self.dynamic_model.predict_complexity(
                    th.as_tensor(replay_data.observations).float(),
                    th.as_tensor(replay_data.actions).float(),
                ).unsqueeze(1)
                # Differential (average-rate) target: no discount, ``ht`` subtracted every step.
                h_delta = complxs - self.ht + (1 - replay_data.dones) * next_qc_values
                # Evaluate the (possibly scheduled) learning rate at the current progress.
                self.ht += self.eta * self._lr_at(self._current_progress_remaining) * th.mean(h_delta)

            # Get current Q-values estimates for the actions stored in the replay buffer
            current_q_values = self.q_net(replay_data.observations)
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

            current_qc_values = self.q_complexity(replay_data.observations)
            current_qc_values = th.gather(current_qc_values, dim=1, index=replay_data.actions.long())
            loss_c = F.smooth_l1_loss(current_qc_values, h_delta)
            # NOTE(review): clamping the scalar loss gives it zero gradient whenever it exceeds 5,
            # so such minibatches contribute no gradient to ``q_complexity`` — confirm intended.
            loss_c = th.clip(loss_c, 0, 5)
            losses_c.append(loss_c.item())

            # Optimize the complexity network (no gradient clipping)
            self.complex_optimizer.zero_grad()
            loss_c.backward()
            self.complex_optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        # ``float`` works whether ``ht`` is still the initial int, a numpy scalar or a 0-d tensor.
        self.logger.record("train/entropy_rate", float(self.ht))
        # With zero gradient steps there are no losses or complexities to report.
        if complxs is not None:
            self.logger.record("train/loss", np.mean(losses))
            self.logger.record("train/complexities", complxs.mean().item())
            self.logger.record("train/complexity_loss", np.mean(losses_c))

    def _predict_low_complexity(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        tolerance: float,
    ) -> np.ndarray:
        """
        For each observation, pick the lowest-complexity action among the near-greedy ones.

        An action is a candidate when its Q-value is at least
        ``max(Q) - tolerance * (max(Q) - min(Q))`` for that observation. Among the candidates,
        the one with the smallest ``q_complexity`` value is returned; ties go to the lowest
        action index. Every row of a batched observation is handled independently.

        :param observation: Single or batched observation (numpy, dict, or tensor).
        :param tolerance: Fraction of the per-observation Q-value range that defines "near-greedy".
        :return: Actions of shape ``(batch,)``, or a 0-d array for a single non-vectorized observation.
        """
        try:
            obs_tensor, vectorized_env = self.policy.obs_to_tensor(observation)
        except Exception:
            # Best-effort fallback for inputs ``obs_to_tensor`` cannot handle
            # (e.g. tensors that ``np.array`` cannot convert, such as CUDA tensors).
            obs_tensor = th.as_tensor(observation).float().to(self.device)
            vectorized_env = False
        with th.no_grad():
            q_values = self.q_net(obs_tensor)
            c_values = self.q_complexity(obs_tensor)
        # Normalize to (batch, n_actions) so each row is handled independently.
        q_values = q_values.reshape(-1, q_values.shape[-1])
        c_values = c_values.reshape(-1, c_values.shape[-1])

        q_max = q_values.max(dim=1, keepdim=True).values
        q_min = q_values.min(dim=1, keepdim=True).values
        candidates = q_values >= q_max - (q_max - q_min) * tolerance
        # Non-candidates get +inf complexity so argmin only selects a candidate;
        # argmin returns the first minimum, i.e. the lowest candidate index on ties.
        masked = th.where(candidates, c_values, th.full_like(c_values, float("inf")))
        action = masked.argmin(dim=1).cpu().numpy().reshape((-1, *self.action_space.shape))
        # Remove batch dimension if needed
        if not vectorized_env and action.shape[0] == 1:
            action = action.squeeze(axis=0)
        return action

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        complexity: bool = True,
        tolerance: float = 0.2,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration
        and complexity-aware action selection.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
            When False, a random action is taken with probability ``exploration_rate``.
        :param complexity: If True, choose the lowest-complexity action among those whose
            Q-value is within ``tolerance`` of the best (relative to the Q-value range);
            otherwise defer to ``self.policy.predict``.
        :param tolerance: Fraction of the per-observation Q-value range treated as "near-greedy".
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
                    n_batch = observation[list(observation.keys())[0]].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        elif complexity:
            action = self._predict_low_complexity(observation, tolerance)
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state

    def learn(
        self: SelfCADQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "CADQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfCADQN:
        """Train for ``total_timesteps``; delegates to the base class with the ``"CADQN"`` log name."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        """
        Names excluded from the pickled attributes when saving.

        Aliases and modules saved through their state dicts (see ``_get_torch_save_params``)
        are excluded. ``env`` is deliberately removed from the exclusion list, which
        presumably makes the environment get saved with the other attributes.
        """
        ex_list = [
            *super()._excluded_save_params(),
            "simulator",
            "q_net",
            "q_net_target",
            "q_complexity",
            "complex_optimizer",
            "q_complexity_target",
            "dynamic_model",
        ]
        # Remove 'env' from list (guarded: ``list.remove`` raises if it is absent).
        if "env" in ex_list:
            ex_list.remove("env")
        return ex_list

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Return the attributes saved via ``state_dict`` (first list) and as raw tensors (second list)."""
        state_dicts = [
            "policy",
            "policy.optimizer",
            "q_complexity",
            "complex_optimizer",
            "q_complexity_target",
            "dynamic_model",
            "simulator.model",
        ]

        return state_dicts, []

    def estimate_complexity_rate(self, T=10):
        """
        Estimate the per-step complexity rate by rolling out in ``self.simulator``.

        Runs ``T`` steps of ``predict(..., deterministic=True, complexity=True, tolerance=0.1)``.
        It averages the complexities that ``self.dynamic_model`` predicts for the visited
        (observation, action) pairs, stores the mean in ``self.ht`` and returns it. The
        simulator is reset whenever its first environment reports ``done``.

        All networks and ``dynamic_model`` are left in evaluation mode; ``train`` switches
        ``policy`` and ``q_complexity`` back to training mode.

        :param T: Number of simulator steps to roll out.
        :return: Mean predicted complexity over the rollout.
        """
        self.policy.set_training_mode(False)
        self.q_complexity.set_training_mode(False)
        self.q_complexity_target.set_training_mode(False)
        self.q_net.set_training_mode(False)
        self.q_net_target.set_training_mode(False)
        self.dynamic_model.train(False)
        obs = self.simulator.reset(is_vec=True)[0]
        h = []
        for _ in range(T):
            action, _states = self.predict(obs, deterministic=True, complexity=True, tolerance=0.1)
            # Inference only (``train`` also calls predict_complexity under no_grad);
            # ``.numpy()`` would reject a tensor that requires grad.
            with th.no_grad():
                step_complexity = self.dynamic_model.predict_complexity(
                    obs, th.as_tensor(action).unsqueeze(0).float().to(self.device)
                )
            h.append(step_complexity.cpu().numpy())
            obs, reward, done, info = self.simulator.step(action)
            obs = th.as_tensor(obs).float().to(self.device)
            if done[0]:
                obs = self.simulator.reset(is_vec=True)[0]
        h = np.mean(h)
        self.ht = h
        return h

