import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

from rsl_rl.storage.buffer import ReplayBuffer, NStepReplayBuffer
from rsl_rl.modules import SACActorCritic
from rsl_rl.modules.normalizer import EmpiricalNormalization, IdentityNormalization
from rsl_rl.utils import optimizer_update, soft_update_target_network


class SAC:
    def __init__(
        self,
        actor_critic: SACActorCritic,
        batch_size: int = 256,
        num_learning_epochs: int = 1,
        actor_lr: float = 1e-3,
        alpha: float = 0.2,
        alpha_lr: float = 0.005,
        critic_lr: float = 1e-3,
        gamma: float = 0.99,
        target_entropy: float | torch.Tensor = -10,
        max_grad_norm: float = 0.5,
        tau: float = 0.005,
        nstep: int = 1,
        empirical_normalization: bool = False,
        device="cpu",
    ):
        self.device = device
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.alpha_lr = alpha_lr
        self.tau = tau
        self.max_grad_norm = max_grad_norm
        self.num_learning_epochs = num_learning_epochs
        self.batch_size = batch_size

        self.actor_critic = actor_critic.to(device)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=self.critic_lr)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(alpha, device=device)))
        self.alpha_optim = optim.Adam([self.log_alpha], lr=alpha_lr)
        if target_entropy is None:
            self.target_entropy = torch.tensor(-np.prod(self.actor_critic.action_shape).item(), device=device)
        else:
            self.target_entropy = torch.tensor([target_entropy], device=device)
        self.gamma = gamma

        self.storage = None
        self.transition = ReplayBuffer.Transition()
        self.nstep = nstep
        self.empirical_normalization = empirical_normalization
        if self.empirical_normalization:
            self.obs_normalizer = EmpiricalNormalization(actor_critic.num_actor_obs).to(device)
            self.critic_obs_normalizer = EmpiricalNormalization(actor_critic.num_critic_obs).to(device)
        else:
            self.obs_normalizer = IdentityNormalization()
            self.critic_obs_normalizer = IdentityNormalization()

    def init_storage(self, num_envs, capacity_per_env, obs_shape, critic_obs_shape, action_shape):
        self.storage = ReplayBuffer(
            num_envs=num_envs,
            capacity_per_env=capacity_per_env,
            obs_shape=obs_shape,
            privileged_obs_shape=critic_obs_shape,
            action_shape=action_shape,
            device=self.device,
        )

        if self.nstep > 1:
            self.nstep_buffer = NStepReplayBuffer(
                num_envs=num_envs,
                obs_shape=obs_shape,
                privileged_obs_shape=critic_obs_shape,
                action_shape=action_shape,
                gamma=self.gamma,
                nstep=self.nstep,
                device=self.device,
            )

    def test_mode(self):
        self.actor_critic.test()

    def train_mode(self):
        self.actor_critic.train()

    @property
    def actor(self):
        return self.actor_critic.actor

    @property
    def critic(self):
        return self.actor_critic.critic

    @property
    def critic_target(self):
        return self.actor_critic.critic_target

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @torch.no_grad()
    def update_normalizer(self, obs, critic_obs):
        self.obs_normalizer.update(obs)
        self.critic_obs_normalizer.update(critic_obs)

    # @torch.compile(dynamic=True)
    def act(self, obs, critic_obs, is_deterministic=False):
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions
        self.update_normalizer(obs, critic_obs)
        return self.actor_critic.act(self.obs_normalizer(obs), is_deterministic)

    def act_inference(self, obs):
        return self.actor_critic.act(self.obs_normalizer(obs), is_deterministic=True).detach()

    def process_env_step(self, obs, critic_obs, next_obs, next_critic_obs, actions, rewards, dones, infos):
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
        self.transition.next_observations = next_obs
        self.transition.next_critic_observations = next_critic_obs
        self.transition.actions = actions
        self.transition.rewards = rewards
        self.transition.dones = dones
        # Bootstrapping on time outs
        if "time_outs" in infos:
            self.transition.dones = torch.logical_and(dones, ~infos["time_outs"])
        else:
            self.transition.dones = dones

        # Record the transition
        if self.nstep > 1:
            rewards, next_critic_obs, dones = self.nstep_buffer.add_transitions(self.transition)
            if self.nstep_buffer.is_ready():
                self.transition.rewards = rewards.squeeze(1)
                self.transition.next_critic_observations = next_critic_obs
                self.transition.dones = dones.squeeze(1)
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

    # @torch.compile(dynamic=True)
    def update(self):
        critic_loss_list = list()
        actor_loss_list = list()
        alpha_loss_list = list()

        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.batch_size, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.batch_size, self.num_learning_epochs)
        for (
            obs_batch,
            critic_obs_batch,
            next_critic_obs_batch,
            actions_batch,
            rewards_batch,
            dones_batch,
            hid_states_batch,
        ) in generator:
            with torch.no_grad():
                obs_batch_n = self.obs_normalizer(obs_batch)
                critic_obs_batch_n = self.critic_obs_normalizer(critic_obs_batch)
                next_critic_obs_batch_n = self.critic_obs_normalizer(next_critic_obs_batch)

            critic_loss = self.update_critic(
                critic_obs_batch_n, actions_batch, rewards_batch, next_critic_obs_batch_n, dones_batch
            )
            critic_loss_list.append(critic_loss)

            actor_loss, alpha_loss = self.update_actor_and_alpha(obs_batch_n, critic_obs_batch_n)
            actor_loss_list.append(actor_loss)
            alpha_loss_list.append(alpha_loss)

            soft_update_target_network(self.critic_target, self.critic, self.tau)

        mean_critic_loss = np.mean(critic_loss_list)
        mean_actor_loss = np.mean(actor_loss_list)
        mean_alpha_loss = np.mean(alpha_loss_list)
        # return critic_loss_list, actor_loss_list, alpha_loss_list
        return mean_critic_loss, mean_actor_loss, mean_alpha_loss

    def update_critic(self, obs, action, reward, next_obs, done):
        with torch.no_grad():
            next_actions, dist = self.actor_critic.get_actions_dist(next_obs)
            log_prob = dist.log_prob(next_actions).sum(-1, keepdim=True)
            target_Q = self.actor_critic.evaluate(next_obs, next_actions) - self.alpha.detach() * log_prob
            target_Q = reward + (1 - done) * self.gamma * target_Q
        current_Q1, current_Q2 = self.actor_critic.get_q1_q2(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        optimizer_update(self.critic_optim, critic_loss, self.max_grad_norm)

        return critic_loss.item()

    def update_actor_and_alpha(self, obs, critic_obs):
        self.critic.requires_grad_(False)
        actions, dist = self.actor_critic.get_actions_dist(obs)
        log_prob = dist.log_prob(actions).sum(-1, keepdim=True)
        Q_value = self.critic(critic_obs, actions)
        actor_loss = (self.alpha.detach() * log_prob - Q_value).mean()
        optimizer_update(self.actor_optim, actor_loss, self.max_grad_norm)
        self.critic.requires_grad_(True)

        alpha_loss = (self.alpha * (-log_prob - self.target_entropy).detach()).mean()
        optimizer_update(self.alpha_optim, alpha_loss, self.max_grad_norm)

        return actor_loss.item(), alpha_loss.item()
