import torch
from PIL import Image, ImageFilter
import random
import numpy as np
import copy
from typing import Any, Mapping
import json
import scipy



def circle_mask(size=64, r=10, x_offset=0, y_offset=0):
    """Return a boolean ``(size, size)`` array that is True inside a filled circle.

    The circle is centred at ``size // 2`` on both axes, shifted by
    ``x_offset`` / ``y_offset``. The row axis is flipped before the distance
    test, so a positive ``y_offset`` moves the circle towards row 0.

    Reference:
    https://stackoverflow.com/questions/69687798/generating-a-soft-circluar-mask-using-numpy-python-3
    """
    half = size // 2
    center_x = half + x_offset
    center_y = half + y_offset

    rows, cols = np.ogrid[:size, :size]
    rows = rows[::-1]  # flip the vertical axis so y grows upwards

    squared_distance = (cols - center_x) ** 2 + (rows - center_y) ** 2
    return squared_distance <= r ** 2

class AddGaussianNoise(object):
    """Callable transform that adds zero-mean Gaussian noise to an image.

    ``std`` is the standard deviation of the noise, expressed on the same
    0-255 scale the result is clipped to.
    """

    def __init__(self, std):
        self.std = std

    def __call__(self, img):
        """Return a noisy copy of ``img`` as an 8-bit PIL image."""
        # Work in float32 so the noise can push values outside the 8-bit range.
        pixels = np.array(img).astype(np.float32)

        # Draw one noise sample per array element and add it in place.
        pixels += np.random.randn(*pixels.shape) * self.std

        # Clamp back into the valid range before converting to uint8.
        clamped = np.clip(pixels, 0, 255)
        return Image.fromarray(clamped.astype(np.uint8))
    
# def mae_loss(embedding1, embedding2, args, mask=None):
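#     """
#     Compute the mean absolute error (MAE) between two embeddings.
#     (The result variable is still named ``mse``.)

#     Args:
#     embedding1: first embedding tensor
#     embedding2: second embedding tensor

#     Returns:
#     mse: the computed loss
#     """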
#     if mask is not None:
#         if args.shift:
#             embedding1 = torch.fft.fftshift(torch.fft.fft2(embedding1), dim=(-1, -2))
#             embedding2 = torch.fft.fftshift(torch.fft.fft2(embedding2), dim=(-1, -2))
#         else:
#             embedding1 = torch.fft.fft2(embedding1)
#             embedding2 = torch.fft.fft2(embedding2)
#         if mask.shape[0] != embedding1.shape[0]:
#             expand_mask = mask.expand(embedding1.shape[0], -1, -1, -1)
            
#         mse = torch.abs(embedding1[expand_mask] - embedding2[expand_mask])
#         mse = torch.mean(mse)
#     else:
#         mse = torch.mean(torch.abs(embedding1 - embedding2))
#     return mse

from skimage.draw import circle_perimeter
from skimage.draw import disk

def create_round_ring_mask(image_size, radius):
    """Return a boolean mask containing a one-pixel-wide circle outline.

    Args:
        image_size: side length of the square mask.
        radius: radius of the circle, centred at ``image_size // 2``.

    Returns:
        A ``(image_size, image_size)`` boolean array that is True on the
        circle perimeter and False elsewhere.
    """
    center = image_size // 2
    mask = np.zeros((image_size, image_size), dtype=bool)

    # Draw the outer ring. Passing ``shape`` clips the perimeter to the mask,
    # so an oversized radius neither wraps around through negative indices
    # nor raises IndexError.
    rr_outer, cc_outer = circle_perimeter(center, center, radius, shape=mask.shape)
    mask[rr_outer, cc_outer] = True

    return mask

def create_round_ring_disk_mask(image_size, radius):
    """Return a boolean ring mask built as the difference of two filled disks.

    Args:
        image_size: side length of the square mask.
        radius: radius of the outer disk, centred at ``image_size // 2``;
            the inner disk that is removed has radius ``radius - 1``.

    Returns:
        A ``(image_size, image_size)`` boolean array that is True on the
        pixels of the outer disk that are not part of the inner disk.
    """
    center = image_size // 2
    mask = np.zeros((image_size, image_size), dtype=bool)

    # Fill the outer disk. ``shape`` clips the coordinates to the mask so an
    # oversized radius cannot index outside of it (or wrap around).
    rr_outer, cc_outer = disk((center, center), radius, shape=mask.shape)
    mask[rr_outer, cc_outer] = True

    # Carve out the inner disk, leaving only the ring.
    rr_inner, cc_inner = disk((center, center), radius - 1, shape=mask.shape)
    mask[rr_inner, cc_inner] = False

    return mask


def get_watermarking_mask(init_latents_w, args, device):
    """Build the boolean mask that marks which latent entries carry the watermark.

    When ``args.w_scale_hw > 1`` the latents are first reshaped from
    ``(b, c, h, w)`` to ``(b, -1, h // s, w // s)`` and the mask has that
    reshaped layout; otherwise the mask has the shape of ``init_latents_w``.

    ``args.w_mask_shape`` selects the region:
      * ``'circle'``: region returned by ``create_rounder_ring_mask`` between
        ``args.w_r_end`` and ``args.w_r_start``;
      * ``'square'``: centred square with half-side ``args.w_radius``;
      * ``'no'``: nothing is selected.
    ``args.w_channel`` lists the channels to mark; a leading ``-1`` means
    every channel.

    Raises:
        NotImplementedError: for any other ``args.w_mask_shape``.
    """
    # NOTE(review): create_rounder_ring_mask is not defined in the visible
    # part of this file - confirm it is provided elsewhere.
    patch = args.w_scale_hw
    if patch > 1:
        batch, _, height, width = init_latents_w.shape
        init_latents_w = init_latents_w.reshape(batch, -1, height // patch, width // patch)

    watermarking_mask = torch.zeros(init_latents_w.shape, dtype=torch.bool).to(device)
    # After patchifying, each original channel spans patch * patch channels.
    group = patch * patch
    shape_name = args.w_mask_shape

    if shape_name == 'circle':
        ring = create_rounder_ring_mask(
            init_latents_w.shape[-1], inner_radius=args.w_r_end, outer_radius=args.w_r_start
        )
        ring = torch.tensor(ring).to(device)

        if args.w_channel[0] == -1:
            # all channels
            watermarking_mask[:, :] = ring
        else:
            for ch in args.w_channel:
                watermarking_mask[:, ch * group:(ch + 1) * group] = ring
    elif shape_name == 'square':
        middle = init_latents_w.shape[-1] // 2
        span = slice(middle - args.w_radius, middle + args.w_radius)
        if args.w_channel[0] == -1:
            # all channels
            watermarking_mask[:, :, span, span] = True
        else:
            for ch in args.w_channel:
                watermarking_mask[:, ch, span, span] = True
    elif shape_name != 'no':
        raise NotImplementedError(f'w_mask_shape: {args.w_mask_shape}')

    return watermarking_mask


def get_watermarking_pattern(args, device, gt_init=None):
    """Derive the watermark pattern from the initial tensor ``gt_init``.

    When ``args.w_scale_hw > 1``, ``gt_init`` is first reshaped from
    ``(b, c, h, w)`` to ``(b, -1, h // s, w // s)``.

    ``args.w_pattern`` is matched by substring, in this order:
      * ``'seed_ring'``: for each radius ``i`` in
        ``range(w_r_start, w_r_end, -w_r_interval)`` the filled circle of
        radius ``i`` in every channel ``j`` is set to element ``[0, j, 0, i]``
        of a snapshot of ``gt_init``. The values are written into ``gt_init``
        itself (``gt_patch`` aliases it).
      * ``'seed_zeros'``: ``gt_init * 0``.
      * ``'seed_rand'``: ``gt_init`` unchanged.
      * ``'rand'``: shifted 2-D FFT of ``gt_init`` with every batch entry
        replaced by batch entry 0.
      * ``'zeros'``: shifted 2-D FFT of ``gt_init`` multiplied by 0.
      * ``'const'``: all-zero spectrum whose real part is set to 10000.
      * ``'ring'``: 2-D FFT of ``gt_init`` (shifted only if ``args.shift``);
        for each radius ``i`` the region returned by
        ``create_rounder_ring_mask(inner_radius=w_r_end, outer_radius=i)`` is
        set, per channel ``j``, to element ``[0, j, 0, i]`` of a snapshot of
        the spectrum, and finally the imaginary part is zeroed.

    Args:
        args: namespace providing ``w_scale_hw``, ``w_pattern`` and, for the
            ring patterns, ``w_r_start``, ``w_r_end``, ``w_r_interval`` and
            ``shift``.
        device: device the ring masks are moved to.
        gt_init: initial tensor the pattern is derived from (required).

    Returns:
        The pattern tensor (real for the ``seed_*`` patterns, complex
        otherwise).

    Raises:
        ValueError: if ``gt_init`` is None.
        NotImplementedError: if ``args.w_pattern`` matches no known pattern.
    """
    # NOTE(review): create_rounder_ring_mask is not defined in the visible
    # part of this file - confirm it is provided elsewhere.
    if gt_init is None:
        raise ValueError('gt_init must be provided')

    if args.w_scale_hw > 1:
        b, _, h, w = gt_init.shape
        gt_init = gt_init.reshape(b, -1, h // args.w_scale_hw, w // args.w_scale_hw)

    pattern = args.w_pattern

    if 'seed_ring' in pattern:
        gt_patch = gt_init

        # Snapshot the values so later (smaller) circles still read the
        # original entries rather than ones already overwritten.
        gt_patch_tmp = copy.deepcopy(gt_patch.detach())
        for i in range(args.w_r_start, args.w_r_end, -args.w_r_interval):
            tmp_mask = circle_mask(gt_init.shape[-1], r=i)
            tmp_mask = torch.tensor(tmp_mask).to(device)

            for j in range(gt_patch.shape[1]):
                gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
    elif 'seed_zeros' in pattern:
        gt_patch = gt_init * 0
    elif 'seed_rand' in pattern:
        gt_patch = gt_init
    elif 'rand' in pattern:
        gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
        gt_patch[:] = gt_patch[0]
    elif 'zeros' in pattern:
        gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
    elif 'const' in pattern:
        gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
        gt_patch.real = 10000
    elif 'ring' in pattern:
        if args.shift:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
        else:
            gt_patch = torch.fft.fft2(gt_init)

        # Snapshot the spectrum for the same reason as in 'seed_ring'.
        gt_patch_tmp = copy.deepcopy(gt_patch.detach())
        for i in range(args.w_r_start, args.w_r_end, -args.w_r_interval):
            tmp_mask = create_rounder_ring_mask(
                gt_init.shape[-1], inner_radius=args.w_r_end, outer_radius=i
            )
            tmp_mask = torch.tensor(tmp_mask).to(device)
            for j in range(gt_patch.shape[1]):
                gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
        gt_patch.imag = 0
    else:
        # Previously an unknown pattern fell through and failed with an
        # UnboundLocalError on the return statement.
        raise NotImplementedError(f'w_pattern: {args.w_pattern}')

    return gt_patch


# # Generate the checkerboard pattern
# u, v = np.meshgrid(np.arange(N), np.arange(N))
# chessboard_pattern = (-1)**(u + v)
# # Set the eta value
# eta = 0.85

# # Multiply in the frequency domain and scale the strength
# fft_watermark_modified = fft_watermark_shifted * chessboard_pattern * eta

def inject_watermark(init_latents_w, watermarking_mask, gt_patch, args, device, return_fft=None):
    """Write the watermark into the latents and return the watermarked latents.

    When ``args.w_scale_hw > 1`` the latents are reshaped from
    ``(b, c, h, w)`` to ``(b, -1, h // s, w // s)`` before injection and
    reshaped back afterwards (except in 'seed' mode, see below).

    ``args.w_injection`` selects the mode:
      * ``'complex'``: masked entries of the 2-D FFT are replaced by the
        masked entries of ``gt_patch``.
      * ``'learn'``: the real part of the masked FFT entries is set to
        ``gt_patch`` and their imaginary part to 0.
      * ``'mean'``: for each radius ``i`` in
        ``range(w_r_start, w_r_end, -w_r_interval)`` and each channel in
        ``w_channel``, the real part inside the ring mask is set to the
        minimum real value found there in batch entry 0, and the imaginary
        part to 0. ``watermarking_mask`` and ``gt_patch`` are not read.
      * ``'seed'``: masked entries of the latents themselves are overwritten
        in place and the latents are returned immediately.

    Args:
        init_latents_w: latent tensor; unpacked as ``(b, c, h, w)`` when
            ``args.w_scale_hw > 1``.
        watermarking_mask: boolean tensor selecting the entries to overwrite.
        gt_patch: watermark values (see the modes above).
        args: namespace providing ``w_scale_hw``, ``shift``, ``w_injection``
            and, for 'mean', ``w_r_start``, ``w_r_end``, ``w_r_interval`` and
            ``w_channel``.
        device: device the ring masks of the 'mean' mode are moved to.
        return_fft: when truthy, the edited spectrum is also returned, with
            ``requires_grad`` enabled.

    Returns:
        The watermarked latents, or ``(latents, spectrum)`` when
        ``return_fft`` is truthy. 'seed' mode always returns only the latents.

    Raises:
        NotImplementedError: for an unknown ``args.w_injection``.
    """
    # NOTE(review): create_rounder_ring_mask is not defined in the visible
    # part of this file - confirm it is provided elsewhere.
    scale = args.w_scale_hw
    injection = args.w_injection

    # Patchify first: the batch expansion below must target the patchified
    # layout, which is the layout get_watermarking_mask and
    # get_watermarking_pattern produce for the same w_scale_hw.
    if scale > 1:
        b, c, h, w = init_latents_w.shape
        init_latents_w = init_latents_w.reshape(b, -1, h // scale, w // scale)

    # Only the modes that index gt_patch / the mask element-wise need them
    # broadcast over the batch.
    if injection in ('complex', 'seed') and init_latents_w.shape[0] != gt_patch.shape[0]:
        gt_patch = gt_patch.expand(init_latents_w.shape)
        watermarking_mask = watermarking_mask.expand(init_latents_w.shape)

    if injection == 'seed':
        # Operates directly on the latents; no FFT round trip is needed.
        # NOTE(review): with w_scale_hw > 1 the result keeps the patchified
        # layout (unchanged from before) - confirm callers expect that.
        init_latents_w[watermarking_mask] = gt_patch[watermarking_mask].clone()
        return init_latents_w

    if args.shift:
        init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents_w), dim=(-1, -2))
    else:
        init_latents_w_fft = torch.fft.fft2(init_latents_w)

    if injection == 'complex':
        init_latents_w_fft[watermarking_mask] = gt_patch[watermarking_mask].clone()
    elif injection == 'learn':
        # Assign through the .real / .imag views. Boolean-mask indexing of
        # the spectrum itself returns a copy, so setting .imag on that copy
        # would silently do nothing.
        init_latents_w_fft.real[watermarking_mask] = gt_patch
        init_latents_w_fft.imag[watermarking_mask] = 0
    elif injection == 'mean':
        for i in range(args.w_r_start, args.w_r_end, -args.w_r_interval):
            tmp_mask = create_rounder_ring_mask(
                init_latents_w.shape[-1], inner_radius=args.w_r_end, outer_radius=i
            )
            tmp_mask = torch.tensor(tmp_mask).to(device)
            # The loop variable must not be named `c`: that would clobber the
            # channel count needed for the final reshape.
            for ch in args.w_channel:
                ring_min = torch.min(init_latents_w_fft[0, ch, tmp_mask].real)
                # Same view-vs-copy concern as in 'learn'.
                init_latents_w_fft.real[:, ch, tmp_mask] = ring_min
                init_latents_w_fft.imag[:, ch, tmp_mask] = 0
    else:
        raise NotImplementedError(f'w_injection: {args.w_injection}')

    if return_fft:
        init_latents_w_fft.requires_grad_(True)

    if args.shift:
        init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
    else:
        init_latents_w = torch.fft.ifft2(init_latents_w_fft).real

    if scale > 1:
        init_latents_w = init_latents_w.reshape(b, c, h, w)

    if return_fft:
        return init_latents_w, init_latents_w_fft
    return init_latents_w

def cal_consistency(x, args):
    """Average, over rings, the mean absolute deviation of ``x[0]`` inside each ring.

    For every radius ``i`` in ``range(w_r_start, w_r_end, -w_r_interval)`` the
    region returned by ``create_rounder_ring_mask(outer_radius=i,
    inner_radius=w_r_end)`` is selected from ``x[0]`` and the mean absolute
    deviation from the region's mean is recorded. The value is recorded once
    per entry of ``args.w_channel``; the channel index itself is not used.

    Returns:
        ``np.mean`` of the recorded deviations.
    """
    # NOTE(review): create_rounder_ring_mask is not defined in the visible
    # part of this file - confirm it is provided elsewhere.
    deviations = []
    for outer in range(args.w_r_start, args.w_r_end, -args.w_r_interval):
        ring = create_rounder_ring_mask(x.shape[-1], outer_radius=outer, inner_radius=args.w_r_end)
        ring = torch.tensor(ring).to(x.device)
        for _ in args.w_channel:
            selected = x[0, ring]
            centre = torch.mean(selected)
            spread = torch.mean(torch.abs(selected - centre)).item()
            deviations.append(spread)
    return np.mean(deviations)

def cal_distance(x, target, args):
    """Mean absolute distance between ``x`` and per-radius reference values.

    For every radius ``i`` in ``range(w_r_start, w_r_end, -w_r_interval)`` and
    every channel ``c`` in ``args.w_channel``, the entries of ``x[0, c]``
    inside the filled circle of radius ``i`` are compared against the single
    reference value ``target[0, c, 0, i]``.

    Returns:
        The mean of all collected absolute differences, as a Python number.
    """
    gaps = []
    for radius in range(args.w_r_start, args.w_r_end, -args.w_r_interval):
        region = torch.tensor(circle_mask(x.shape[-1], r=radius)).to(x.device)
        for ch in args.w_channel:
            inside = x[0, ch, region]
            reference = target[0, ch, 0, radius]
            gaps.extend(torch.abs(inside - reference))
    collected = torch.tensor(gaps)
    return torch.mean(collected).item()


def loss_mae_part(reversed_latents, watermarking_mask, gt_patch, args):
    """Mean absolute error between the masked spectrum and the masked pattern.

    The 2-D FFT of ``reversed_latents`` is taken (and shifted when
    ``args.shift`` is truthy); the loss is the mean magnitude of the
    difference to ``gt_patch`` over the entries selected by
    ``watermarking_mask``.
    """
    spectrum = torch.fft.fft2(reversed_latents)
    if args.shift:
        spectrum = torch.fft.fftshift(spectrum, dim=(-1, -2))

    residual = spectrum[watermarking_mask] - gt_patch[watermarking_mask]
    return torch.abs(residual).mean()

def eval_watermark(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args):
    """Score non-watermarked and watermarked latents against the watermark.

    ``args.w_measurement`` is matched by substring and combines two choices.

    Domain:
      * ``'complex'``: compare in the 2-D FFT domain (shifted when
        ``args.shift`` is truthy);
      * ``'seed'``: compare the latents directly.

    Metric (checked in this order):
      * ``'l1'``: per selected channel, mean absolute error to ``gt_patch``
        inside ``watermarking_mask``; the result is the mean of the five
        largest per-channel errors (of all of them if there are fewer).
      * ``'distri'``: ``cal_distance`` against ``gt_patch[0]`` for the
        watermarked latents and ``gt_patch[1]`` for the non-watermarked ones.
      * ``'mean'``: per selected channel, ``cal_consistency``; the result is
        the minimum over channels.

    When ``args.w_scale_hw > 1`` both latents are reshaped from
    ``(b, c, h, w)`` to ``(b, -1, h // s, w // s)`` first, and every channel in
    ``args.w_channel`` expands to ``s * s`` consecutive channels.

    Returns:
        ``(no_w_metric, w_metric)``.

    Raises:
        NotImplementedError: if ``args.w_measurement`` names no known domain
            or no known metric.
    """
    measurement = args.w_measurement
    scale = args.w_scale_hw

    # patchify
    if scale > 1:
        b, _, h, w = reversed_latents_no_w.shape
        reversed_latents_no_w = reversed_latents_no_w.reshape(b, -1, h // scale, w // scale)
        reversed_latents_w = reversed_latents_w.reshape(b, -1, h // scale, w // scale)

    if 'complex' in measurement:
        if args.shift:
            reversed_latents_no_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_no_w), dim=(-1, -2))
            reversed_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_w), dim=(-1, -2))
        else:
            reversed_latents_no_w_fft = torch.fft.fft2(reversed_latents_no_w)
            reversed_latents_w_fft = torch.fft.fft2(reversed_latents_w)
    elif 'seed' in measurement:
        reversed_latents_no_w_fft = reversed_latents_no_w
        reversed_latents_w_fft = reversed_latents_w
    else:
        # The exception used to be constructed but never raised, so the
        # function carried on and failed later on an unbound variable.
        raise NotImplementedError(f'w_measurement: {args.w_measurement}')
    target_patch = gt_patch

    if 'l1' in measurement:
        scale_factor = scale * scale
        c_list = [j for c in args.w_channel for j in range(c * scale_factor, (c + 1) * scale_factor)]
        no_w_metric_list = []
        w_metric_list = []
        for ch in c_list:
            channel_mask = watermarking_mask[:, ch]
            channel_target = target_patch[:, ch][channel_mask]
            no_w_metric_list.append(
                torch.abs(reversed_latents_no_w_fft[:, ch][channel_mask] - channel_target).mean().item()
            )
            w_metric_list.append(
                torch.abs(reversed_latents_w_fft[:, ch][channel_mask] - channel_target).mean().item()
            )
        # Average the five worst (largest-error) channels.
        no_w_metric = np.mean(np.sort(no_w_metric_list)[-5:])
        w_metric = np.mean(np.sort(w_metric_list)[-5:])
    elif 'distri' in measurement:
        target_patch_w, target_patch_no_w = target_patch[0], target_patch[1]
        no_w_metric = cal_distance(reversed_latents_no_w_fft, target_patch_no_w, args)
        w_metric = cal_distance(reversed_latents_w_fft, target_patch_w, args)
    elif 'mean' in measurement:
        scale_factor = scale * scale
        c_list = [j for c in args.w_channel for j in range(c * scale_factor, (c + 1) * scale_factor)]
        no_w_metric_list = []
        w_metric_list = []
        for ch in c_list:
            no_w_metric_list.append(cal_consistency(reversed_latents_no_w_fft[:, ch], args))
            w_metric_list.append(cal_consistency(reversed_latents_w_fft[:, ch], args))
        no_w_metric = min(no_w_metric_list)
        w_metric = min(w_metric_list)
    else:
        # Same missing `raise` as above.
        raise NotImplementedError(f'w_measurement: {args.w_measurement}')

    return no_w_metric, w_metric

def _ncx2_p_value(latents, watermarking_mask, target_patch):
    """Return the non-central chi-squared CDF value for one latent tensor."""
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.stats`` on older SciPy releases.
    from scipy import stats

    spectrum = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))[watermarking_mask].flatten()
    # Real and imaginary parts are treated as separate real observations.
    observed = torch.cat([spectrum.real, spectrum.imag])

    sigma = observed.std()
    noncentrality = (target_patch ** 2 / sigma ** 2).sum().item()
    statistic = (((observed - target_patch) / sigma) ** 2).sum().item()
    return stats.ncx2.cdf(x=statistic, df=len(target_patch), nc=noncentrality)


def get_p_value(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args):
    """Compute non-central chi-squared p-values for both latent tensors.

    Assumes a Fourier-space watermark: both latents are transformed with a
    shifted 2-D FFT (regardless of ``args.shift``; ``args`` is not read), the
    masked entries are split into real and imaginary parts, and the squared
    distance to the masked ``gt_patch``, normalised by the standard deviation
    of the observed values, is evaluated with ``scipy.stats.ncx2.cdf``.

    Args:
        reversed_latents_no_w: latents recovered from a non-watermarked sample.
        reversed_latents_w: latents recovered from a watermarked sample.
        watermarking_mask: boolean mask selecting the watermark entries.
        gt_patch: complex watermark pattern, indexed with the mask.
        args: unused; kept for interface compatibility.

    Returns:
        ``(p_no_w, p_w)``.
    """
    # assume it's Fourier space wm
    target_patch = gt_patch[watermarking_mask].flatten()
    # torch.cat is available in every torch release, unlike its newer
    # alias torch.concatenate.
    target_patch = torch.cat([target_patch.real, target_patch.imag])

    p_no_w = _ncx2_p_value(reversed_latents_no_w, watermarking_mask, target_patch)
    p_w = _ncx2_p_value(reversed_latents_w, watermarking_mask, target_patch)

    return p_no_w, p_w
