import torch
import numpy as np
from .iq_sac import IQ_SAC
import torch.nn.functional as F

from mushroom_rl.utils.torch import to_float_tensor


class LSIQ(IQ_SAC):

    """ This is the vanilla version of LSIQ that neither uses an entropy critic nor a regularization critic. It is
        implemented the same way as IQ. Note that there is another version of LSIQ, that uses the SQIL-like loss
        formulation.

        The class only overrides the critic loss (``_lossQ``); everything else is inherited from ``IQ_SAC``. """

    def __init__(self, Q_max, Q_min, loss_mode_exp="fix", abs_mult=1.0, Q_exp_loss=None,
                 treat_absorbing_states=False, **kwargs):
        """
        Constructor.

        Args:
            Q_max: upper bound used to clip the next-state value; also the fixed regression target for the
                expert Q-values when ``loss_mode_exp == "fix"``;
            Q_min: lower bound used to clip the next-state value;
            loss_mode_exp (str, "fix"): how the expert term of the critic loss is built, either "fix"
                (regress expert Q-values onto ``Q_max``) or "bootstrap" (negative mean implicit expert reward);
            abs_mult (float, 1.0): stored on the instance; not read by the code in this class;
            Q_exp_loss (str, None): "MSE" or "Huber"; only used (and then required) when
                ``loss_mode_exp == "fix"``;
            treat_absorbing_states (bool, False): forwarded to ``regularizer_loss``;
            **kwargs: forwarded unchanged to the ``IQ_SAC`` constructor.

        """
        # call parent
        super().__init__(**kwargs)

        self._Q_max = Q_max
        self._Q_min = Q_min
        self._loss_mode_exp = loss_mode_exp  # "fix" or "bootstrap"
        self._Q_exp_loss = Q_exp_loss   # either MSE or Huber | Only used when loss_mode_exp == "fix"
        self._abs_mult = abs_mult
        self._treat_absorbing_states = treat_absorbing_states   # when using fixed targets, this is not needed.

    def _lossQ(self, obs, act, next_obs, absorbing, is_expert):
        """
        Build the critic loss, apply one critic update and log diagnostics.

        The loss is the sum of an expert term, a policy/value term (selected by ``self._plcy_loss_mode``)
        and the regularizer returned by ``regularizer_loss``.

        Args:
            obs: batch of observations;
            act: batch of actions;
            next_obs: batch of next observations;
            absorbing: absorbing flags of the transitions; it gets a trailing axis added via ``unsqueeze(1)``,
                so it is presumably of shape (batch,) -- TODO confirm against the caller;
            is_expert: boolean mask selecting the expert samples of the batch (it is indexed with and
                inverted via ``~``).

        Returns:
            The tuple ``(loss_term1, loss_term2, chi2_loss)``, i.e. the expert term, the policy/value term
            (the plain float ``0.0`` when the policy loss mode is "off") and the regularizer.

        Raises:
            ValueError: if ``loss_mode_exp``, ``Q_exp_loss`` or the policy loss mode is not supported.

        """
        # Calculate 1st term of loss
        gamma = to_float_tensor(self.mdp_info.gamma)
        # as_tensor leaves an existing tensor untouched (no copy, no copy-construct warning) and converts
        # array-likes, so the CPU and the CUDA path see the same type.
        absorbing = torch.as_tensor(absorbing)
        if self._use_cuda:
            gamma = gamma.cuda()
            absorbing = absorbing.cuda()

        current_Q = self._critic_approximator(obs, act, output_tensor=True)
        if self._use_target:
            # the target network only provides a constant bootstrap value
            with torch.no_grad():
                next_v = self.get_targetV(next_obs).detach()
        else:
            next_v = self.getV(next_obs)
        absorbing = torch.unsqueeze(absorbing, 1)

        # discounted, clipped next-state value; zeroed out for absorbing transitions
        y = (1 - absorbing) * gamma.detach() * self._Q_Q_multiplier * torch.clip(next_v, self._Q_min, self._Q_max)

        # implicit reward r = Q(s, a) - gamma * V(s')
        reward = (self._Q_Q_multiplier*current_Q - y)

        if self._loss_mode_exp == "bootstrap":
            exp_reward = reward[is_expert]
            loss_term1 = - exp_reward.mean()
        elif self._loss_mode_exp == "fix":
            if self._Q_exp_loss == "MSE":
                expert_Q = current_Q[is_expert]
                loss_term1 = F.mse_loss(expert_Q, torch.ones_like(expert_Q) * self._Q_max)
            elif self._Q_exp_loss == "Huber":
                expert_Q = current_Q[is_expert]
                loss_term1 = F.huber_loss(expert_Q, torch.ones_like(expert_Q) * self._Q_max)
            elif self._Q_exp_loss is None:
                raise ValueError("If you choose loss_mode_exp == fix, you have to specify Q_exp_loss. Setting it to "
                                 "None is not valid.")
            else:
                raise ValueError(
                    "Choosen Q_exp_loss %s is not supported. Choose either MSE or Huber." % self._Q_exp_loss)
        else:
            # Without this branch loss_term1 would stay unbound and the method would only fail further
            # down with an unrelated UnboundLocalError.
            raise ValueError("Undefined expert loss mode: %s. Choose either fix or bootstrap."
                             % self._loss_mode_exp)

        # do the logging
        self.logging_loss(current_Q, y, reward, is_expert, obs, act, absorbing)

        # 2nd term for our loss (use expert and policy states)
        V = self._Q_Q_multiplier * self.getV(obs)
        value = (V - y)
        # NOTE(review): V already contains _Q_Q_multiplier, so the logged value carries the multiplier
        # twice -- confirm whether this is intended.
        self.sw_add_scalar('V for policy on all states', self._Q_Q_multiplier * V.mean(), self._iter)
        value_loss = value
        if self._plcy_loss_mode == "value":
            loss_term2 = value_loss.mean()
        elif self._plcy_loss_mode == "value_expert":
            loss_term2 = value_loss[is_expert].mean()
        elif self._plcy_loss_mode == "value_policy":
            loss_term2 = value_loss[~is_expert].mean()
        elif self._plcy_loss_mode == "q_old_policy":
            loss_term2 = reward[~is_expert].mean()
        elif self._plcy_loss_mode == "value_q_old_policy":
            loss_term2 = reward[~is_expert].mean() + value_loss.mean()
        elif self._plcy_loss_mode == "v0":
            value_loss_v0 = (1-gamma.detach()) * self.getV(obs[is_expert])
            loss_term2 = value_loss_v0.mean()
        elif self._plcy_loss_mode == "off":
            loss_term2 = 0.0
        else:
            raise ValueError("Undefined policy loss mode: %s" % self._plcy_loss_mode)

        # regularize
        chi2_loss = self.regularizer_loss(absorbing, reward, gamma, is_expert,
                                          treat_absorbing_states=self._treat_absorbing_states)

        loss_Q = loss_term1 + loss_term2 + chi2_loss
        self.update_Q_parameters(loss_Q)

        # The gradient norm is a pure diagnostic, so it is only computed on logging iterations.
        if self._iter % self._logging_iter == 0:
            # parameters that did not take part in the loss have grad None and are skipped
            grads = [param.grad.reshape(-1)
                     for param in self._critic_approximator.model.network.parameters()
                     if param.grad is not None]
            if grads:
                norm = torch.cat(grads).norm(dim=0, p=2)
                self.sw_add_scalar('Gradients/Norm2 Gradient LossQ wrt. Q-parameters', norm, self._iter)

        return loss_term1, loss_term2, chi2_loss
