from PIL import Image
import os
import numpy as np
from scipy.fftpack import dct, idct
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

def patch_badnet(img):
    """Stamp a BadNets-style checkerboard trigger near the bottom-right corner.

    The trigger is a square of side ``max(1, min(width, height) // 10)``
    pixels. It is white, with black pixels wherever both coordinates are
    even. It is pasted so that a 1-pixel margin separates it from the right
    and bottom edges. For images too small for that margin, the position is
    clamped to the image origin.

    Args:
        img: A PIL image of any mode; it is converted to RGB.

    Returns:
        A new RGB ``PIL.Image.Image`` with the trigger applied. The input
        image is never modified.
    """
    # ``convert`` always returns a new image (a copy when the mode is already
    # RGB), so the ``paste`` below can never mutate the caller's image.
    img = img.convert('RGB')

    width, height = img.size
    trigger_size = max(1, min(width, height) // 10)

    trigger = Image.new('RGB', (trigger_size, trigger_size), 'white')
    pixels = trigger.load()
    # Black only where both coordinates are even.
    for i in range(0, trigger_size, 2):
        for j in range(0, trigger_size, 2):
            pixels[i, j] = (0, 0, 0)

    # Clamp so that 1-pixel-wide/tall images still receive the trigger
    # instead of having it pasted off-canvas at a negative offset.
    x = max(0, width - trigger_size - 1)
    y = max(0, height - trigger_size - 1)

    img.paste(trigger, (x, y))
    return img


def patch_blended(img, trigger_path='mimi.png', alpha=0.1, seed=42):
    """Alpha-blend a full-image trigger into ``img`` (Blended attack).

    If ``trigger_path`` is truthy and exists on disk, that image is used as
    the trigger, converted to RGB and resized to ``img``'s size. Otherwise
    the function silently falls back to uniform random RGB noise, generated
    deterministically from ``seed``.

    Args:
        img: A PIL image of any mode; it is converted to RGB if needed.
        trigger_path: Path to the trigger image file, or a falsy value to
            force the random-noise trigger.
        alpha: Blend weight passed to ``Image.blend``. The result is
            ``img * (1 - alpha) + trigger * alpha``.
        seed: Seed for the fallback noise trigger.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')

    width, height = img.size

    if trigger_path and os.path.exists(trigger_path):
        # The context manager releases the file handle deterministically;
        # ``convert`` fully loads the pixel data before the file is closed.
        with Image.open(trigger_path) as src:
            trigger = src.convert('RGB').resize((width, height))
    else:
        rng = np.random.RandomState(seed)
        noise = rng.randint(0, 256, (height, width, 3), dtype=np.uint8)
        trigger = Image.fromarray(noise)

    return Image.blend(img, trigger, alpha)


def patch_sig(img, delta=20, f=6):
    """Add a horizontal sinusoidal signal to an image (SIG attack).

    Every row and every channel receives the same 1-D signal
    ``delta * sin(2 * pi * f * x / w)``, where ``x`` is the column index and
    ``w`` is the image width. The result is clipped to [0, 255] and cast to
    uint8; the cast truncates.

    Args:
        img: A PIL image, or any array accepted by ``np.array`` with shape
            (h, w) or (h, w, c). PIL images whose mode is neither 'L' nor
            'RGB' are converted to RGB first. Without this, palette indices,
            alpha channels or 1-bit data would be modulated.
        delta: Signal amplitude, in pixel-intensity units.
        f: Number of full sine periods across the image width.

    Returns:
        A new ``PIL.Image.Image`` built from the uint8 result.
    """
    if isinstance(img, Image.Image) and img.mode not in ('L', 'RGB'):
        img = img.convert('RGB')

    img_arr = np.array(img, dtype=np.float32)
    w = img_arr.shape[1]

    signal = delta * np.sin(2 * np.pi * f * np.arange(w) / w)
    if img_arr.ndim == 3:
        # (w, 1) broadcasts over rows and channels of an (h, w, c) array,
        # including the single-channel (h, w, 1) case.
        signal = signal[:, np.newaxis]

    # A (w,) signal broadcasts over the rows of an (h, w) array.
    img_arr = np.clip(img_arr + signal, 0, 255)

    return Image.fromarray(img_arr.astype(np.uint8))


def patch_ftrojan(img: Image.Image,
                  channel_list=(0, 1, 2),
                  magnitude=100.0,
                  window_size=32,
                  pos_list=((15, 15),),
                  use_yuv=True) -> Image.Image:
    """Embed an FTrojan-style frequency-domain trigger.

    The image is converted to RGB and, if ``use_yuv`` is set, mapped into the
    YUV space defined by the matrix below. Each selected channel is then cut
    into non-overlapping ``window_size`` x ``window_size`` blocks. Partial
    blocks at the right/bottom edges are left untouched. Each block is
    transformed with an orthonormal 2-D DCT, ``magnitude`` is added to every
    listed coefficient, and the block is inverse-transformed back.

    Args:
        img: Input PIL image (any mode; converted to RGB).
        channel_list: Channel indices to perturb (YUV channels when
            ``use_yuv`` is true, else RGB). Indices >= 3 are ignored.
        magnitude: Amount added to each selected DCT coefficient.
        window_size: Side length of the DCT blocks.
        pos_list: ``(u, v)`` coefficient positions inside each block.
            Positions with ``u`` or ``v`` >= ``window_size`` are ignored.
        use_yuv: Whether to operate in YUV rather than RGB.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    img = img.convert('RGB')
    img_arr = np.array(img, dtype=np.float32)

    if use_yuv:
        transform_matrix = np.array([[0.299, 0.587, 0.114],
                                     [-0.14713, -0.28886, 0.436],
                                     [0.615, -0.51499, -0.10001]])
        img_arr = img_arr.dot(transform_matrix.T)

    h, w, c = img_arr.shape
    # Filter once instead of per block; duplicates are kept (each adds again).
    valid_pos = [(u, v) for u, v in pos_list
                 if u < window_size and v < window_size]

    for ch in channel_list:
        if ch >= c:
            continue

        # Basic slicing yields a view, so block writes land in img_arr.
        channel_data = img_arr[:, :, ch]

        # Start indices of full windows only (i + window_size <= h).
        for i in range(0, h - window_size + 1, window_size):
            for j in range(0, w - window_size + 1, window_size):
                block = channel_data[i:i + window_size, j:j + window_size]

                # Separable 2-D DCT: first along axis 0, then along axis 1.
                dct_block = dct(dct(block.T, norm='ortho').T, norm='ortho')
                for u, v in valid_pos:
                    dct_block[u, v] += magnitude
                channel_data[i:i + window_size, j:j + window_size] = idct(
                    idct(dct_block.T, norm='ortho').T, norm='ortho')

    if use_yuv:
        img_arr = img_arr.dot(np.linalg.inv(transform_matrix).T)

    # Round before casting. The float round-trips leave values such as
    # 99.9999999, which a bare astype would truncate to 99 even in pixels the
    # trigger never touched.
    img_arr = np.clip(np.rint(img_arr), 0, 255).astype(np.uint8)
    return Image.fromarray(img_arr)

def patch_wanet(img: Image.Image,
                k: int = 16,
                s: float = 8.0,
                grid_rescale: float = 0.95,
                seed: int = 42) -> Image.Image:
    """Apply a WaNet-style smooth warping trigger to an image.

    A random ``k`` x ``k`` two-channel displacement field, seeded by
    ``seed``, is normalized by its mean absolute value. It is upsampled
    bicubically to the image size and added, scaled by ``s / H``, to an
    identity sampling grid. The grid is multiplied by ``grid_rescale``,
    clamped to [-1, 1], and used with ``F.grid_sample`` to resample the
    image.

    Args:
        img: Input PIL image.
        k: Resolution of the random control grid.
        s: Warping strength.
        grid_rescale: Factor applied to the warped grid before clamping.
        seed: Seed for the displacement field; the same seed yields the
            same warp.

    Returns:
        A new ``PIL.Image.Image`` produced by ``ToPILImage`` from a uint8
        tensor.
    """
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    x = to_tensor(img).unsqueeze(0)
    _, _, H, W = x.shape

    # A private generator yields the same numbers as torch.manual_seed(seed)
    # followed by torch.rand. Unlike that approach, it never touches (or
    # needs to restore) the global CPU/CUDA RNG state, even on exceptions.
    gen = torch.Generator().manual_seed(seed)
    ins = torch.rand(1, 2, k, k, generator=gen) * 2 - 1
    ins = ins / torch.mean(torch.abs(ins))

    noise_grid = F.interpolate(ins, size=(H, W), mode="bicubic", align_corners=True)
    noise_grid = noise_grid.permute(0, 2, 3, 1)  # (1, H, W, 2)

    array_h = torch.linspace(-1, 1, steps=H)
    array_w = torch.linspace(-1, 1, steps=W)
    xx, yy = torch.meshgrid(array_h, array_w, indexing='ij')
    # grid_sample expects (x, y) = (width, height) in the last dimension.
    identity_grid = torch.stack((yy, xx), 2).unsqueeze(0)

    grid_temps = (identity_grid + s * noise_grid / H) * grid_rescale
    grid_temps = torch.clamp(grid_temps, -1, 1)
    x_bd = F.grid_sample(x, grid_temps, align_corners=True)

    # ToPILImage converts float tensors with mul(255).byte(), which truncates
    # (100/255 can come back as 99). Round to uint8 explicitly instead.
    x_bd = x_bd.squeeze(0).mul(255).round().clamp(0, 255).to(torch.uint8)
    return to_pil(x_bd)
