import torch
from torchvision import transforms
from datasets import load_dataset

from PIL import Image, ImageFilter
import requests
import os
from io import BytesIO
import random
import numpy as np
import copy
from typing import Any, Mapping
import json

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import seaborn as sns


def read_json(filename: str) -> Mapping[str, Any]:
    """Load the JSON document stored at ``filename``.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    with open(filename, encoding="utf-8") as fp:
        return json.load(fp)
    

def set_random_seed(seed=0):
    """Seed torch, torch.cuda, NumPy and ``random`` from one base seed.

    Each generator receives ``seed`` plus a fixed offset. The seeding calls
    are issued in the order listed below; ``torch.cuda.manual_seed_all`` is
    deliberately called twice (offsets 2 and 4), matching the established
    behaviour of this helper.

    Args:
        seed: Base integer seed.
    """
    # (seeding function, offset added to ``seed``), applied top to bottom.
    seeding_plan = (
        (torch.manual_seed, 0),
        (torch.cuda.manual_seed, 1),
        (torch.cuda.manual_seed_all, 2),
        (np.random.seed, 3),
        (torch.cuda.manual_seed_all, 4),
        (random.seed, 5),
    )
    for seed_fn, offset in seeding_plan:
        seed_fn(seed + offset)


def download_image(url, timeout=30):
    """Fetch ``url`` and return its content as an RGB PIL image.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up; passed
            straight to ``requests.get`` so a stalled server cannot hang the
            caller forever.

    Returns:
        The decoded image converted to ``"RGB"``, or ``None`` when the
        request fails (connection problem, timeout, or HTTP error status).

    Raises:
        Whatever ``PIL.Image.open`` raises if a successful response does not
        contain decodable image data.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Treat 4xx/5xx replies as failures instead of trying to decode an
        # error page as an image.
        response.raise_for_status()
    except requests.RequestException:
        # Only network/HTTP failures are swallowed; programming errors and
        # KeyboardInterrupt propagate (the old bare ``except`` hid them).
        return None
    return Image.open(BytesIO(response.content)).convert("RGB")


def transform_img(image, target_size=512):
    """Resize, centre-crop and tensorize ``image``, then apply ``2x - 1``.

    Args:
        image: Input accepted by the torchvision transforms used here
            (presumably a PIL image -- confirm against callers).
        target_size: Size passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        The ``ToTensor`` output mapped linearly through ``2.0 * x - 1.0``.
    """
    steps = [
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
    ]
    tensor = transforms.Compose(steps)(image)
    return 2.0 * tensor - 1.0


def latents_to_imgs(pipe, latents):
    """Turn ``latents`` into images using the pipeline's own helpers.

    The latents go through ``pipe.decode_image``, ``pipe.torch_to_numpy`` and
    ``pipe.numpy_to_pil`` in that order, and the final result is returned.
    """
    decoded = pipe.decode_image(latents)
    return pipe.numpy_to_pil(pipe.torch_to_numpy(decoded))


def channelwise_heatmap(heatmap, all_channels=False):
    """Draw ``heatmap`` with seaborn and show the figure.

    Only ``heatmap[0]`` is drawn; it is indexed per channel as
    ``heatmap[0][channel]`` (presumably a (batch, channel, H, W) tensor --
    confirm against callers).

    The previous implementation guarded the single-channel plot with
    ``len(heatmap[0]) >= 0``, which is always true, so the per-channel
    plotting code could never run. That choice is now an explicit flag whose
    default keeps the behaviour callers have been getting.

    Args:
        heatmap: Tensor to visualise; it is detached and moved to the CPU.
        all_channels: ``False`` (default) draws channel 0 only, using the
            'YlGnBu' palette with colour limits symmetric around zero.
            ``True`` draws every channel side by side with a diverging
            palette and a colour bar.
    """
    channels = heatmap.detach().cpu().numpy()[0]

    if not all_channels:
        first = channels[0]
        # Symmetric limits sized by the larger absolute extreme of the map.
        limit = max(abs(first.min().item()), abs(first.max().item()))
        sns.heatmap(first, cmap='YlGnBu', xticklabels=False, yticklabels=False, square=True, cbar=False, cbar_kws={'shrink': 0.25}, vmin=-limit, vmax=limit)
    else:
        # squeeze=False keeps ``axes`` two-dimensional even when there is a
        # single channel, so indexing it never fails.
        _, axes = plt.subplots(nrows=1, ncols=len(channels), figsize=(18, 14), squeeze=False)
        palette = sns.diverging_palette(220, 20, as_cmap=True)
        for axis, channel in zip(axes[0], channels):
            sns.heatmap(channel, cmap=palette, xticklabels=False, yticklabels=False, square=True, cbar=True, ax=axis, cbar_kws={'shrink': 0.25})

    plt.show()


