import numpy as np
import torch
import torch.nn as nn

class RevIN(nn.Module):
    """Instance normalization with a matching inverse, without affine parameters.

    ``forward`` standardizes its input along dim 1 using that input's own
    mean and biased (population) standard deviation. It caches both
    statistics on the module (``self.means`` / ``self.stdev``, detached from
    the autograd graph) so that ``inverse_normalize`` can map a model's output
    back to the original scale.

    Args:
        channel: Number of channels. Not used in any computation (this
            implementation has no learnable affine weights); stored only for
            introspection and kept for interface compatibility.
        output_dim: Length along dim 1 that predictions were originally
            expected to have in ``inverse_normalize``. Kept for interface
            compatibility; the inverse now broadcasts and accepts any length
            along dim 1.
        eps: Constant added to the variance before the square root, so that
            constant series do not divide by zero. Defaults to ``1e-5``.
    """

    def __init__(self, channel, output_dim, eps=1e-5):
        super().__init__()
        self.channel = channel
        self.output_dim = output_dim
        self.eps = eps

    @staticmethod
    def _to_tensor(x):
        """Convert a NumPy array to a float32 tensor; return anything else unchanged."""
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).float()
        return x

    def forward(self, x):
        """Standardize ``x`` along dim 1 and remember the statistics used.

        Args:
            x: Tensor or NumPy array with at least 2 dims. NumPy input is
                converted to a float32 tensor. Dim 1 is presumably the time
                axis (e.g. ``(batch, seq_len, channels)``) — confirm with
                callers.

        Returns:
            Tensor with the same shape as ``x``, zero mean and (approximately)
            unit variance along dim 1.
        """
        x = self._to_tensor(x)
        # Statistics are per instance (every index except dim 1) and are
        # detached, so no gradient flows through the normalization constants.
        self.means = x.mean(1, keepdim=True).detach()
        self.stdev = torch.sqrt(
            x.var(1, keepdim=True, unbiased=False) + self.eps
        ).detach()
        return (x - self.means) / self.stdev

    def inverse_normalize(self, x_normalized):
        """Undo the most recent ``forward`` call's normalization on ``x_normalized``.

        Args:
            x_normalized: Tensor or NumPy array in normalized space. It must
                share its leading (batch) dim with the input of the last
                ``forward`` call. Its length along dim 1 may differ (e.g. a
                forecast horizon). NumPy input is converted to a float32
                tensor.

        Returns:
            ``x_normalized * stdev + means``, using the cached statistics
            broadcast over dim 1.

        Raises:
            AttributeError: If ``forward`` has not been called yet (no cached
                statistics).
        """
        x_normalized = self._to_tensor(x_normalized)
        stdev = self.stdev
        means = self.means
        # Stats have size 1 along dim 1 (keepdim=True). If the prediction has
        # more dims than the stats (e.g. 2-D stats vs. a (B, T, 1) output),
        # append trailing singleton dims so broadcasting aligns on the batch
        # dim rather than the trailing dims.
        while stdev.dim() < x_normalized.dim():
            stdev = stdev.unsqueeze(-1)
        while means.dim() < x_normalized.dim():
            means = means.unsqueeze(-1)
        return x_normalized * stdev + means
