from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from torch import nn
from gymnasium import spaces
from torch.nn import functional as F

from src.utils.buffers import ReplayBufferComplexity
from stable_baselines3.common.noise import ActionNoise
from src.agent.BaseAgentsSB3 import RPCBaseAlgorithm
from src.utils.torch_networks import RPCPolicy, RPCActor
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule


# Type variable bound to RPC so that methods returning ``self`` (e.g. ``RPC.learn``) keep the caller's concrete type.
SelfRPC = TypeVar("SelfRPC", bound="RPC")

def squash_to_range(t, low=-float('inf'), high=float('inf')):
    """Smoothly squash ``t`` into the range ``[low, high]`` with scaled tanh curves.

    Negative entries go through ``-low * tanh(t / -low)`` and non-negative entries
    through ``high * tanh(t / high)``. An infinite bound leaves that side of the
    input unchanged (identity).

    :param t: input tensor.
    :param low: lower bound; the formula assumes it is negative (or ``-inf``).
    :param high: upper bound; the formula assumes it is positive (or ``+inf``).
    :return: tensor of the same shape as ``t``.
    """
    neg_inf = -float('inf')
    pos_inf = float('inf')
    below_zero = t if low == neg_inf else -low * th.tanh(t / (-low))
    above_zero = t if high == pos_inf else high * th.tanh(t / high)
    # Pick the branch per element: negative values use the lower curve.
    return th.where(t < 0, below_zero, above_zero)

def kl_divergence(mu_q, logvar_q, mu_p, logvar_p):
    """Return KL(q || p) for diagonal Gaussians, summed over dimension 1.

    Per dimension: ``0.5 * (log var_p - log var_q + (var_q + (mu_q - mu_p)^2) / var_p - 1)``.

    :param mu_q: mean of q.
    :param logvar_q: log-variance of q.
    :param mu_p: mean of p.
    :param logvar_p: log-variance of p.
    :return: tensor with dimension 1 reduced (one KL value per row for 2-D inputs).
    """
    var_q = th.exp(logvar_q)
    var_p = th.exp(logvar_p)
    mean_gap_sq = (mu_q - mu_p) ** 2
    per_dim = 0.5 * (logvar_p - logvar_q + (var_q + mean_gap_sq) / var_p - 1)
    return per_dim.sum(dim=1)


class EncoderNet(nn.Module):
    """Gaussian encoder mapping an input vector to a latent mean and log-variance.

    Two modes:

    * ``identity_encoder=True``: the mean is the input itself (``nn.Identity``) and
      only the log-variance is learned, through a single linear layer. This mode
      requires ``latent_dim == input_dim``.
    * otherwise: a two-hidden-layer ReLU MLP with separate mean and log-variance
      heads; the mean is squashed into [-30, 30].

    In both modes the log-variance is squashed into [log(0.1), log(10)], i.e. the
    variance is kept within (0.1, 10).

    NOTE(review): only the MLP mode creates an ``optimizer`` attribute; the identity
    mode relies on whatever optimizer owns its parameters elsewhere — confirm.

    :param input_dim: size of the last input dimension.
    :param hidden_dim: width of the hidden layers (MLP mode only).
    :param latent_dim: size of the latent mean / log-variance.
    :param identity_encoder: use the identity-mean mode described above.
    :raises ValueError: if ``identity_encoder`` is set and ``latent_dim != input_dim``.
    """

    def __init__(self, input_dim, hidden_dim=128, latent_dim=64, identity_encoder=False):
        super().__init__()
        self.identity_encoder = identity_encoder
        self.input_dims = input_dim
        if identity_encoder:
            # Explicit check instead of ``assert`` so it is not stripped under ``python -O``;
            # a mismatch would otherwise give mu and logvar different widths.
            if latent_dim != input_dim:
                raise ValueError(
                    "identity_encoder requires latent_dim == input_dim, "
                    f"got latent_dim={latent_dim} and input_dim={input_dim}"
                )
            self.fc_mu = nn.Identity()
            self.fc_logvar = nn.Linear(input_dim, latent_dim)
        else:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, hidden_dim)
            self.fc_mu = nn.Linear(hidden_dim, latent_dim)  # Output layer for mean
            self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # Output layer for log variance
            self.optimizer = th.optim.Adam(self.parameters(), lr=3e-4)

    def _encode(self, x):
        """Return ``(mu, logvar)`` for input ``x`` according to the configured mode."""
        if self.identity_encoder:
            features = x
            mu = self.fc_mu(x)
        else:
            features = F.relu(self.fc1(x))
            features = F.relu(self.fc2(features))
            mu = squash_to_range(self.fc_mu(features), low=-30.0, high=30.0)
        # Bound the log-variance so the variance stays within (0.1, 10).
        logvar = squash_to_range(self.fc_logvar(features), low=np.log(0.1), high=np.log(10))
        return mu, logvar

    def forward(self, x):
        """Return the latent Gaussian parameters ``(mu, logvar)`` for ``x``."""
        return self._encode(x)

    def sample(self, x):
        """Draw a reparameterized latent sample for ``x``, squashed into [-30, 30]."""
        mu, logvar = self.forward(x)
        z = reparameterize(mu, logvar)
        z = squash_to_range(z, low=-30.0, high=30.0)
        return z