def _jpeg_compress(image, quality):
    """Return ``image`` after an in-memory JPEG encode/decode at ``quality``."""
    buffer = BytesIO()
    image.save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    compressed = Image.open(buffer)
    # Image.open is lazy: decode now so the pixels are fixed immediately.
    compressed.load()
    return compressed


def _add_gaussian_noise(image, noise):
    """Return ``image`` plus ``noise`` (0-255 scale), saturated to uint8."""
    pixels = np.asarray(image).astype(np.float64)
    noisy = np.clip(np.rint(pixels + noise), 0, 255).astype(np.uint8)
    return Image.fromarray(noisy)


def image_distortion(img1, img2, w_mask, seed, args):
    """Apply the distortions selected in ``args`` to two images and a mask.

    Each distortion is enabled by the matching ``args`` field and applied in
    this order: rotation (``r_degree``), JPEG compression (``jpeg_ratio``),
    random resized crop (``crop_scale`` and ``crop_ratio``), Gaussian blur
    (``gaussian_blur_r``), Gaussian noise (``gaussian_std``), horizontal flip
    (``flip``), brightness jitter (``brightness_factor``) and finally
    ``rand_aug`` randomly chosen augmentations from the ``aug_*`` helpers.
    Rotation, crop and flip are also applied to ``w_mask`` when it is given.

    Fixes relative to the previous version:

    * JPEG compression is done in memory. Both images used to be written to
      the same temp file, and the second save overwrote the file while the
      first image was still only lazily opened.
    * Gaussian noise is added in floating point and saturated. The noise used
      to be cast to ``uint8`` and added in ``uint8``, which wraps around and
      makes the subsequent clip a no-op.
    * The crop output size is passed as (height, width). PIL's ``size`` is
      (width, height), so non-square images came out transposed.

    Args:
        img1: First PIL image.
        img2: Second PIL image; receives the same distortions as ``img1``.
        w_mask: Optional mask (indexed with ``shape[-2:]``) or ``None``.
        seed: Seed used to make the crops agree; also stored on
            ``args.rand_seed`` when ``rand_aug`` is active.
        args: Namespace carrying the fields named above.

    Returns:
        Tuple ``(img1, img2, w_mask)`` after distortion.
    """
    if args.r_degree is not None:
        rotate = transforms.RandomRotation((args.r_degree, args.r_degree))
        img1 = rotate(img1)
        img2 = rotate(img2)

        if w_mask is not None:
            w_mask = rotate(w_mask)

    if args.jpeg_ratio is not None:
        img1 = _jpeg_compress(img1, args.jpeg_ratio)
        img2 = _jpeg_compress(img2, args.jpeg_ratio)

    if args.crop_scale is not None and args.crop_ratio is not None:
        scale = (args.crop_scale, args.crop_scale)
        ratio = (args.crop_ratio, args.crop_ratio)

        # Re-seed before every crop so each input draws the same parameters.
        set_random_seed(seed)
        img1 = transforms.RandomResizedCrop(img1.size[::-1], scale=scale, ratio=ratio)(img1)
        set_random_seed(seed)
        img2 = transforms.RandomResizedCrop(img2.size[::-1], scale=scale, ratio=ratio)(img2)

        if w_mask is not None:
            set_random_seed(seed)
            w_mask = transforms.RandomResizedCrop(w_mask.shape[-2:], scale=scale, ratio=ratio)(w_mask)

    if args.gaussian_blur_r is not None:
        blur = ImageFilter.GaussianBlur(radius=args.gaussian_blur_r)
        img1 = img1.filter(blur)
        img2 = img2.filter(blur)

    if args.gaussian_std is not None:
        # One noise field, shaped like img1, is shared by both images.
        g_noise = np.random.normal(0, args.gaussian_std, np.array(img1).shape) * 255
        img1 = _add_gaussian_noise(img1, g_noise)
        img2 = _add_gaussian_noise(img2, g_noise)

    if args.flip is True:
        flip = transforms.RandomHorizontalFlip(p=1)
        img1 = flip(img1)
        img2 = flip(img2)

        if w_mask is not None:
            w_mask = flip(w_mask)

    if args.brightness_factor is not None:
        img1 = transforms.ColorJitter(brightness=args.brightness_factor)(img1)
        img2 = transforms.ColorJitter(brightness=args.brightness_factor)(img2)

    if args.rand_aug >= 1:
        # aug_crop reads the seed back from args.
        args.rand_seed = seed
        all_augs = [aug_rotate, aug_jpeg, aug_crop, aug_blur, aug_gaussian, aug_color]
        random.shuffle(all_augs)

        for curr_aug in all_augs[:args.rand_aug]:
            img1, img2 = curr_aug(img1, img2, args)

    return img1, img2, w_mask


