import torch
import torch.nn.functional as F
from .sac_base import SAC_Base


class SAC_mini(SAC_Base):
    """Soft actor-critic agent that does not use a latent variable.

    ``has_latent`` is always ``False`` and ``latent_dist`` is always ``None``.
    The networks (``actor``, ``critic``, ``critic_target``), the optimizers,
    the temperature (``alpha`` / ``target_entropy``) and
    ``update_target_networks`` are not defined in this class; they are
    expected to be provided by ``SAC_Base``.
    """

    def __init__(self, CONFIG):
        """Forward ``CONFIG`` unchanged to ``SAC_Base.__init__``."""
        super().__init__(CONFIG)

    @property
    def has_latent(self):
        """Always ``False``: this agent has no latent variable."""
        return False

    @property
    def latent_dist(self):
        """Always ``None``, consistent with ``has_latent`` being ``False``."""
        return None

    def build_network(self,
                      CONFIG_ARCH,
                      build_optimizer=True,
                      verbose=True):
        """Build the networks and, optionally, the optimizers.

        Args:
            CONFIG_ARCH: architecture configuration, forwarded to
                ``SAC_Base.build_network``.
            build_optimizer: when true, ``SAC_Base.build_optimizer`` is called
                after the networks are built.
            verbose: forwarded positionally to ``SAC_Base.build_network``.
        """
        super().build_network(CONFIG_ARCH, verbose)

        # Set up optimizer
        if build_optimizer:
            super().build_optimizer()

    def update_critic(self, batch, detach_lang):
        """Run one gradient step on the twin critics.

        The regression target is ``reward`` for every transition, plus
        ``GAMMA * (min(Q1', Q2') - alpha * log_prob)`` for the transitions
        selected by ``non_final_mask``, where ``Q1'``/``Q2'`` come from
        ``critic_target`` evaluated at actions sampled from the actor.

        Args:
            batch: 9-tuple unpacked as ``(non_final_mask,
                non_final_state_nxt, state, action, reward, append,
                non_final_append_nxt, _, _)``; the last two entries are
                ignored. ``reward`` is not modified.
            detach_lang: forwarded to the critic as ``detach_lang``.

        Returns:
            float: sum of the two critics' MSE losses.
        """
        (non_final_mask, non_final_state_nxt, state, action, reward, append,
         non_final_append_nxt, _, _) = batch
        self.critic.train()
        self.critic_target.eval()
        self.actor.eval()

        # == get Q(s,a) ==
        # Used to compute loss (non-target part).
        q1, q2 = self.critic(state,
                             action,
                             append=append,
                             detach_lang=detach_lang)

        # == compute actor next_actions and feed to critic_target ==
        with torch.no_grad():
            next_actions, next_log_prob = self.actor.sample(
                non_final_state_nxt,
                append=non_final_append_nxt,
                get_prob=True)
            next_q1, next_q2 = self.critic_target(non_final_state_nxt,
                                                  next_actions,
                                                  append=non_final_append_nxt)
            q_min = torch.min(next_q1, next_q2).view(-1)
            # Computed on the non-final subset only, so it can have fewer
            # entries than y.
            target_q = q_min - self.alpha * next_log_prob.view(-1)

            # Clone first: the in-place add below must not write the
            # bootstrapped target back into the caller's reward tensor.
            y = reward.clone()
            y[non_final_mask] += self.GAMMA * target_q

        # == MSE update for both Q1 and Q2 ==
        loss_q1 = F.mse_loss(input=q1.view(-1), target=y)
        loss_q2 = F.mse_loss(input=q2.view(-1), target=y)
        loss_q = loss_q1 + loss_q2

        # == backpropagation ==
        self.critic_optimizer.zero_grad()
        loss_q.backward()
        self.critic_optimizer.step()

        return loss_q.item()

    def update_actor(self, batch):
        """Run one gradient step on the actor and, optionally, the temperature.

        Both the actor and the critic are called with ``detach_encoder=True``
        (per the original note: so that conv layers are not updated).

        Args:
            batch: 9-tuple with the same layout as in ``update_critic``; only
                ``state`` (index 2) and ``append`` (index 5) are used.

        Returns:
            tuple[float, float, float]: ``(loss_pi, loss_entropy,
            loss_alpha)``. ``loss_alpha`` is always computed and returned, but
            only optimized when ``self.LEARN_ALPHA`` is true.
        """
        _, _, state, _, _, append, _, _, _ = batch

        self.critic.eval()
        self.actor.train()

        action_sample, log_prob = self.actor.sample(state,
                                                    append=append,
                                                    detach_encoder=True,
                                                    get_prob=True)
        q_pi_1, q_pi_2 = self.critic(state,
                                     action_sample,
                                     append=append,
                                     detach_encoder=True)
        q_pi = torch.min(q_pi_1, q_pi_2)

        # Objective: max_theta E[Q - alpha * log pi]
        # => loss_pi = -Q + alpha * log pi
        loss_entropy = log_prob.view(-1).mean()
        loss_q_eval = q_pi.mean()
        loss_pi = -loss_q_eval + self.alpha * loss_entropy
        self.actor_optimizer.zero_grad()
        loss_pi.backward()
        self.actor_optimizer.step()

        # Automatic temperature tuning. The entropy gap is detached so this
        # loss can only produce gradients through self.alpha.
        loss_alpha = (self.alpha *
                      (-log_prob - self.target_entropy).detach()).mean()
        if self.LEARN_ALPHA:
            self.log_alpha_optimizer.zero_grad()
            loss_alpha.backward()
            self.log_alpha_optimizer.step()
        return loss_pi.item(), loss_entropy.item(), loss_alpha.item()

    def update(self, batch, timer, update_period=2, detach_lang=False):
        """Run one training iteration on ``batch``.

        The critic is updated on every call. The actor (and temperature) and
        the target networks are updated only when ``timer`` is a multiple of
        ``update_period``. Critic and actor are left in eval mode on return.

        Args:
            batch: transition batch, passed to ``update_critic`` and
                ``update_actor``.
            timer: update counter used to schedule the delayed actor update.
            update_period: the actor/target update runs every this many calls.
            detach_lang: forwarded to ``update_critic``.

        Returns:
            tuple: ``(loss_q, loss_pi, loss_entropy, loss_alpha)``; the last
            three are ``0`` on iterations where the actor is not updated.
        """
        self.critic.train()
        self.actor.train()

        loss_q = self.update_critic(batch, detach_lang)
        loss_pi, loss_entropy, loss_alpha = 0, 0, 0
        if timer % update_period == 0:
            loss_pi, loss_entropy, loss_alpha = self.update_actor(batch)
            self.update_target_networks()

        self.critic.eval()
        self.actor.eval()

        return loss_q, loss_pi, loss_entropy, loss_alpha

    def value(self, obs, append):
        """Evaluate the first critic at the actor's action for ``obs``.

        Args:
            obs: observation passed to both the actor and the critic.
            append: extra input forwarded to both networks as ``append``.

        Returns:
            The first output of ``self.critic``; when ``obs`` has exactly
            three dimensions, only its first element is returned.
        """
        # The actor output goes through torch.from_numpy, so the actor's
        # forward call is expected to return a NumPy array here.
        u = self.actor(obs, append=append)
        u = torch.from_numpy(u).to(self.device)
        v = self.critic(obs, u, append=append)[0]
        # NOTE(review): a 3-D obs is presumably a single un-batched
        # observation, hence the extra indexing - confirm against callers.
        if len(obs.shape) == 3:
            v = v[0]
        return v
