from typing import Any, Dict, List, Optional, Tuple, Type, Union


import gym
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.dqn.policies import DQNPolicy

from stable_baselines3 import DQN
import torch as th
import torch.nn.functional as F
import numpy as np



class MDQN(DQN):
    """Munchausen-style DQN built on top of Stable-Baselines3 ``DQN``.

    The regression target used in :meth:`train` differs from the plain DQN
    target in two ways:

    * the bootstrap term is the expectation, under a softmax policy with
      temperature ``entropy_tau``, of ``q - entropy_tau * log(pi)`` computed
      from the target network on the next observations;
    * the reward is augmented with ``MunCoef`` times the clipped, scaled
      log-probability of the action that was actually taken.

    All constructor arguments other than ``entropy_tau`` and ``MunCoef`` are
    forwarded unchanged to ``DQN.__init__``.

    :param entropy_tau: Softmax temperature used for the policy derived from
        the target network's Q-values.
    :param MunCoef: Weight of the clipped log-policy bonus added to the reward.
    """

    # Lower clipping bound for the scaled log-policy bonus. The bonus is
    # clamped to [_MUNCHAUSEN_CLIP_MIN, 0] before being weighted by MunCoef.
    _MUNCHAUSEN_CLIP_MIN = -1.0

    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1000000,  # 1e6
        learning_starts: int = 5000,
        batch_size: Optional[int] = 32,
        tau: float = 0.99,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 8000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.01,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        entropy_tau: float = 0.03,
        MunCoef: float = 0.9,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            # These two were accepted but never forwarded, so a custom replay
            # buffer requested by the caller was silently ignored.
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            target_update_interval=target_update_interval,
            exploration_fraction=exploration_fraction,
            exploration_initial_eps=exploration_initial_eps,
            exploration_final_eps=exploration_final_eps,
            max_grad_norm=max_grad_norm,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

        self.entropy_tau = entropy_tau
        self.MunCoef = MunCoef

    def _scaled_log_policy(self, q_values: th.Tensor) -> th.Tensor:
        """Return ``entropy_tau * log_softmax(q_values / entropy_tau)`` over dim 1.

        The per-row maximum is subtracted before exponentiating so that the
        log-sum-exp stays finite even when ``entropy_tau`` is small.

        :param q_values: Q-values with the action dimension at index 1.
        :return: Tensor of the same shape as ``q_values``.
        """
        advantages = q_values - q_values.max(dim=1, keepdim=True)[0]
        log_partition = th.logsumexp(advantages / self.entropy_tau, dim=1, keepdim=True)
        return advantages - self.entropy_tau * log_partition

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` optimisation steps on the online Q-network.

        Each step samples a batch from the replay buffer, builds the
        Munchausen target from the target network, and minimises the Huber
        loss between that target and the online Q-value of the taken action.

        :param gradient_steps: Number of gradient updates to perform.
        :param batch_size: Number of transitions sampled per update.
        """
        # Upstream DQN.train switches the policy to training mode first, which
        # matters for dropout / batch-norm layers. Guarded because this method
        # is presumably missing on older SB3 releases - TODO confirm the
        # minimum supported version.
        set_training_mode = getattr(self.policy, "set_training_mode", None)
        if set_training_mode is not None:
            set_training_mode(True)

        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
            actions = replay_data.actions.long()

            with th.no_grad():
                # Soft bootstrap value of the next state:
                #   sum_a' pi(a'|s') * (q(s', a') - tau * log pi(a'|s'))
                next_q_values = self.q_net_target(replay_data.next_observations)
                next_tau_log_pi = self._scaled_log_policy(next_q_values)
                next_pi = F.softmax(next_q_values / self.entropy_tau, dim=1)
                soft_next_value = (next_pi * (next_q_values - next_tau_log_pi)).sum(dim=1, keepdim=True)

                # No bootstrapping past terminal transitions.
                target_q = self.gamma * (1 - replay_data.dones) * soft_next_value

                # Munchausen bonus: scaled log-probability of the action that
                # was taken, evaluated with the target network on the current
                # observation and clipped to a bounded non-positive range.
                tau_log_pi = self._scaled_log_policy(self.q_net_target(replay_data.observations))
                munchausen = tau_log_pi.gather(1, actions)
                munchausen_reward = replay_data.rewards + self.MunCoef * th.clamp(
                    munchausen, min=self._MUNCHAUSEN_CLIP_MIN, max=0
                )

                # Full regression target: augmented reward plus soft bootstrap.
                target_q_values = target_q + munchausen_reward

            # Online Q-value estimates for the actions stored in the buffer
            current_q_values = self.q_net(replay_data.observations)
            current_q_values = th.gather(current_q_values, dim=1, index=actions)

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))