import torch
from functools import reduce
from typing import Tuple, Union


def unravel_ind(idx: torch.Tensor,
                img_size: Union[torch.Size, Tuple]):
    """Turn flat indices into ``(channel, row, col)`` coordinates.

    Args:
        idx: integer tensor of flat (row-major) indices into a tensor of
            shape ``img_size``.
        img_size: shape whose last two entries are height and width.

    Returns:
        The channel, row and column index tensors stacked along dim 1, e.g.
        ``(2, 3, K)`` for an ``idx`` of shape ``(2, K)``.
    """
    height, width = img_size[-2], img_size[-1]
    plane = height * width
    channel = idx // plane
    row = (idx // width) % height
    col = idx % width
    return torch.stack([channel, row, col], dim=1)


def gen_private_key(img_size: Union[torch.Size, Tuple],
                    key_len: int, is_structured: bool,
                    inner_region_len = 150, delta = 10,
                    device: Union[torch.device, str] = None):
    """Sample a random private key made of ``key_len`` pixel pairs.

    With ``is_structured`` the work is delegated to
    ``structured_gen_private_key``. Otherwise ``2 * key_len`` distinct flat
    positions are drawn with ``torch.randperm`` and converted to
    ``(channel, row, col)`` coordinates, so no pixel appears twice.

    Args:
        img_size: image shape (last two entries are height and width); a
            4-D shape yields one independent key per leading entry.
        key_len: number of pixel pairs.
        is_structured: use the centre/periphery sampler instead.
        inner_region_len, delta: only used by the structured sampler.
        device: device on which the key is created.

    Returns:
        Tensor of shape ``(2, 3, key_len)``, or ``(B, 2, 3, key_len)`` for a
        4-D ``img_size``.

    Raises:
        ValueError: if the image has fewer than ``2 * key_len`` elements.
    """
    if is_structured:
        # The structured sampler handles batched shapes on its own.
        return structured_gen_private_key(
            img_size, key_len, device=device,
            inner_region_len=inner_region_len, delta=delta,
        )

    if len(img_size) == 4:
        # One independent key per image of the batch.
        per_image = [
            gen_private_key(img_size[1:], key_len, is_structured,
                            inner_region_len=inner_region_len,
                            delta=delta, device=device)
            for _ in range(img_size[0])
        ]
        return torch.stack(per_image)

    total = reduce(lambda acc, dim: acc * dim, img_size)
    if 2 * key_len > total:
        raise ValueError("Impossible to generate unique index pairs!")

    flat = torch.randperm(total, device=device)
    pairs = flat[:2 * key_len].view(2, key_len)
    return unravel_ind(pairs, img_size)


def structured_gen_private_key(img_size: Union[torch.Size, Tuple],
                    key_len: int,
                    device: Union[torch.device, str] = None,
                    inner_region_len = 150,
                    delta = 10,
                    ):
    """Sample a private key whose pairs join a central pixel to a peripheral one.

    For each of the ``key_len`` pairs, the first pixel lies in a centred
    square band of side ``inner_region_len``. The second pixel lies outside
    that band on both spatial axes. Along each axis, the inner coordinates
    (and, separately, the outer ones) are drawn without replacement, so no
    coordinate repeats within a key. Channels are drawn uniformly, with
    replacement, from the ``img_size[-3]`` channels (a single channel if
    ``img_size`` has no channel entry).

    Args:
        img_size: ``(C, H, W)``, or ``(B, C, H, W)`` for one independent key
            per image.
        key_len: number of pixel pairs.
        device: device on which the key is created.
        inner_region_len: side length, in pixels, of the central band.
        delta: margin separating the band from the outer samples (see note).

    Returns:
        Int64 tensor of shape ``(2, 3, key_len)``, i.e. ``[inner, outer]``,
        each made of ``(channel, row, col)`` rows. The shape is
        ``(B, 2, 3, key_len)`` for a batched ``img_size``.

    Raises:
        ValueError: if the image has fewer than ``2 * key_len`` elements, if
            ``key_len > inner_region_len``, or if a spatial side is smaller
            than ``inner_region_len + delta + key_len``.

    NOTE(review): along an axis of length ``s`` the outer coordinates cover
    ``[0, s//2 - L//2 - delta]`` and ``[s//2 - L//2 + L + 1, s - 1]``, where
    ``L`` is ``inner_region_len``. The margin is therefore ``delta - 1``
    pixels below the band but one pixel above it. This is kept so that
    seeded keys stay reproducible.
    """
    if len(img_size) == 4:
        # Forward every option so each per-image key honours the caller's
        # device and region settings.
        return torch.stack([
            structured_gen_private_key(img_size[1:], key_len, device,
                                       inner_region_len, delta)
            for _ in range(img_size[0])
        ])

    numel = reduce(lambda x, y: x * y, img_size)
    if numel < 2 * key_len:
        raise ValueError("Impossible to generate unique index pairs!")

    height, width = img_size[-2], img_size[-1]
    # Fail early with a clear message instead of a shape mismatch in
    # torch.stack (or a negative-size randperm) further down.
    if key_len > inner_region_len:
        raise ValueError(
            f"key_len ({key_len}) must not exceed inner_region_len "
            f"({inner_region_len})")
    if key_len > min(height, width) - inner_region_len - delta:
        raise ValueError(
            f"Image side {min(height, width)} is too small for "
            f"key_len={key_len}, inner_region_len={inner_region_len}, "
            f"delta={delta}")
    # Sample from the image's own channels rather than a fixed 3, so
    # grayscale images get valid channel indices.
    num_channels = img_size[-3] if len(img_size) >= 3 else 1

    def _inner_coords(size):
        # Distinct coordinates inside the centred band of length inner_region_len.
        offset = size // 2 - inner_region_len // 2
        return torch.randperm(inner_region_len, device=device)[:key_len] + offset

    def _outer_coords(size):
        # Draw from a range with the band (plus margin) removed, then shift
        # everything past the lower cut-off over the band. This maps the
        # range one-to-one onto
        # [0, cutoff] U [cutoff + inner_region_len + delta + 1, size - 1],
        # so the coordinates stay distinct and in bounds. Each axis uses its
        # own size, so non-square images are handled correctly.
        coords = torch.randperm(size - inner_region_len - delta,
                                device=device)[:key_len]
        cutoff = size // 2 - inner_region_len // 2 - delta
        return torch.where(coords > cutoff,
                           coords + inner_region_len + delta, coords)

    # Random draws happen in a fixed order (inner c, y, x then outer c, y, x)
    # so a seeded generator reproduces the same key.
    inner_idx_c = torch.randint(low=0, high=num_channels, size=(key_len,),
                                device=device)
    inner_idx_y = _inner_coords(height)
    inner_idx_x = _inner_coords(width)

    outer_idx_c = torch.randint(low=0, high=num_channels, size=(key_len,),
                                device=device)
    outer_idx_y = _outer_coords(height)
    outer_idx_x = _outer_coords(width)

    inner = torch.stack((inner_idx_c, inner_idx_y, inner_idx_x))
    outer = torch.stack((outer_idx_c, outer_idx_y, outer_idx_x))
    return torch.stack((inner, outer))



def gen_public_key(num_keys: int, key_len: int,
            device: Union[torch.device, str]=None):
    """Draw ``num_keys`` random bit strings of length ``key_len``.

    Returns:
        Int64 tensor of shape ``(num_keys, key_len)`` with entries in {0, 1},
        created on ``device``.
    """
    shape = (num_keys, key_len)
    return torch.randint(low=0, high=2, size=shape, device=device)



def extract_wm(img: torch.Tensor, private_key: torch.Tensor, hard=True):
    """Read the watermark encoded by pixel-pair comparisons.

    For each pair ``(p1, p2)`` in ``private_key`` the soft value is
    ``img[p1] - img[p2]``. The hard bit is ``img[p1] >= img[p2]``.

    Args:
        img: ``(C, H, W)`` image or ``(B, C, H, W)`` batch.
        private_key: ``(2, 3, K)`` key of ``(channel, row, col)`` rows, or
            ``(B, 2, 3, K)`` with one key per image. For a batch, a single
            key (``(2, 3, K)`` or ``(1, 2, 3, K)``) is applied to every
            image. If both have a batch dimension of different sizes,
            pairing stops at the shorter one (``zip``).
        hard: return booleans instead of raw float differences.

    Returns:
        ``(K,)`` for a single image or ``(B, K)`` for a batch. The values are
        bool when ``hard`` is set, otherwise float32.
    """
    img = img.float() # to avoid uint8 case -> inadequate pixel differences
    if img.ndim == 4:
        if private_key.ndim == 3:
            private_key = private_key[None]
        if private_key.size(0) == 1 and img.size(0) > 1:
            # Broadcast the shared key. Otherwise zip() stops after the first
            # image and silently drops the rest of the batch.
            private_key = private_key.expand(img.size(0), *private_key.shape[1:])
        return torch.stack([
            extract_wm(im, priv, hard=hard) for im, priv
            in zip(img, private_key)
        ])
    pk1, pk2 = private_key

    wm = img[pk1[0], pk1[1], pk1[2]] - img[pk2[0], pk2[1], pk2[2]]
    if hard:
        wm = wm >= 0
    return wm


def compute_wm_error(attacked_img, priv_key, pub_key, tau=None, reduce='average'):
    """Compare the watermark extracted from ``attacked_img`` with ``pub_key``.

    Args:
        attacked_img: image or batch passed to ``extract_wm``.
        priv_key: private key(s) passed to ``extract_wm``.
        pub_key: expected bits, ``(K,)`` or ``(B, K)``.
        tau: when given, count the images whose number of matching bits falls
            outside ``[K/2 - tau, K/2 + tau]``, instead of computing error
            rates.
        reduce: ``'average'`` gives the bit error rate over the whole batch.
            ``'none'`` gives one error rate per image. Ignored when ``tau``
            is given.

    Returns:
        An ``int`` count when ``tau`` is given. Otherwise a ``float`` for
        ``'average'``, or a float tensor of shape ``(B,)`` for ``'none'``.

    Raises:
        NotImplementedError: for an unknown ``reduce`` (when ``tau`` is None).
    """
    attacked_wm = extract_wm(attacked_img, priv_key)
    if attacked_wm.ndim == 1:
        # A single image: add a batch dimension so the per-row logic applies.
        attacked_wm = attacked_wm[None]
    key_len = pub_key.size(-1)

    if tau is not None:
        num_matches = torch.sum(attacked_wm == pub_key, dim=1)
        # Centre the band on K/2 (the expected match count for unrelated
        # bits) rather than a fixed 50, so keys of any length work.
        center = key_len / 2
        detected = torch.logical_or(num_matches < center - tau,
                                    num_matches > center + tau)
        return detected.sum().item()

    errors = attacked_wm != pub_key
    if reduce == 'average':
        return errors.sum().item() / errors.size(0) / key_len
    elif reduce == 'none':
        # One rate per image, including images without any error.
        return errors.sum(dim=1) / key_len
    else:
        raise NotImplementedError
    

def swap_pxls(img: torch.Tensor, private_key: torch.Tensor,
            public_key: torch.Tensor, eps: Union[float, int] = 1):
    """Embed ``public_key`` by rewriting the pixel pairs whose bit is wrong.

    For every pair ``(p1, p2)`` whose extracted bit differs from the target
    bit, the two pixel values are exchanged and pushed ``eps`` further apart.
    Afterwards ``img[p1] >= img[p2]`` matches the target, provided
    ``eps > 0`` and the dtype can hold the new values.

    Args:
        img: ``(1, C, H, W)`` or ``(C, H, W)`` image, or a ``(B, C, H, W)``
            batch, which is processed image by image. Modified in place.
        private_key: ``(2, 3, K)`` or ``(1, 2, 3, K)`` key for a single image.
            For a batch, provide one such key per image along the first
            dimension.
        public_key: target bits, ``(K,)`` or ``(1, K)`` per image (``(B, K)``
            for a batch).
        eps: extra margin added on top of the swap.

    Returns:
        The watermarked image. For a single image this is ``img`` itself.
        For a batch it is a new tensor, and the rows of ``img`` are also
        updated in place.

    NOTE(review): if one pixel appears in several mismatched pairs, the
    indexed writes collide and the last write is unspecified.
    """
    if img.ndim == 4 and img.size(0) > 1:
        return torch.concatenate([
            swap_pxls(im[None], priv, pub, eps) for im, priv, pub
            in zip(img, private_key, public_key)
        ])

    # Flatten to 1-D so (1, K) and (K,) shapes line up without special cases.
    wm = extract_wm(img, private_key).reshape(-1)
    public_key = public_key.to(wm.device).reshape(-1)
    diff = wm != public_key

    # +eps where the pair reads 1 but should read 0 (p1 must drop below p2).
    # -eps where it reads 0 but should read 1.
    sgn = (wm > public_key).float() * 2 - 1
    sgn = sgn[diff] * eps

    private_key = private_key.to(img.device)
    if private_key.ndim == 4:
        # (1, 2, 3, K) -> (2, 3, K). 3-D keys, as produced by iterating a
        # (B, 2, 3, K) batch key, are used as they are.
        private_key = private_key[0]
    swap_ind1, swap_ind2 = private_key[..., diff]

    # `...` leaves an optional leading batch dimension of size 1 untouched.
    idx1 = (..., swap_ind1[0], swap_ind1[1], swap_ind1[2])
    idx2 = (..., swap_ind2[0], swap_ind2[1], swap_ind2[2])
    # Advanced indexing returns copies, so both reads happen before the writes.
    val1, val2 = img[idx1], img[idx2]
    img[idx1] = val2 - sgn
    img[idx2] = val1 + sgn
    return img


