import torch
# import lpips
import torch.nn as nn


class FastSpeech2Loss(nn.Module):
    """Combined training loss for a FastSpeech2-style acoustic model.

    Computes MSE losses on pitch, energy and log-durations over unmasked phone
    positions, an optional MSE loss on "egemap" features, and an L1 loss on
    mel frames over unmasked frame positions. When the mel loss is computed,
    the sum of all terms is returned as ``"total_loss"``.
    """

    def __init__(self, config):
        super(FastSpeech2Loss, self).__init__()
        self.mae_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        # n_mels is only referenced by the disabled LPIPS code below;
        # n_egemap_features is not used inside this class.
        self.n_mels = config.n_mel_channels
        self.n_egemap_features = config.n_egemap_features
        # LPIPS perceptual loss is currently disabled (needs the `lpips` package).
        # self.lpips_loss_scale = config.lpips_loss_scale
        # self.lpips_loss = lpips.LPIPS(net=config.lpips_net) if config.lpips_loss_scale > 0 else None

    # @staticmethod
    # def _prepare_input_for_lpips_loss(mel_input: torch.Tensor, mel_mask: torch.Tensor) -> torch.Tensor:
    #     # Min-max normalize to [-1, 1] (LPIPS works better this way), then zero the padding.
    #     # NOTE(review): the previous version had a precedence bug,
    #     # `2 * (x - min / (max - min) - 1)`; corrected here for when this is re-enabled.
    #     mel_min, mel_max = torch.min(mel_input), torch.max(mel_input)
    #     normed_mel = 2 * (mel_input - mel_min) / (mel_max - mel_min) - 1
    #     normed_mel[~mel_mask] = 0
    #     return normed_mel.unsqueeze(1)

    def forward(self, device: torch.device, inputs: dict, predictions: dict, compute_mel_loss: bool = True) -> dict:
        """Compute the individual loss terms and, optionally, their sum.

        Args:
            device: Device the masks are moved to and on which the zero
                eGeMAP placeholder is created.
            inputs: Ground-truth batch. Reads ``"pitches"``, ``"energies"``,
                ``"durations"``, ``"mels"`` (only when ``compute_mel_loss``) and,
                optionally, ``"egemap_features"``. A missing key or ``None``
                disables the eGeMAP term.
            predictions: Model outputs. Reads ``"phone_masks"``,
                ``"predicted_pitch"``, ``"predicted_energy"``,
                ``"predicted_log_durations"``, ``"predicted_egemap"`` (only when
                eGeMAP targets are given) and, when ``compute_mel_loss``,
                ``"mel_masks"`` / ``"predicted_mel"``. Both masks are inverted
                before use, so they are assumed to be True at padded positions
                (TODO confirm against the model).
            compute_mel_loss: If False, return only the pitch/energy/duration
                (and eGeMAP) losses, without ``"mel_loss"`` / ``"total_loss"``.

        Returns:
            Dict with ``"pitch_loss"``, ``"energy_loss"`` and ``"duration_loss"``,
            plus ``"egemap_loss"`` when eGeMAP targets are present, plus
            ``"mel_loss"`` and ``"total_loss"`` when ``compute_mel_loss`` is True.
            Every value is a 0-dim tensor (mean-reduced loss).
        """
        # Invert so that True marks positions that contribute to the loss.
        phone_masks = ~predictions["phone_masks"].to(device)

        pitch_predictions = predictions["predicted_pitch"].masked_select(phone_masks)
        energy_predictions = predictions["predicted_energy"].masked_select(phone_masks)
        log_duration_predictions = predictions["predicted_log_durations"].masked_select(phone_masks)
        pitch_targets = inputs["pitches"].detach().masked_select(phone_masks)
        energy_targets = inputs["energies"].detach().masked_select(phone_masks)
        # Durations are regressed in log(d + 1) space; log1p is the numerically stable form.
        log_duration_targets = torch.log1p(inputs["durations"].float()).detach().masked_select(phone_masks)

        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
        energy_loss = self.mse_loss(energy_predictions, energy_targets)
        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)

        losses_dict = {"pitch_loss": pitch_loss, "energy_loss": energy_loss, "duration_loss": duration_loss}

        egemap_targets = inputs.get("egemap_features")
        if egemap_targets is not None:
            egemap_loss = self.mse_loss(predictions["predicted_egemap"], egemap_targets.detach())
            losses_dict["egemap_loss"] = egemap_loss
        else:
            # 0-dim zero so total_loss stays a scalar; a shape-(1,) placeholder
            # would broadcast total_loss to shape (1,).
            egemap_loss = torch.zeros((), device=device)

        if not compute_mel_loss:
            return losses_dict

        mel_masks = ~predictions["mel_masks"].to(device)
        mel_predictions = predictions["predicted_mel"]
        # postnet_mel_predictions = predictions["post_net_predicted_mel"]
        mel_targets = inputs["mels"].detach()

        # if self.lpips_loss:
        #     # reshape mask to 3d size -> normalize mels to [-1, 1] -> change pad values to 0
        #     mask3d = mel_masks.unsqueeze(2).expand(-1, -1, self.n_mels)
        #     mel_predicted_lpips = self._prepare_input_for_lpips_loss(mel_predictions, mask3d)
        #     mel_target_lpips = self._prepare_input_for_lpips_loss(mel_targets, mask3d)
        #     lpips_loss = torch.mean(self.lpips_loss(mel_predicted_lpips, mel_target_lpips)) * self.lpips_loss_scale
        #     losses_dict["lpips_loss"] = lpips_loss
        # else:
        #     lpips_loss = torch.zeros((), device=device)

        # Add a trailing axis so the frame mask (presumably (B, T)) broadcasts
        # across the mel-channel axis: (B, T, 1) -> (B, T, C).
        frame_masks = mel_masks.unsqueeze(-1)
        mel_predictions = mel_predictions.masked_select(frame_masks)
        # postnet_mel_predictions = postnet_mel_predictions.masked_select(frame_masks)
        mel_targets = mel_targets.masked_select(frame_masks)
        mel_loss = self.mae_loss(mel_predictions, mel_targets)
        # postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)

        total_loss = mel_loss + duration_loss + pitch_loss + energy_loss + egemap_loss  # + postnet_mel_loss + lpips_loss
        losses_dict["total_loss"] = total_loss
        losses_dict["mel_loss"] = mel_loss
        # losses_dict["postnet_mel_loss"] = postnet_mel_loss

        return losses_dict
