import torch
import kornia
from typing import Tuple
from .pxl_swap import gen_public_key, extract_wm
from tqdm import tqdm
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torch.nn.functional import mse_loss
import cv2

def get_ksize(sigma: torch.Tensor):
    """Derive odd Gaussian kernel sizes from per-axis standard deviations.

    Inverts OpenCV's default relation
    ``sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8`` to
    ``ksize = 2 * ((sigma - 0.8) / 0.3 + 1) + 1``. The value is then truncated
    to an integer and bumped to the next odd number.

    Args:
        sigma: standard deviations; dim 0 is averaged away, so a ``(B, 2)``
            input yields one kernel size per axis shared by the whole batch.

    Returns:
        int64 tensor of odd kernel sizes, each at least 1.
    """
    ksize = 2 * ((sigma.mean(0) - 0.8) / 0.3 + 1) + 1
    # Small sigmas (<= 0.2) make the formula negative. Clamp before the
    # odd-adjustment so the smallest possible size is 1 (an identity kernel).
    ksize = ksize.long().clamp(min=0)
    ksize += (ksize + 1) % 2  # bump even sizes (including 0) to the next odd value
    return ksize

def unsqueeze_to_4d(*tt: torch.Tensor):
    """Lazily append three singleton dims after dim 0 of each tensor in ``tt``.

    For a 1-D tensor of length N the yielded view has shape ``(N, 1, 1, 1)``.
    This lets per-sample scalars broadcast against a ``(N, C, H, W)`` batch.

    Yields:
        One view per input tensor, in order.
    """
    for x in tt:
        yield x[:, None, None, None]

def prep_transform_tensor(low, high, bsz, device, dtype='float'):
    """Draw one random transform parameter per batch element.

    Args:
        low, high: sampling bounds; ``high`` is exclusive in both modes.
        bsz: number of values to draw.
        device: device on which the values are created.
        dtype: ``'int'`` draws integers with ``torch.randint``. Any other value
            draws floats as ``torch.rand`` (in [0, 1)) scaled to the bounds.

    Returns:
        Tensor of shape ``(bsz, 1, 1, 1)``, ready to broadcast over an image
        batch.
    """
    if dtype != 'int':
        unit = torch.rand(size=(bsz,), device=device)
        samples = unit * (high - low) + low
    else:
        samples = torch.randint(low=low, high=high, size=(bsz,), device=device)
    return samples[:, None, None, None]

def linear_transform_to_range(x, new_max=255):
    """Min-max rescale every plane ``x[b, c]`` independently to ``[0, new_max]``.

    Dims 2 onward are flattened to find each plane's min and max, so ``x`` is
    expected to be ``(B, C, H, W)``.

    Args:
        x: tensor to rescale.
        new_max: upper end of the target range. When it equals 255 the result
            is truncated to int64; otherwise it is returned as floats.

    Returns:
        The rescaled tensor. Planes whose values are all equal map to 0.
    """
    flat = x.flatten(2)
    min_el = flat.min(-1)[0][..., None, None]
    max_el = flat.max(-1)[0][..., None, None]
    span = max_el - min_el
    # A constant plane has zero span. Dividing by it would give NaN, which
    # .long() turns into an undefined integer, so divide by 1 instead.
    span = torch.where(span > 0, span, torch.ones_like(span))
    x = (x - min_el) / span * new_max
    if new_max == 255:
        x = x.long()
    return x

def attack_br_shift(x: torch.Tensor, low, high, *, clip: bool = False):
    """Shift each image's brightness by a random integer offset.

    One offset per batch element is drawn from ``[low, high)`` with
    ``prep_transform_tensor(..., dtype='int')`` and added to every pixel of
    that image.

    Args:
        x: image batch (batch dim first).
        low, high: bounds of the integer offset (``high`` exclusive).
        clip: if True, saturate the shifted image to [0, 255] and cast it to
            int64. If False (default, the original behaviour), rescale each
            channel to [0, 255] with ``linear_transform_to_range``.

    NOTE(review): with ``clip=False`` the per-channel min-max rescale
    subtracts the channel minimum, which removes any constant offset. The
    output is therefore independent of the drawn shift (up to float rounding).
    Use ``clip=True`` for a real brightness change.
    """
    add = prep_transform_tensor(low, high, x.size(0), x.device, dtype='int')
    x = x + add
    if clip:
        return x.clip(0, 255).long()
    return linear_transform_to_range(x)

