import functools

import torch
from torch.utils import mkldnn

import ofa.config


def is_mkldnn_layout(x):
    """Return True if ``x`` is a tensor stored in MKLDNN layout.

    Non-tensor values (Python scalars, ``None``, ...) have no ``layout``
    attribute; they are reported as not-MKLDNN instead of raising
    ``AttributeError``, so callers can pass mixed tensor/scalar operands.

    Args:
        x: Any value; typically a ``torch.Tensor``.

    Returns:
        bool: ``True`` only for tensors whose ``layout`` is ``torch._mkldnn``.
    """
    return isinstance(x, torch.Tensor) and x.layout == torch._mkldnn

def skip_mkldnn(func):
    """Decorator that runs a single-input forward method on a dense tensor.

    When ``ofa.config.USE_MKLDNN`` is ``False`` the wrapped method is called
    unchanged. Otherwise, an MKLDNN-layout input tensor is converted to a
    dense tensor before the call, and the method's result is converted back
    to MKLDNN layout with ``.to_mkldnn()``. Presumably this is for ops that
    have no MKLDNN implementation -- NOTE(review): confirm at call sites.

    Args:
        func: A method with the signature ``func(self, x)``.

    Returns:
        A wrapper with the same ``(self, x)`` signature and the same
        ``__name__``/``__doc__`` as ``func``.
    """

    @functools.wraps(func)
    def wrapper(self, x):
        if ofa.config.USE_MKLDNN is False:
            return func(self, x)

        # isinstance (rather than an exact type comparison) also accepts
        # Tensor subclasses; non-tensor inputs are passed through untouched.
        if isinstance(x, torch.Tensor) and is_mkldnn_layout(x):
            x = x.to_dense()

        # Run the op on the dense tensor, then hand back an MKLDNN tensor.
        return func(self, x).to_mkldnn()

    return wrapper


# TODO: fix to use decorator
def mat_mul(x, y):
    """Multiply ``x`` and ``y`` element-wise, honouring the MKLDNN setting.

    Despite its name this is the ``*`` operator (element-wise product with
    broadcasting), not a matrix multiplication; the name is kept for
    backward compatibility.

    When ``ofa.config.USE_MKLDNN`` is ``False`` the operands are multiplied
    as given. Otherwise any MKLDNN-layout tensor operand is first converted
    to a dense tensor, the product is computed densely, and the result is
    converted back to MKLDNN layout.

    Args:
        x: A tensor (dense or MKLDNN) or a Python scalar.
        y: A tensor (dense or MKLDNN) or a Python scalar.

    Returns:
        ``x * y``; converted with ``.to_mkldnn()`` when MKLDNN is enabled.
    """
    if ofa.config.USE_MKLDNN is False:
        return x * y

    # Scalars have no ``layout`` attribute, so only inspect real tensors;
    # MKLDNN operands are densified before the element-wise product.
    if isinstance(x, torch.Tensor) and is_mkldnn_layout(x):
        x = x.to_dense()
    if isinstance(y, torch.Tensor) and is_mkldnn_layout(y):
        y = y.to_dense()

    return (x * y).to_mkldnn()

def squeeze_hw(x):
    """Remove the spatial (H, W) axes from an ``(N, C, H, W)`` tensor.

    Assumes ``x`` is 4-D with H == W == 1 (e.g. after global pooling) --
    TODO confirm with callers. Only axis 2 is squeezed, twice, in both the
    dense and the MKLDNN configuration. The batch and channel axes are
    therefore preserved even when they have size 1. A bare
    ``torch.squeeze(x)`` would drop a batch of size 1 and return a
    different rank than the MKLDNN path.

    Args:
        x: A 4-D tensor, dense or in MKLDNN layout.

    Returns:
        The squeezed tensor, ``(N, C)`` for an ``(N, C, 1, 1)`` input.
        When ``ofa.config.USE_MKLDNN`` is not ``False`` the result is
        converted with ``.to_mkldnn()``.
    """
    use_mkldnn = ofa.config.USE_MKLDNN is not False

    if use_mkldnn and is_mkldnn_layout(x):
        x = x.to_dense()

    # Squeezing axis 2 twice removes H, then W (which shifts into axis 2
    # once H is gone). A squeeze on a non-size-1 axis is a no-op, so W is
    # only removed when H was.
    x = torch.squeeze(x, dim=2)
    x = torch.squeeze(x, dim=2)

    return x.to_mkldnn() if use_mkldnn else x

