import numpy as np
from torch.nn import functional as F
import torch

def unfold_tensor(x, step_h, step_w):
    """Split a 4-D tensor into non-overlapping ``step_h x step_w`` patches.

    Dims 2 and 3 are treated as height and width; dims 0 and 1 (presumably
    batch and channel) are passed through unchanged. Height and width are
    zero-padded on the bottom/right up to the next multiple of the patch
    size, so no input element is dropped.

    Args:
        x: 4-D tensor ``(B, C, H, W)``.
        step_h: Patch height, also used as the vertical stride.
        step_w: Patch width, also used as the horizontal stride.

    Returns:
        A tuple ``(patches, unfold_shape)``:

        * ``patches`` has shape ``(B, C, step_h * step_w, nh * nw)``. There is
          one column per patch, with patches ordered row-major over the
          ``(nh, nw)`` grid, and each patch flattened row-major.
        * ``unfold_shape`` is the ``torch.Size``
          ``(B, C, nh, nw, step_h, step_w)`` of the intermediate unfolded view.
          It is useful for folding the patches back.

        Here ``nh = ceil(H / step_h)`` and ``nw = ceil(W / step_w)``.

    Raises:
        ValueError: If ``step_h`` or ``step_w`` is not positive.
    """
    if step_h <= 0 or step_w <= 0:
        raise ValueError(
            f"step_h and step_w must be positive, got {step_h} and {step_w}"
        )
    kh, kw = step_h, step_w  # kernel size
    dh, dw = step_h, step_w  # stride == kernel size -> non-overlapping patches

    # Padding that rounds H and W up to the next multiple of the kernel size
    # (0 when already divisible). Plain Python ints, safe to hand to F.pad.
    pad_h = -x.size(2) % kh
    pad_w = -x.size(3) % kw
    # F.pad lists the LAST dim first: (w_left, w_right, h_top, h_bottom).
    x = F.pad(x, (0, pad_w, 0, pad_h))

    patches = x.unfold(2, kh, dh).unfold(3, kw, dw)  # (B, C, nh, nw, kh, kw)
    unfold_shape = patches.size()
    # (B, C, nh, nw, kh*kw) -> (B, C, nh*nw, kh*kw)
    patches = patches.flatten(start_dim=4).flatten(start_dim=2, end_dim=3)
    patches = patches.permute(0, 1, 3, 2)  # (B, C, kh*kw, nh*nw)
    return patches, unfold_shape