def attack_contrast(x: torch.Tensor, low, high, *, clip: bool = False):
    """Scale each image's intensities by a random per-image multiplier.

    One float multiplier per batch element is drawn from ``[low, high)`` with
    ``prep_transform_tensor`` and multiplied into every pixel of that image.

    Args:
        x: image batch (batch dim first).
        low, high: bounds of the multiplier.
        clip: if True, saturate the scaled image to [0, 255] and cast it to
            int64. If False (default, the original behaviour), rescale each
            channel to [0, 255] with ``linear_transform_to_range``.

    NOTE(review): with ``clip=False`` and a positive multiplier, the
    per-channel min-max rescale divides the scale back out. The output is
    therefore independent of the drawn multiplier (up to float rounding).
    Use ``clip=True`` for a real contrast change.
    """
    mul = prep_transform_tensor(low, high, x.size(0), x.device)
    x = x * mul
    if clip:
        return x.clip(0, 255).long()
    return linear_transform_to_range(x)

def attack_gamma(x: torch.Tensor, low, high):
    """Apply a random per-image power-law (gamma) transform.

    Every pixel of image ``b`` is raised to an exponent drawn for that image
    from ``[low, high)``. The result is then min-max rescaled per channel to
    int64 values in [0, 255] by ``linear_transform_to_range``.

    NOTE(review): with a non-integer exponent, negative pixel values produce
    NaN, so ``x`` is presumably non-negative — confirm with callers.
    """
    exponent = prep_transform_tensor(low, high, x.size(0), x.device)
    return linear_transform_to_range(torch.pow(x, exponent))


def attack_noise(x, inf_norm_val):
    """Add uniform noise of L-inf magnitude ``inf_norm_val``, then rescale.

    The noise is drawn element-wise from [-inf_norm_val, inf_norm_val) and
    added to ``x``. The sum is then min-max rescaled per channel to int64
    values in [0, 255] by ``linear_transform_to_range``.
    """
    unit_noise = 2 * torch.rand_like(x) - 1  # uniform in [-1, 1)
    noisy = x + inf_norm_val * unit_noise
    return linear_transform_to_range(noisy)


def attack_blur(x: torch.Tensor, sigma: torch.Tensor, ksize=None):
    """Gaussian-blur an image batch with per-image blur strengths.

    Args:
        x: image batch passed to ``kornia.filters.gaussian_blur2d`` (kornia
            documents that input as ``(B, C, H, W)``).
        sigma: blur *variance*; its square root is used as the Gaussian
            standard deviation. Accepted shapes:
            - a scalar (0-d tensor or Python number), shared by all images;
            - ``(B,)`` or ``(B, 1)``, the same value for both axes;
            - ``(B, 2)``, one value per axis.
        ksize: kernel size forwarded to kornia. When None it is derived from
            the standard deviations with ``get_ksize``.

    Returns:
        The blurred batch, computed with reflect border padding.
    """
    sigma = torch.as_tensor(sigma)
    if sigma.ndim == 0:
        # A single value applies to every image in the batch.
        sigma = sigma.expand(x.size(0))
    if sigma.ndim == 1:
        sigma = sigma.unsqueeze(1)
    if sigma.size(1) == 1:
        sigma = sigma.repeat(1, 2)
    # sigma is now (B, 2)

    sigma = torch.sqrt(sigma)  # variance -> standard deviation
    if ksize is None:
        ksize = get_ksize(sigma)

    sigma = sigma.to(x.device)
    return kornia.filters.gaussian_blur2d(x, kernel_size=ksize, sigma=sigma,
                                          border_type='reflect')