### for rand_aug
def aug_rotate(img1, img2, args):
    """Rotate both images by a fixed 75 degrees (``args`` is unused)."""
    # The degree range is degenerate (75, 75), so the angle is always 75.
    rotate = transforms.RandomRotation((75, 75))
    return rotate(img1), rotate(img2)

def aug_jpeg(img1, img2, args):
    """JPEG-compress both images at quality 25.

    The round trip is done in memory. The previous version wrote both images
    to the same ``tmp_25_<run_name>.jpg`` file; because ``Image.open`` is
    lazy, the second save overwrote the file while the first image was still
    undecoded. ``args`` is kept for signature compatibility and is no longer
    read.

    Returns:
        Tuple ``(img1, img2)`` of the recompressed images.
    """
    def recompress(image):
        buffer = BytesIO()
        image.save(buffer, format="JPEG", quality=25)
        buffer.seek(0)
        compressed = Image.open(buffer)
        # Decode immediately so the result is independent of the buffer.
        compressed.load()
        return compressed

    return recompress(img1), recompress(img2)

def aug_crop(img1, img2, args):
    """Apply the same random resized crop (scale 0.75, ratio 0.75) to both images.

    The RNGs are re-seeded from ``args.rand_seed`` before each crop so both
    images draw the same crop parameters. The output size is passed as
    (height, width): PIL's ``size`` is (width, height), and passing it
    unreversed transposed the dimensions of non-square images.

    Returns:
        Tuple ``(img1, img2)`` of the cropped images.
    """
    set_random_seed(args.rand_seed)
    img1 = transforms.RandomResizedCrop(img1.size[::-1], scale=(0.75, 0.75), ratio=(0.75, 0.75))(img1)
    set_random_seed(args.rand_seed)
    img2 = transforms.RandomResizedCrop(img2.size[::-1], scale=(0.75, 0.75), ratio=(0.75, 0.75))(img2)

    return img1, img2

def aug_blur(img1, img2, args):
    """Blur both images with a radius-4 Gaussian filter (``args`` is unused)."""
    blur = ImageFilter.GaussianBlur(radius=4)
    return img1.filter(blur), img2.filter(blur)

def aug_gaussian(img1, img2, args):
    """Add the same Gaussian noise field (std 0.1 of full scale) to both images.

    The noise is drawn once, shaped like ``img1``, and scaled to the 0-255
    range. It is added in floating point and saturated to [0, 255]. The
    previous version cast the noise to ``uint8`` and added it in ``uint8``,
    so negative noise and bright pixels wrapped around and the clip had no
    effect. ``args`` is unused.

    Returns:
        Tuple ``(img1, img2)`` of the noisy images.
    """
    g_noise = np.random.normal(0, 0.1, np.array(img1).shape) * 255

    def add_noise(image):
        pixels = np.asarray(image).astype(np.float64)
        noisy = np.clip(np.rint(pixels + g_noise), 0, 255).astype(np.uint8)
        return Image.fromarray(noisy)

    return add_noise(img1), add_noise(img2)

def aug_color(img1, img2, args):
    """Apply ``ColorJitter(brightness=6)`` to each image (``args`` is unused).

    The jitter is sampled separately for each call, as before.
    """
    jitter = transforms.ColorJitter(brightness=6)
    return jitter(img1), jitter(img2)
