from typing import Any, Dict, List, Optional, Tuple, Type, Union
from clearml.backend_config import entry
import gym
import numpy as np
import torch as th
import torch.nn.functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update

from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import QRDQNPolicy
from sb3_contrib.qrdqn import QRDQN


class MunchausenQRDQN(QRDQN):
    """QR-DQN trained with a Munchausen (entropy-regularised) TD target.

    The bootstrap term is the soft value of the next state under the Boltzmann
    policy ``softmax(Q_target(s') / entropy_tau)``. The reward is augmented with
    ``alpha * clamp(entropy_tau * log pi(a|s), -1, 0)`` for the taken action,
    where ``pi`` is the Boltzmann policy of the target network at ``s``.

    NOTE(review): the target is collapsed to one scalar per transition (shape
    ``(batch_size, 1)``). The quantile Huber loss therefore regresses every
    current quantile towards that expected target, not towards a per-quantile
    target distribution.

    :param entropy_tau: temperature of the softmax policy / entropy regulariser.
    :param alpha: weight of the clipped log-policy (Munchausen) reward bonus.
    All other parameters are forwarded unchanged to ``QRDQN``.
    """

    def __init__(
        self,
        policy: Union[str, Type[QRDQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 0.00005,
        buffer_size: int = 1000000,
        learning_starts: int = 5000,
        batch_size: Optional[int] = 32,
        tau: float = 1,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 8000,
        exploration_fraction: float = 0.005,
        exploration_initial_eps: float = 1,
        exploration_final_eps: float = 0.01,
        max_grad_norm: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        entropy_tau: float = 0.03,
        alpha: float = 0.9,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            target_update_interval=target_update_interval,
            exploration_fraction=exploration_fraction,
            exploration_initial_eps=exploration_initial_eps,
            exploration_final_eps=exploration_final_eps,
            max_grad_norm=max_grad_norm,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

        self.entropy_tau = entropy_tau
        self.alpha = alpha

    def _scaled_log_softmax(self, q_values: th.Tensor) -> th.Tensor:
        """Return ``entropy_tau * log softmax(q_values / entropy_tau)`` along dim 1.

        Uses the max-shifted log-sum-exp form for numerical stability.
        """
        advantage = q_values - q_values.max(dim=1)[0].unsqueeze(-1)
        logsum = th.logsumexp(advantage / self.entropy_tau, dim=1).unsqueeze(-1)
        return advantage - self.entropy_tau * logsum

    def _compute_target_quantiles(self, replay_data, batch_size: int) -> th.Tensor:
        """Return the Munchausen soft TD target for a replay batch, shape ``(batch_size, 1)``.

        Only the target network is evaluated; callers wrap this in ``th.no_grad()``.

        :param replay_data: a batch sampled from ``self.replay_buffer``.
        :param batch_size: number of transitions in ``replay_data``.
        """
        # Q values of the next state = mean over the quantile dimension (dim 1).
        next_values = self.quantile_net_target(replay_data.next_observations).mean(dim=1)
        next_policy = F.softmax(next_values / self.entropy_tau, dim=1)
        next_tau_logpi = self._scaled_log_softmax(next_values)

        # Soft state value sum_a pi(a|s') * (Q(s', a) - tau * log pi(a|s')): (batch_size, 1)
        next_soft_values = (next_policy * (next_values - next_tau_logpi)).sum(dim=1).unsqueeze(-1)
        assert next_soft_values.shape == replay_data.rewards.shape

        # Munchausen bonus: tau * log pi(a|s) of the taken action, from the target network.
        current_values = self.quantile_net_target(replay_data.observations).mean(dim=1)
        current_tau_logpi = self._scaled_log_softmax(current_values)
        munchausen_addon = current_tau_logpi.gather(dim=1, index=replay_data.actions.long())

        # The bonus is clipped to [-1, 0] before weighting by alpha.
        munchausen_rewards = replay_data.rewards + self.alpha * th.clamp(munchausen_addon, min=-1, max=0)
        target_quantiles = munchausen_rewards + (1 - replay_data.dones) * self.gamma * next_soft_values
        assert target_quantiles.shape == (batch_size, 1)
        return target_quantiles

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` quantile-regression updates on replay samples.

        :param gradient_steps: number of gradient updates to perform.
        :param batch_size: number of transitions sampled per update.
        """
        # NOTE(review): the policy is not switched to training mode here (the
        # ``set_training_mode(True)`` call was disabled). Confirm this is
        # intended if the policy contains dropout / batch-norm layers.
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                target_quantiles = self._compute_target_quantiles(replay_data, batch_size)

            # Current quantile estimates: (batch_size, n_quantiles, n_actions)
            current_quantiles = self.quantile_net(replay_data.observations)
            # Make "n_quantiles" copies of actions: (batch_size, n_quantiles, 1)
            actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
            # Keep only the quantiles of the stored actions: (batch_size, n_quantiles)
            current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)

            # Quantile Huber loss, summed over the quantile dimension.
            loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
            losses.append(loss.item())

            self.policy.optimizer.zero_grad()
            loss.backward()
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))




class MunchausenSparseQRDQN(MunchausenQRDQN):
    """Munchausen QR-DQN variant whose policy is ``sparsemax(Q / entropy_tau)``.

    The bootstrap term is the expectation of the target-network Q values of the
    next state under the sparsemax policy. The reward bonus is
    ``alpha * clamp(log(pi(a|s) + 1e-8), -1, 0)`` with ``pi`` the sparsemax
    policy of the target network at ``s``.
    """

    def _sparsemax_operator(self, q_values: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Sparsemax projection along dim 1 and its regularised value.

        Implements sparsemax (Martins & Astudillo, 2016):
        ``policy = max(z - threshold, 0)``. The threshold is chosen so that each
        row of ``policy`` sums to one.

        :param q_values: 2-D tensor of logits ``z``, projected along dim 1.
            Callers in this file pass ``Q / entropy_tau``.
        :return: ``(policy, spmax)``.
            ``policy`` has the same shape as ``q_values``.
            ``spmax`` has shape ``(rows, 1)`` and equals
            ``entropy_tau * (0.5 * sum_{j in support} (z_j**2 - threshold**2) + 0.5)``
            evaluated at the un-shifted ``z``.
        """
        dim = 1
        n_logits = q_values.size(dim)

        # Shift by the row max for numerical stability. The policy is invariant
        # to this shift; the value is shifted by the same constant (added back below).
        row_max = q_values.max(dim=dim, keepdim=True)[0]
        z = q_values - row_max

        # Sort in descending order (z_(1) >= z_(2) >= ...), as in the paper.
        z_sorted = th.sort(z, dim=dim, descending=True)[0]
        ranks = th.arange(1, n_logits + 1, device=z.device, dtype=z.dtype).view(1, -1)

        # Support condition 1 + k * z_(k) > sum_{j<=k} z_(j). It holds on a prefix of
        # the sorted entries, so the largest satisfying k is the support size.
        in_support = (1 + ranks * z_sorted > th.cumsum(z_sorted, dim)).to(z.dtype)
        support_size = (in_support * ranks).max(dim=dim, keepdim=True)[0].clamp(min=1)

        # Threshold shared by the policy and the value function: (rows, 1)
        threshold = (th.sum(in_support * z_sorted, dim, keepdim=True) - 1) / support_size
        policy = th.clamp(z - threshold, min=0)

        # Only entries inside the support contribute to the value. The row max
        # that was removed for stability is added back afterwards.
        support_terms = in_support * (z_sorted.pow(2) - threshold.pow(2))
        spmax_shifted = 0.5 * support_terms.sum(dim=dim, keepdim=True) + 0.5
        spmax = self.entropy_tau * (spmax_shifted + row_max)
        return policy, spmax

    def _compute_target_quantiles(self, replay_data, batch_size: int) -> th.Tensor:
        """Return the sparse-Munchausen TD target for a replay batch, shape ``(batch_size, 1)``.

        Only the target network is evaluated; callers wrap this in ``th.no_grad()``.

        :param replay_data: a batch sampled from ``self.replay_buffer``.
        :param batch_size: number of transitions in ``replay_data``.
        """
        # Q values of the next state = mean over the quantile dimension (dim 1).
        next_values = self.quantile_net_target(replay_data.next_observations).mean(dim=1)
        next_policy, _ = self._sparsemax_operator(next_values / self.entropy_tau)

        # Expected next-state Q value under the sparsemax policy: (batch_size, 1)
        next_expected_values = (next_policy * next_values).sum(dim=1).unsqueeze(-1)
        assert next_expected_values.shape == replay_data.rewards.shape

        current_values = self.quantile_net_target(replay_data.observations).mean(dim=1)
        current_policy, _ = self._sparsemax_operator(current_values / self.entropy_tau)
        # The 1e-8 floor avoids log(0) for actions outside the sparse support;
        # those get log(1e-8) ~ -18.4, which the clamp below maps to -1.
        # NOTE(review): unlike MunchausenQRDQN, this bonus is log pi rather than
        # entropy_tau * log pi. Confirm the missing temperature factor is intended.
        munchausen_addon = th.log(current_policy + 1e-8).gather(dim=1, index=replay_data.actions.long())
        assert munchausen_addon.shape == replay_data.rewards.shape

        munchausen_rewards = replay_data.rewards + self.alpha * th.clamp(munchausen_addon, min=-1, max=0)
        target_quantiles = munchausen_rewards + (1 - replay_data.dones) * self.gamma * next_expected_values
        assert target_quantiles.shape == (batch_size, 1)
        return target_quantiles

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` quantile-regression updates on replay samples.

        :param gradient_steps: number of gradient updates to perform.
        :param batch_size: number of transitions sampled per update.
        """
        # NOTE(review): the policy is not switched to training mode here (the
        # ``set_training_mode(True)`` call was disabled). Confirm this is
        # intended if the policy contains dropout / batch-norm layers.
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                target_quantiles = self._compute_target_quantiles(replay_data, batch_size)

            # Current quantile estimates: (batch_size, n_quantiles, n_actions)
            current_quantiles = self.quantile_net(replay_data.observations)
            # Make "n_quantiles" copies of actions: (batch_size, n_quantiles, 1)
            actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
            # Keep only the quantiles of the stored actions: (batch_size, n_quantiles)
            current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)

            # Quantile Huber loss, summed over the quantile dimension.
            loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
            losses.append(loss.item())

            self.policy.optimizer.zero_grad()
            loss.backward()
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))




class MVIq(MunchausenSparseQRDQN):
    """Sparsemax QR-DQN whose TD target carries an advantage-style bonus.

    Target per transition (all Q values from the target network, averaged over
    quantiles, with ``pi = sparsemax(Q / entropy_tau)``)::

        r + (1 - done) * gamma * <pi(s'), Q(s')> + alpha * (Q(s, a) - <pi(s), Q(s)>)
    """

    def _compute_target_quantiles(self, replay_data, batch_size: int) -> th.Tensor:
        """Build the ``(batch_size, 1)`` TD target for one replay batch.

        Only the target network is evaluated; callers wrap this in ``th.no_grad()``.
        """
        sparsemax = self._sparsemax_operator
        temperature = self.entropy_tau

        # Next state: expected Q under the sparsemax policy.
        q_next = self.quantile_net_target(replay_data.next_observations).mean(dim=1)
        pi_next = sparsemax(q_next / temperature)[0]
        v_next = (pi_next * q_next).sum(dim=1).unsqueeze(-1)
        assert v_next.shape == replay_data.rewards.shape

        # Current state: Q of the taken action minus its policy-weighted mean.
        q_now = self.quantile_net_target(replay_data.observations).mean(dim=1)
        pi_now = sparsemax(q_now / temperature)[0]
        q_taken = q_now.gather(dim=1, index=replay_data.actions.long())
        v_now = (pi_now * q_now).sum(dim=1, keepdim=True)

        bootstrap = (1 - replay_data.dones) * self.gamma * v_next
        target = replay_data.rewards + bootstrap + self.alpha * (q_taken - v_now)
        assert target.shape == (batch_size, 1)
        return target

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Perform ``gradient_steps`` optimiser updates, each on a fresh replay sample.

        The policy's train/eval mode is left untouched (the switch was disabled).
        """
        optimizer = self.policy.optimizer
        # Apply the learning-rate schedule before the updates.
        self._update_learning_rate(optimizer)

        loss_history = []
        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                target = self._compute_target_quantiles(batch, batch_size)

            # Index tensor (batch_size, n_quantiles, 1) selecting the stored action.
            action_index = batch.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
            predicted = self.quantile_net(batch.observations).gather(dim=2, index=action_index).squeeze(dim=2)

            loss = quantile_huber_loss(predicted, target, sum_over_quantiles=True)
            loss_history.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            optimizer.step()

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(loss_history))




class TsallisQRDQN(MunchausenSparseQRDQN):
    """Sparsemax QR-DQN without any Munchausen reward bonus.

    Target per transition: ``r + (1 - done) * gamma * <pi(s'), Q(s')>``, with
    ``pi = sparsemax(Q / entropy_tau)`` and Q from the target network averaged
    over quantiles. ``alpha`` is accepted via the base class but unused here.
    """

    def _compute_target_quantiles(self, replay_data, batch_size: int) -> th.Tensor:
        """Build the ``(batch_size, 1)`` TD target for one replay batch.

        Only the target network is evaluated; callers wrap this in ``th.no_grad()``.
        """
        q_next = self.quantile_net_target(replay_data.next_observations).mean(dim=1)
        pi_next = self._sparsemax_operator(q_next / self.entropy_tau)[0]
        v_next = (pi_next * q_next).sum(dim=1).unsqueeze(-1)
        assert v_next.shape == replay_data.rewards.shape

        target = replay_data.rewards + (1 - replay_data.dones) * self.gamma * v_next
        assert target.shape == (batch_size, 1)
        return target

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Perform ``gradient_steps`` optimiser updates, each on a fresh replay sample.

        The policy's train/eval mode is left untouched (the switch was disabled).
        """
        optimizer = self.policy.optimizer
        # Apply the learning-rate schedule before the updates.
        self._update_learning_rate(optimizer)

        loss_history = []
        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                target = self._compute_target_quantiles(batch, batch_size)

            # Index tensor (batch_size, n_quantiles, 1) selecting the stored action.
            action_index = batch.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
            predicted = self.quantile_net(batch.observations).gather(dim=2, index=action_index).squeeze(dim=2)

            loss = quantile_huber_loss(predicted, target, sum_over_quantiles=True)
            loss_history.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            optimizer.step()

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(loss_history))