class PredictorNet(nn.Module):
    """Gaussian head predicting a latent distribution from an ``(x, u)`` pair.

    ``x`` and ``u`` are concatenated along dimension 1, passed through two ReLU
    hidden layers, and mapped to a mean (squashed into [-30, 30]) and a
    log-variance squashed into [log(0.1), log(10)], i.e. a variance in (0.1, 10).

    :param input_dim: combined size of ``x`` and ``u`` along dimension 1.
    :param hidden_dim: width of both hidden layers.
    :param latent_dim: size of the predicted mean / log-variance.
    """

    def __init__(self, input_dim, hidden_dim=128,latent_dim=64):
        super().__init__()
        # Layers are created in this exact order; it fixes parameter initialization
        # under a given RNG seed and the state_dict key order.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.optimizer = th.optim.Adam(self.parameters(), lr=3e-4)

    def forward(self, x, u):
        """Return ``(mu, logvar)`` of the predicted latent Gaussian."""
        hidden = th.cat([x, u], dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        mean = squash_to_range(self.fc_mu(hidden), low=-30.0, high=30.0)
        # Keep the variance (not the std) inside (0.1, 10).
        log_var = squash_to_range(self.fc_logvar(hidden), low=np.log(0.1), high=np.log(10))
        return mean, log_var

    def sample(self, x, u):
        """Draw one reparameterized sample from the predicted latent Gaussian."""
        return reparameterize(*self.forward(x, u))

def reparameterize(mu, logvar):
    """Draw ``mu + sigma * eps`` with ``eps ~ N(0, 1)`` and ``sigma = exp(logvar / 2)``.

    This is the reparameterization trick: the sample stays differentiable with
    respect to ``mu`` and ``logvar``.

    :param mu: mean tensor.
    :param logvar: log-variance tensor, same shape as the noise to draw.
    :return: one Gaussian sample per element.
    """
    sigma = (0.5 * logvar).exp()
    noise = th.randn_like(sigma)
    return mu + noise * sigma

class RPC(RPCBaseAlgorithm):
    """
    RPC implementation: a SAC-style off-policy actor-critic whose policy is built with an
    encoder (``EncoderNet``) and a predictor (``PredictorNet``). The per-sample KL divergence
    returned by the actor's ``full_forward_pass`` is subtracted from the reward in the critic
    target and added to the actor loss, weighted by a learned, non-negative dual coefficient
    that is adjusted so that the batch-mean KL tracks ``kl_constraint``.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param ent_coef: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.)  Controlling exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
    :param target_update_interval: update the target network every ``target_update_interval``
        gradient steps.
    :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param clip_mean: Stored as ``self.clip_mean``; not read in this file (presumably used by
        the policy/actor — confirm).
    :param clip_min_stddev: Optional minimum standard deviation; converted to the inverse-softplus
        bound ``self.std_low`` (``-inf`` when ``None``). Not read elsewhere in this file.
    :param clip_max_stddev: Optional maximum standard deviation; converted to the inverse-softplus
        bound ``self.std_high`` (``+inf`` when ``None``). Not read elsewhere in this file.
    :param latent_dim: Latent size shared by the encoder and the predictor.
    :param hidden_dim: Hidden width of the predictor network.
    :param log_prob_reward_scale: Stored as an attribute; not read in this file.
    :param predictor_updates_encoder: Stored as an attribute; not read in this file.
    :param predict_prior: Stored as an attribute; not read in this file.
    :param use_recurrent_actor: Stored as an attribute; not read in this file.
    :param rnn_sequence_length: Stored as an attribute; not read in this file.
    :param predictor_num_layers: Stored as an attribute; not read in this file.
    :param use_identity_encoder: Stored as an attribute. NOTE(review): ``_setup_model`` always
        builds an identity encoder regardless of this flag — confirm intended.
    :param identity_encoder_single_stddev: Stored as an attribute; not read in this file.
    :param kl_constraint: Target for the batch-mean KL divergence. The dual coefficient grows
        while the mean KL exceeds it and shrinks (down to 0) while it is below.
    :param use_residual_predictor: Stored as an attribute; not read in this file.
    :param predict_prior_std: Stored as an attribute; not read in this file.
    """

    # NOTE(review): "CnnPolicy" / "MultiInputPolicy" are the plain SAC policies; ``_setup_model``
    # passes ``encoder``/``decoder`` kwargs that those policies may not accept — confirm.
    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": RPCPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    policy: RPCPolicy
    actor: RPCActor
    critic: ContinuousCritic
    critic_target: ContinuousCritic

    def __init__(
        self,
        policy: Union[str, Type[RPCPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBufferComplexity]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        clip_mean: Optional[float] = None,
        clip_min_stddev: Optional[float] = None,
        clip_max_stddev: Optional[float] = None,
        latent_dim=64,
        hidden_dim=64,
        log_prob_reward_scale=0.0,
        predictor_updates_encoder=False,
        predict_prior=True,
        use_recurrent_actor=False,
        rnn_sequence_length=20,
        predictor_num_layers=2,
        use_identity_encoder=False,
        identity_encoder_single_stddev=False,
        kl_constraint=1.0,
        use_residual_predictor=True,
        predict_prior_std=True,

    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Box,),
            support_multi_env=True,

        )

        self.target_entropy = target_entropy
        self.log_ent_coef = None  # type: Optional[th.Tensor]
        # Entropy coefficient / Entropy temperature
        # Inverse of the reward scale
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        self.ent_coef_optimizer: Optional[th.optim.Adam] = None
        self.clip_mean = clip_mean
        self.clip_min_stddev = clip_min_stddev
        self.clip_max_stddev = clip_max_stddev
        # Inverse softplus (log(exp(s) - 1)) of the stddev bounds; infinite when unset.
        self.std_low = th.log(th.exp(th.as_tensor(self.clip_min_stddev)) - 1.0) if self.clip_min_stddev is not None else -float('inf')
        self.std_high = th.log(th.exp(th.as_tensor(self.clip_max_stddev)) - 1.0) if self.clip_max_stddev is not None else float('inf')
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.log_prob_reward_scale = log_prob_reward_scale
        self.predictor_updates_encoder = predictor_updates_encoder
        self.predict_prior = predict_prior
        self.use_recurrent_actor = use_recurrent_actor
        self.rnn_sequence_length = rnn_sequence_length
        self.predictor_num_layers = predictor_num_layers
        self.use_identity_encoder = use_identity_encoder
        self.identity_encoder_single_stddev = identity_encoder_single_stddev
        self.kl_constraint = kl_constraint
        self.use_residual_predictor = use_residual_predictor
        self.predict_prior_std = predict_prior_std
        # Lagrange multiplier for the KL constraint; kept >= 0 in ``train`` after every step.
        self.dual_coef = th.tensor(0.0001, device=self.device, requires_grad=True)
        self.dual_optimizer = th.optim.Adam([self.dual_coef], lr=3e-4)
        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build encoder/predictor, the policy (via the base class) and the entropy-coefficient state."""
        # NOTE(review): ``identity_encoder=True`` is hard-coded and ``self.use_identity_encoder`` is
        # ignored; this also requires ``latent_dim == observation_space.shape[0]`` — confirm intended.
        encoder = EncoderNet(self.observation_space.shape[0], latent_dim=self.latent_dim, identity_encoder=True)
        predictor = PredictorNet(self.latent_dim + self.action_space.shape[0], hidden_dim=self.hidden_dim,
                                 latent_dim=self.latent_dim)
        encoder_decoder_kwargs = {
            "encoder": encoder,
            "decoder": predictor,
        }
        self.policy_kwargs.update(encoder_decoder_kwargs)
        super()._setup_model()
        self._create_aliases()
        # Running mean and running var
        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32))  # type: ignore
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
        else:
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)

    def _create_aliases(self) -> None:
        """Expose the policy's actor and critics as attributes of the algorithm."""
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Run ``gradient_steps`` updates of the dual coefficient, entropy coefficient, critic and actor.

        Each step samples a replay batch and, in order:

        1. takes a dual-ascent step on ``dual_coef``, projected onto ``dual_coef >= 0``;
        2. updates the entropy coefficient if it is learned;
        3. regresses the critics onto a soft Bellman target whose reward is penalised by
           ``dual_coef * kl_div``;
        4. updates the actor;
        5. Polyak-updates the target critics every ``target_update_interval`` steps.

        :param gradient_steps: number of gradient steps to perform.
        :param batch_size: replay minibatch size.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate (the dual optimizer follows the same schedule)
        optimizers = [self.actor.optimizer, self.critic.optimizer, self.dual_optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []
        kl_divs = []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            actions_pi, log_prob, kl_div, sampled_z_next = self.actor.full_forward_pass(
                th.as_tensor(replay_data.observations, dtype=th.float32),
                th.as_tensor(replay_data.actions, dtype=th.float32),
                th.as_tensor(replay_data.next_observations, dtype=th.float32),
            )
            kl_divs.append(kl_div.detach().mean().item())

            # Dual ascent on the KL multiplier: minimising -lambda * (E[KL] - C) increases lambda
            # while the constraint is violated and decreases it otherwise.
            # (A disabled variant clipped kl_div to [0, 2 * kl_constraint] here to prevent divergence.)
            dual_loss = -1 * self.dual_coef * (kl_div.mean().detach() - self.kl_constraint)
            self.dual_optimizer.zero_grad()
            dual_loss.backward()
            self.dual_optimizer.step()
            # Project onto lambda >= 0: the multiplier of an inequality constraint must be
            # non-negative, otherwise a satisfied constraint turns the KL penalty into a bonus.
            with th.no_grad():
                self.dual_coef.clamp_(min=0.0)

            # Action by the current actor for the sampled state
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            # Computed outside no_grad: ``next_actions`` is reused with gradients in the actor loss.
            next_actions, next_log_prob = self.actor.action_log_prob(sampled_z_next)
            with th.no_grad():
                # Compute the next Q values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions.detach()), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - ent_coef * next_log_prob.detach().reshape(-1, 1)
                # td error + entropy term, with the KL cost subtracted from the reward
                target_q_values = replay_data.rewards-self.dual_coef*kl_div.detach().reshape(-1, 1) + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(replay_data.observations, replay_data.actions)

            # Compute critic loss
            critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            assert isinstance(critic_loss, th.Tensor)  # for type checker
            critic_losses.append(critic_loss.item())  # type: ignore[union-attr]

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
            # Min over all critic networks
            # NOTE(review): unlike SB3 SAC, Q is evaluated at (next_observations, next_actions)
            # while the entropy term uses the current-step ``log_prob``, and ``actions_pi`` is
            # otherwise unused. Standard SAC would use
            # ``self.critic(replay_data.observations, actions_pi)`` — confirm this is intended.
            q_values_pi = th.cat(self.critic(replay_data.next_observations, next_actions), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi + self.dual_coef.detach()*kl_div.reshape(-1,1)).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        self.logger.record("train/dual_coeff", self.dual_coef.item())
        # Mean over this call's gradient steps (previously only the last step was logged,
        # and a call with zero steps raised on the undefined ``kl_div``).
        if kl_divs:
            self.logger.record("train/kl_div", np.mean(kl_divs))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
        self: SelfRPC,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "RPC",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfRPC:
        """Delegate to the base-class training loop, using ``"RPC"`` as the default log name."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        """Exclude the actor/critic aliases; they are restored from the saved ``policy``."""
        return super()._excluded_save_params() + ["actor", "critic", "critic_target"]  # noqa: RUF005

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Return (state-dict attribute names, raw tensor attribute names) to save.

        The dual coefficient and its optimizer are saved alongside the SAC state.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        saved_pytorch_variables.append("dual_coef")
        state_dicts.append("dual_optimizer")
        return state_dicts, saved_pytorch_variables
