import contextlib
from collections.abc import Sequence
from typing import Any

import torch

from external.stylegan_xl.pg_modules.discriminator import MultiScaleD
from external.stylegan_xl.pg_modules.projector import F_RandomProj
from src.utils.style_gan_xl.utils import get_feature


def create_style_gan_xl_discriminator_and_discriminator_feature_extractor(
        image_size: int,
        conditional: bool,
        c_out: int = 64,
        backbones: Sequence[str] = ('deit_base_distilled_patch16_224', 'tf_efficientnet_lite0'),
) -> tuple[torch.nn.ModuleDict, torch.nn.ModuleDict]:
    """Build StyleGAN-XL discriminator heads together with their feature networks.

    For every backbone name an ``F_RandomProj`` feature network is created, and a
    ``MultiScaleD`` head is sized from that network's ``CHANNELS`` and
    ``RESOLUTIONS``. The feature network is constructed before its head for each
    backbone, so parameter initialisation consumes the RNG in a fixed order.

    Args:
        image_size: Image resolution, forwarded to both modules as ``im_res``.
        conditional: If True the heads are built with ``cond=1``, else ``cond=0``.
        c_out: Channel count forwarded to both modules as ``cout``.
        backbones: Backbone names passed to ``F_RandomProj``. The default is the
            pair that was previously hard-coded. Iteration order is preserved as
            the key order of both returned ``ModuleDict``s.

    Returns:
        ``(discriminator, discriminator_feature_extractor)``: two ``ModuleDict``s
        keyed by backbone name, holding the heads and the feature networks.

    Raises:
        ValueError: If ``backbones`` is empty or contains duplicate names.
    """
    if not backbones:
        raise ValueError('at least one backbone name is required')
    if len(set(backbones)) != len(backbones):
        # A repeated name would build a module and then silently overwrite it.
        raise ValueError(f'backbone names must be unique, got {list(backbones)}')

    backbone_kwargs: dict[str, Any] = {'im_res': image_size, 'cout': c_out}
    discriminator_feature_extractor_dict: dict[str, F_RandomProj] = {}
    discriminator_dict: dict[str, MultiScaleD] = {}
    for bb_name in backbones:
        feat: F_RandomProj = F_RandomProj(bb_name, **backbone_kwargs)
        discriminator_feature_extractor_dict[bb_name] = feat
        # The head's per-scale layout comes from the feature network it consumes.
        discriminator_dict[bb_name] = MultiScaleD(
            cond=1 if conditional else 0,
            channels=feat.CHANNELS,
            resolutions=feat.RESOLUTIONS,
            **backbone_kwargs,
        )

    discriminator_feature_extractor: torch.nn.ModuleDict = torch.nn.ModuleDict(discriminator_feature_extractor_dict)
    discriminator: torch.nn.ModuleDict = torch.nn.ModuleDict(discriminator_dict)
    return discriminator, discriminator_feature_extractor


