import torch
from torchvision.transforms import Normalize

from src.utils.style_gan_xl.diff_augment import diff_augment


def get_feature(
        x: torch.Tensor,
        feat: torch.nn.Module,
        brightness: torch.Tensor,
        saturation: torch.Tensor,
        contrast: torch.Tensor,
        translation_x: torch.Tensor,
        translation_y: torch.Tensor,
        offset_x: torch.Tensor,
        offset_y: torch.Tensor
) -> dict[str, torch.Tensor]:
    """Optionally augment a batch, normalize it for ``feat``, and run ``feat``.

    The leading ``brightness.shape[0]`` samples of ``x`` are passed through
    ``diff_augment`` together with the augmentation parameters. The rest of
    the batch is kept unchanged and concatenated back after them. If
    ``brightness`` is empty, no augmentation happens at all.

    The batch is then rescaled with ``(v + 1) / 2``. This presumably maps
    images from [-1, 1] to [0, 1]; confirm the input range with callers. It is
    then normalized with ``feat.normstats['mean']`` / ``feat.normstats['std']``.
    When the height of the ORIGINAL ``x`` (``x.shape[-2]``) is below 256, the
    normalized batch is bilinearly resized to size 224 before ``feat`` runs.

    Returns:
        Whatever ``feat`` returns for the prepared batch. It is annotated as a
        dict of feature tensors.
    """
    n_aug = brightness.shape[0]

    # Only the first ``n_aug`` samples receive augmentation.
    images = x
    if n_aug > 0:
        augmented_head = diff_augment(
            x[:n_aug],
            brightness,
            saturation,
            contrast,
            translation_x,
            translation_y,
            offset_x,
            offset_y,
        )
        images = torch.cat((augmented_head, x[n_aug:]))

    rescaled = images.add(1).div(2)
    normalize = Normalize(feat.normstats['mean'], feat.normstats['std'])
    prepared = normalize(rescaled)

    # The size check deliberately uses the un-augmented input ``x``.
    if x.shape[-2] < 256:
        prepared = torch.nn.functional.interpolate(
            prepared, 224, mode='bilinear', align_corners=False
        )

    return feat(prepared)


def calculate_adaptive_weight(
        loss1: torch.Tensor,
        loss2: torch.Tensor,
        last_layer: torch.nn.Parameter = None,
        eps: float = 1e-4,
        max_weight: float = 1e4
) -> torch.Tensor:
    """Return a detached weight balancing ``loss2`` against ``loss1``.

    The weight is the ratio of the gradient norms of the two losses with
    respect to ``last_layer``::

        ||d loss1 / d last_layer|| / (||d loss2 / d last_layer|| + eps)

    The result is clamped to ``[0, max_weight]`` and detached, so the weight
    itself never receives gradients.

    Args:
        loss1: Scalar loss whose gradient norm forms the numerator.
        loss2: Scalar loss whose gradient norm forms the denominator.
        last_layer: Parameter both losses are differentiated against. It is
            required. The ``None`` default exists only to keep the original
            signature.
        eps: Added to the denominator so a zero ``loss2`` gradient does not
            divide by zero. Defaults to the original hard-coded ``1e-4``.
        max_weight: Upper clamp bound for the weight. Defaults to the
            original hard-coded ``1e4``.

    Returns:
        A detached tensor holding the clamped weight.

    Raises:
        TypeError: If ``last_layer`` is not provided.
    """
    if last_layer is None:
        # Without this check torch.autograd.grad fails with an opaque
        # "'NoneType' object is not iterable" error.
        raise TypeError("calculate_adaptive_weight() requires 'last_layer'")

    # retain_graph=True keeps the autograd graph alive, so the losses can
    # still be back-propagated after these gradient queries.
    loss1_grad: torch.Tensor = torch.autograd.grad(loss1, last_layer, retain_graph=True)[0]
    loss2_grad: torch.Tensor = torch.autograd.grad(loss2, last_layer, retain_graph=True)[0]

    d_weight = torch.norm(loss1_grad) / (torch.norm(loss2_grad) + eps)
    return torch.clamp(d_weight, 0.0, max_weight).detach()