### for rand_aug


def parity_watermark(latent, gamma=0.5, place=1):
    """Return a copy of ``latent`` whose chosen decimal digit has random parity.

    Every element is formatted with ``place + 1`` decimals. The digit at
    decimal position ``place`` is then made even with probability ``gamma``
    and odd otherwise. A digit with the wrong parity is moved by -1 or +1
    (fair coin) and wrapped modulo 10. The element is replaced by the value
    parsed back from the edited string.

    Args:
        latent: Tensor to watermark; it is deep-copied, not modified.
        gamma: Probability of targeting an even digit.
        place: 1-based decimal position of the digit to control.

    Returns:
        Tensor with the same shape as ``latent``.
    """
    marked = copy.deepcopy(latent)
    original_shape = marked.shape
    flat = marked.flatten()
    fmt = f'%.{place+1}f'

    for idx, value in enumerate(flat.tolist()):
        text = fmt % value
        digit_pos = text.find('.') + place
        digit = int(text[digit_pos])

        # First draw picks the target parity; a second draw happens only when
        # the digit has to move.
        want_even = random.random() < gamma
        if want_even != (digit % 2 == 0):
            digit += -1 if random.random() < 0.5 else 1

        digit %= 10
        flat[idx] = float(text[:digit_pos] + str(digit) + text[digit_pos + 1:])

    return flat.reshape(original_shape)

def check_parity(x, place=1):
    """Return 1 if the digit at decimal position ``place`` of ``x`` is even, else -1.

    ``x`` is formatted with ``place + 1`` decimals and the digit ``place``
    characters after the decimal point is inspected.
    """
    text = f'%.{place+1}f' % x
    digit = int(text[text.find('.') + place])
    return 1 if digit % 2 == 0 else -1


def get_parity_map(latent, place=1):
    """Return a tensor of ``check_parity`` results, shaped like ``latent``.

    Each element of ``latent`` is passed to ``check_parity`` with the given
    ``place``. The results are collected in a default-dtype tensor on
    ``latent``'s device.
    """
    signs = [check_parity(value, place=place) for value in latent.flatten().tolist()]
    parity_map = torch.tensor(signs, dtype=torch.get_default_dtype()).to(latent.device)
    return parity_map.reshape(latent.shape)


# for one prompt to multiple images
def measure_similarity(images, prompt, model, clip_preprocess, tokenizer, device):
    """Score each image against ``prompt`` by cosine similarity of features.

    Args:
        images: Iterable of images accepted by ``clip_preprocess``.
        prompt: Single text prompt compared against every image.
        model: Object exposing ``encode_image`` and ``encode_text``.
        clip_preprocess: Callable turning one image into a tensor.
        tokenizer: Callable turning a list of prompts into a tensor.
        device: Device the image batch and tokens are moved to.

    Returns:
        One similarity value per image: the product of the L2-normalised
        image and text features, averaged over the last dimension.
    """
    with torch.no_grad():
        # Stacking is the same as unsqueeze(0) on each item followed by cat.
        batch = torch.stack([clip_preprocess(image) for image in images]).to(device)
        image_features = model.encode_image(batch)

        tokens = tokenizer([prompt]).to(device)
        text_features = model.encode_text(tokens)

        # Normalise in place, as the original did.
        image_features.div_(image_features.norm(dim=-1, keepdim=True))
        text_features.div_(text_features.norm(dim=-1, keepdim=True))

        similarity = image_features @ text_features.T
        return similarity.mean(-1)


def get_dataset(args):
    """Load the prompt dataset named by ``args.dataset``.

    * Names containing ``'laion'`` use the ``'train'`` split and the
      ``'TEXT'`` prompt column.
    * Names containing ``'coco'`` read the ``'annotations'`` list from
      ``fid_outputs/coco/meta_data.json`` and use the ``'caption'`` key.
    * Anything else uses the ``'test'`` split and the ``'Prompt'`` column.

    Returns:
        Tuple ``(dataset, no_w_dataset, prompt_key)``; ``no_w_dataset`` is
        always ``None``.
    """
    name = args.dataset

    if 'laion' in name:
        dataset, prompt_key = load_dataset(name)['train'], 'TEXT'
    elif 'coco' in name:
        with open('fid_outputs/coco/meta_data.json') as f:
            dataset = json.load(f)['annotations']
        prompt_key = 'caption'
    else:
        dataset, prompt_key = load_dataset(name)['test'], 'Prompt'

    return dataset, None, prompt_key