class StyleGANXLDiscriminatorFeatureExtractor(torch.nn.Module):
    """Run every discriminator feature network on randomly augmented inputs.

    For each entry of the wrapped ``ModuleDict`` a fresh set of augmentation
    parameters is sampled. These are brightness, saturation, contrast,
    translation and cutout offsets. They are handed to ``get_feature`` together
    with the input batch and that feature network.
    """

    def __init__(self, discriminator_feature_extractor: torch.nn.ModuleDict) -> None:
        super().__init__()
        # Feature networks keyed by backbone name; iteration order defines output order.
        self.discriminator_feature_extractor: torch.nn.ModuleDict = discriminator_feature_extractor

    @staticmethod
    def _sample_augmentation(
            x: torch.Tensor,
            prob_aug: float,
            shift_ratio: float,
            cutout_ratio: float,
    ) -> tuple[torch.Tensor, ...]:
        """Draw one set of random augmentation parameters for the batch ``x``.

        Returns ``(brightness, saturation, contrast, translation_x, translation_y,
        offset_x, offset_y)``. This is the positional order ``get_feature`` takes
        after ``(x, feat)``. The random draws are issued in exactly this order,
        so seeded runs keep producing the same sequence.
        """
        n_aug = int(x.size(0) * prob_aug)
        dev = x.device

        # Colour factors, shape (n_aug, 1, 1, 1), in x's dtype.
        brightness = torch.rand(n_aug, 1, 1, 1, dtype=x.dtype, device=dev) - 0.5  # [-0.5, 0.5)
        saturation = torch.rand(n_aug, 1, 1, 1, dtype=x.dtype, device=dev) * 2  # [0, 2)
        contrast = torch.rand(n_aug, 1, 1, 1, dtype=x.dtype, device=dev) + 0.5  # [0.5, 1.5)

        # Integer shifts bounded by shift_ratio of dims 2 and 3 (rounded half-up).
        shift_x = int(x.size(2) * shift_ratio + 0.5)
        shift_y = int(x.size(3) * shift_ratio + 0.5)
        translation_x = torch.randint(-shift_x, shift_x + 1, size=[n_aug, 1, 1], device=dev)
        translation_y = torch.randint(-shift_y, shift_y + 1, size=[n_aug, 1, 1], device=dev)

        # The cutout extent is only used here to pick the offset upper bound
        # (parity adjustment); it is not passed on to get_feature.
        # NOTE(review): confirm get_feature derives the same cutout size from its own ratio.
        cutout_x = int(x.size(2) * cutout_ratio + 0.5)
        cutout_y = int(x.size(3) * cutout_ratio + 0.5)
        offset_x = torch.randint(0, x.size(2) + (1 - cutout_x % 2), size=[n_aug, 1, 1], device=dev)
        offset_y = torch.randint(0, x.size(3) + (1 - cutout_y % 2), size=[n_aug, 1, 1], device=dev)

        return brightness, saturation, contrast, translation_x, translation_y, offset_x, offset_y

    def forward(
            self,
            x: torch.Tensor,
            requires_feature_grad: bool,
            prob_aug: float = 1.0,
            shift_ratio: float = 0.125,
            cutout_ratio: float = 0.2
    ) -> list[dict[str, torch.Tensor]]:
        """Extract augmented features from every feature network.

        Args:
            x: Input batch. Dim 0 is treated as the batch size; dims 2 and 3
                size the translation and cutout ranges (presumably NCHW; confirm).
            requires_feature_grad: When False, ``get_feature`` runs under
                ``torch.no_grad()``. When True, the caller's grad mode is used
                unchanged.
            prob_aug: Fraction of the batch for which augmentation parameters
                are drawn (``int(batch * prob_aug)`` rows). How a shorter
                parameter batch is applied is up to ``get_feature``.
            shift_ratio: Maximum translation as a fraction of each spatial size.
            cutout_ratio: Cutout size as a fraction of each spatial size.

        Returns:
            One feature dict per feature network, in ``ModuleDict`` order.
        """
        features: list[dict[str, torch.Tensor]] = []
        for feat in self.discriminator_feature_extractor.values():
            # Each feature network sees its own independently sampled augmentation.
            aug_params = self._sample_augmentation(x, prob_aug, shift_ratio, cutout_ratio)
            # nullcontext (not set_grad_enabled(True)) so an outer no_grad is still honoured.
            grad_ctx = contextlib.nullcontext() if requires_feature_grad else torch.no_grad()
            with grad_ctx:
                features.append(get_feature(x, feat, *aug_params))
        return features


class StyleGANXLDiscriminator(torch.nn.Module):
    """Apply each per-backbone discriminator head to its precomputed features."""

    def __init__(self, discriminator: torch.nn.ModuleDict) -> None:
        super().__init__()
        # Discriminator heads keyed by backbone name; order must match ``features``.
        self.discriminator: torch.nn.ModuleDict = discriminator

    def get_xl_logit(
            self,
            features: list[dict[str, torch.Tensor]],
            label: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Run every head on its feature dict and concatenate the outputs into one list.

        Args:
            features: One feature dict per head, in the same order as
                ``self.discriminator`` (as produced by the feature extractor).
            label: Conditioning input passed unchanged to every head.

        Returns:
            A flat list built by extending with each head's output in order.

        Raises:
            ValueError: If the number of feature dicts differs from the number of heads.
        """
        if len(features) != len(self.discriminator):
            # zip() would otherwise silently drop the unmatched heads or features.
            raise ValueError(
                f'expected {len(self.discriminator)} feature dicts (one per head), got {len(features)}')
        logit: list[torch.Tensor] = []
        for head, feature in zip(self.discriminator.values(), features):
            # `list +=` extends with whatever iterable the head returns.
            # NOTE(review): if MultiScaleD returns a single tensor, this splits it along dim 0 — confirm.
            logit += head(feature, label)
        return logit

    def forward(
            self,
            features: list[dict[str, torch.Tensor]],
            label: torch.Tensor
    ) -> list[torch.Tensor]:
        """Alias for :meth:`get_xl_logit`."""
        return self.get_xl_logit(
            features=features,
            label=label
        )
