import torch
import torchvision.transforms.functional as F

# Per-channel mean (three channels), given on a 0-255 scale and divided by
# 255 so it applies to tensors scaled to [0, 1]. Channel order is presumably
# RGB — NOTE(review): confirm against the data loader.
ACTIVITY_NET_MEAN = [channel / 255. for channel in (114.7748, 107.7354, 99.4750)]

def batch_mean_sub(batch, time_dim=1):
    """Mean-subtract every clip of a batch via ``mean_sub``, in place.

    The batch axis is assumed to be dim 0. Each clip (one slice along dim 0,
    with that axis squeezed away) is passed to ``mean_sub`` together with
    ``time_dim``. ``time_dim`` therefore indexes the clip's own axes, not the
    batch's.

    Returns:
        ``batch`` itself, with each processed clip assigned back into it.
    """
    if not len(batch):
        return batch
    clips = batch.split(1, dim=0)
    for idx, clip in enumerate(clips):
        batch[idx] = mean_sub(clip.squeeze(0), time_dim=time_dim)
    return batch

def batch_mean_add(batch, time_dim=1):
    """Add ``ACTIVITY_NET_MEAN`` back onto a batch and return a new tensor.

    The mean is reshaped to ``(C, 1, 1, 1)`` and broadcast against ``batch``,
    so channels must sit at dim -4 (e.g. a ``(B, C, T, H, W)`` batch). The
    addition is out-of-place; the input tensor is left unchanged.

    ``time_dim`` is accepted for signature parity with ``batch_mean_sub``
    but is not used.
    """
    channel_mean = torch.Tensor(ACTIVITY_NET_MEAN).to(batch.device)
    return batch + channel_mean.view(-1, 1, 1, 1)

def mean_sub(seq, time_dim=1, mean=None):
    """Subtract a per-channel mean from every frame of ``seq``, in place.

    ``seq`` is walked frame by frame along ``time_dim``. Each frame is passed
    to ``torchvision.transforms.functional.normalize`` with a unit std, so
    the mean is subtracted along the frame's channel axis (dim -3 of the
    frame, e.g. ``(C, H, W)``). With the default ``time_dim=1`` this matches
    a ``(C, T, H, W)`` clip.

    Args:
        seq: floating-point tensor of frames (torchvision's ``normalize``
            rejects non-float input); modified in place.
        time_dim: axis of ``seq`` that indexes frames. Any axis, including
            negative indices, is supported.
        mean: per-channel values to subtract; defaults to
            ``ACTIVITY_NET_MEAN``.

    Returns:
        ``seq`` itself (the same tensor object) after subtraction.
    """
    if mean is None:
        mean = ACTIVITY_NET_MEAN
    # Nothing to do; also keeps empty frames away from normalize().
    if seq.numel() == 0:
        return seq
    # A std of 1 turns normalize() into a pure mean subtraction.
    unit_std = [1.] * len(mean)
    for i in range(seq.shape[time_dim]):
        # Write back along the same axis we iterate over. Indexing a fixed
        # dim here would corrupt or reject any time_dim other than 1.
        frame = seq.select(time_dim, i)
        frame.copy_(F.normalize(frame, mean, unit_std))
    return seq

def mean_add(seq, time_dim=1, mean=None):
    """Add a per-channel mean back onto every frame of ``seq``, in place.

    Counterpart of ``mean_sub``. ``seq`` is walked frame by frame along
    ``time_dim``. Each frame gets the mean, shaped ``(C, 1, 1)``, added over
    its channel axis (dim -3 of the frame, e.g. ``(C, H, W)``). With the
    default ``time_dim=1`` this matches a ``(C, T, H, W)`` clip.

    Args:
        seq: tensor of frames; modified in place.
        time_dim: axis of ``seq`` that indexes frames. Any axis, including
            negative indices, is supported.
        mean: per-channel values to add; defaults to ``ACTIVITY_NET_MEAN``.

    Returns:
        ``seq`` itself (the same tensor object) after the addition.
    """
    if mean is None:
        mean = ACTIVITY_NET_MEAN
    if seq.numel() == 0:
        return seq
    # Loop-invariant: build the (C, 1, 1) mean once, on the data's device.
    channel_mean = torch.Tensor(mean).view(-1, 1, 1).to(seq.device)
    for i in range(seq.shape[time_dim]):
        # Write back along the same axis we iterate over. Indexing a fixed
        # dim here would corrupt or reject any time_dim other than 1.
        frame = seq.select(time_dim, i)
        # Out-of-place add then copy (not add_) keeps plain-assignment dtype
        # semantics; e.g. integer frames are cast back instead of erroring.
        frame.copy_(frame + channel_mean)
    return seq
