import torch
import torch.nn as nn
import torch.nn.functional as F

# Relaxed Instance Frequency-wise Normalization
# y = lambda * LN(x) + (1-lambda) * IFN(x)

class IFN(nn.Module):
    """Instance frequency-wise normalization with no learnable parameters.

    Mean and biased variance are taken jointly over dims 1 and 3, so each
    (dim-0, dim-2) position gets its own statistics. The input therefore needs
    at least 4 dims. Presumably x is (batch, channel, freq, time), which makes
    this a per-frequency normalization; TODO confirm with callers.

    Args:
        eps: constant added to the variance before the square root so that
            constant slices do not divide by zero.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        reduce_dims = (1, 3)
        mean = x.mean(dim=reduce_dims, keepdim=True)
        biased_var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
        std = torch.sqrt(biased_var + self.eps)
        return (x - mean) / std  # IFN
class InsLN(nn.Module):
    """Instance layer normalization with no learnable parameters.

    Every index along dim 0 is normalized with one mean and one biased
    variance, computed over dims 1, 2 and 3 together. The input therefore
    needs at least 4 dims. Presumably x is (batch, channel, freq, time);
    TODO confirm with callers.

    Args:
        eps: constant added to the variance before the square root to avoid
            division by zero on constant inputs.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        reduce_dims = (1, 2, 3)
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        denom = torch.sqrt(
            x.var(dim=reduce_dims, keepdim=True, unbiased=False) + self.eps
        )
        return centered / denom  # LN
class RFN(nn.Module):
    """Relaxed instance frequency-wise normalization.

    Blends two parameter-free normalizations of the same input:

        y = lam * InsLN(x) + (1 - lam) * IFN(x)

    With lam = 1 the output is pure instance layer norm; with lam = 0 it is
    pure IFN.

    Args:
        lam: mixing weight for the instance-layer-norm branch. It is stored
            as-is and not range-checked; values outside [0, 1] extrapolate
            between the two branches.
        eps: variance stabilizer forwarded to both sub-normalizers. The
            default of 1e-5 reproduces the previous behavior, in which the
            sub-modules were always built with their own defaults.
    """

    def __init__(self, lam=0.5, eps=1e-5):
        super().__init__()
        self.lam = lam
        # Both branches share one eps so the blend stays consistent.
        self.ifn = IFN(eps=eps)
        self.ins_ln = InsLN(eps=eps)

    def forward(self, x):
        ln_out = self.ins_ln(x)
        ifn_out = self.ifn(x)
        return self.lam * ln_out + (1.0 - self.lam) * ifn_out