def circle_mask(size=64, r=10, x_offset=0, y_offset=0):
    """Return a ``size`` x ``size`` boolean array marking a filled disc.

    The disc has radius ``r`` and is centred at ``size // 2`` shifted by the
    offsets. The vertical coordinate counts upward from the last row, so a
    positive ``y_offset`` moves the disc toward row 0.

    Reference: https://stackoverflow.com/questions/69687798/generating-a-soft-circluar-mask-using-numpy-python-3
    """
    center_x = size // 2 + x_offset
    center_y = size // 2 + y_offset

    columns = np.arange(size)[np.newaxis, :]
    # Row 0 carries the largest y value (vertical axis is flipped).
    rows = np.arange(size)[::-1][:, np.newaxis]

    squared_distance = (columns - center_x) ** 2 + (rows - center_y) ** 2
    return squared_distance <= r ** 2


def get_watermarking_mask(init_latents_w, args, device):
    """Build the boolean mask selecting the latent entries to watermark.

    Args:
        init_latents_w: Tensor whose shape the mask copies. Its last
            dimension sets the mask geometry and dimension 1 is treated as
            the channel axis.
        args: Namespace with ``w_mask_shape`` (``'circle'``, ``'square'`` or
            ``'no'``), ``w_radius`` and ``w_channel`` (``-1`` selects every
            channel).
        device: Device the mask is placed on.

    Returns:
        Boolean tensor shaped like ``init_latents_w``; all ``False`` for
        ``'no'``.

    Raises:
        NotImplementedError: For any other ``w_mask_shape``.
    """
    mask = torch.zeros(init_latents_w.shape, dtype=torch.bool).to(device)
    shape_name = args.w_mask_shape

    if shape_name == 'no':
        return mask
    if shape_name not in ('circle', 'square'):
        raise NotImplementedError(f'w_mask_shape: {args.w_mask_shape}')

    # A full slice addresses all channels; otherwise only the chosen one.
    channels = slice(None) if args.w_channel == -1 else args.w_channel
    side = init_latents_w.shape[-1]

    if shape_name == 'circle':
        disc = torch.tensor(circle_mask(side, r=args.w_radius)).to(device)
        mask[:, channels] = disc
    else:
        center = side // 2
        window = slice(center - args.w_radius, center + args.w_radius)
        mask[:, channels, window, window] = True

    return mask


def _paint_rings(patch, radius, device):
    """Overwrite concentric discs of ``patch`` in place and return it.

    Radii run from ``radius`` down to 1. For each radius ``r`` and channel
    ``c``, every pixel inside the ``circle_mask`` disc of radius ``r`` is set
    to the value found at ``[0, c, 0, r]`` in a copy taken before painting.
    Smaller discs overwrite the interior of larger ones.
    """
    reference = copy.deepcopy(patch)
    for r in range(radius, 0, -1):
        disc = torch.tensor(circle_mask(patch.shape[-1], r=r)).to(device)
        for channel in range(patch.shape[1]):
            patch[:, channel, disc] = reference[0, channel, 0, r].item()
    return patch


