# Normalizers for diffusion.
from typing import List, Optional

import torch
from torch import nn


# Base class for normalizers with normalize and unnormalize methods.
class BaseNormalizer(nn.Module):
    """Interface for invertible data normalizers.

    Subclasses override ``normalize`` and its inverse ``unnormalize``; the
    base implementations only raise ``NotImplementedError``. The inherited
    ``nn.Module`` constructor is used as-is.
    """

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw data ``x`` into normalized space."""
        raise NotImplementedError

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map normalized ``x`` back into raw data space."""
        raise NotImplementedError


class IdentityNormalizer(BaseNormalizer):
    """Pass-through normalizer: both directions hand back ``x`` itself.

    No statistics are stored; the inherited constructor is sufficient.
    """

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged (the very same object, no copy)."""
        return x

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged (the very same object, no copy)."""
        return x


# Min-max normalizer that scales each feature to [-1, 1]; reset() refits it to a new dataset.
class MinMaxNormalizer(BaseNormalizer):
    """Affinely map each feature from its dataset range onto [-1, 1].

    Per-feature extrema are taken over dim 0 of ``dataset`` and registered as
    the buffers ``min`` and ``max`` (``eps`` is added to the stored maximum).
    ``reset`` recomputes both buffers from a new dataset.
    """

    # Fallback for instances restored without ``_eps`` in their __dict__
    # (e.g. whole-module pickles made before this attribute existed).
    _eps: float = 1e-5

    def __init__(self, dataset: torch.Tensor, eps: float = 1e-5):
        """
        Args:
            dataset: Samples stacked along dim 0.
            eps: Added to the per-feature maximum and used as the smallest
                allowed range, so constant features never divide by zero.
        """
        super().__init__()
        self._eps = eps
        lo, hi = self._extrema(dataset, eps)
        self.register_buffer('min', lo)
        self.register_buffer('max', hi)
        self._report()

    @staticmethod
    def _extrema(dataset: torch.Tensor, eps: float):
        """Return ``(per-feature min, per-feature max + eps)`` over dim 0."""
        return dataset.min(dim=0).values, dataset.max(dim=0).values + eps

    def _report(self) -> None:
        """Print the fitted extrema."""
        print('Mins:', self.min)
        print('Maxs:', self.max)

    def _span(self) -> torch.Tensor:
        """Per-feature range ``max - min``, floored at ``eps``.

        ``max + eps`` rounds back to ``max`` when |max| is large relative to
        eps (in float32 with eps=1e-5 this already happens for |max| >= 256).
        A constant feature then had a range of exactly 0, and
        normalize()/unnormalize() produced NaN/inf. The floor keeps both
        maps finite and mutually inverse; for ordinary features the range
        already exceeds eps and is unchanged.
        """
        return (self.max - self.min).clamp_min(self._eps)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` from ``[min, max]`` onto ``[-1, 1]`` feature-wise."""
        return (x - self.min) / self._span() * 2 - 1

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`normalize`: map ``[-1, 1]`` back to ``[min, max]``."""
        return (x + 1) / 2 * self._span() + self.min

    def reset(self, dataset: torch.Tensor, eps: float = 1e-5):
        """Refit ``min``/``max`` to ``dataset``.

        Assigning to the buffer names replaces the registered buffers with
        the newly computed tensors.
        """
        self._eps = eps
        self.min, self.max = self._extrema(dataset, eps)
        self._report()


# Standardizes each feature to zero mean and a target std, optionally
# exempting selected dimensions from mean/std standardization.
class Normalizer(BaseNormalizer):
    """Standardize each feature to zero mean and a target standard deviation.

    Statistics are computed over dim 0 of ``dataset`` and registered as the
    buffers ``mean`` and ``std``. Indices in ``skip_dims`` select entries of
    those per-feature statistics that are forced to mean 0 / std 1, so the
    corresponding features are only scaled by ``target_std``.
    """

    def __init__(self,
                 dataset: torch.Tensor,
                 eps: float = 1e-5,
                 skip_dims: Optional[List[int]] = None,
                 target_std: float = 1.0):
        """
        Args:
            dataset: Samples stacked along dim 0. Non-floating tensors
                (e.g. integer or bool) are cast to float32 before the
                statistics are computed.
            eps: Added to the std so constant features do not divide by 0.
            skip_dims: Indices into the per-feature statistics to exempt
                from standardization. ``None`` (the default) exempts
                nothing. The sequence is copied into a new list.
            target_std: Standard deviation of the normalized output.
        """
        super().__init__()
        # Copy rather than alias: the old ``[]`` default was a single list
        # object shared by every instance. Converting to a list also makes
        # tuples index feature entries instead of being read as a
        # multi-dimensional index.
        self.skip_dims = [] if skip_dims is None else list(skip_dims)
        self.target_std = target_std
        mean, std = self._compute_stats(dataset, eps)
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)
        self._report()

    def _compute_stats(self, dataset: torch.Tensor, eps: float):
        """Return ``(mean, std + eps)`` over dim 0 with ``skip_dims`` neutralized."""
        if not (dataset.is_floating_point() or dataset.is_complex()):
            # torch.mean / torch.std reject integer and bool tensors.
            dataset = dataset.float()
        mean = dataset.mean(dim=0)
        # NOTE(review): torch.std uses Bessel's correction by default, so a
        # single-sample dataset yields NaN here; callers presumably pass N > 1.
        std = dataset.std(dim=0) + eps
        if self.skip_dims:
            mean[self.skip_dims] = 0.0
            std[self.skip_dims] = 1.0
        return mean, std

    def _report(self) -> None:
        """Print the fitted statistics."""
        print('Means:', self.mean)
        print('Stds:', self.std)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw ``x`` to ``(x - mean) / std * target_std``."""
        return (x - self.mean) / self.std * self.target_std

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`normalize`."""
        return x / self.target_std * self.std + self.mean

    def reset(self, dataset: torch.Tensor, eps: float = 1e-5):
        """Refit ``mean``/``std`` to ``dataset``, keeping ``skip_dims`` and
        ``target_std``. Assigning to the buffer names replaces the
        registered buffers with the new tensors."""
        self.mean, self.std = self._compute_stats(dataset, eps)
        self._report()


# Factory function to create a normalizer based on the specified type.
def normalizer_factory(normalizer_type: str,
                       dataset: torch.Tensor,
                       skip_dims: Optional[List[int]] = None,
                       **kwargs) -> BaseNormalizer:
    """Build a normalizer by name.

    Args:
        normalizer_type: One of ``"minmax"``, ``"standard"`` or ``"identity"``.
        dataset: Samples used to fit the statistics (ignored by
            ``"identity"``).
        skip_dims: Forwarded to ``Normalizer`` for ``"standard"``; ignored by
            the other types. ``None`` (the default) means no skipped
            dimensions.
        **kwargs: Forwarded to the ``"minmax"`` / ``"standard"``
            constructors; ignored by ``"identity"``.

    Returns:
        The constructed normalizer.

    Raises:
        ValueError: If ``normalizer_type`` is not recognized.
    """
    if normalizer_type == "minmax":
        return MinMaxNormalizer(dataset, **kwargs)
    if normalizer_type == "standard":
        # A fresh list per call: the previous ``[]`` default was one shared
        # object that Normalizer stored by reference.
        dims = [] if skip_dims is None else list(skip_dims)
        return Normalizer(dataset, skip_dims=dims, **kwargs)
    if normalizer_type == "identity":
        return IdentityNormalizer()
    raise ValueError(f'Unknown normalizer type: {normalizer_type}')
