# Normalizers for diffusion.

from typing import List

import torch
from torch import nn


class BaseNormalizer(nn.Module):
    def __init__(self):
        super().__init__()

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class MinMaxNormalizer(BaseNormalizer):
    def __init__(self, dataset: torch.Tensor, eps: float = 1e-5):
        super().__init__()
        self.register_buffer('min', dataset.min(dim=0).values)
        self.register_buffer('max', dataset.max(dim=0).values + eps)
        # print('Mins:', self.min)
        # print('Maxs:', self.max)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.min) / (self.max - self.min) * 2 - 1

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x + 1) / 2 * (self.max - self.min) + self.min

    def reset(self, dataset: torch.Tensor, eps: float = 1e-5):
        self.min = dataset.min(dim=0).values
        self.max = dataset.max(dim=0).values + eps
        # print('Mins:', self.min)
        # print('Maxs:', self.max)


class Normalizer(BaseNormalizer):
    def __init__(
            self,
            dataset: torch.Tensor,
            eps: float = 1e-5,
            skip_dims: List[int] = [],
            target_std: float = 1.0,
            added_dims: List[int] = None,
    ):
        super().__init__()
        self.register_buffer('mean', dataset.mean(dim=0))
        self.register_buffer('std', dataset.std(dim=0) + eps)
        self.skip_dims = skip_dims
        self.added_dims = added_dims
        if skip_dims:
            self.mean[skip_dims] = 0.0
            self.std[skip_dims] = 1.0
        self.target_std = target_std
        # print('Means:', self.mean)
        # print('Stds:', self.std)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] > self.mean.shape[-1]:
            assert self.added_dims is not None, "Please define the added dims."
            obs_dim, act_dim = self.added_dims
            mean = torch.cat([self.mean, self.mean[obs_dim:obs_dim+act_dim]], dim=-1)
            std = torch.cat([self.std, self.std[obs_dim:obs_dim+act_dim]], dim=-1)
            return (x - mean) / std * self.target_std
        else:
            return (x - self.mean) / self.std * self.target_std

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] > self.mean.shape[-1]:
            assert self.added_dims is not None, "Please define the added dims."
            obs_dim, act_dim = self.added_dims
            mean = torch.cat([self.mean, self.mean[obs_dim:obs_dim+act_dim]], dim=-1)
            std = torch.cat([self.std, self.std[obs_dim:obs_dim+act_dim]], dim=-1)
            return x / self.target_std * std + mean
        else:
            return x / self.target_std * self.std + self.mean

    def reset(self, dataset: torch.Tensor, eps: float = 1e-5):
        if self.mean.shape[0] != dataset.shape[1]:
            size = self.mean.shape[0]
            dataset = dataset[:, -size:]
        self.mean = dataset.mean(dim=0)
        self.std = dataset.std(dim=0) + eps
        if self.skip_dims:
            self.mean[self.skip_dims] = 0.0
            self.std[self.skip_dims] = 1.0
        # print('Means:', self.mean)
        # print('Stds:', self.std)


def normalizer_factory(
        normalizer_type: str,
        dataset: torch.Tensor,
        skip_dims: List[int] = [],
        added_dims: List[int] = None,
        **kwargs,
) -> BaseNormalizer:
    if normalizer_type == 'minmax':
        return MinMaxNormalizer(dataset, **kwargs)
    elif normalizer_type == 'standard':
        return Normalizer(dataset, 
                          skip_dims=skip_dims, 
                          added_dims=added_dims, 
                          **kwargs)
    else:
        raise ValueError(f'Unknown normalizer type: {normalizer_type}')