def get_watermarking_pattern(pipe, args, device, shape=None):
    """Generate the ground-truth watermark patch selected by ``args.w_pattern``.

    The RNGs are seeded from ``args.w_seed``. Initial latents come from
    ``torch.randn(*shape)`` when ``shape`` is given, otherwise from
    ``pipe.get_random_latents()``. ``w_pattern`` is matched by substring, in
    this order:

    * ``seed_ring``  -- ring-painted latents (painted in place)
    * ``seed_zeros`` -- zeros shaped like the latents
    * ``seed_rand``  -- the latents themselves
    * ``rand``       -- centred FFT of the latents, first batch entry
      broadcast to every entry
    * ``zeros``      -- zero tensor in the FFT domain
    * ``const``      -- FFT-domain tensor filled with ``args.w_pattern_const``
    * ``ring``       -- ring-painted centred FFT of the latents

    If ``w_pattern`` also contains ``mag`` the absolute value is returned.

    An unrecognised ``w_pattern`` now raises ``NotImplementedError``;
    previously it failed later with an ``UnboundLocalError``.

    Args:
        pipe: Pipeline providing ``get_random_latents()``; only used when
            ``shape`` is ``None``.
        args: Namespace with ``w_seed``, ``w_pattern``, ``w_radius`` and,
            for ``const``, ``w_pattern_const``.
        device: Device for generated tensors.
        shape: Optional explicit latent shape.

    Returns:
        The watermark patch tensor.

    Raises:
        NotImplementedError: If ``args.w_pattern`` matches no known pattern.
    """
    set_random_seed(args.w_seed)
    if shape is not None:
        gt_init = torch.randn(*shape, device=device)
    else:
        gt_init = pipe.get_random_latents()

    def centered_fft():
        return torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))

    pattern = args.w_pattern
    # Order matters: the 'seed_*' names must be tested before their suffixes.
    if 'seed_ring' in pattern:
        gt_patch = _paint_rings(gt_init, args.w_radius, device)
    elif 'seed_zeros' in pattern:
        gt_patch = gt_init * 0
    elif 'seed_rand' in pattern:
        gt_patch = gt_init
    elif 'rand' in pattern:
        gt_patch = centered_fft()
        gt_patch[:] = gt_patch[0]
    elif 'zeros' in pattern:
        gt_patch = centered_fft() * 0
    elif 'const' in pattern:
        gt_patch = centered_fft() * 0
        gt_patch += args.w_pattern_const
    elif 'ring' in pattern:
        gt_patch = _paint_rings(centered_fft(), args.w_radius, device)
    else:
        raise NotImplementedError(f'w_pattern: {args.w_pattern}')

    if 'mag' in pattern:
        gt_patch = gt_patch.abs()

    return gt_patch


def inject_watermark(init_latents_w, watermarking_mask, gt_patch, args):
    """Write the watermark patch into the latents at the masked positions.

    ``args.w_injection`` selects how:

    * ``'complex'`` -- masked FFT coefficients are replaced by the patch.
    * ``'mag'``     -- the patch supplies the magnitude; the existing
      coefficient's phase is kept.
    * ``'phase'``   -- the existing magnitude is kept; the patch supplies the
      phase angle.
    * ``'seed'``    -- masked latent values are overwritten directly, in
      place, with no FFT involved.

    An unknown mode now raises. The exception used to be constructed but not
    raised, so the call silently returned the round-tripped latents with no
    watermark.

    Args:
        init_latents_w: Latent tensor to watermark.
        watermarking_mask: Boolean mask shaped like the latents.
        gt_patch: Patch tensor indexed with the same mask.
        args: Namespace providing ``w_injection``.

    Returns:
        The watermarked latents. For ``'seed'`` this is ``init_latents_w``
        itself; otherwise it is the real part of the inverse FFT.

    Raises:
        NotImplementedError: If ``args.w_injection`` is not a known mode.
    """
    if args.w_injection == 'seed':
        init_latents_w[watermarking_mask] = gt_patch[watermarking_mask].clone()
        return init_latents_w

    init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents_w), dim=(-1, -2))
    if args.w_injection == 'complex':
        init_latents_w_fft[watermarking_mask] = gt_patch[watermarking_mask].clone()
    elif args.w_injection == 'mag':
        init_latents_w_fft[watermarking_mask] = gt_patch[watermarking_mask].clone() * init_latents_w_fft[watermarking_mask] / init_latents_w_fft[watermarking_mask].abs()
    elif args.w_injection == 'phase':
        phase_shift = torch.exp(1j * gt_patch[watermarking_mask].clone())
        init_latents_w_fft[watermarking_mask] = init_latents_w_fft[watermarking_mask].abs() * phase_shift
    else:
        raise NotImplementedError(f'w_injection: {args.w_injection}')

    init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real

    return init_latents_w


def _measurement_domain(latents_no_w, latents_w, gt_patch, measurement):
    """Map both latents and the patch into the domain named in ``measurement``.

    Keywords are matched by substring in the order 'complex', 'abs', 'real',
    'seed'. 'complex' is tested before 'real' so that names containing
    'as_real' together with 'complex' keep the complex spectrum.
    """
    def centered_fft(latents):
        return torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))

    if 'complex' in measurement:
        return centered_fft(latents_no_w), centered_fft(latents_w), gt_patch
    if 'abs' in measurement:
        return centered_fft(latents_no_w).abs(), centered_fft(latents_w).abs(), gt_patch.abs()
    if 'real' in measurement:
        return centered_fft(latents_no_w).real, centered_fft(latents_w).real, gt_patch.real
    if 'seed' in measurement:
        # Spatial-domain comparison: no transform at all.
        return latents_no_w, latents_w, gt_patch
    raise NotImplementedError(f'w_measurement: {measurement}')


