from functools import partial

import torch
import torch.nn.functional as F


# Callable that inserts a new leading axis: unsqueezer(x) is torch.unsqueeze(x, dim=0).
unsqueezer = partial(torch.unsqueeze, dim=0)


def map_fn(batch, fn):
    """Apply ``fn`` to every leaf of a nested dict/list structure.

    - Strings are returned untouched (not treated as leaves).
    - Dicts are updated *in place*; the same dict object is returned.
    - Lists are rebuilt as new lists.
    - Anything else (including tuples) is a leaf and is passed to ``fn``.
    """
    if isinstance(batch, str):
        return batch
    if isinstance(batch, list):
        return [map_fn(item, fn) for item in batch]
    if not isinstance(batch, dict):
        return fn(batch)
    # Reassigning existing keys does not resize the dict, so iterating it directly is safe.
    for key in batch:
        batch[key] = map_fn(batch[key], fn)
    return batch


def to(data, device, non_blocking=True):
    """Recursively move every leaf of a nested container to ``device``.

    Dicts, lists and tuples are rebuilt (dicts as plain ``dict``; namedtuples
    keep their type); strings are returned unchanged. Every other leaf must
    provide ``.to(device, non_blocking=...)``, e.g. a ``torch.Tensor``.

    Args:
        data: A tensor, a string, or a dict/list/tuple nesting of those.
        device: Target device, forwarded to each leaf's ``.to``.
        non_blocking: Forwarded to each leaf's ``.to``.

    Returns:
        A structure of the same shape with every leaf moved.
    """
    if isinstance(data, dict):
        return {k: to(v, device, non_blocking=non_blocking) for k, v in data.items()}
    if isinstance(data, list):
        return [to(v, device, non_blocking=non_blocking) for v in data]
    if isinstance(data, tuple):
        moved = [to(v, device, non_blocking=non_blocking) for v in data]
        # namedtuples are built from positional fields, plain tuples from an iterable.
        return type(data)(*moved) if hasattr(data, "_fields") else tuple(moved)
    if isinstance(data, str):
        return data
    return data.to(device, non_blocking=non_blocking)


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on all parameters of one or more networks.

    Args:
        nets: A single object with ``.parameters()`` or a list of them;
            ``None`` entries (or a bare ``None``) are skipped.
        requires_grad: Value assigned to every parameter's ``requires_grad``.
    """
    modules = nets if isinstance(nets, list) else [nets]
    # Explicit `is not None` check: empty containers like nn.Sequential() are falsy
    # but must still be visited.
    for module in (n for n in modules if n is not None):
        for weight in module.parameters():
            weight.requires_grad = requires_grad


def mask_mean(t: torch.Tensor, m: torch.Tensor, dim=None, keepdim=False):
    """Mean of ``t`` over ``dim``, excluding the entries where ``m`` is set.

    Masked entries are removed from both the sum and the element count, so
    the result is the mean of the *unmasked* entries only. A reduction in
    which every entry is masked computes 0 / 0 and yields nan.

    Args:
        t: Input tensor (not modified).
        m: Mask of entries to exclude. Non-bool masks are interpreted as
            "non-zero means excluded". Must be usable as a boolean index
            into ``t`` (typically the same shape as ``t``).
        dim: An int, or a list/tuple of ints. ``None`` or an empty sequence
            reduces over all dimensions.
        keepdim: Keep the reduced dimensions with size 1.

    Returns:
        Tensor holding the masked mean.
    """
    # A long/float 0/1 mask would otherwise be treated as integer indices by t[m].
    m = m.to(torch.bool)
    t = t.clone()
    t[m] = 0
    if isinstance(dim, int):
        # torch.sum accepts a bare int, but the element count below needs a sequence.
        dim = [dim]
    if dim is None or len(dim) == 0:
        dim = list(range(t.dim()))
    els = 1
    for d in dim:
        els *= t.shape[d]
    masked_count = torch.sum(m.to(torch.float), dim=dim, keepdim=keepdim)
    return torch.sum(t, dim=dim, keepdim=keepdim) / (els - masked_count)


def apply_crop(array, crop):
    """Return the region of ``array`` selected by ``crop``.

    ``crop`` is read as ``(start0, start1, length0, length1)``: entries 0 and 2
    slice the first axis, entries 1 and 3 the second. Extra entries are ignored.
    """
    start0, start1, length0, length1 = crop[0], crop[1], crop[2], crop[3]
    first_axis = slice(start0, start0 + length0)
    second_axis = slice(start1, start1 + length1)
    return array[first_axis, second_axis]


def shrink_mask(mask, shrink=3):
    """Erode a binary mask with a ``shrink`` x ``shrink`` square window.

    A position stays 1 only if every value in its window is 1. Positions
    outside the mask count as 0, so the border is eroded too. The output has
    the same shape as the input for every window size.

    Args:
        mask: Tensor whose last two dims are spatial; it must be a shape that
            ``F.avg_pool2d`` accepts (3D or 4D). Values are assumed to be 0/1.
        shrink: Window side length (>= 1). An even window cannot be centred:
            it reaches ``shrink // 2`` pixels before and ``shrink // 2 - 1``
            after each position, so erosion is one pixel stronger at the
            top/left edges.

    Returns:
        float32 tensor of 0s and 1s with the same shape as ``mask``.
    """
    mask = mask.to(torch.float32)
    before = shrink // 2
    after = shrink - 1 - before
    # Explicit zero padding (order: left, right, top, bottom). Symmetric
    # padding=shrink // 2 inside avg_pool2d would make the output one pixel
    # larger than the input whenever shrink is even.
    mask = F.pad(mask, (before, after, before, after), mode="constant", value=0.)
    mask = F.avg_pool2d(mask, kernel_size=shrink, stride=1)
    return (mask == 1.).to(torch.float32)


def get_mask(size, border=5, device=None):
    """Build an all-ones float32 mask of shape ``size`` and erode it.

    The ones are passed through ``shrink_mask(..., border)``; for an odd
    ``border`` this zeroes ``border // 2`` pixels along each spatial edge.
    The erosion runs on the CPU. The result is moved to ``device`` only when
    ``device`` is given.
    """
    eroded = shrink_mask(torch.ones(size, dtype=torch.float32), border)
    return eroded if device is None else eroded.to(device)


def get_grid(H, W, normalize=True):
    """Build an (H, W, 2) float32 grid of coordinates in (x, y) order.

    ``grid[i, j]`` is ``(x_j, y_i)``: the last axis holds the column
    coordinate first, then the row coordinate.

    Args:
        H: Number of rows.
        W: Number of columns.
        normalize: If True, coordinates are evenly spaced in [-1, 1]
            (``torch.linspace``). Otherwise they are integer indices
            ``0..W-1`` / ``0..H-1``, stored as float.

    Returns:
        float32 tensor of shape (H, W, 2).
    """
    if normalize:
        ys = torch.linspace(-1, 1, H)
        xs = torch.linspace(-1, 1, W)
    else:
        ys = torch.arange(0, H)
        xs = torch.arange(0, W)
    try:
        # Explicit "ij" avoids the UserWarning torch >= 1.10 emits when it is omitted.
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    except TypeError:
        # torch < 1.10 has no `indexing` argument; "ij" was its only behavior.
        grid_y, grid_x = torch.meshgrid(ys, xs)
    # Stack as (x, y) directly rather than stacking (y, x) and flipping.
    return torch.stack((grid_x, grid_y), dim=-1).float()


def detach(t):
    """Detach a tensor, or every element of a tuple/list, from the autograd graph.

    A tuple yields a plain tuple and a list yields a list. Any other input
    must provide ``.detach()``, e.g. a ``torch.Tensor``.
    """
    if isinstance(t, tuple):
        return tuple(t_.detach() for t_ in t)
    if isinstance(t, list):
        # Lists previously raised AttributeError; handle them like tuples for
        # consistency with map_fn/to in this module.
        return [t_.detach() for t_ in t]
    return t.detach()
