from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import gym
import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3 import SAC
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from rl_zoo3.sacm.policies import (
    MlpPolicy,
    SACMPolicy,
)
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

# Type variable bound to SACM, used to annotate methods such as ``learn`` that
# take and return the model itself, so subclasses keep their own type.
SACMSelf = TypeVar("SACMSelf", bound="SACM")


class SACM(SAC):
    """
    Soft Actor Critic Mixture (SACM).

    Extends SAC to a policy made of ``n_components`` mixture components:

    * each component has its own entropy coefficient, learned independently
      when ``ent_coef`` is ``"auto"`` / ``"auto_<init>"``;
    * the soft critic target uses the log-density of the weighted mixture
      ``sum_i w_i * pi_i(a|s)`` (see :meth:`mixture_entropy`);
    * the actor takes one optimizer step per component.

    :param n_components: Number of mixture components; also forwarded to the
        policy through ``policy_kwargs["n_components"]``.
    :param component_weights: Mixture weights ``w_i`` used by
        :meth:`mixture_entropy`. Defaults to uniform weights ``1 / n_components``.
    :param save_path: Unused; kept for backward compatibility.
    :param _init_setup_model: Whether to build the replay buffer, policy and
        entropy coefficients in the constructor. Pass ``False`` to defer this
        (as SB3's loading code does) and call ``_setup_model()`` yourself.

    All other parameters have the same meaning as in :class:`SAC`.
    With a fixed (non-"auto") ``ent_coef``, every component uses that value, so
    the effective mixture temperature is ``n_components * ent_coef`` (the
    per-component coefficients are summed, exactly as in "auto" mode).
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[SACMPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        n_components: int = 3,
        component_weights: Optional[List[float]] = None,
        save_path: Optional[str] = "",
    ):
        self.n_components = n_components
        self.component_weights = (
            component_weights
            if component_weights is not None
            else [1.0 / n_components] * n_components
        )
        # The parent must NOT build the model itself: ``n_components`` has to be
        # injected into ``policy_kwargs`` first, otherwise the policy would be
        # created without it (and the model would then be built a second time).
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            ent_coef=ent_coef,
            target_update_interval=target_update_interval,
            target_entropy=target_entropy,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=False,
        )
        self.policy_kwargs["n_components"] = n_components

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Create the replay buffer, the mixture policy and the entropy coefficients.

        This replaces SAC's setup entirely (it does not call
        ``super()._setup_model()``) so that one entropy coefficient and one Adam
        optimizer are created per mixture component.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        # Use DictReplayBuffer if needed
        if self.replay_buffer_class is None:
            if isinstance(self.observation_space, spaces.Dict):
                self.replay_buffer_class = DictReplayBuffer
            else:
                self.replay_buffer_class = ReplayBuffer

        elif self.replay_buffer_class == HerReplayBuffer:
            assert (
                self.env is not None
            ), "You must pass an environment when using `HerReplayBuffer`"

            # If using offline sampling, we need a classic replay buffer too
            if self.replay_buffer_kwargs.get("online_sampling", True):
                replay_buffer = None
            else:
                replay_buffer = DictReplayBuffer(
                    self.buffer_size,
                    self.observation_space,
                    self.action_space,
                    device=self.device,
                    optimize_memory_usage=self.optimize_memory_usage,
                )

            self.replay_buffer = HerReplayBuffer(
                self.env,
                self.buffer_size,
                device=self.device,
                replay_buffer=replay_buffer,
                **self.replay_buffer_kwargs,
            )

        if self.replay_buffer is None:
            self.replay_buffer = self.replay_buffer_class(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                device=self.device,
                n_envs=self.n_envs,
                optimize_memory_usage=self.optimize_memory_usage,
                **self.replay_buffer_kwargs,
            )

        self.policy = self.policy_class(  # pytype:disable=not-instantiable
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # Convert train freq parameter to TrainFreq object
        self._convert_train_freq()

        self._create_aliases()
        # Running mean and running var
        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(
            self.critic_target, ["running_"]
        )

        # Target entropy is used when learning the entropy coefficients.
        # The isinstance check keeps this safe to re-run once the value is numeric.
        if isinstance(self.target_entropy, str) and self.target_entropy == "auto":
            # -dim(action space); read from ``self.action_space`` so setup also
            # works when no environment is attached.
            self.target_entropy = -np.prod(self.action_space.shape).astype(
                np.float32
            )
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert (
                    init_value > 0.0
                ), "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef_components = [
                th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
                for _ in range(self.n_components)
            ]
            self.ent_coef_optimizer_components = [
                th.optim.Adam([log_ent_coef], lr=self.lr_schedule(1))
                for log_ent_coef in self.log_ent_coef_components
            ]
        else:
            # Fixed coefficient: nothing to optimize. The None markers let
            # ``train`` take the fixed-coefficient path.
            self.log_ent_coef_components = None
            self.ent_coef_optimizer_components = None
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)

    def _create_aliases(self) -> None:
        """Expose the policy's actor, critic and target critic as attributes."""
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    @th.no_grad()
    def mixture_entropy(self, ent_coef_components, next_actions):
        """
        Compute the entropy term subtracted from the soft Q-target.

        Returns ``alpha * log(sum_i w_i * exp(log_pi_i(a)) + 1e-5)``, where
        ``alpha`` is the sum of ``ent_coef_components``, ``w_i`` are
        ``self.component_weights`` and ``log_pi_i`` is what
        ``self.actor.log_prob_component(next_actions, i)`` returns. The ``1e-5``
        keeps the logarithm finite when the mixture density underflows to zero.

        :param ent_coef_components: Per-component entropy coefficients (floats).
        :param next_actions: Actions to evaluate, one per batch row.
        :return: Tensor of shape ``(batch, 1)``.
        """
        ent_coef_mixture = th.as_tensor(ent_coef_components, device=self.device).sum()
        mixture_prob = th.zeros(next_actions.shape[0], device=self.device)
        for i in range(self.n_components):
            log_prob = self.actor.log_prob_component(next_actions, i)
            mixture_prob += self.component_weights[i] * th.exp(log_prob)
        return ent_coef_mixture * th.log(mixture_prob.reshape(-1, 1) + 1e-5)

    def _update_ent_coef_components(
        self, observations
    ) -> Tuple[List[float], List[float]]:
        """
        Take one optimizer step on every component's entropy coefficient (Eq. 13).

        :param observations: Batch of observations sampled from the replay buffer.
        :return: ``(ent_coefs, losses)``. ``ent_coefs`` are the per-component
            coefficients read *before* this step's update. ``losses`` are the
            matching entropy-coefficient losses. With a fixed ``ent_coef``,
            nothing is optimized: every component gets the fixed value and
            ``losses`` is empty.
        """
        if self.ent_coef_optimizer_components is None:
            return [float(self.ent_coef_tensor)] * self.n_components, []

        ent_coefs, losses = [], []
        for i in range(self.n_components):
            log_ent_coef = self.log_ent_coef_components[i]
            optimizer = self.ent_coef_optimizer_components[i]
            # Only the values of the log-probabilities enter the loss (they act
            # as constants w.r.t. the actor), so skip building an actor graph.
            with th.no_grad():
                _, log_prob = self.actor.action_log_prob_component(observations, i)
            ent_coef = th.exp(log_ent_coef.detach())
            ent_coef_loss = -(log_ent_coef * (log_prob + self.target_entropy)).mean()
            ent_coefs.append(ent_coef.item())
            losses.append(ent_coef_loss.item())

            optimizer.zero_grad()
            ent_coef_loss.backward()
            optimizer.step()
        return ent_coefs, losses

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run ``gradient_steps`` SACM updates on batches sampled from the replay buffer.

        Each gradient step runs four updates in order:

        1. update every component's entropy coefficient;
        2. regress the critics onto a soft target built from the mixture
           log-density;
        3. update the actor once per component;
        4. soft-update the target critic every ``target_update_interval`` steps.

        A step whose target Q-values are not finite is abandoned after (1). The
        number of such steps is logged as ``train/n_skipped_steps``.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rate according to lr schedule, including the
        # per-component entropy-coefficient optimizers (as SAC does for its one)
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer_components is not None:
            optimizers += list(self.ent_coef_optimizer_components)
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []
        n_skipped_steps = 0

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(
                batch_size, env=self._vec_normalize_env
            )

            # Equation (13): Update entropy temperature of every component
            ent_coef_components, ent_coef_loss_components = (
                self._update_ent_coef_components(replay_data.observations)
            )
            ent_coefs.append(np.mean(ent_coef_components))
            if ent_coef_loss_components:
                ent_coef_losses.append(np.mean(ent_coef_loss_components))

            # Equation (9) and (10): Update Q-function
            with th.no_grad():
                # Select action according to mixture policy
                next_actions, _ = self.actor.action_log_prob(
                    replay_data.next_observations
                )
                next_actions = next_actions.to(self.device)
                next_q_values = th.cat(
                    self.critic_target(replay_data.next_observations, next_actions),
                    dim=1,
                )
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # Add entropy term (uses the mixture log-density)
                next_q_values = next_q_values - self.mixture_entropy(
                    ent_coef_components, next_actions
                )
                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            # A degenerate mixture density can make the entropy term inf/NaN;
            # abandon this step rather than corrupt the critic.
            if not th.isfinite(next_q_values).all():
                n_skipped_steps += 1
                continue

            current_q_values = self.critic(
                replay_data.observations, replay_data.actions
            )
            critic_loss = 0.5 * sum(
                F.mse_loss(current_q, target_q_values) for current_q in current_q_values
            )
            critic_losses.append(critic_loss.item())

            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Equation (12): Update policy weights, one optimizer step per component
            ent_coef_mixture = th.as_tensor(
                ent_coef_components, device=self.device
            ).sum()

            actor_loss_components = []
            for i in range(self.n_components):
                actions_pi, log_prob = self.actor.action_log_prob_component(
                    replay_data.observations, i
                )
                # Column vector as in SAC: avoids an (N, N) broadcast against
                # ``min_qf_pi`` if log-probs are 1-D (the mean is the same).
                log_prob = log_prob.reshape(-1, 1)
                q_values_pi = th.cat(
                    self.critic(replay_data.observations, actions_pi), dim=1
                )
                min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)

                actor_loss = (ent_coef_mixture * log_prob - min_qf_pi).mean()
                actor_loss_components.append(actor_loss.item())

                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()

            actor_losses.append(np.mean(actor_loss_components))

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(
                    self.critic.parameters(), self.critic_target.parameters(), self.tau
                )
                # Copy running statistics (batch norm) verbatim
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/n_skipped_steps", n_skipped_steps)
        # Only record averages that exist: lists can be empty when steps are
        # skipped or ent_coef is fixed (avoids NaN and "mean of empty slice").
        for key, values in (
            ("train/ent_coef", ent_coefs),
            ("train/actor_loss", actor_losses),
            ("train/critic_loss", critic_losses),
            ("train/ent_coef_loss", ent_coef_losses),
        ):
            if values:
                self.logger.record(key, np.mean(values))

    def learn(
        self: SACMSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "SACM",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SACMSelf:
        """Train the model; identical to ``SAC.learn`` except ``tb_log_name`` defaults to ``"SACM"``."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        """Exclude the network aliases; they are re-created from ``policy``."""
        return super()._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        Name the objects whose PyTorch state dicts are stored in a checkpoint.

        NOTE: the per-component entropy coefficients and their optimizers live
        in Python lists. They are not listed here, so they are not restored
        from the torch checkpoint; ``_setup_model()`` re-creates them at their
        initial value.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        return state_dicts, []
