from imitation_lib.imitation import LSIQ
import torch
import torch.nn.functional as F

from mushroom_rl.utils.torch import to_float_tensor


class LSIQ_SQIL(LSIQ):

    """
    This is a version of LSIQ that uses a SQIL-like loss formulation.

    The critic is regressed towards fixed-reward targets: expert samples use
    a reward of ``+1 / reg_mult`` and policy samples a reward of
    ``-1 / reg_mult``, each combined with a discounted next-state value.

    """

    def __init__(self, target_clipping=True, **kwargs):
        """
        Constructor.

        Args:
            target_clipping (bool, True): if True, the next-state value is
                clipped to ``[self._Q_min, self._Q_max]`` before it enters
                the bootstrapped target;
            **kwargs: forwarded unchanged to the ``LSIQ`` constructor.

        """
        # Stored before the parent constructor runs (original ordering kept).
        self._target_clipping = target_clipping
        super().__init__(**kwargs)

    def _sqil_regression_loss(self, prediction, target):
        """
        Regression loss selected by ``self._Q_exp_loss``.

        Args:
            prediction (torch.Tensor): predicted Q-values;
            target (torch.Tensor): regression targets.

        Returns:
            The scalar mean squared error (``"MSE"``) or the Huber loss
            (``"Huber"``) between ``prediction`` and ``target``.

        Raises:
            ValueError: if ``self._Q_exp_loss`` is neither ``"MSE"`` nor
                ``"Huber"``.

        """
        if self._Q_exp_loss == "MSE":
            return torch.mean(torch.square(prediction - target))
        if self._Q_exp_loss == "Huber":
            return F.huber_loss(prediction, target)
        raise ValueError("Unknown loss.")

    def _lossQ(self, obs, act, next_obs, absorbing, is_expert):
        """
        Compute the SQIL-like critic loss, update the critic and log.

        Args:
            obs: observations of the batch, passed to the critic;
            act: actions of the batch, passed to the critic;
            next_obs: next observations, used for the next-state value;
            absorbing: per-sample absorbing flags; converted to a tensor and
                given a trailing axis (assumed to be a 1-D array/tensor of
                0/1 values -- TODO confirm against the caller);
            is_expert: boolean mask selecting the expert samples of the
                batch; its negation selects the policy samples.

        Returns:
            A tuple ``(loss_term1, loss_term2, 0.0)`` holding the expert
            loss, the policy loss and a constant ``0.0`` (presumably a slot
            expected by the parent class' loss interface -- confirm).

        Raises:
            ValueError: if ``self._loss_mode_exp`` or ``self._Q_exp_loss``
                holds an unsupported value.

        """
        gamma = to_float_tensor(self.mdp_info.gamma)
        if self._use_cuda:
            gamma = gamma.cuda()
        gamma = gamma.detach()

        # as_tensor accepts arrays and tensors alike (no copy-construct
        # warning for tensors) and is applied on the CPU path as well, since
        # torch.unsqueeze below needs a tensor in either case.
        absorbing = torch.as_tensor(absorbing)
        if self._use_cuda:
            absorbing = absorbing.cuda()
        absorbing = torch.unsqueeze(absorbing, 1)

        current_Q = self._critic_approximator(obs, act, output_tensor=True)

        # Next-state value: no gradient flows through it when the target
        # network is used.
        if self._use_target:
            with torch.no_grad():
                next_v = self.get_targetV(next_obs).detach()
        else:
            next_v = self.getV(next_obs)
        if self._target_clipping:
            next_v = torch.clip(next_v, self._Q_min, self._Q_max)

        # Discounted bootstrap part of the target; zero for absorbing states.
        y = (1 - absorbing) * gamma * self._Q_Q_multiplier * next_v

        # Fixed rewards: +1/reg_mult for expert and -1/reg_mult for policy
        # samples. If absorbing states are treated specially, the reward of
        # an absorbing transition is scaled by 1 / (1 - gamma).
        reward_scale = 1 / self._reg_mult
        if self._treat_absorbing_states:
            absorbing_factor = 1 / (1 - gamma)
            r_max = (1 - absorbing) * reward_scale \
                + absorbing * absorbing_factor * reward_scale
            r_min = (1 - absorbing) * (-reward_scale) \
                + absorbing * absorbing_factor * (-reward_scale)
        else:
            r_max = torch.ones_like(absorbing) * reward_scale
            r_min = torch.ones_like(absorbing) * (-reward_scale)

        is_policy = ~is_expert
        expert_Q = current_Q[is_expert]
        policy_Q = current_Q[is_policy]

        # Expert part: regress either towards the bootstrapped target or
        # towards the fixed value Q_max.
        if self._loss_mode_exp == "bootstrap":
            expert_target = r_max[is_expert] + y[is_expert]
        elif self._loss_mode_exp == "fix":
            expert_target = torch.ones_like(expert_Q) * self._Q_max
        else:
            raise ValueError("Unknown expert loss mode.")
        loss_term1 = self._sqil_regression_loss(expert_Q, expert_target)

        # Policy part: always bootstrapped.
        # NOTE(review): the policy term also selects its loss through
        # `_Q_exp_loss`; confirm that no separate policy setting was intended.
        policy_target = r_min[is_policy] + y[is_policy]
        loss_term2 = self._sqil_regression_loss(policy_Q, policy_target)

        # do the logging
        reward = self._Q_Q_multiplier * current_Q - y
        self.logging_loss(current_Q, y, reward, is_expert, obs, act, absorbing)

        loss_Q = loss_term1 + loss_term2
        self.update_Q_parameters(loss_Q)

        # The gradient norm is only needed for logging, so it is computed
        # only on logging iterations. Parameters without a gradient are
        # skipped instead of failing on `None.view`.
        if self._iter % self._logging_iter == 0:
            grads = [
                param.grad.reshape(-1)
                for param in self._critic_approximator.model.network.parameters()
                if param.grad is not None
            ]
            if grads:
                norm = torch.cat(grads).norm(dim=0, p=2)
                self.sw_add_scalar('Gradients/Norm2 Gradient LossQ wrt. Q-parameters', norm, self._iter)

        return loss_term1, loss_term2, 0.0
