import torch
import torch.nn as nn


class DPPLoss(nn.Module):
    """DPP loss: combine mel, pitch, energy and duration terms into one objective.

    The mel (L1) and energy (MSE) terms are computed here against ground-truth
    targets, restricted to non-padded positions.  The pitch and duration terms
    are taken from the model outputs and summed as-is.
    """

    def __init__(self, preprocess_config, model_config):
        """Build the loss.

        Args:
            preprocess_config: Preprocessing config dict.  The values at
                ``["preprocessing"]["pitch"]["feature"]`` and
                ``["preprocessing"]["energy"]["feature"]`` are stored on the
                instance.
            model_config: Not used by this loss.
        """
        super().__init__()
        preprocessing = preprocess_config["preprocessing"]
        # Stored as attributes; forward() does not consult them.
        self.pitch_feature_level = preprocessing["pitch"]["feature"]
        self.energy_feature_level = preprocessing["energy"]["feature"]
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, inputs, predictions):
        """Compute the total loss and its components.

        Args:
            inputs: Batch sequence.  ``inputs[6:]`` must hold exactly six
                items ``(mel_targets, _, _, pitch_targets, energy_targets,
                duration_targets)``.  Only ``mel_targets`` and
                ``energy_targets`` are used.
            predictions: Nine-item sequence ``(mel_predictions,
                pitch_predictions, energy_predictions, l_length, _,
                src_masks, mel_masks, _, _)``.  Both masks are inverted before
                ``masked_select``.  They are therefore expected to be boolean
                with ``True`` at padded positions.

        Returns:
            Tuple ``(total_loss, mel_loss, pitch_loss, energy_loss,
            duration_loss)`` of scalar tensors.
        """
        # Unpack the full slice so a wrong-length batch still fails loudly.
        (
            mel_targets,
            _,
            _,
            _pitch_targets,
            energy_targets,
            _duration_targets,
        ) = inputs[6:]
        (
            mel_predictions,
            pitch_predictions,
            energy_predictions,
            l_length,
            _,
            src_masks,
            mel_masks,
            _,
            _,
        ) = predictions

        # Invert padding masks into "valid position" masks.
        valid_src = ~src_masks
        valid_mel = ~mel_masks

        # Truncate targets to the predicted mel length.  Use detach() rather
        # than assigning requires_grad in place.  That assignment raises on
        # non-leaf tensors (such as this slice when the target requires grad),
        # and it would mutate the caller's tensors.
        mel_targets = mel_targets[:, : valid_mel.shape[1], :].detach()
        energy_targets = energy_targets.detach()

        # NOTE(review): energy is always masked with the source mask,
        # regardless of self.energy_feature_level.  Confirm the model only
        # emits source-length energy.
        energy_predictions = energy_predictions.masked_select(valid_src)
        energy_targets = energy_targets.masked_select(valid_src)

        # Broadcast the per-frame mask over the last (mel-bin) axis.
        mel_frame_mask = valid_mel.unsqueeze(-1)
        mel_predictions = mel_predictions.masked_select(mel_frame_mask)
        mel_targets = mel_targets.masked_select(mel_frame_mask)

        mel_loss = self.mae_loss(mel_predictions.float(), mel_targets)
        # pitch_predictions and l_length are summed directly.  They presumably
        # are loss terms already produced by the model; confirm with the model.
        pitch_loss = torch.sum(pitch_predictions.float())
        energy_loss = self.mse_loss(energy_predictions.float(), energy_targets)
        duration_loss = torch.sum(l_length.float())

        total_loss = mel_loss + duration_loss + pitch_loss + energy_loss

        return (
            total_loss,
            mel_loss,
            pitch_loss,
            energy_loss,
            duration_loss,
        )
