# file: prism/losses/losses.py
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torchvision.models.vgg import vgg16, vgg19, VGG16_Weights, VGG19_Weights

from prism.core.base_objects import BaseLoss
from prism.core.registry import LOSSES


@LOSSES.register("mse")
class MSELoss(BaseLoss):
    """Mean-squared-error loss, registered under the name ``"mse"``.

    Thin wrapper around a default-constructed :class:`torch.nn.MSELoss`
    (element-wise squared error, averaged over all elements).
    """

    def __init__(self, config):
        super().__init__(config)
        self.loss = nn.MSELoss()

    def forward(self, input, target):
        """Return the mean squared error between ``input`` and ``target``."""
        criterion = self.loss
        return criterion(input, target)


@LOSSES.register("l1")
class L1Loss(BaseLoss):
    """Mean-absolute-error loss, registered under the name ``"l1"``.

    Thin wrapper around a default-constructed :class:`torch.nn.L1Loss`
    (element-wise absolute error, averaged over all elements).
    """

    def __init__(self, config):
        super().__init__(config)
        self.loss = nn.L1Loss()

    def forward(self, input, target):
        """Return the mean absolute error between ``input`` and ``target``."""
        criterion = self.loss
        return criterion(input, target)


@LOSSES.register("bce")
class BCEWithLogitsLoss(BaseLoss):
    """Binary cross-entropy on raw logits, registered under ``"bce"``.

    Rather than an explicit target tensor, callers pass a flag; the target is
    an all-ones tensor when ``target_is_real`` is truthy and all-zeros
    otherwise, shaped like ``input``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, input, target_is_real):
        """Score logits ``input`` against a constant 1 (real) or 0 target."""
        make_target = torch.ones_like if target_is_real else torch.zeros_like
        return self.loss(input, make_target(input))


@LOSSES.register("softplus")
class SoftplusLoss(BaseLoss):
    """Softplus-based logistic loss, registered under ``"softplus"``.

    Returns ``mean(softplus(-input))`` when ``target_is_real`` is truthy and
    ``mean(softplus(input))`` otherwise. Since ``softplus(-x) == -log(sigmoid(x))``,
    this is the numerically stable form of binary cross-entropy on logits
    against a constant 1 / 0 target.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input, target_is_real):
        """Return the averaged softplus of the sign-adjusted logits."""
        signed_logits = -input if target_is_real else input
        return torch.nn.functional.softplus(signed_logits).mean()


class _BaseVGGLoss(BaseLoss, ABC):
    """Shared setup for perceptual losses computed on torchvision VGG features.

    Settings are read from ``config.loss.vgg``:

    * ``input_is_tanh`` (default ``True``): when set, inputs are remapped with
      ``(x + 1) / 2`` (i.e. from ``[-1, 1]`` to ``[0, 1]``) before normalization.
    * ``loss_type`` (default ``"l1"``, case-insensitive): ``"l1"`` selects
      :class:`torch.nn.L1Loss`; ``"mse"`` or ``"l2"`` select
      :class:`torch.nn.MSELoss`.
    * ``model_type`` (default ``"vgg19"``, case-insensitive): ``"vgg16"`` or
      ``"vgg19"``, loaded with torchvision's ``DEFAULT`` pretrained weights.

    Only the ``features`` part of the VGG model is kept; it is put in eval mode
    and frozen with ``requires_grad_(False)``. Subclasses implement ``forward``.

    Raises:
        ValueError: if ``model_type`` or ``loss_type`` is not supported.
    """

    def __init__(self, config):
        super().__init__(config)
        vgg_cfg = self.config.loss.vgg

        self.input_is_tanh = getattr(vgg_cfg, "input_is_tanh", True)
        loss_type = getattr(vgg_cfg, "loss_type", "l1").lower()
        model_type = getattr(vgg_cfg, "model_type", "vgg19").lower()

        supported_models = {
            'vgg16': (vgg16, VGG16_Weights.DEFAULT),
            'vgg19': (vgg19, VGG19_Weights.DEFAULT)
        }
        if model_type not in supported_models:
            raise ValueError(f"Unsupported model_type: {model_type}. Choose from {list(supported_models.keys())}")

        # Reject unknown criteria explicitly instead of silently falling back
        # to MSE (which would hide typos like "smooth_l1"). Checked before the
        # pretrained VGG is built so a bad config fails fast.
        supported_losses = {
            "l1": nn.L1Loss,
            "mse": nn.MSELoss,
            "l2": nn.MSELoss,
        }
        if loss_type not in supported_losses:
            raise ValueError(f"Unsupported loss_type: {loss_type}. Choose from {list(supported_losses.keys())}")

        model_fn, weights = supported_models[model_type]
        vgg_features = model_fn(weights=weights).features

        self.vgg_features = vgg_features
        self.vgg_features.eval()
        # The extractor is a fixed reference network; never train it.
        self.vgg_features.requires_grad_(False)

        # Per-channel normalization constants used with torchvision's
        # ImageNet-pretrained weights, shaped (1, 3, 1, 1) to broadcast over
        # NCHW batches. Registered as buffers so they follow .to(device).
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

        self.crit = supported_losses[loss_type]()

    def _preprocess(self, x):
        """Map ``x`` into the normalized input space of the pretrained VGG.

        Applies the ``(x + 1) / 2`` remap when ``input_is_tanh`` is set,
        replicates a single-channel dimension 1 to three channels, then
        normalizes with the ``mean``/``std`` buffers. Assumes NCHW layout
        (channel at dimension 1).
        """
        if self.input_is_tanh:
            x = (x + 1.0) / 2.0
        if x.shape[1] == 1:
            # Grayscale -> RGB so it matches VGG's 3-channel first conv.
            x = x.repeat(1, 3, 1, 1)
        return (x - self.mean) / self.std

    @abstractmethod
    def forward(self, input, target):
        raise NotImplementedError


