from model_utils import *
import pickle
from tqdm import tqdm
import torch

class MultiDec(torch.nn.Module):
    """
    AAAI 2019 model 1.

    A shared RNN encoder followed by a content projection (``proj_content``)
    feeds two style-specific decoders: ``decoder_pos`` decodes rows labelled 1
    and ``decoder_neg`` rows labelled 0.  ``forward`` alternates updates of a
    style classifier, the ``d_gamma`` critic, the decoders and finally the
    encoder / projection, which are trained on ``loss_gen + loss_mi``.

    NOTE: ``__init__`` deliberately raises ``NotImplementedError``, so the class
    cannot currently be instantiated; the implementation is kept for reference.
    """

    def __init__(self, args, reny_dataloader):
        """Build the sub-modules from ``args``.

        Args:
            args: experiment namespace; attributes read here are
                ``number_of_tokens``, ``hidden_dim``, ``number_of_layers``,
                ``number_of_styles``, ``use_complex_classifier``, ``mul_mi`` and
                ``tokenizer.pad_token_id``.
            reny_dataloader: iterable of dict batches with ``'line'`` and
                ``'label'`` entries, consumed by the inner loop of ``forward``.

        Raises:
            NotImplementedError: always -- the model is disabled for now.
        """
        super(MultiDec, self).__init__()
        raise NotImplementedError
        self.args = args
        self.training_all_except_encoder = True
        self.reny_dataloader = reny_dataloader
        self.encoder = EncoderRNN(args, args.number_of_tokens, args.hidden_dim, args.number_of_layers)
        self.decoder_pos = DecoderRNN(args, args.hidden_dim, args.number_of_tokens, args.number_of_layers)
        self.decoder_neg = DecoderRNN(args, args.hidden_dim, args.number_of_tokens, args.number_of_layers)
        self.loss = torch.nn.NLLLoss(ignore_index=self.args.tokenizer.pad_token_id)
        self.proj_content = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim), nn.LeakyReLU(),
                                          nn.Linear(args.hidden_dim, args.hidden_dim))

        # D_\gamma critic; its input is [content, one-hot style] (see _d_gamma_pair).
        self.d_gamma = ClassifierGamma(args.hidden_dim + args.number_of_styles,
                                       args.number_of_styles)  # TODO: clarify

        # Style classifier reading the content space.
        self.style_classifier = Classifier(args.hidden_dim, args.number_of_styles, args.use_complex_classifier)

        # Loss: paper multipliers.
        self.mul_mi = self.args.mul_mi

    # ------------------------------------------------------------------ helpers

    def _zero(self):
        """Return a fresh scalar ``0.0`` tensor on ``args.device``."""
        return torch.tensor(0.0).to(self.args.device)

    def _train_only(self, *trainable):
        """Unfreeze the given sub-modules and freeze every other sub-module."""
        for module in (self.style_classifier, self.d_gamma, self.proj_content,
                       self.encoder, self.decoder_pos, self.decoder_neg):
            if any(module is keep for keep in trainable):
                free_params(module)
            else:
                frozen_params(module)

    def _encode_content(self, inputs):
        """Encode ``inputs`` from a fresh hidden state and project it to the content space."""
        encoder_hidden = self.encoder.initHidden()
        _, encoder_hidden = self.encoder(inputs, encoder_hidden)
        return self.proj_content(encoder_hidden)

    @staticmethod
    def _one_hot_styles(labels, device):
        """Map each label to ``[1., 0.]`` when it equals 0 and to ``[0., 1.]`` otherwise."""
        return torch.tensor([[1., 0.] if el == 0 else [0., 1.] for el in labels.tolist()]).to(device)

    @staticmethod
    def _greedy_next(decoder_output):
        """Return the arg-max token of every row as a detached ``(rows, 1)`` index tensor.

        ``decoder_output`` is assumed to be ``(rows, 1, vocab)`` (the losses
        squeeze dim 1).  Only the last dim is squeezed, so a single-row style
        group keeps its ``(1, 1)`` shape instead of collapsing to ``(1,)`` as a
        bare ``squeeze()`` followed by ``unsqueeze(-1)`` would.
        """
        return decoder_output.topk(1)[1].squeeze(-1).detach()

    def _d_gamma_pair(self, content, labels, style_pred):
        """Build the two ``d_gamma`` inputs for ``content``.

        ``u`` concatenates ``content`` with the one-hot ``labels``; ``v`` with
        the one-hot arg-max of ``style_pred``.  The one-hot rows are repeated
        along ``content``'s first dim (previously hard-coded to 4,
        "bid + 2 layers"), which ``torch.cat`` requires to match anyway.
        """
        predicted = style_pred.topk(1, dim=-1)[-1].squeeze(-1)
        depth = content.size(0)
        label_u = self._one_hot_styles(labels, self.args.device).unsqueeze(0).float().repeat(depth, 1, 1)
        label_v = self._one_hot_styles(predicted, self.args.device).unsqueeze(0).float().repeat(depth, 1, 1)
        u = torch.cat([content, label_u], dim=-1)
        v = torch.cat([content, label_v], dim=-1)
        return u, v

    def _decode_by_style(self, encoder_hidden, golden, labels, use_teacher_forcing):
        """Decode every row with the decoder of its style and return the loss.

        Rows with ``labels == 1`` go through ``decoder_pos`` and rows with
        ``labels == 0`` through ``decoder_neg``; both start from SEP tokens and
        from the matching slices (dim 1) of ``encoder_hidden``.  The NLL against
        ``golden[:, step]`` is accumulated for ``max_length`` steps, each term
        divided by ``max_length``.

        Returns:
            ``(loss, decoded_words)``: ``decoded_words`` is a list whose first
            entry holds the ``(batch_size, 1)`` start tokens; without teacher
            forcing every step appends the greedy tokens, label-0 rows first and
            then label-1 rows.
        """
        pos_mask = labels == 1
        neg_mask = labels == 0
        golden_pos = golden[pos_mask]
        golden_neg = golden[neg_mask]

        start = torch.ones(self.args.batch_size, 1).to(
            self.args.device).long() * self.args.tokenizer.sep_token_id
        decoded_words = [start]
        input_pos, input_neg = start[pos_mask], start[neg_mask]
        hidden_pos, hidden_neg = encoder_hidden[:, pos_mask, :], encoder_hidden[:, neg_mask, :]

        loss = 0
        for di in range(self.args.max_length):
            output_pos, hidden_pos = self.decoder_pos(input_pos, hidden_pos)
            output_neg, hidden_neg = self.decoder_neg(input_neg, hidden_neg)
            loss += self.loss(output_pos.squeeze(1), golden_pos[:, di]) / self.args.max_length
            loss += self.loss(output_neg.squeeze(1), golden_neg[:, di]) / self.args.max_length
            if use_teacher_forcing:
                # Teacher forcing: feed the target as the next input.
                input_pos = golden_pos[:, di].unsqueeze(-1)
                input_neg = golden_neg[:, di].unsqueeze(-1)
            else:
                # Feed back the model's own (detached) greedy predictions.
                input_pos = self._greedy_next(output_pos)
                input_neg = self._greedy_next(output_neg)
                decoded_words.append(torch.cat([input_neg, input_pos], dim=0))
        return loss, decoded_words

    # ------------------------------------------------------------------ API

    def forward(self, input_tensor, labels, teacher_ratio):
        """Run one training (or evaluation) step.

        When ``training_all_except_encoder`` is set, batches of
        ``reny_dataloader`` are consumed to update, in turn, the style
        classifier, ``d_gamma`` (unless ``args.no_minimization_of_mi_training``)
        and the two decoders.  The encoder and ``proj_content`` are then updated
        on ``input_tensor`` with ``loss_gen + loss_mi``.  Backward passes and
        optimiser steps only happen in training mode.

        Args:
            input_tensor: token ids, indexed as ``[:, step]`` for
                ``step < args.max_length``.
            labels: per-row style labels; rows equal to 1 use ``decoder_pos``
                and rows equal to 0 use ``decoder_neg``.
            teacher_ratio: probability of teacher forcing (training mode only).

        Returns:
            ``(losses_dic, decoded_words, input_tensor_gen)`` where
            ``decoded_words`` concatenates the start tokens with the greedy
            predictions (label-0 rows first) and ``input_tensor_gen`` is the
            (possibly noised) encoder input.
        """
        reny, loss_gamma = self._zero(), self._zero()
        loss_style_classifier, loss_gen_reny = self._zero(), self._zero()
        loss_h_sz, loss_h_s = self._zero(), self._zero()
        gradient_encoder, gradient_decoder = self._zero(), self._zero()
        gradient_content_proj, gradient_reny = self._zero(), self._zero()

        if self.training_all_except_encoder:
            for step, batch in enumerate(self.reny_dataloader):
                # NOTE(review): this uses reny_training + 1 batches, while
                # 'gradient_reny' is averaged over reny_training below -- confirm.
                if step == self.args.reny_training + 1:
                    break
                inputs_reny = batch['line'].to(self.args.device)
                labels_reny = batch['label'].to(self.args.device)

                # 1) Style classifier on the content space.
                self._train_only(self.style_classifier)
                content = self._encode_content(inputs_reny)
                loss_style_classifier = self.loss(self.style_classifier(content), labels_reny)
                if self.training:
                    loss_style_classifier.backward()
                    torch.nn.utils.clip_grad_norm_(self.style_classifier.parameters(), self.args.max_grad_norm)
                    self.args.optimizer.step()
                    self.style_classifier.zero_grad()

                # 2) D_gamma critic, fed with the reny batch's own labels.
                if not self.args.no_minimization_of_mi_training:
                    self._train_only(self.d_gamma)
                    content = self._encode_content(inputs_reny.clone())
                    u, v = self._d_gamma_pair(content, labels_reny, self.style_classifier(content))
                    d_gamma_u = self.d_gamma(u)  # - log
                    d_gamma_v = self.d_gamma(v)  # - log
                    # First term in (16).
                    loss_gamma = -torch.mean(d_gamma_u[:, 0]) / 2 - torch.mean(d_gamma_v[:, 1]) / 2
                    if self.training:
                        loss_gamma.backward()
                        gradient_reny += comput_gradient_norm(self.d_gamma)
                        torch.nn.utils.clip_grad_norm_(self.d_gamma.parameters(), self.args.max_grad_norm)
                        self.args.optimizer.step()
                        self.d_gamma.zero_grad()

                # 3) Style-specific decoders reconstructing the reny batch; rows
                # are routed with labels_reny since they belong to this batch.
                self._train_only(self.decoder_pos, self.decoder_neg)
                golden_reny = inputs_reny.clone()
                inputs_reny_gen = inputs_reny.clone()
                if self.args.add_noise:
                    inputs_reny_gen = corrupt_input(self, inputs_reny)
                content = self._encode_content(inputs_reny_gen)
                use_teacher_forcing = random.random() < teacher_ratio
                loss_gen_reny, _ = self._decode_by_style(
                    content, golden_reny, labels_reny, use_teacher_forcing and self.training)
                if self.training:
                    loss_gen_reny.backward()
                    torch.nn.utils.clip_grad_norm_(self.decoder_pos.parameters(), self.args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(self.decoder_neg.parameters(), self.args.max_grad_norm)
                    self.args.optimizer.step()
                    self.decoder_neg.zero_grad()
                    self.decoder_pos.zero_grad()

        # Update Genloss + lambda * MI on the main batch (encoder + proj_content).
        input_tensor_gen_golden = input_tensor.clone()
        input_tensor_gen = input_tensor.clone()
        if self.args.add_noise:
            input_tensor_gen = corrupt_input(self, input_tensor_gen)
        input_tensor_mi = input_tensor_gen.clone()
        if self.training:
            self._train_only(self.proj_content, self.encoder)

        content_mi = self._encode_content(input_tensor_mi)
        style_content_pred_mi = self.style_classifier(content_mi)

        if not self.args.no_minimization_of_mi_training:
            # H(S): classifier scores weighted by the batch's label proportions (Jensen).
            pos_rate = torch.mean(labels.float()).item()
            loss_h_s = (-torch.mean(style_content_pred_mi[:, 0]) * (1 - pos_rate)
                        - torch.mean(style_content_pred_mi[:, 1]) * pos_rate)

            # Renyi term from the d_gamma critic.
            u, v = self._d_gamma_pair(content_mi, labels, style_content_pred_mi)
            d_gamma_u = self.d_gamma(u)
            d_gamma_v = self.d_gamma(v)
            ratio = torch.exp(d_gamma_u[:, 0]) / torch.exp(d_gamma_v[:, 1])
            R = torch.mean(ratio ** (self.args.alpha - 1))
            reny = torch.abs(torch.log(R) / (self.args.alpha - 1))

            loss_h_sz = self.loss(style_content_pred_mi, labels)
            mi_estimate = loss_h_s - loss_h_sz
            if not self.args.no_reny:
                mi_estimate = mi_estimate + reny
            loss_mi = self.args.mul_mi * torch.abs(mi_estimate)
        else:
            loss_mi = - self.args.mul_mi * self.loss(style_content_pred_mi, labels)

        # Generation loss on the (possibly noised) main batch.
        content = self._encode_content(input_tensor_gen)
        use_teacher_forcing = random.random() < teacher_ratio
        loss_gen, decoded_words = self._decode_by_style(
            content, input_tensor_gen_golden, labels, use_teacher_forcing and self.training)

        if self.training:
            loss = loss_gen + loss_mi
            loss.backward()
            gradient_encoder = comput_gradient_norm(self.encoder)
            gradient_decoder = comput_gradient_norm(self.decoder_pos)
            gradient_content_proj = comput_gradient_norm(self.proj_content)
            for module in (self.proj_content, self.encoder, self.decoder_pos, self.decoder_neg):
                torch.nn.utils.clip_grad_norm_(module.parameters(), self.args.max_grad_norm)
            self.args.optimizer.step()
            self.args.scheduler.step()
            self.proj_content.zero_grad()
            self.encoder.zero_grad()

        decoded_words = torch.cat(decoded_words, dim=-1)  # memory inefficient but nicer
        losses_dic = {'loss_gen': loss_gen, 'loss_mi': loss_mi, 'reny': reny, 'loss_gamma': loss_gamma,
                      'loss_h_sz': loss_h_sz, 'loss_h_s': loss_h_s, 'loss_style_classifier': loss_style_classifier,
                      'gradient_encoder': gradient_encoder, 'gradient_decoder': gradient_decoder,
                      'loss_gen_reny': loss_gen_reny, 'gradient_content_proj': gradient_content_proj,
                      'gradient_reny': gradient_reny / self.args.reny_training}
        return (losses_dic, decoded_words, input_tensor_gen)

    def predict_latent_space(self, input_tensor):
        """Return the content-space projection of ``input_tensor``."""
        return self._encode_content(input_tensor)

    def forward_transfert(self, inputs, labels_to_transfert, pos_style, neg_style):
        """Greedily decode ``inputs`` with the decoder chosen by ``labels_to_transfert``.

        Rows whose label is 1 are decoded by ``decoder_pos`` and rows whose
        label is 0 by ``decoder_neg``; the output rows are put back in the
        original batch order.  ``pos_style`` and ``neg_style`` are unused.
        """
        pos_mask = labels_to_transfert == 1
        neg_mask = labels_to_transfert == 0
        encoder_hidden = self._encode_content(inputs)

        start = torch.ones(self.args.batch_size, 1).to(
            self.args.device).long() * self.args.tokenizer.sep_token_id
        decoded_words = [start]
        input_pos, input_neg = start[pos_mask], start[neg_mask]
        hidden_pos, hidden_neg = encoder_hidden[:, pos_mask, :], encoder_hidden[:, neg_mask, :]

        for _ in range(self.args.max_length):
            output_pos, hidden_pos = self.decoder_pos(input_pos, hidden_pos)
            output_neg, hidden_neg = self.decoder_neg(input_neg, hidden_neg)
            input_pos = self._greedy_next(output_pos)
            input_neg = self._greedy_next(output_neg)
            decoded_words.append(torch.cat([input_neg, input_pos], dim=0))

        decoded_words = torch.cat(decoded_words, dim=-1)
        # Rows are stacked label-0 first; scatter them back to their batch slots.
        n_neg = torch.sum(neg_mask).item()
        reorder_decoded_words = torch.zeros_like(decoded_words).to(self.args.device)
        reorder_decoded_words[neg_mask] = decoded_words[:n_neg, :]
        reorder_decoded_words[pos_mask] = decoded_words[n_neg:, :]
        return reorder_decoded_words

class DAE(torch.nn.Module):
    """
    ACL 2019
    """

    def __init__(self, args, reny_dataloader):
        super(DAE, self).__init__()
        print('Reverifier')
        raise NotImplementedError
        self.args = args
        self.ignored_index = []
        self.reny_dataloader = reny_dataloader
        self.encoder = EncoderRNN(args, args.number_of_tokens, args.hidden_dim, args.number_of_layers)
        self.decoder = DecoderRNN(args, args.style_dim + args.content_dim, args.number_of_tokens, args.number_of_layers)
        self.loss = torch.nn.NLLLoss(ignore_index=self.args.tokenizer.pad_token_id)  # ignored_indextortss
        self.proj_style = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim), nn.LeakyReLU(),
                                        nn.Linear(args.hidden_dim, args.style_dim))
        self.proj_content = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim), nn.LeakyReLU(),
                                          nn.Linear(args.hidden_dim, args.content_dim))

        # D_\gamma
        self.d_gamma = ClassifierGamma(args.content_dim + args.number_of_styles,
                                       args.number_of_styles)

        # Style classifier
        self.style_classifier_from_content = Classifier(args.content_dim, args.number_of_styles)
        self.style_classifier_from_style = Classifier(args.style_dim, args.number_of_styles)

        # Content classifier
        self.content_classifier_from_content = Classifier(args.content_dim, args.number_of_tokens)
        self.content_classifier_from_style = Classifier(args.style_dim, args.number_of_tokens)

        # Loss : paper multipliers
        self.mul_mi = self.args.mul_mi

        self.lambda_content_style = 0.03  # read content prediction FROM Style
        self.lambda_content_content = 3  # read content prediction FROM content
        self.lambda_style_content = self.args.mul_mi  # read style prediction FROM content
        self.lambda_style_style = 10  # read style prediction FROM Style

        # TODO : loss with index not to backprop

    def content_loss(self, content_content_pred, labels):
        """Bag-of-token loss: reward each sentence's distinct non-zero token ids.

        For every row ``b`` of ``labels``, the scores of its distinct non-zero
        token ids are gathered from row ``b`` of ``content_content_pred`` and
        their mean is subtracted from the loss, so the result is the negated sum
        of the per-sentence means.

        Args:
            content_content_pred: per-sentence token scores, assumed
                ``(batch, vocab)`` like the other ``Classifier`` outputs that are
                fed to ``NLLLoss`` with ``(batch,)`` targets -- TODO confirm.
            labels: ``(batch, length)`` token ids; id 0 is dropped (presumably
                padding).

        Returns:
            The loss tensor, or the int ``0`` if every row contains only id 0.
        """
        loss = 0
        for b_index in range(labels.size(0)):
            tokens = torch.unique(labels[b_index])
            # Drop id 0; unlike list.remove(0) this does not raise when a
            # sentence fills its whole row and contains no id 0.
            tokens = tokens[tokens != 0]
            if tokens.numel() == 0:
                # Nothing to score; an empty index previously crashed index_select.
                continue
            # Score each sentence against its own prediction row (the previous
            # code read row 0, the first sentence's prediction, for every row).
            scores = content_content_pred[b_index].index_select(0, tokens.to(content_content_pred.device))
            loss -= torch.mean(scores)
        return loss

    def forward(self, input_tensor, labels, teacher_ratio):
        ###############################
        # Update style classifier loss:
        ###############################
        loss_gen, loss_mi, reny, loss_gamma, loss_style_classifier, loss_gen_reny, loss_adv_first_stage = torch.tensor(
            0.0).to(self.args.device), torch.tensor(0.0).to(self.args.device), torch.tensor(0.0).to(
            self.args.device), torch.tensor(0.0).to(self.args.device), torch.tensor(0.0).to(
            self.args.device), torch.tensor(0.0).to(self.args.device), torch.tensor(0.0).to(self.args.device)
        gradient_encoder, gradient_decoder, gradient_content_proj, gradient_style_proj, gradient_reny = torch.tensor(
            0.0).to(self.args.device), torch.tensor(0.0).to(self.args.device), torch.tensor(0.0).to(
            self.args.device), torch.tensor(0.0).to(self.args.device), torch.tensor(0.0).to(self.args.device)

        # epoch_iterator = tqdm(self.reny_dataloader, desc="Reny Training")

        for step, batch in enumerate(self.reny_dataloader):
            if step == self.args.reny_training + 1:
                break
            inputs_reny = batch['line'].to(self.args.device)
            labels_reny = batch['label'].to(self.args.device)
            free_params(self.style_classifier_from_content)
            free_params(self.style_classifier_from_style)
            free_params(self.content_classifier_from_content)
            free_params(self.content_classifier_from_style)
            frozen_params(self.d_gamma)
            frozen_params(self.proj_style)
            frozen_params(self.proj_content)
            frozen_params(self.encoder)
            frozen_params(self.decoder)

            encoder_hidden = self.encoder.initHidden()
            encoder_output, encoder_hidden = self.encoder(inputs_reny, encoder_hidden)

            ################################################
            # 1 . Update Style Classifier from Content     #
            # 2 . Update Content Classifier from Content   #
            # 3 . Update Style Classifier from Style       #
            # 4 . Update Content Classifier from Style     #
            ################################################
            content = self.proj_content(encoder_hidden)
            style_content_pred = self.style_classifier_from_content(content)
            loss_style_classifier_from_content = self.loss(style_content_pred, labels_reny)

            content_content_pred = self.content_classifier_from_content(content)
            loss_content_classifier_from_content = self.content_loss(content_content_pred, inputs_reny)

            style = self.proj_style(encoder_hidden)
            style_style_pred = self.style_classifier_from_style(style)
            loss_style_classifier_from_style = self.loss(style_style_pred, labels_reny)

            style_content_pred = self.content_classifier_from_style(style)
            loss_content_classifier_from_style = self.content_loss(style_content_pred, inputs_reny)

            if self.training:
                loss_adv_first_stage = loss_style_classifier_from_style + loss_style_classifier_from_content + loss_content_classifier_from_content + loss_content_classifier_from_style
                loss_adv_first_stage.backward()
                torch.nn.utils.clip_grad_norm_(self.style_classifier_from_content.parameters(), self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.style_classifier_from_style.parameters(), self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.content_classifier_from_style.parameters(), self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.content_classifier_from_content.parameters(),
                                               self.args.max_grad_norm)
                self.args.optimizer.step()
                self.style_classifier_from_content.zero_grad()
                self.content_classifier_from_style.zero_grad()
                self.content_classifier_from_content.zero_grad()
                self.style_classifier_from_style.zero_grad()

            ################
            # Update Decoder
            ################
            free_params(self.decoder)
            frozen_params(self.style_classifier_from_content)
            frozen_params(self.content_classifier_from_content)
            frozen_params(self.style_classifier_from_style)
            frozen_params(self.content_classifier_from_style)
            frozen_params(self.proj_style)
            frozen_params(self.proj_content)
            frozen_params(self.d_gamma)
            frozen_params(self.encoder)

            input_tensor_reny_gen_golden = inputs_reny.clone()
            inputs_reny_gen = inputs_reny.clone()
            if self.args.add_noise:
                inputs_reny_gen = corrupt_input(self, inputs_reny)

            encoder_hidden = self.encoder.initHidden()
            encoder_output, encoder_hidden = self.encoder(inputs_reny_gen, encoder_hidden)

            # Style space
            style = self.proj_style(encoder_hidden)
            # Content space
            content = self.proj_content(encoder_hidden)

            # Concatenate style and other
            encoder_hidden = torch.cat([style, content], dim=-1)

            # Sentence Generation
            loss_gen_reny = 0
            decoder_input = torch.ones(self.args.batch_size, 1).to(
                self.args.device).long() * self.args.tokenizer.sep_token_id
            decoder_hidden = encoder_hidden
            use_teacher_forcing = True if random.random() < teacher_ratio else False

            decoded_words = [decoder_input]

            if use_teacher_forcing and self.training:
                # Teacher forcing: Feed the target as the next input
                for di in range(self.args.max_length):
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    loss_gen_reny += self.loss(decoder_output.squeeze(1),
                                               input_tensor_reny_gen_golden[:, di]) / self.args.max_length
                    decoder_input = input_tensor_reny_gen_golden[:, di].unsqueeze(-1)  # Teacher forcing

            else:
                # Without teacher forcing: use its own predictions as the next input
                for di in range(self.args.max_length):
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    topv, topi = decoder_output.topk(1)
                    decoder_input = topi.squeeze().detach().unsqueeze(-1)  # detach from history as input
                    decoded_words.append(topi.squeeze(-1))
                    loss_gen_reny += self.loss(decoder_output.squeeze(1),
                                               input_tensor_reny_gen_golden[:, di]) / self.args.max_length

            if self.training:
                loss_gen_reny.backward()
                torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.args.max_grad_norm)
                self.args.optimizer.step()
                self.decoder.zero_grad()

            ################
            # Update D_gamma
            ################
            free_params(self.d_gamma)
            frozen_params(self.style_classifier_from_style)
            frozen_params(self.style_classifier_from_content)
            frozen_params(self.content_classifier_from_content)
            frozen_params(self.content_classifier_from_style)
            frozen_params(self.proj_content)
            frozen_params(self.proj_style)
            frozen_params(self.encoder)
            frozen_params(self.decoder)
            if not self.args.no_minimization_of_mi_training:
                loss_gamma = 0

                inputs_reny_dgamma = inputs_reny.clone()
                encoder_hidden = self.encoder.initHidden()
                encoder_output, encoder_hidden = self.encoder(inputs_reny_dgamma, encoder_hidden)

                # Content space
                content = self.proj_content(encoder_hidden)

                style_content_pred = self.style_classifier_from_content(content)
                label_content_pred = style_content_pred.topk(1, dim=-1)[-1].squeeze(-1)
                label_v = torch.tensor(
                    [[1., 0.] if el == 0 else [0., 1.] for el in label_content_pred.tolist()]).to(
                    self.args.device)
                label_u = torch.tensor([[1., 0.] if el == 0 else [0., 1.] for el in labels_reny.tolist()]).to(
                    self.args.device)
                u = torch.cat([content, label_u.unsqueeze(0).float().repeat(4, 1, 1)],
                              dim=-1)  # 4 is for bid + 2 layers
                v = torch.cat([content, label_v.unsqueeze(0).float().repeat(4, 1, 1)], dim=-1)

                d_gamma_content_pred_u = self.d_gamma(u)  # - log
                d_gamma_content_pred_v = self.d_gamma(v)  # - log

                loss_gamma -= torch.mean(d_gamma_content_pred_u[:, 0]) / 2  # first term in (16)
                loss_gamma -= torch.mean(d_gamma_content_pred_v[:, 1]) / 2
                if self.training:
                    loss_gamma.backward()
                    gradient_reny += comput_gradient_norm(self.d_gamma)
                    torch.nn.utils.clip_grad_norm_(self.d_gamma.parameters(), self.args.max_grad_norm)
                    self.args.optimizer.step()
                    self.d_gamma.zero_grad()

        ###############################
        # Update \lambda * MI
        ###############################
        frozen_params(self.decoder)
        frozen_params(self.d_gamma)
        frozen_params(self.style_classifier_from_content)
        frozen_params(self.style_classifier_from_style)
        frozen_params(self.content_classifier_from_content)
        frozen_params(self.content_classifier_from_style)
        free_params(self.proj_style)
        free_params(self.proj_content)
        free_params(self.encoder)
        input_tensor_gen_golden = input_tensor.clone()
        input_tensor_gen = input_tensor.clone()
        if self.args.add_noise:
            input_tensor_gen = corrupt_input(self, input_tensor_gen)

        #########################
        ######## Genloss ########
        #########################
        encoder_hidden = self.encoder.initHidden()
        encoder_output, encoder_hidden = self.encoder(input_tensor_gen, encoder_hidden)

        # Style space
        style = self.proj_style(encoder_hidden)
        # Content space
        content = self.proj_content(encoder_hidden)

        # Concatenate style and other
        encoder_hidden = torch.cat([style, content], dim=-1)

        # Sentence Generation
        loss_gen = 0
        decoder_input = torch.ones(self.args.batch_size, 1).to(
            self.args.device).long() * self.args.tokenizer.sep_token_id
        decoder_hidden = encoder_hidden
        use_teacher_forcing = True if random.random() < teacher_ratio else False

        decoded_words = [decoder_input]

        if use_teacher_forcing and self.training:
            # Teacher forcing: Feed the target as the next input
            for di in range(self.args.max_length):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                loss_gen += self.loss(decoder_output.squeeze(1), input_tensor_gen_golden[:, di]) / self.args.max_length
                decoder_input = input_tensor_gen_golden[:, di].unsqueeze(-1)  # Teacher forcing

        else:
            # Without teacher forcing: use its own predictions as the next input
            for di in range(self.args.max_length):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach().unsqueeze(-1)  # detach from history as input
                decoded_words.append(topi.squeeze(-1))
                loss_gen += self.loss(decoder_output.squeeze(1), input_tensor_gen_golden[:, di]) / self.args.max_length

        #########################
        ######## Adv loss #######
        #########################
        encoder_hidden = self.encoder.initHidden()
        encoder_output, encoder_hidden = self.encoder(input_tensor_gen, encoder_hidden)

        content_mi = self.proj_content(encoder_hidden)
        style_content_pred_mi = self.style_classifier_from_content(content)
        if not self.args.no_minimization_of_mi_training:
            loss_h_s = - torch.mean(style_content_pred_mi[:, 0]) * (1 - torch.mean(labels.float()).item()) - torch.mean(
                style_content_pred_mi[:, 1]) * torch.mean(labels.float()).item()  # we use Jensen

            # Compute Reny
            label_content_pred = style_content_pred_mi.topk(1, dim=-1)[-1].squeeze(-1)
            label_v = torch.tensor([[1., 0.] if el == 0 else [0., 1.] for el in label_content_pred.tolist()]).to(
                self.args.device)
            label_u = torch.tensor([[1., 0.] if el == 0 else [0., 1.] for el in labels.tolist()]).to(self.args.device)
            u = torch.cat([content_mi, label_u.unsqueeze(0).float().repeat(4, 1, 1)], dim=-1)
            v = torch.cat([content_mi, label_v.unsqueeze(0).float().repeat(4, 1, 1)], dim=-1)
            d_gamma_content_pred_u = self.d_gamma(u)
            d_gamma_content_pred_v = self.d_gamma(v)
            R = torch.mean((torch.exp(d_gamma_content_pred_u[:, 0]) / torch.exp(d_gamma_content_pred_v[:, 1])) ** (
                    self.args.alpha - 1))

            # Remove biais from gradients
            reny = torch.abs(torch.log(R) / (self.args.alpha - 1))

            loss_h_sz = self.loss(style_content_pred_mi, labels)

            # MI
            if self.args.no_reny:
                loss_mi = self.args.mul_mi * torch.abs(loss_h_s - loss_h_sz)
            else:
                loss_mi = self.args.mul_mi * torch.abs(loss_h_s - loss_h_sz + reny)
        else:
            loss_mi = - self.lambda_style_content * self.loss(style_content_pred_mi, labels)

        content_content_pred = self.content_classifier_from_content(content)
        loss_content_classifier_from_content = self.lambda_content_content * self.content_loss(content_content_pred,
                                                                                               input_tensor_gen)

        style = self.proj_style(encoder_hidden)
        style_style_pred = self.style_classifier_from_style(style)
        loss_style_classifier_from_style = self.lambda_style_style * self.loss(style_style_pred, labels)

        style_content_pred = self.content_classifier_from_style(style)
        loss_content_classifier_from_style = self.lambda_content_style * self.content_loss(style_content_pred,
                                                                                           input_tensor_gen)

        loss_adv_seconde_stage = loss_content_classifier_from_content + loss_style_classifier_from_style - loss_content_classifier_from_style

        # Compute All Losses for MI
        if self.training:
            loss = loss_gen + loss_adv_seconde_stage + loss_mi
            loss.backward()
            gradient_encoder, gradient_content_proj, gradient_style_proj = comput_gradient_norm(
                self.encoder), comput_gradient_norm(self.proj_content), comput_gradient_norm(self.proj_style)
            torch.nn.utils.clip_grad_norm_(self.proj_style.parameters(), self.args.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.proj_content.parameters(), self.args.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.args.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.args.max_grad_norm)
            self.args.optimizer.step()
            self.args.scheduler.step()  # Update learning rate schedule
            self.proj_style.zero_grad()
            self.proj_content.zero_grad()
            self.encoder.zero_grad()

        decoded_words = torch.cat(decoded_words, dim=-1)  # memory ineficient but nicer
        losses_dic = {'loss_gen': loss_gen, 'loss_mi': loss_mi,
                      'loss_adv_first_stage': loss_adv_first_stage, 'loss_adv_seconde_stage': loss_adv_seconde_stage,
                      'gradient_encoder': gradient_encoder, 'gradient_decoder': gradient_decoder,
                      'loss_gen_reny': loss_gen_reny, "reny": reny,
                      'gradient_content_proj': gradient_content_proj, 'gradient_style_proj': gradient_style_proj,
                      'gradient_reny': gradient_reny / self.args.reny_training}
        return (losses_dic, decoded_words, input_tensor_gen)

    def predict_latent_space(self, input_tensor):
        """Encode ``input_tensor`` and return its content representation.

        The encoder is run from ``self.encoder.initHidden()``. Its output
        sequence is discarded, and only the final hidden state is passed
        through ``self.proj_content``.

        Args:
            input_tensor: token ids accepted by ``self.encoder``.

        Returns:
            Whatever ``self.proj_content`` produces for the encoder's final
            hidden state.
        """
        start_state = self.encoder.initHidden()
        _, final_state = self.encoder(input_tensor, start_state)
        return self.proj_content(final_state)

    def predict_style_space(self, input_tensor, labels):
        """Encode ``input_tensor`` and split its style projection by label.

        The encoder's final hidden state is passed through ``self.proj_style``.
        The result is then partitioned along dim 1 using boolean masks built
        from ``labels``.

        Args:
            input_tensor: token ids accepted by ``self.encoder``.
            labels: per-example labels; presumably a 1-D tensor aligned with
                dim 1 of the style projection (TODO confirm).

        Returns:
            ``(style_for_label_0, style_for_label_1)``. Examples whose label is
            neither 0 nor 1 appear in neither half.
        """
        start_state = self.encoder.initHidden()
        _, final_state = self.encoder(input_tensor, start_state)
        style = self.proj_style(final_state)

        is_label_zero = labels == 0
        is_label_one = labels == 1
        return style[:, is_label_zero, :], style[:, is_label_one, :]

    def forward_transfert(self, inputs, labels_to_transfert, pos_style, neg_style):
        """Greedily decode ``inputs`` after substituting a target style vector.

        The content half of the decoder's initial hidden state is
        ``self.proj_content`` applied to the encoder's final hidden state. The
        style half is chosen per example: ``pos_style`` when the target label is
        1, ``neg_style`` otherwise.

        Args:
            inputs: token ids accepted by ``self.encoder``.
            labels_to_transfert: 1-D sequence of per-example target labels
                (iterated via ``.tolist()``).
            pos_style: style vector used for label 1. Each style vector is
                ``unsqueeze(1)``-ed and concatenated along dim 1, so it is
                presumably the per-example style with the batch axis removed
                (TODO confirm).
            neg_style: style vector used for every other label.

        Returns:
            Long tensor ``(batch, max_length + 1)``: the ``sep_token_id`` start
            token followed by the greedily predicted token ids.
        """
        encoder_hidden = self.encoder.initHidden()
        _, encoder_hidden = self.encoder(inputs, encoder_hidden)

        # Style space: one target style vector per example, stacked on dim 1.
        style = torch.cat(
            [(pos_style if label == 1 else neg_style).unsqueeze(1)
             for label in labels_to_transfert.tolist()],
            dim=1,
        ).to(self.args.device)
        # TODO: derive the style from the labels via self.proj_style instead of
        # using the fixed pos_style / neg_style vectors.

        # Content space
        content = self.proj_content(encoder_hidden)

        # Initial decoder state: style and content halves concatenated.
        decoder_hidden = torch.cat([style, content], dim=-1)

        # Size the start tokens from the hidden state actually built (one entry
        # per label), not args.batch_size, so a smaller final batch still lines
        # up with decoder_hidden.
        batch_size = style.size(1)
        decoder_input = torch.full(
            (batch_size, 1),
            self.args.tokenizer.sep_token_id,
            dtype=torch.long,
            device=self.args.device,
        )

        decoded_words = [decoder_input]
        for _ in range(self.args.max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            # decoder_output is (batch, 1, vocab), so topi is (batch, 1, 1).
            # squeeze(-1) keeps the batch axis even when batch == 1, whereas a
            # bare squeeze() would collapse it to a scalar.
            _, topi = decoder_output.topk(1)
            next_tokens = topi.squeeze(-1)
            decoder_input = next_tokens.detach()  # feed back the greedy choice
            decoded_words.append(next_tokens)
        return torch.cat(decoded_words, dim=-1)
