import torch

from src.configs.dataset import DatasetType
from src.utils.style_gan_xl.utils import calculate_adaptive_weight
from src.utils.utils import extract_model


class StyleGANXLGeneratorLoss:
    """Generator-side losses for StyleGAN-XL style adversarial training.

    Produces a reconstruction term, an adversarial term computed through the
    discriminator feature extractor + discriminator, and a weight for the
    adversarial term. When ``use_adaptive_weight`` is set, that weight comes from
    ``calculate_adaptive_weight`` evaluated against a dataset-specific decoder
    layer weight and is clipped to [0.01, 10]; otherwise it is a constant 1.0.

    Args:
        model: generator; ``extract_model(model).model.dec`` must expose the
            decoder modules named in ``_get_last_layer_weight``.
        discriminator_feature_extractor: called with ``x``,
            ``requires_feature_grad`` and augmentation settings; returns the
            features fed to ``discriminator``.
        discriminator: called with ``features`` and ``label``; returns a list of
            logit tensors.
        reconstruction_loss: called as ``reconstruction_loss(fake, target)``.
        dataset_type: selects the decoder layer used for the adaptive weight.
        use_adaptive_weight: whether to compute the adaptive weight at all.
    """

    def __init__(
            self,
            model: torch.nn.Module,
            discriminator_feature_extractor: torch.nn.Module,
            discriminator: torch.nn.Module,
            reconstruction_loss: torch.nn.Module,
            dataset_type: DatasetType,
            use_adaptive_weight: bool = True
    ) -> None:
        self.model: torch.nn.Module = model
        self.discriminator_feature_extractor: torch.nn.Module = discriminator_feature_extractor
        self.discriminator: torch.nn.Module = discriminator
        self.reconstruction_loss: torch.nn.Module = reconstruction_loss
        self.dataset_type: DatasetType = dataset_type
        self.use_adaptive_weight: bool = use_adaptive_weight

    def _get_last_layer_weight(self) -> torch.Tensor:
        """Return the decoder weight passed to ``calculate_adaptive_weight`` as ``last_layer``.

        The dataset type is checked before the model is unwrapped, so an
        unsupported dataset fails with ``ValueError`` without touching the model.

        Raises:
            ValueError: if ``self.dataset_type`` has no configured layer.
        """
        if self.dataset_type == DatasetType.CIFAR10_32_32:
            return extract_model(self.model).model.dec['32x32_aux_conv'].weight
        if self.dataset_type == DatasetType.ImageNet_64_64:
            return extract_model(self.model).model.dec['64x64_block3'].conv1.weight
        if self.dataset_type == DatasetType.AFHQv2_64_64:
            return extract_model(self.model).model.dec['64x64_aux_conv'].weight
        raise ValueError(f'unknown dataset type: {self.dataset_type}')

    def get_generator_loss(
            self,
            fake: torch.Tensor,
            target: torch.Tensor,
            label: torch.Tensor = None,
            prob_aug: float = 1.0,
            shift_ratio: float = 0.125,
            cutout_ratio: float = 0.2
    ) -> dict[str, torch.Tensor]:
        """Compute the generator losses for one batch.

        Args:
            fake: generated images.
            target: tensor ``fake`` is compared against by ``reconstruction_loss``.
            label: optional conditioning labels forwarded to the discriminator.
            prob_aug: forwarded to the discriminator feature extractor.
            shift_ratio: forwarded to the discriminator feature extractor.
            cutout_ratio: forwarded to the discriminator feature extractor.

        Returns:
            A dict with:
            - ``'reconstruction_loss'``: mean of ``reconstruction_loss(fake, target)``;
            - ``'model_gan_loss'``: average over the discriminator's logit tensors
              of ``(-logit).mean()``;
            - ``'adaptive_weight'``: ``calculate_adaptive_weight(...)`` clipped to
              [0.01, 10] if ``use_adaptive_weight``, else scalar 1.0 on the same
              device as ``model_gan_loss``.

        Raises:
            ValueError: if ``use_adaptive_weight`` is set and ``dataset_type`` is
                not supported.
        """
        reconstruction_loss_value: torch.Tensor = self.reconstruction_loss(fake, target).mean()
        # requires_feature_grad=True here (False in the discriminator loss):
        # presumably so the adversarial loss can backprop into ``fake`` — confirm
        # against the feature extractor.
        features: list[dict[str, torch.Tensor]] = self.discriminator_feature_extractor(
            x=fake,
            requires_feature_grad=True,
            prob_aug=prob_aug,
            shift_ratio=shift_ratio,
            cutout_ratio=cutout_ratio
        )
        logit_fake: list[torch.Tensor] = self.discriminator(
            features=features,
            label=label
        )
        model_gan_loss: torch.Tensor = torch.mean(torch.stack([(-logit).mean() for logit in logit_fake]))

        if self.use_adaptive_weight:
            d_weight: torch.Tensor = calculate_adaptive_weight(
                reconstruction_loss_value, model_gan_loss,
                last_layer=self._get_last_layer_weight()
            )
            adaptive_weight: torch.Tensor = torch.clip(d_weight, 0.01, 10.)
        else:
            # Create the neutral weight where the losses live so callers can
            # combine/stack it with them without a CPU/accelerator device mismatch.
            adaptive_weight = torch.tensor(1., device=model_gan_loss.device)

        return {
            'reconstruction_loss': reconstruction_loss_value,
            'model_gan_loss': model_gan_loss,
            'adaptive_weight': adaptive_weight,
        }


