'''Modified from https://github.com/alinlab/LfF/blob/master/util.py'''

import torch
import torch.nn.functional as nnf

class EMA:
    """Per-sample exponential moving average tracker with per-label queries.

    One running value is kept for every sample (``parameter``) together with
    a 0/1 flag recording whether that sample has been updated at least once
    (``updated``). ``max_loss`` / ``mean_loss`` reduce the running values over
    all samples that share a label.

    Args:
        label: Tensor with one label per sample; ``label.size(0)`` fixes the
            number of tracked values.
        num_classes: Length of the ``max`` buffer. When ``None`` it is
            inferred as ``label.max() + 1`` (assumes integer labels that
            start at 0).
        alpha: Decay factor used by ``update`` when no ``curve`` is given.
    """

    def __init__(self, label, num_classes=None, alpha=0.9):
        # Prefer CUDA (the historical behaviour) but fall back to the label's
        # own device so the tracker also works on CPU-only hosts.
        device = torch.device("cuda") if torch.cuda.is_available() else label.device
        self.label = label.to(device)
        self.alpha = alpha
        self.parameter = torch.zeros(label.size(0))
        self.updated = torch.zeros(label.size(0))
        if num_classes is None:
            # torch.zeros(None) raises, so derive a size from the labels.
            num_classes = int(label.max()) + 1 if label.numel() > 0 else 0
        self.num_classes = num_classes
        self.max = torch.zeros(self.num_classes, device=device)

    def update(self, data, index, curve=None, iter_range=None, step=None):
        """Blend ``data`` into the running values stored at ``index``.

        The first update of a sample stores ``data`` unchanged (its
        ``updated`` flag is still 0, so the new value gets weight 1); later
        updates compute ``alpha * old + (1 - alpha) * data``.

        Args:
            data: Tensor of new values, one per entry of ``index``.
            index: Tensor of sample positions to update.
            curve: Optional base of a step-dependent decay. When given, the
                decay is ``curve ** -(step / iter_range)`` instead of
                ``self.alpha``.
            iter_range: Denominator of the exponent; required with ``curve``.
            step: Numerator of the exponent; required with ``curve``.
        """
        # The buffers are running statistics: drop autograd history so they
        # never keep computation graphs alive across iterations.
        data = data.detach()

        # Buffers follow the device of the incoming values.
        self.parameter = self.parameter.to(data.device)
        self.updated = self.updated.to(data.device)
        index = index.to(data.device)

        if curve is None:
            alpha = self.alpha
        else:
            alpha = curve ** -(step / iter_range)

        previous = self.parameter[index]
        seen = self.updated[index]
        self.parameter[index] = alpha * previous + (1 - alpha * seen) * data

        self.updated[index] = 1

    def max_loss(self, label):
        """Return the largest running value among samples labelled ``label``."""
        label_index = torch.where(self.label == label)[0]
        # Labels and running values may live on different devices.
        return self.parameter[label_index.to(self.parameter.device)].max()

    def mean_loss(self, label):
        """Return the mean running value among samples labelled ``label``."""
        label_index = torch.where(self.label == label)[0]
        # Labels and running values may live on different devices.
        return self.parameter[label_index.to(self.parameter.device)].mean()


class Hook:
    """Attach a recording hook to a module and keep what the hook receives.

    Args:
        module: Object exposing ``register_forward_hook`` and
            ``register_backward_hook`` (e.g. a ``torch.nn.Module``).
        backward: When equal to ``False`` a forward hook is registered;
            any other value registers a backward hook.
    """

    def __init__(self, module, backward=False):
        # Every ``output`` the hook has seen, in call order.
        self.feature = []
        # Equality (not truthiness) decides the hook kind, so only values
        # that compare equal to False select the forward hook.
        wants_forward = backward == False  # noqa: E712
        register = (
            module.register_forward_hook
            if wants_forward
            else module.register_backward_hook
        )
        self.hook = register(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Record the latest hook arguments and append ``output`` to ``feature``."""
        self.feature.append(output)
        self.output = output
        self.input = input

    def close(self):
        """Detach the hook from the module."""
        self.hook.remove()
        
class ShufflePatches(object):
    """Randomly rearrange the non-overlapping patches of each image in a batch.

    Args:
        patch_size: Kernel size and stride handed to ``unfold`` / ``fold``.
    """

    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        """Return ``x`` with the patches of every image independently permuted.

        An independent ``torch.randperm`` is drawn for each image, in batch
        order.

        NOTE(review): presumably the spatial size is a multiple of
        ``patch_size``; otherwise ``unfold`` should skip the trailing
        rows/columns and ``fold`` leave them zero - confirm against callers.
        """
        size = self.ps

        # One column per patch: (batch, channels * patch area, num_patches).
        columns = nnf.unfold(x, kernel_size=size, stride=size, padding=0)

        # Reorder the patch columns of each image with its own permutation.
        shuffled = []
        for image_columns in columns:
            order = torch.randperm(image_columns.shape[-1])
            shuffled.append(image_columns[:, order].unsqueeze(0))
        shuffled_columns = torch.cat(shuffled, dim=0)

        # Reassemble the permuted patches into images of the original size.
        return nnf.fold(
            shuffled_columns,
            x.shape[-2:],
            kernel_size=size,
            stride=size,
            padding=0,
        )

class CenterOcclude(object):
    """Flatten a centered square of every image to that square's mean colour.

    Args:
        occlusion_size: Side length, in pixels, of the centered square.
    """

    def __init__(self, occlusion_size):
        self.os = occlusion_size

    def __call__(self, x):
        """Return a copy of ``x`` with its central patch replaced by the patch mean.

        Args:
            x: 4-D tensor unpacked as ``(batch, channels, height, width)``.

        Returns:
            A new tensor shaped like ``x``. Inside the centered patch every
            pixel holds the per-image, per-channel mean of the original
            patch; everything outside the patch is copied from ``x``.
        """
        b, c, h, w = x.size()

        # An occlusion larger than the image simply covers the whole image.
        occ_h = max(0, min(self.os, h))
        occ_w = max(0, min(self.os, w))

        # Anchor the patch at its top-left corner and extend it by exactly
        # the patch size. Deriving the far edge as ``dim - pad`` makes the
        # patch one pixel too large whenever ``dim - size`` is odd.
        top = (h - occ_h) // 2
        left = (w - occ_w) // 2
        rows = slice(top, top + occ_h)
        cols = slice(left, left + occ_w)

        # Mean over the spatial dims only, kept as (b, c, 1, 1) so it
        # broadcasts across the patch on assignment.
        patch_mean = x[:, :, rows, cols].mean(dim=(2, 3), keepdim=True)

        occluded_images = x.clone()
        occluded_images[:, :, rows, cols] = patch_mean
        return occluded_images
    
def gray_scale(x):
    """Turn a batch of 3-channel images into 3-channel grayscale images.

    Inputs whose second dimension is not 3 are returned untouched. Otherwise
    the three channels are combined with the weights 0.2989 / 0.5870 / 0.1140
    and the single result is copied into three channels, so the output keeps
    the ``(batch_size, 3, height, width)`` shape.
    """
    # Anything that is not 3-channel passes through unchanged.
    if x.shape[1] != 3:
        return x

    first = x[:, 0, :, :]
    second = x[:, 1, :, :]
    third = x[:, 2, :, :]

    # Weighted channel sum; the original comments treat the channels as RGB.
    luminance = 0.2989 * first + 0.5870 * second + 0.1140 * third

    # Restore the channel axis and copy the single plane three times.
    return luminance.unsqueeze(1).repeat(1, 3, 1, 1)