from copy import deepcopy
import torch
import numpy as np
from .iq_sac import IQ_SAC
import torch.nn.functional as F
from mushroom_rl.utils.parameters import to_parameter
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator

from mushroom_rl.utils.torch import to_float_tensor


class LSIQ_HC(IQ_SAC):

    """ LSIQ version that learns a separate entropy critic (the H function)
        next to the Q-function critic. Apart from the H function, it is
        implemented the same way as IQ (it derives from IQ_SAC). """

    def __init__(self, H_tau, Q_max, Q_min, abs_mult=1.0, clip_expert_entropy_to_policy_max=True,
                 loss_mode_exp="fix", H_loss_mode="Huber",
                 clipp_reward_pen_to_zero=False, Q_exp_loss=None, H_params=None, treat_absorbing_states=False,
                 max_H_policy_tau_down=1e-4, max_H_policy_tau_up=1e-2, **kwargs):
        """
        Constructor.

        Args:
            H_tau: coefficient used for the soft update of the target H
                function (converted with ``to_parameter``);
            Q_max: upper clipping bound for value estimates; also the
                regression target of the expert Q-values when
                ``loss_mode_exp`` is "fix";
            Q_min: lower clipping bound for value estimates;
            abs_mult: stored on the instance; not read by the methods defined
                in this class;
            clip_expert_entropy_to_policy_max: whether the entropy term of
                expert samples is clipped against a running estimate of the
                maximum entropy term on policy samples when the H target is
                built;
            loss_mode_exp: expert term of the Q loss, either "fix" or
                "bootstrap";
            H_loss_mode: loss used to fit the H function, either "Huber" or
                "MSE";
            clipp_reward_pen_to_zero: if True, the upper clipping bound of the
                reward penalty inside the H target is 0;
            Q_exp_loss: loss of the expert term when ``loss_mode_exp`` is
                "fix", either "MSE" or "Huber";
            H_params: keyword arguments of the ``Regressor`` that builds the
                H function; must contain an "optimizer" entry with the keys
                "class" and "params". Required despite the ``None`` default;
            treat_absorbing_states: forwarded to ``regularizer_loss``;
            max_H_policy_tau_down: smoothing rate of the running maximum
                policy entropy used when the batch maximum does not exceed
                the running value;
            max_H_policy_tau_up: smoothing rate used when the batch maximum
                exceeds the running value;
            **kwargs: forwarded to the parent constructor.

        Raises:
            ValueError: if ``H_params`` is not provided.

        """
        # fail early with a clear message instead of the TypeError that
        # unpacking ``**None`` would produce further down
        if H_params is None:
            raise ValueError("H_params has to be specified to build the H function approximator.")

        # call parent
        super().__init__(**kwargs)

        self._H_tau = to_parameter(H_tau)

        self._Q_max = Q_max
        self._Q_min = Q_min
        self._loss_mode_exp = loss_mode_exp     # either fix or bootstrap
        self._Q_exp_loss = Q_exp_loss   # either MSE or Huber | Only used when loss_mode_exp == "fix"
        self._abs_mult = abs_mult
        self._clipp_reward_pen_to_zero = clipp_reward_pen_to_zero
        self._H_loss_mode = H_loss_mode     # either MSE or Huber
        self._treat_absorbing_states = treat_absorbing_states

        # define the H function with the target; the copy is taken before the
        # online regressor is built so both start from untouched parameters
        target_H_params = deepcopy(H_params)
        self._H_approximator = Regressor(TorchApproximator, **H_params)
        self._target_H_approximator = Regressor(TorchApproximator, **target_H_params)
        self._clip_expert_entropy_to_policy_max = clip_expert_entropy_to_policy_max
        # running estimate of the maximum policy entropy term; initialised on
        # the first H update
        self._max_H_policy = None
        self._max_H_policy_tau_down = max_H_policy_tau_down
        self._max_H_policy_tau_up = max_H_policy_tau_up

        # define the optimizer for the H function
        net_params = self._H_approximator.model.network.parameters()
        self._H_optimizer = H_params["optimizer"]["class"](net_params, **H_params["optimizer"]["params"])

    def _lossQ(self, obs, act, next_obs, absorbing, is_expert):
        """
        Compute the Q loss, update the Q function, update the H function and
        log diagnostics.

        Args:
            obs: batch of states;
            act: batch of actions;
            next_obs: batch of next states;
            absorbing: absorbing flags of the batch;
            is_expert: boolean mask over the batch, True for expert samples
                (used for indexing, ``~is_expert`` selects policy samples).

        Returns:
            The tuple (loss_term1, loss_term2, chi2_loss) that was summed
            into the Q loss.

        Raises:
            ValueError: if ``loss_mode_exp``, ``Q_exp_loss`` or the policy
                loss mode is not supported.

        """
        # Calculate 1st term of loss
        gamma = to_float_tensor(self.mdp_info.gamma).cuda() if self._use_cuda else to_float_tensor(self.mdp_info.gamma)
        # NOTE(review): without CUDA the flags are used as passed in, so the
        # unsqueeze below presumably needs them to be a tensor already -- confirm.
        absorbing = torch.tensor(absorbing).cuda() if self._use_cuda else absorbing
        current_Q = self._critic_approximator(obs, act, output_tensor=True)
        if not self._use_target:
            next_v = self.getV(next_obs)
        else:
            with torch.no_grad():
                next_v = self.get_targetV(next_obs).detach()
        absorbing = torch.unsqueeze(absorbing, 1)
        # bootstrapped part of the target; zero for absorbing transitions
        y = (1 - absorbing) * gamma.detach() * self._Q_Q_multiplier * torch.clip(next_v, self._Q_min, self._Q_max)

        reward = (self._Q_Q_multiplier*current_Q - y)
        exp_reward = reward[is_expert]

        if self._loss_mode_exp == "bootstrap":
            loss_term1 = -exp_reward.mean()
        elif self._loss_mode_exp == "fix":
            # regress the expert Q-values towards the fixed target Q_max
            if self._Q_exp_loss == "MSE":
                loss_term1 = F.mse_loss(current_Q[is_expert], torch.ones_like(current_Q[is_expert]) * self._Q_max)
            elif self._Q_exp_loss == "Huber":
                loss_term1 = F.huber_loss(current_Q[is_expert], torch.ones_like(current_Q[is_expert]) * self._Q_max)
            elif self._Q_exp_loss is None:
                raise ValueError("If you choose loss_mode_exp == fix, you have to specify Q_exp_loss. Setting it to "
                                 "None is not valid.")
            else:
                raise ValueError(
                    "Choosen Q_exp_loss %s is not supported. Choose either MSE or Huber." % self._Q_exp_loss)
        else:
            # without this branch loss_term1 would be unbound further down
            raise ValueError(
                "Choosen loss_mode_exp %s is not supported. Choose either fix or bootstrap." % self._loss_mode_exp)

        # do the logging
        self.logging_loss(current_Q, y, reward, is_expert, obs, act, absorbing)

        # 2nd term for our loss (use expert and policy states)
        V = self._Q_Q_multiplier * self.getV(obs)
        value = (V - y)
        # NOTE(review): V already contains the multiplier, so the logged value
        # applies it twice, and this scalar is written on every call rather
        # than only on logging iterations -- confirm whether intended.
        self.sw_add_scalar('V for policy on all states', self._Q_Q_multiplier * V.mean(), self._iter)
        value_loss = value
        if self._plcy_loss_mode == "value":
            loss_term2 = value_loss.mean()
        elif self._plcy_loss_mode == "value_expert":
            value_loss_exp = value_loss[is_expert]
            loss_term2 = value_loss_exp.mean()
        elif self._plcy_loss_mode == "value_policy":
            value_loss_plcy = value_loss[~is_expert]
            loss_term2 = value_loss_plcy.mean()
        elif self._plcy_loss_mode == "q_old_policy":
            reward_plcy = reward[~is_expert]
            loss_term2 = reward_plcy.mean()
        elif self._plcy_loss_mode == "value_q_old_policy":
            reward_plcy = reward[~is_expert]
            loss_term2 = reward_plcy.mean() + value_loss.mean()
        elif self._plcy_loss_mode == "v0":
            value_loss_v0 = (1-gamma.detach()) * self.getV(obs[is_expert])
            loss_term2 = value_loss_v0.mean()
        else:
            raise ValueError("Undefined policy loss mode: %s" % self._plcy_loss_mode)

        # regularize
        chi2_loss = self.regularizer_loss(absorbing, reward, gamma, is_expert, treat_absorbing_states=self._treat_absorbing_states)

        loss_Q = loss_term1 + loss_term2 + chi2_loss
        self.update_Q_parameters(loss_Q)

        # update the H function
        loss_H, H, logpi = self.update_H_Delta(obs, act, next_obs, absorbing, gamma.detach(), is_expert)

        if self._iter % self._logging_iter == 0:
            # the gradient norm is only needed for logging, so it is only
            # assembled on logging iterations
            grads = torch.cat([param.grad.view(-1)
                               for param in self._critic_approximator.model.network.parameters()])
            norm = grads.norm(dim=0, p=2)
            self.sw_add_scalar('Gradients/Norm2 Gradient LossQ wrt. Q-parameters', norm, self._iter)
            self.sw_add_scalar('H function/Loss', loss_H, self._iter)
            self.sw_add_scalar('H function/H', np.mean(H), self._iter)
            self.sw_add_scalar('H function/H plcy', np.mean(H[~is_expert]), self._iter)
            self.sw_add_scalar('H function/H expert', np.mean(H[is_expert]), self._iter)
            self.sw_add_scalar('H function/H_step', np.mean(-logpi), self._iter)
            self.sw_add_scalar('H function/H_step plcy', np.mean(-logpi[~is_expert]), self._iter)
            self.sw_add_scalar('H function/H_step expert', np.mean(-logpi[is_expert]), self._iter)

        return loss_term1, loss_term2, chi2_loss

    def update_H_Delta(self, obs, action, next_obs, absorbing, gamma, is_expert):
        """
        Perform one optimisation step of the H function.

        The regression target is a clipped, squared reward penalty computed
        from the target critic plus the discounted target H value of the next
        state-action pair and the alpha-weighted entropy term.

        Args:
            obs: batch of states;
            action: batch of actions;
            next_obs: batch of next states;
            absorbing: absorbing flags, broadcastable against the critic
                output;
            gamma: discount factor as tensor;
            is_expert: boolean mask over the batch, True for expert samples.

        Returns:
            The tuple (loss_H, H, log_pi): the loss tensor, the H predictions
            as numpy array and the (unclipped) log-probabilities of the
            actions sampled for ``next_obs`` as numpy array.

        Raises:
            ValueError: if the configured H loss mode is not supported.

        """
        H = self._H_approximator(obs, action, output_tensor=True)

        # upper clipping bounds of the penalty for non-absorbing and absorbing samples
        if self._clipp_reward_pen_to_zero:
            upper_non_absorbing, upper_absorbing = 0.0, 0.0
        else:
            upper_non_absorbing, upper_absorbing = 1/self._reg_mult, self._Q_max

        with torch.no_grad():
            next_action, log_pi = self.policy.compute_action_and_log_prob_t(next_obs)
            Q_plcy = self._target_critic_approximator(obs, action, output_tensor=True)
            # NOTE(review): the value is evaluated at obs, whereas _lossQ
            # bootstraps from next_obs -- confirm this is intended.
            V_plcy = self.get_targetV(obs)
            y = (1 - absorbing) * gamma.detach() * self._Q_Q_multiplier * torch.clip(V_plcy, self._Q_min,
                                                                                     self._Q_max)

            pen_non_absorbing = torch.square(
                torch.clip(Q_plcy - y, -1/self._reg_mult, upper_non_absorbing)).detach()
            pen_absorbing = torch.square(
                torch.clip(Q_plcy - y, self._Q_min, upper_absorbing)).detach()
            squared_reg_reward_plcy = (1 - absorbing) * self._reg_mult * pen_non_absorbing \
                                      + absorbing * (1.0 - gamma.detach()) * self._reg_mult * pen_absorbing

        # Entropy clipping of the expert samples against the running maximum of the policy samples.
        neg_log_pi = -log_pi
        if self._clip_expert_entropy_to_policy_max:
            batch_max_H_policy = torch.max(neg_log_pi[~is_expert])
            if self._max_H_policy is None:
                # first call: start the running estimate from the batch value
                self._max_H_policy = batch_max_H_policy
            else:
                # the estimate follows increases and decreases with different rates
                if batch_max_H_policy > self._max_H_policy:
                    rate = self._max_H_policy_tau_up
                else:
                    rate = self._max_H_policy_tau_down
                self._max_H_policy = (1 - rate) * self._max_H_policy + rate * batch_max_H_policy
            # NOTE(review): the running maximum is passed as the LOWER clip
            # bound, i.e. expert values are raised to at least that value;
            # confirm this matches the intended "restrict to the policy maximum".
            neg_log_pi[is_expert] = torch.clip(neg_log_pi[is_expert], self._max_H_policy, 100000)

        next_H = self._target_H_approximator(next_obs, next_action, output_tensor=True).detach()
        entropy_term = self._alpha.detach() * torch.unsqueeze(neg_log_pi, 1)
        target_H = squared_reg_reward_plcy + (1 - absorbing) * gamma * (next_H + entropy_term)

        # keep the target within a bounded range
        Q2_max = (1.0/self._reg_mult)**2 / (1 - gamma.detach())
        target_H = torch.clip(target_H, -1000, Q2_max+100)

        mode = self._H_loss_mode
        if mode == "Huber":
            loss_H = F.huber_loss(H, target_H)
        elif mode == "MSE":
            loss_H = F.mse_loss(H, target_H)
        else:
            raise ValueError("Unsupported H_loss %s" % self._H_loss_mode)

        optimizer = self._H_optimizer
        optimizer.zero_grad()
        loss_H.backward()
        optimizer.step()

        H_np = H.detach().cpu().numpy()
        log_pi_np = log_pi.detach().cpu().numpy()
        return loss_H, H_np, log_pi_np

    def iq_update(self, input_states, input_actions, input_n_states, input_absorbing, is_expert):
        """
        Run one update step: critics (Q and H), policy, alpha and targets.

        The critics and their targets are only updated when the iteration
        counter is a multiple of ``_delay_Q``; the policy is only updated
        once the replay memory holds more than the warm-up transitions and
        the iteration counter is a multiple of ``_delay_pi``.

        Args:
            input_states: batch of states;
            input_actions: batch of actions;
            input_n_states: batch of next states;
            input_absorbing: absorbing flags of the batch;
            is_expert: boolean mask over the batch, True for expert samples.

        """
        if self._iter % self._delay_Q == 0:
            loss1, loss2, chi2_loss = self._lossQ(input_states, input_actions, input_n_states, input_absorbing,
                                                  is_expert)
            if self._iter % self._logging_iter == 0:
                self.sw_add_scalar('IQ-Loss/Loss1', loss1, self._iter)
                self.sw_add_scalar('IQ-Loss/Loss2', loss2, self._iter)
                self.sw_add_scalar('IQ-Loss/Chi2 Loss', chi2_loss, self._iter)
                self.sw_add_scalar('IQ-Loss/Alpha', self._alpha, self._iter)

        # update policy
        if self._replay_memory.size > self._warmup_transitions() and self._iter % self._delay_pi == 0:
            if self._train_policy_only_on_own_states:
                policy_training_states = input_states[~is_expert]
                policy_training_next_states = input_n_states[~is_expert]
            else:
                policy_training_states = input_states
                policy_training_next_states = input_n_states
            action_new, log_prob = self.policy.compute_action_and_log_prob_t(policy_training_states)
            loss = self._actor_loss(policy_training_states, action_new, policy_training_next_states, log_prob)
            self._optimize_actor_parameters(loss)
            if self._iter % self._logging_iter == 0:
                # the gradient norm is only needed for logging, so it is only
                # assembled on logging iterations
                grads = torch.cat([param.grad.view(-1)
                                   for param in self.policy._approximator.model.network.parameters()])
                norm = grads.norm(dim=0, p=2)
                self.sw_add_scalar('Gradients/Norm2 Gradient Q wrt. Pi-parameters', norm,
                                    self._iter)
                self.sw_add_scalar('Actor/Loss', loss, self._iter)
                # separate name on purpose: log_prob of the training batch must
                # stay untouched, because it is used for the alpha update below
                _, log_prob_all = self.policy.compute_action_and_log_prob_t(input_states)
                self.sw_add_scalar('Actor/Entropy Expert States', torch.mean(-log_prob_all[is_expert]).detach().item(), self._iter)
                self.sw_add_scalar('Actor/Entropy Policy States', torch.mean(-log_prob_all[~is_expert]).detach().item(), self._iter)
                _, logsigma = self.policy.get_mu_log_sigma(input_states[~is_expert])
                ent_gauss = self.policy.entropy_from_logsigma(logsigma)
                e_lb = self.policy.get_e_lb()
                self.sw_add_scalar('Actor/Entropy from Gaussian Policy States', torch.mean(ent_gauss).detach().item(), self._iter)
                self.sw_add_scalar('Actor/Entropy Lower Bound', e_lb, self._iter)
                _, logsigma = self.policy.get_mu_log_sigma(input_states[is_expert])
                ent_gauss = self.policy.entropy_from_logsigma(logsigma)
                self.sw_add_scalar('Actor/Entropy from Gaussian Expert States', torch.mean(ent_gauss).detach().item(), self._iter)
            if self._learnable_alpha:
                # always uses the log-probabilities the policy was trained on,
                # independent of whether this is a logging iteration
                self._update_alpha(log_prob.detach())

        if self._iter % self._delay_Q == 0:
            self._update_target(self._critic_approximator,
                                self._target_critic_approximator)
            self._update_target_H(self._H_approximator,
                                  self._target_H_approximator)

    def _update_target_H(self, online, target):
        """
        Soft-update the target H function towards the online H function.

        Every model of ``target`` gets the weights
        ``tau * online + (1 - tau) * target`` with ``tau`` given by ``_H_tau``.

        Args:
            online: the online H approximator;
            target: the target H approximator, modified in place.

        """
        n_models = len(target)
        for idx in range(n_models):
            online_model = online[idx]
            target_model = target[idx]
            # calling the parameter is kept per model, as in the weighting
            # below only get_value() is used
            blended = self._H_tau() * online_model.get_weights()
            keep = 1 - self._H_tau.get_value()
            blended += keep * target_model.get_weights()
            target_model.set_weights(blended)

    def _actor_loss(self, state, action_new, next_state, log_prob):
        """
        Compute the policy loss from the Q function and the H function.

        Args:
            state: batch of states the policy is trained on;
            action_new: actions sampled from the policy for ``state``;
            next_state: batch of next states (not used here);
            log_prob: log-probabilities of ``action_new``.

        Returns:
            The mean of ``alpha * log_prob - Q_pi_multiplier * (Q + H)``.

        """
        critic_value = self._critic_approximator(state, action_new, output_tensor=True)
        entropy_value = self._H_approximator(state, action_new, output_tensor=True)
        # the H function is added on top of the Q-value
        combined_value = critic_value + entropy_value
        per_sample_loss = self._alpha.detach() * log_prob - self._Q_pi_multiplier * combined_value
        return per_sample_loss.mean()

    def getV(self, obs):
        """
        Evaluate the critic at ``obs`` with an action sampled from the policy.

        Args:
            obs: batch of states.

        Returns:
            The critic output as tensor; the sampled action is passed as a
            numpy array, so no gradient flows through it.

        """
        with torch.no_grad():
            sampled_action, _ = self.policy.compute_action_and_log_prob_t(obs)
        action_np = sampled_action.detach().cpu().numpy()
        return self._critic_approximator(obs, action_np, output_tensor=True)

    def get_targetV(self, obs):
        """
        Evaluate the target critic at ``obs`` with an action sampled from the
        policy.

        Args:
            obs: batch of states.

        Returns:
            The target critic output as tensor; the sampled action is passed
            as a numpy array.

        """
        with torch.no_grad():
            sampled_action, _ = self.policy.compute_action_and_log_prob_t(obs)
        action_np = sampled_action.detach().cpu().numpy()
        return self._target_critic_approximator(obs, action_np, output_tensor=True)
