

import torch
import torch.nn as nn


class RevIN(nn.Module):
    """Reversible instance normalization for stabilizing multivariate series.

    ``forward(x, 'norm')`` computes statistics over every axis of ``x`` except
    the first and the last, stores them on the module, and returns the
    normalized tensor. ``forward(y, 'denorm')`` inverts that transform using
    the statistics stored by the most recent ``'norm'`` call, so the two calls
    must be paired. The code assumes the first axis indexes independent
    instances (presumably the batch) and the last axis holds the
    ``num_features`` channels -- confirm against callers.

    Args:
        num_features: Length of the last axis; size of the affine weight and
            bias.
        eps: Added to the variance before the square root, and (squared) to
            the affine weight when denormalizing, to avoid division by zero.
        affine: If True, a learnable per-feature scale and shift is applied
            after normalization and removed before denormalization.
        subtract_last: If True, center each instance on its final element
            along axis 1 instead of on its mean.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last

        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``'norm'``) or denormalize (``'denorm'``) ``x``.

        Args:
            x: Tensor to transform.
            mode: ``'norm'`` to compute/store statistics and normalize,
                ``'denorm'`` to invert using the stored statistics.

        Returns:
            The transformed tensor.

        Raises:
            ValueError: If ``mode`` is neither ``'norm'`` nor ``'denorm'``.
            RuntimeError: If ``'denorm'`` is requested before a ``'norm'``
                call has stored the needed statistics.
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise ValueError(
                f"Invalid mode '{mode}'. Expected 'norm' or 'denorm'.")
        return x

    def _init_params(self):
        """Create the per-feature affine weight (ones) and bias (zeros)."""
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Compute and store the per-instance centering and scale statistics.

        Sets ``self.last`` (when ``subtract_last``) or ``self.mean``, plus
        ``self.stdev``. Each keeps a size-1 axis 1 so it broadcasts along that
        axis when applied.
        """
        # Reduce over every axis except the first and the last.
        dim2reduce = tuple(range(1, x.ndim - 1))

        if self.subtract_last:
            # Unlike mean/stdev this is not detached, so gradients can flow
            # back through the last element.
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()

        # Population (biased) variance; eps guards against zero variance.
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True,
                      unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        """Center, scale, and optionally apply the affine transform to ``x``."""
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean

        x = x / self.stdev

        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias

        return x

    def _denormalize(self, x):
        """Invert ``_normalize`` using the statistics from the last 'norm' call.

        Raises:
            RuntimeError: If no statistics have been stored yet.
        """
        # nn.Module.__getattr__ raises AttributeError for unset attributes,
        # so getattr's default detects a missing 'norm' call.
        center = getattr(self, 'last' if self.subtract_last else 'mean', None)
        stdev = getattr(self, 'stdev', None)
        if center is None or stdev is None:
            raise RuntimeError(
                "RevIN.forward(x, 'denorm') was called before statistics "
                "were computed; call forward(x, 'norm') first.")

        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite if a learned weight reaches 0;
            # it also makes the inverse approximate rather than exact.
            x = x / (self.affine_weight + self.eps * self.eps)

        x = x * stdev
        x = x + center

        return x


class RevINComplex(nn.Module):
    """Apply RevIN to complex tensors by normalizing the real and imaginary
    parts with two independent RevIN instances.

    Real-valued inputs are routed through the real-part instance only.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        super().__init__()
        # Keep the configuration on the wrapper as well as on the children.
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last

        shared_cfg = (num_features, eps, affine, subtract_last)
        # Separate instances: each component keeps its own statistics and
        # affine parameters.
        self.revin_real = RevIN(*shared_cfg)
        self.revin_imag = RevIN(*shared_cfg)

    def forward(self, x, mode: str):
        """Normalize or denormalize ``x`` according to ``mode``.

        Complex inputs are split into real and imaginary components, each
        handled by its own RevIN, then recombined with ``torch.complex``.
        Any other input goes straight to the real-part RevIN.
        """
        if not torch.is_complex(x):
            return self.revin_real(x, mode)

        real_part = self.revin_real(x.real, mode)
        imag_part = self.revin_imag(x.imag, mode)
        return torch.complex(real_part, imag_part)


def test_revin():
    """Smoke-test RevIN and RevINComplex round trips on synthetic data.

    Builds a (batch_size, seq_len, num_features) tensor whose features have
    different scales and linear trends. Prints statistics before and after
    normalization and after denormalization, then asserts that the
    norm -> denorm round trip reproduces the input for both the real-valued
    and the complex-valued module.
    """
    batch_size = 4
    seq_len = 96
    num_features = 7
    # Round-trip tolerance: values here reach a few hundred, and float32
    # rounding error should stay well below this bound.
    tolerance = 1e-3

    torch.manual_seed(42)
    x = torch.randn(batch_size, seq_len, num_features)

    # Give each feature its own scale and linear trend; the 1-D trend
    # broadcasts across the batch axis.
    for i in range(num_features):
        x[:, :, i] = x[:, :, i] * (i + 1) * 10 + torch.linspace(0, i * 5, seq_len)

    print("Original data statistics:")
    print(f"Mean: {x.mean(dim=1).mean(dim=0)}")
    print(f"Std: {x.std(dim=1).mean(dim=0)}")
    print(f"Shape: {x.shape}")

    revin = RevIN(num_features, affine=True, subtract_last=False)

    x_norm = revin(x, 'norm')
    print("\nNormalized data statistics:")
    print(f"Mean: {x_norm.mean(dim=1).mean(dim=0)}")
    print(f"Std: {x_norm.std(dim=1).mean(dim=0)}")

    x_denorm = revin(x_norm, 'denorm')
    print("\nDenormalized data statistics:")
    print(f"Mean: {x_denorm.mean(dim=1).mean(dim=0)}")
    print(f"Std: {x_denorm.std(dim=1).mean(dim=0)}")

    abs_diff = torch.abs(x - x_denorm)
    reconstruction_error = abs_diff.mean().item()
    max_abs_diff = abs_diff.max().item()
    print(f"\nReconstruction error: {reconstruction_error:.6f}")
    print(f"Max absolute difference: {max_abs_diff:.6f}")
    assert max_abs_diff < tolerance, (
        f"RevIN round trip exceeded tolerance: {max_abs_diff:.6g}")

    print("\n" + "=" * 50)
    print("Testing Complex RevIN")
    x_complex = torch.complex(x, x * 0.5)

    revin_complex = RevINComplex(num_features, affine=True)
    x_complex_norm = revin_complex(x_complex, 'norm')
    x_complex_denorm = revin_complex(x_complex_norm, 'denorm')

    complex_diff = torch.abs(x_complex - x_complex_denorm)
    complex_error = complex_diff.mean().item()
    print(f"Complex reconstruction error: {complex_error:.6f}")
    max_complex_diff = complex_diff.max().item()
    assert max_complex_diff < tolerance, (
        f"RevINComplex round trip exceeded tolerance: {max_complex_diff:.6g}")


if __name__ == "__main__":
    # Run the RevIN demo/self-check when this file is executed as a script.
    test_revin()
