import gtimer as gt
from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer


class SACTrainer(TorchTrainer):
    """Soft Actor-Critic trainer with twin Q-functions.

    One call to ``train_from_torch`` performs, in order: an optional
    temperature (alpha) update, an update of both Q-functions against a
    shared clipped double-Q target, a policy update, and a periodic soft
    update of the target Q-functions.

    Two switches extend the basic algorithm:

    * ``Flag_entropy`` -- when false, the entropy term is removed from both
      the Q target and the policy loss.
    * ``gradient`` -- when equal to ``1``, the norm of the loss gradient with
      respect to the observation batch is recorded for the critic (QF1 loss)
      and the actor (policy loss).
    """

    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            discount=0.99,
            reward_scale=1.0,
            alpha=1.,
            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=optim.Adam,
            soft_target_tau=1e-2,
            target_update_period=1,
            clip_norm=0.,
            use_automatic_entropy_tuning=True,
            target_entropy=None,
            Flag_entropy=True,  # flag: entropy or not
            gradient=0,  # flag: gradient
    ):
        """Store the networks and build the optimizers.

        Args:
            env: Environment; only ``env.action_space.shape`` is read, and
                only when ``target_entropy`` is not supplied.
            policy: Module called as ``policy(obs, reparameterize=True,
                return_log_prob=True)``; the first four returned values are
                used as ``(actions, mean, log_std, log_pi)``.
            qf1, qf2: Q-functions called as ``qf(obs, actions)``.
            target_qf1, target_qf2: Target Q-functions; they are never
                optimized directly, only soft-updated from ``qf1``/``qf2``.
            discount: Discount factor applied to the bootstrapped value.
            reward_scale: Multiplier applied to rewards in the Q target.
            alpha: Fixed temperature, used only when
                ``use_automatic_entropy_tuning`` is false.
            policy_lr: Learning rate of the policy optimizer; also used for
                the temperature optimizer.
            qf_lr: Learning rate of both Q-function optimizers.
            optimizer_class: Optimizer constructor, called as
                ``optimizer_class(params, lr=...)``.
            soft_target_tau: Coefficient passed to
                ``ptu.soft_update_from_to``.
            target_update_period: Target networks are soft-updated on train
                steps whose index is a multiple of this value.
            clip_norm: Passed to ``ptu.fast_clip_grad_norm`` together with
                the policy parameters (exact semantics live in that helper).
            use_automatic_entropy_tuning: When true, ``log_alpha`` is learned
                against ``target_entropy``.
            target_entropy: Entropy target for the temperature update.
                ``None`` selects minus the number of action dimensions.
            Flag_entropy: When false, drop the entropy term from the Q
                target and the policy loss.
            gradient: ``1`` enables recording of observation-gradient norms.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # Compare against None rather than truthiness so that an explicit
            # target entropy of 0 is honoured instead of being replaced.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # heuristic value from Tuomas
                self.target_entropy = -np.prod(
                    self.env.action_space.shape
                ).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )
        else:
            self.alpha = alpha

        # Losses and optimizers exist only for the policy and the two online
        # Q-functions; the target networks are updated by soft copy.
        self.qf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.clip_norm = clip_norm
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.Flag_entropy = Flag_entropy
        self.gradient = gradient

    def train_from_torch(self, batch):
        """Run one SAC training step on ``batch``.

        Args:
            batch: Mapping with the keys ``'rewards'``, ``'terminals'``,
                ``'observations'``, ``'actions'`` and
                ``'next_observations'``.

        Side effects:
            Steps the temperature (if tuned), Q-function and policy
            optimizers, soft-updates the target networks when due, and
            refreshes ``self.eval_statistics`` on the first step after
            ``end_epoch``.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        track_obs_grad = self.gradient == 1
        gt.stamp('preback_start', unique=False)

        # ---- Update alpha ----
        new_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(
                self.log_alpha * (log_pi + self.target_entropy).detach()
            ).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        # ---- Update the two Q-functions ----
        with torch.no_grad():
            new_next_actions, _, _, new_log_pi, *_ = self.policy(
                next_obs,
                reparameterize=True,
                return_log_prob=True,
            )
            # Clipped double-Q estimate of the next-state value.
            target_q_values = torch.min(
                self.target_qf1(next_obs, new_next_actions),
                self.target_qf2(next_obs, new_next_actions),
            )
            if self.Flag_entropy:
                # Soft value: Q(s', a') - alpha * log pi(a' | s').
                target_q_values = target_q_values - alpha * new_log_pi
            # NOTE: ``alpha`` is deliberately left untouched when the entropy
            # term is disabled. Rebinding it to a Python float made the
            # ``alpha.item()`` call in the statistics section fail.
            q_target = (
                self.reward_scale * rewards
                + (1. - terminals) * self.discount * target_q_values
            )  # TD learning

        if track_obs_grad:
            # Probe gradients through a fresh leaf so that the caller's batch
            # tensor is not mutated and no stale ``.grad`` is carried over.
            obs = obs.detach().requires_grad_(True)

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)
        gt.stamp('preback_qf', unique=False)

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        if track_obs_grad:
            # Norm of d(qf1_loss)/d(obs), read before any further backward
            # pass accumulates into ``obs.grad``.
            critic_norm_loss = torch.norm(obs.grad.detach())
        self.qf1_optimizer.step()
        gt.stamp('backward_qf1', unique=False)

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()
        gt.stamp('backward_qf2', unique=False)

        # ---- Update the policy ----
        if track_obs_grad:
            # Drop the critic gradients accumulated above; otherwise the
            # actor norm would be the norm of the sum of all three losses.
            obs.grad = None
        q_new_actions = torch.min(
            self.qf1(obs, new_actions),  # new action comes from the policy
            self.qf2(obs, new_actions),
        )
        if self.Flag_entropy:
            # Equation 12 in the SAC paper.
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            policy_loss = - q_new_actions.mean()
        gt.stamp('preback_policy', unique=False)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()

        if track_obs_grad:
            # ``new_actions`` was produced before ``obs`` tracked gradients,
            # so this measures the path through the Q-functions' observation
            # input only.
            actor_norm_loss = torch.norm(obs.grad.detach())

        policy_grad = ptu.fast_clip_grad_norm(
            self.policy.parameters(), self.clip_norm
        )
        self.policy_optimizer.step()
        gt.stamp('backward_policy', unique=False)

        # ---- Soft updates of the target networks ----
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        # ---- Save some statistics for eval ----
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            if track_obs_grad:
                self.eval_statistics['critic gradient loss norm'] = (
                    critic_norm_loss.cpu().numpy()
                )
                self.eval_statistics['actor gradient loss norm'] = (
                    actor_norm_loss.cpu().numpy()
                )

            self.eval_statistics['QF1 Loss'] = qf1_loss.item()
            self.eval_statistics['QF2 Loss'] = qf2_loss.item()
            self.eval_statistics['Policy Loss'] = policy_loss.item()
            self.eval_statistics['Policy Grad'] = policy_grad
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                # ``alpha`` is always the tensor ``log_alpha.exp()`` here.
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics recorded by the most recent logged step."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request that the next training step refresh the statistics."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """List of all networks held by this trainer, targets included."""
        return [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
        ]

    def get_snapshot(self):
        """Return the ``state_dict`` of every network, keyed by its name."""
        return dict(
            policy=self.policy.state_dict(),
            qf1=self.qf1.state_dict(),
            qf2=self.qf2.state_dict(),
            # Save the target networks themselves, not the online ones.
            target_qf1=self.target_qf1.state_dict(),
            target_qf2=self.target_qf2.state_dict(),
        )
