from copy import deepcopy

import torch
import torch.nn.functional as F
import numpy as np

from imitation_lib.imitation import GAIL_TRPO
from imitation_lib.utils import to_float_tensors


class VAIL(GAIL_TRPO):

    """GAIL variant that uses the variational discriminator bottleneck (VDB), as proposed by:

        Xue Bin Peng, Angjoo Kanazawa, Sam Toyer, Pieter Abbeel, Sergey Levine
        Variational Discriminator Bottleneck: Improving Imitation Learning, Inverse RL, and GANs by Constraining Information Flow.

    The discriminator ``self._D`` returns a triple ``(d, mu, logvar)``: the discriminator
    output plus the statistics of the latent bottleneck distribution.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``GAIL_TRPO.__init__``."""
        super().__init__(**kwargs)

    def discrim_output(self, *inputs, apply_mask=True):
        """Return only the discriminator output ``d`` for ``inputs``.

        The inputs are first passed through ``prepare_discrim_inputs`` (which presumably
        applies the discriminator input mask when ``apply_mask`` is True — see the base
        class). The latent statistics ``mu``/``logvar`` returned by ``self._D`` are dropped.
        """
        inputs = self.prepare_discrim_inputs(inputs, apply_mask=apply_mask)
        d_out, _, _ = self._D(*inputs)
        return d_out

    def _discriminator_logging(self, inputs, targets):
        """Log discriminator diagnostics to the summary writer ``self._sw``.

        Does nothing when no summary writer is set. Every scalar is logged at step
        ``self._iter // 3``.

        Args:
            inputs: discriminator inputs for policy and expert samples combined. They are
                fed to ``self._D`` as-is below, so they are assumed to be already
                prepared/masked by the caller — TODO confirm against the caller.
            targets: labels matching ``inputs``; the first half is used as the policy
                targets and the second half as the expert targets.
        """
        if not self._sw:
            return

        step = self._iter // 3
        plcy_inputs, demo_inputs = self.divide_data_to_demo_and_plcy(inputs)
        # Logging-only computations run on a copy of the loss object, so any state it
        # keeps is not touched on the object used for training.
        loss = deepcopy(self._loss)

        # One forward pass on the full batch, shared by the total loss and the
        # bottleneck loss below; the VAIL discriminator returns (d, mu, logvar).
        d_full = to_float_tensors(self._D(*inputs))
        loss_eval = loss.forward(d_full, torch.tensor(targets))
        self._sw.add_scalar('DiscrimLoss', loss_eval, step)

        # Discriminator outputs and accuracies. ``inputs`` are already prepared (they go
        # straight into ``self._D`` above), so the mask must not be applied a second
        # time. The accuracies reuse the very sigmoid outputs that are logged, instead
        # of re-running the discriminator on differently-masked inputs.
        dout_exp = torch.sigmoid(torch.tensor(self.discrim_output(*demo_inputs, apply_mask=False)))
        dout_plcy = torch.sigmoid(torch.tensor(self.discrim_output(*plcy_inputs, apply_mask=False)))
        accuracy_exp = np.mean((dout_exp > 0.5).numpy())
        accuracy_gen = np.mean((dout_plcy < 0.5).numpy())
        self._sw.add_scalar('D_Generator_Accuracy', accuracy_gen, step)
        self._sw.add_scalar('D_Out_Generator', np.mean(dout_plcy.numpy()), step)
        self._sw.add_scalar('D_Expert_Accuracy', accuracy_exp, step)
        self._sw.add_scalar('D_Out_Expert', np.mean(dout_exp.numpy()), step)

        # Individual loss terms. Same reasoning as above: no second masking.
        logits = torch.tensor(self.discrim_output(*inputs, apply_mask=False))
        bernoulli_ent = torch.mean(loss.logit_bernoulli_entropy(logits))
        neg_bernoulli_ent_loss = -loss.entcoeff * bernoulli_ent
        half = len(targets) // 2
        plcy_target = targets[:half]
        demo_target = targets[half:]
        # NOTE(review): the order of the loss.forward calls matches the original
        # implementation, in case forward() updates internal state (e.g. beta) on the copy.
        loss_exp = loss.forward(to_float_tensors(self._D(*demo_inputs)), torch.tensor(demo_target)) / 2
        loss_gen = loss.forward(to_float_tensors(self._D(*plcy_inputs)), torch.tensor(plcy_target)) / 2
        self._sw.add_scalar('Bernoulli Ent.', bernoulli_ent, step)
        self._sw.add_scalar('Neg. Bernoulli Ent. Loss (incl. in DiscrimLoss)', neg_bernoulli_ent_loss, step)
        self._sw.add_scalar('Generator_loss', loss_gen, step)
        self._sw.add_scalar('Expert_Loss', loss_exp, step)

        # Information-bottleneck term, computed from the latent statistics of the
        # full-batch forward pass above.
        _, mu, logvar = d_full
        bottleneck_loss = loss.bottleneck_loss(mu, logvar)
        # NOTE(review): this is beta on the copied loss after the forward() calls above;
        # if forward() adapts beta, it may differ from self._loss._beta — confirm intent.
        beta = loss._beta
        self._sw.add_scalar('Bottleneck_Loss', bottleneck_loss, step)
        self._sw.add_scalar('Beta', beta, step)
        self._sw.add_scalar('Bottleneck_Loss_times_Beta', beta * bottleneck_loss, step)
