import torch
import torch.nn as nn

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS


class LatentLPIPS(nn.Module):
    """Loss combining a latent-space squared error with LPIPS terms on decoded images.

    Latents are decoded with a decoder built from ``decoder_config``; the
    ``LPIPS`` module is then evaluated between decoded predictions and
    decoded targets, and optionally between decoded predictions and the
    supplied input images.
    """

    def __init__(
        self,
        decoder_config,
        perceptual_weight: float = 1.0,
        latent_weight: float = 1.0,
        scale_input_to_tgt_size: bool = False,
        scale_tgt_to_input_size: bool = False,
        perceptual_weight_on_inputs: float = 0.0,
    ) -> None:
        """Build the decoder and the LPIPS module and store the weights.

        Args:
            decoder_config: Config handed to ``instantiate_from_config`` to
                create the decoder. The resulting object must provide a
                ``decode`` method.
            perceptual_weight: Weight of the LPIPS term between decoded
                predictions and decoded targets. The term is skipped
                unless this is > 0.
            latent_weight: Weight of the mean squared latent error. Only
                applied inside the ``perceptual_weight > 0`` branch of
                ``forward``.
            scale_input_to_tgt_size: If true, ``image_inputs`` is resized
                to the spatial size of the decoded predictions before the
                input-image LPIPS term. Takes precedence over
                ``scale_tgt_to_input_size``.
            scale_tgt_to_input_size: If true (and the flag above is
                false), the decoded predictions are resized to the spatial
                size of ``image_inputs`` instead.
            perceptual_weight_on_inputs: Weight of the LPIPS term between
                ``image_inputs`` and the decoded predictions. The term is
                skipped unless this is > 0.
        """
        super().__init__()
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        self.init_decoder(decoder_config)
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    def init_decoder(self, config) -> None:
        """Instantiate ``self.decoder`` from ``config`` and drop its encoder, if any."""
        self.decoder = instantiate_from_config(config)
        # Only ``decode`` is used by this loss; presumably the encoder is
        # removed to avoid keeping unused weights around.
        if hasattr(self.decoder, "encoder"):
            del self.decoder.encoder

    def forward(
        self,
        latent_inputs: torch.Tensor,
        latent_predictions: torch.Tensor,
        image_inputs: torch.Tensor,
        split: str = "train",
    ):
        """Compute the loss and a dictionary of detached scalars for logging.

        Args:
            latent_inputs: Target latents.
            latent_predictions: Predicted latents, broadcast-compatible
                with ``latent_inputs``.
            image_inputs: Images compared against the decoded predictions.
                Only read when ``perceptual_weight_on_inputs > 0``.
            split: Prefix for the keys of the returned log dictionary.

        Returns:
            A ``(loss, log)`` tuple. ``log`` maps ``"{split}/..."`` keys to
            detached mean values of the individual terms.

        NOTE(review): when ``perceptual_weight <= 0`` the latent term is
        left unreduced (element-wise squared error) and ``latent_weight``
        is not applied; only the ``perceptual_weight > 0`` branch produces
        a weighted scalar. This is kept as-is so existing callers see the
        same values and shapes -- confirm whether it is intended.
        """
        log = dict()
        loss = (latent_inputs - latent_predictions) ** 2
        log[f"{split}/latent_l2_loss"] = loss.mean().detach()

        # Stays None until the predictions have been decoded, so that the
        # second branch can reuse the result of the first.
        image_reconstructions = None

        if self.perceptual_weight > 0.0:
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual_mean = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            ).mean()
            loss = (
                self.latent_weight * loss.mean()
                + self.perceptual_weight * perceptual_mean
            )
            log[f"{split}/perceptual_loss"] = perceptual_mean.detach()

        if self.perceptual_weight_on_inputs > 0.0:
            # Decode only if the branch above did not already do so. Passing
            # the decode call as a fallback *argument* would run the decoder
            # eagerly even when its result is thrown away.
            if image_reconstructions is None:
                image_reconstructions = self.decoder.decode(latent_predictions)

            # Optionally bring both images to a common spatial size
            # (dimensions 2 onward) before comparing them.
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
                    image_reconstructions.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )
            elif self.scale_tgt_to_input_size:
                image_reconstructions = torch.nn.functional.interpolate(
                    image_reconstructions,
                    image_inputs.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )

            perceptual_mean_on_inputs = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            ).mean()
            loss = loss + self.perceptual_weight_on_inputs * perceptual_mean_on_inputs
            log[f"{split}/perceptual_loss_on_inputs"] = (
                perceptual_mean_on_inputs.detach()
            )
        return loss, log