def attack_WB(
    x, priv_key, num_iter,
    lr, attack_budget, wm_loss_w,
    lpips_w, mse_w, eps
    ):
    """White-box watermark attack: projected gradient steps on the pixels.

    The images are mapped with ``(x - 127.5) / 127.5``, which sends
    presumably-[0, 255] inputs to [-1, 1]. They are then optimized to minimize
    ``wm_loss_batched`` against a freshly generated random target key. After
    every step they are projected back into an L-inf ball of radius
    ``attack_budget`` around the mapped input.

    Args:
        x: image batch; it is detached, so no gradient flows back to the caller.
        priv_key: key passed to ``extract_wm``. Its ``size(0)`` and ``size(-1)``
            set the number and length of the random target keys.
        num_iter: number of optimization steps.
        lr: Adam learning rate, i.e. the per-step size (see note in the loop).
        attack_budget: L-inf radius, measured in the [-1, 1] pixel scale.
        wm_loss_w, lpips_w, mse_w, eps: forwarded to ``wm_loss_batched``.

    Returns:
        ``linear_transform_to_range`` of the attacked batch, i.e. int64 values
        with each channel min-max rescaled to [0, 255].
    """
    # Enable autograd only for the duration of the attack instead of flipping
    # the global grad mode and leaving it on for the caller.
    with torch.enable_grad():
        target_wm = gen_public_key(num_keys=priv_key.size(0),
                                   key_len=priv_key.size(-1),
                                   device=priv_key.device,
                                   ).float()

        # to [-1, 1]. Detach so the optimized tensor is always a leaf.
        x = (x.detach() - 127.5) / 127.5
        src = x.clone()
        lower, upper = src - attack_budget, src + attack_budget

        lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex')
        lpips = lpips.to(x.device)

        for _ in range(num_iter):
            x.requires_grad_(True)
            # The projection below creates a new tensor each step, so a new
            # Adam is built every iteration. Its moment estimates never carry
            # over, so each update is lr * g / (|g| + 1e-8), i.e. roughly a
            # PGD-style signed gradient step of size lr.
            opt = torch.optim.Adam(params=(x,), lr=lr)

            loss = wm_loss_batched(x, priv_key, target_wm,
                                   lpips, src, wm_loss_w,
                                   lpips_w, mse_w, eps)

            opt.zero_grad()
            loss.backward()
            opt.step()

            with torch.no_grad():
                # Project onto the L-inf ball around the source images.
                x = torch.maximum(torch.minimum(x, upper), lower)

    return linear_transform_to_range(x.detach())
    

def wm_loss_batched(imgs, priv_key, pub_key, lpips, real_img, wm_loss_w, lpips_w, mse_w, eps):
    """Weighted sum of a watermark hinge loss and a perceptual-fidelity loss.

    Args:
        imgs: images being optimized. They are clipped to [-1, 1] only for the
            LPIPS call.
        priv_key: key forwarded to ``extract_wm``.
        pub_key: target key. It is mapped to signs via ``2 * pub_key - 1``, so
            it presumably holds bits in {0, 1}.
        lpips: callable scoring ``(imgs, real_img)``. In ``attack_WB`` this is
            a torchmetrics LPIPS module.
        real_img: reference images for LPIPS and MSE.
        wm_loss_w, lpips_w: weights of the watermark and perceptual terms.
        mse_w: weight of the MSE term inside the perceptual term.
        eps: margin. A bit stops contributing once ``sign * wm >= eps``.

    Returns:
        ``wm_loss_w * wm_error + lpips_w * (lpips + mse_w * mse)``.
    """
    # Perceptual fidelity: LPIPS on clipped images plus weighted MSE.
    fidelity = lpips(imgs.clip(-1, 1), real_img) + mse_w * mse_loss(imgs, real_img)

    # Watermark term: hinge on the signed soft watermark, summed over bits and
    # averaged over the batch.
    soft_wm = extract_wm(imgs, priv_key, hard=False)
    signs = 2 * pub_key - 1
    margin = signs * soft_wm - eps
    shortfall = -torch.minimum(margin, torch.zeros_like(soft_wm))
    wm_error = shortfall.sum(-1).mean()

    return wm_loss_w * wm_error + lpips_w * fidelity