@LOSSES.register("vgg_single_layer")
class SingleLayerVGGLoss(_BaseVGGLoss):
    """Perceptual loss on the output of a single VGG ``features`` layer.

    Reads ``config.loss.vgg.layer``: the index into the VGG ``features``
    sequence whose output is compared. The extractor is truncated to
    ``features[:layer + 1]`` so later layers are never run.

    Raises:
        ValueError: if ``layer`` is outside ``[0, len(features) - 1]``.
    """

    def __init__(self, config):
        super().__init__(config)
        vgg_cfg = self.config.loss.vgg
        self.layer_idx = int(vgg_cfg.layer)

        # Sequential slicing never raises on a bad index: -1 would yield an
        # empty (identity) extractor and an oversized index the full stack,
        # both silently wrong. Validate explicitly.
        num_layers = len(self.vgg_features)
        if not 0 <= self.layer_idx < num_layers:
            raise ValueError(
                f"VGG layer index {self.layer_idx} out of range; "
                f"expected 0 <= layer < {num_layers}"
            )

        # Truncate the feature extractor to the desired layer
        self.vgg_features = self.vgg_features[:self.layer_idx + 1]

    def forward(self, input, target):
        """Return ``crit`` between the layer features of ``input`` and ``target``.

        Both tensors are moved to the device of the normalization buffers,
        preprocessed, and run through the extractor as one concatenated batch
        (split back at ``input.shape[0]``). Target features are detached, so
        gradients flow only through the ``input`` half.
        """
        input = input.to(self.mean.device)
        target = target.to(self.mean.device)

        batch = torch.cat([self._preprocess(input), self._preprocess(target)])
        features = self.vgg_features(batch)

        sep = input.shape[0]
        input_feats = features[:sep]
        target_feats = features[sep:]

        return self.crit(input_feats, target_feats.detach())


@LOSSES.register("vgg_multi_layer")
class MultiLayerVGGLoss(_BaseVGGLoss):
    """Weighted sum of perceptual losses over several VGG ``features`` layers.

    Reads ``config.loss.vgg.layers_and_weights``: a mapping from a layer index
    (into the VGG ``features`` sequence) to the weight of that layer's term.
    Keys are coerced with ``int`` and weights with ``float``, so a mapping whose
    keys arrive as strings (e.g. ``{"8": 1.0}``) behaves like one with integer
    keys.

    Raises:
        ValueError: if the mapping is empty or contains a layer index outside
            ``[0, len(features) - 1]``.
    """

    def __init__(self, config):
        super().__init__(config)
        vgg_cfg = self.config.loss.vgg
        # Normalize to a plain {int: float} dict. With raw string keys the
        # membership test in _extract_features never matched (KeyError in
        # forward) and max() compared lexicographically ("8" > "22").
        self.layer_weights = {
            int(layer): float(weight)
            for layer, weight in vgg_cfg.layers_and_weights.items()
        }
        if not self.layer_weights:
            raise ValueError("VGG layers_and_weights must name at least one layer")

        num_layers = len(self.vgg_features)
        out_of_range = sorted(
            idx for idx in self.layer_weights if not 0 <= idx < num_layers
        )
        if out_of_range:
            raise ValueError(
                f"VGG layer indices {out_of_range} out of range; "
                f"expected 0 <= layer < {num_layers}"
            )

        self.max_layer = max(self.layer_weights)

    def _extract_features(self, x):
        """Run ``x`` through the extractor, collecting the requested layer outputs.

        Returns a dict mapping each index in ``layer_weights`` to that layer's
        output; iteration stops after ``max_layer`` so deeper layers are skipped.
        """
        feats = {}
        out = x
        for idx, module in enumerate(self.vgg_features):
            out = module(out)
            if idx in self.layer_weights:
                feats[idx] = out
            if idx >= self.max_layer:
                break
        return feats

    def forward(self, input, target):
        """Return ``sum(weight * crit(feat_in, feat_tgt))`` over configured layers.

        Both tensors are moved to the device of the normalization buffers and
        preprocessed. Target features are computed under ``torch.no_grad()``,
        so gradients flow only through ``input``.
        """
        input = input.to(self.mean.device)
        target = target.to(self.mean.device)

        in_n = self._preprocess(input)
        tgt_n = self._preprocess(target)

        in_feats = self._extract_features(in_n)
        # Reference features need no autograd graph.
        with torch.no_grad():
            tgt_feats = self._extract_features(tgt_n)

        loss = 0.0
        for layer_idx, weight in self.layer_weights.items():
            loss += weight * self.crit(in_feats[layer_idx], tgt_feats[layer_idx])
        return loss