def _masked_metric(candidate, target_patch, mask, measurement):
    """Return the scalar metric named in ``measurement`` over the masked entries."""
    if 'as_real_l1' in measurement:
        return torch.abs(torch.view_as_real(candidate[mask]) - torch.view_as_real(target_patch[mask])).mean().item()
    if 'as_real_l2' in measurement:
        # Zero out infinities in place before squaring, as the original did.
        candidate[candidate == float("Inf")] = 0
        return (torch.view_as_real(candidate[mask]) - torch.view_as_real(target_patch[mask])).pow(2).mean().item()
    if 'l1' in measurement:
        return torch.abs(candidate[mask] - target_patch[mask]).mean().item()
    if 'cosine' in measurement:
        cos_funct = torch.nn.CosineSimilarity(dim=0)
        return cos_funct(candidate[mask].flatten(), target_patch[mask].flatten()).item()
    if 'var' in measurement:
        return candidate[mask].var().item()
    if 'mag' in measurement:
        return torch.abs(candidate[mask].abs() - target_patch[mask]).mean().item()
    raise NotImplementedError(f'w_measurement: {measurement}')


def _score_latents(latents_no_w, latents_w, mask, gt_patch, measurement):
    """Return ``(no_w_metric, w_metric)`` for one pair of latents."""
    no_w_domain, w_domain, target_patch = _measurement_domain(latents_no_w, latents_w, gt_patch, measurement)
    no_w_metric = _masked_metric(no_w_domain, target_patch, mask, measurement)
    w_metric = _masked_metric(w_domain, target_patch, mask, measurement)
    return no_w_metric, w_metric


def eval_watermark(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args):
    """Measure how closely each set of reversed latents matches the watermark.

    ``args.w_measurement`` is matched by substring twice. First it selects
    the comparison domain ('complex', 'abs', 'real' or 'seed'). Then it
    selects the metric ('as_real_l1', 'as_real_l2', 'l1', 'cosine', 'var' or
    'mag'). The metric is computed over the entries selected by
    ``watermarking_mask``.

    When ``args.rotate_test is True`` both latents are rotated by every
    integer degree in [0, 360) and the minimum metric over all rotations is
    returned for each.

    Changes relative to the previous version:

    * An unknown domain or metric raises ``NotImplementedError``. The
      exception used to be constructed but never raised, so the call failed
      later with an unbound-variable error.
    * Both branches share one implementation and one keyword order. The
      rotate branch used to test 'real' before 'complex', so names such as
      ``'as_real_l1_complex'`` selected the real part there.
    * With the 'seed' domain, the rotate test now scores the rotated latents.
      It used to score the unrotated inputs on every iteration.

    Args:
        reversed_latents_no_w: Reversed latents of the non-watermarked image.
        reversed_latents_w: Reversed latents of the watermarked image.
        watermarking_mask: Boolean mask selecting the compared entries.
        gt_patch: Ground-truth watermark patch.
        args: Namespace providing ``rotate_test`` and ``w_measurement``.

    Returns:
        Tuple ``(no_w_metric, w_metric)`` of Python scalars.

    Raises:
        NotImplementedError: If ``args.w_measurement`` names no known domain
            or no known metric.
    """
    measurement = args.w_measurement

    if args.rotate_test is not True:
        return _score_latents(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, measurement)

    # Hard-coded sweep over every whole degree.
    no_w_scores = []
    w_scores = []
    for degree in range(360):
        rotate = transforms.RandomRotation((degree, degree))
        no_w_metric, w_metric = _score_latents(
            rotate(reversed_latents_no_w),
            rotate(reversed_latents_w),
            watermarking_mask,
            gt_patch,
            measurement,
        )
        no_w_scores.append(no_w_metric)
        w_scores.append(w_metric)

    return min(no_w_scores), min(w_scores)