class StyleGANXLDiscriminatorLoss:
    """Hinge loss for training the discriminator.

    Inputs are passed through ``discriminator_feature_extractor`` (with
    ``requires_feature_grad=False``) and then ``discriminator``, which returns a
    list of logit tensors. Each loss averages, over that list, the per-tensor
    mean hinge value: ``relu(1 - logit)`` for real inputs and ``relu(1 + logit)``
    for fake inputs.

    Args:
        discriminator_feature_extractor: called with ``x``,
            ``requires_feature_grad`` and augmentation settings.
        discriminator: called with ``features`` and ``label``; returns a list of
            logit tensors.
    """

    def __init__(
            self,
            discriminator_feature_extractor: torch.nn.Module,
            discriminator: torch.nn.Module
    ) -> None:
        self.discriminator_feature_extractor: torch.nn.Module = discriminator_feature_extractor
        self.discriminator: torch.nn.Module = discriminator

    def _logits(
            self,
            images: torch.Tensor,
            label: torch.Tensor,
            prob_aug: float,
            shift_ratio: float,
            cutout_ratio: float
    ) -> list[torch.Tensor]:
        """Run ``images`` through the feature extractor and discriminator."""
        extracted = self.discriminator_feature_extractor(
            x=images,
            requires_feature_grad=False,
            prob_aug=prob_aug,
            shift_ratio=shift_ratio,
            cutout_ratio=cutout_ratio,
        )
        return self.discriminator(features=extracted, label=label)

    def get_discriminator_loss_fake(
            self,
            fake: torch.Tensor,
            fake_label: torch.Tensor = None,
            prob_aug: float = 1.0,
            shift_ratio: float = 0.125,
            cutout_ratio: float = 0.2
    ) -> torch.Tensor:
        """Hinge loss on generated samples: mean over heads of ``relu(1 + logit).mean()``."""
        heads = self._logits(fake, fake_label, prob_aug, shift_ratio, cutout_ratio)
        per_head = [torch.nn.functional.relu(1.0 + head).mean() for head in heads]
        return torch.stack(per_head).mean()

    def get_discriminator_loss_real(
            self,
            real: torch.Tensor,
            real_label: torch.Tensor = None,
            prob_aug: float = 1.0,
            shift_ratio: float = 0.125,
            cutout_ratio: float = 0.2
    ) -> torch.Tensor:
        """Hinge loss on real samples: mean over heads of ``relu(1 - logit).mean()``."""
        heads = self._logits(real, real_label, prob_aug, shift_ratio, cutout_ratio)
        per_head = [torch.nn.functional.relu(1.0 - head).mean() for head in heads]
        return torch.stack(per_head).mean()

    def get_discriminator_loss(
            self,
            real: torch.Tensor,
            fake: torch.Tensor,
            real_label: torch.Tensor = None,
            fake_label: torch.Tensor = None,
            prob_aug: float = 1.0,
            shift_ratio: float = 0.125,
            cutout_ratio: float = 0.2
    ) -> dict[str, torch.Tensor]:
        """Return ``{'fake_loss': ..., 'real_loss': ...}``.

        The fake half is evaluated before the real half (dict literals evaluate
        in order), matching the original call order.
        """
        augment = {'prob_aug': prob_aug, 'shift_ratio': shift_ratio, 'cutout_ratio': cutout_ratio}
        return {
            'fake_loss': self.get_discriminator_loss_fake(fake=fake, fake_label=fake_label, **augment),
            'real_loss': self.get_discriminator_loss_real(real=real, real_label=real_label, **augment),
        }
