import numpy as np
import scipy.interpolate as interpolate
import torch
class Normalizer:
    '''
        Base class for data normalizers.

        Keeps a float32 copy of the data together with its column-wise
        (axis 0) minima and maxima. Subclasses provide `normalize` and
        `unnormalize`; calling an instance is shorthand for `normalize`.
    '''

    def __init__(self, x):
        # the extrema are taken from the raw input, so they keep x's dtype
        self.X = x.astype(np.float32)
        self.mins, self.maxs = x.min(axis=0), x.max(axis=0)

    def __call__(self, *args, **kwargs):
        forward = self.normalize
        return forward(*args, **kwargs)

    def normalize(self, *arg, **kwargs):
        '''Map data into the normalized space; subclasses must override.'''
        raise NotImplementedError

    def unnormalize(self, *arg, **kwargs):
        '''Map normalized data back to the data space; subclasses must override.'''
        raise NotImplementedError

class CDFNormalizer(Normalizer):
    '''
        makes training data uniform (over each dimension) by transforming it with marginal CDFs

        One CDFNormalizer1d is fitted per column of the (2d) training data;
        inputs to normalize / unnormalize are reshaped to (-1, dim), so their
        trailing axis is treated as the feature axis.
    '''

    def __init__(self, X):
        super().__init__(X)
        self.dim = self.X.shape[1]
        self.cdfs = [CDFNormalizer1d(self.X[:, i]) for i in range(self.dim)]

    def __repr__(self):
        return f'[ CDFNormalizer ] dim: {self.mins.size}\n' + '    |    '.join(
            f'{i:3d}: {cdf}' for i, cdf in enumerate(self.cdfs)
        )

    def wrap(self, fn_name, x):
        '''
            Apply method `fn_name` of each per-dimension normalizer to the
            matching column of x and return the result in x's shape.
        '''
        shape = x.shape
        ## reshape to 2d
        x = x.reshape(-1, self.dim)
        # The per-dimension maps produce fractional values; an integer output
        # buffer (np.zeros_like on an int input) would silently truncate them.
        # Float/complex inputs keep their dtype, exactly as before.
        if np.issubdtype(x.dtype, np.inexact):
            dtype = x.dtype
        else:
            dtype = np.float64
        out = np.zeros_like(x, dtype=dtype)
        for i, cdf in enumerate(self.cdfs):
            fn = getattr(cdf, fn_name)
            out[:, i] = fn(x[:, i])
        return out.reshape(shape)

    def normalize(self, x):
        return self.wrap('normalize', x)

    def unnormalize(self, x):
        return self.wrap('unnormalize', x)

class CDFNormalizer1d:
    '''
        CDF normalizer for a single dimension

        normalize: clips to the training range, maps through the empirical
        CDF of the training data (linear interpolation between the observed
        values) and rescales the result from [ 0, 1 ] to [ -1, 1 ].
        unnormalize: the inverse, by interpolating the inverted CDF.
        If all training values are identical, both maps return their input
        unchanged.
    '''

    def __init__(self, X):
        self.X = X.astype(np.float32)
        # Training range, recorded for every dimension so that __repr__ also
        # works when the data is constant. For non-constant data this equals
        # the min / max of the sorted unique quantiles from empirical_cdf.
        self.xmin, self.xmax = self.X.min(), self.X.max()
        if self.xmax == self.xmin:
            self.constant = True
        else:
            self.constant = False
            quantiles, cumprob = empirical_cdf(self.X)
            # value -> cumulative probability, and its inverse
            self.fn = interpolate.interp1d(quantiles, cumprob)
            self.inv = interpolate.interp1d(cumprob, quantiles)

            # cumprob's smallest entry is count(xmin) / N > 0, so normalize
            # maps xmin strictly above -1
            self.ymin, self.ymax = cumprob.min(), cumprob.max()

    def __repr__(self):
        return (
            f'[{np.round(self.xmin, 2):.4f}, {np.round(self.xmax, 2):.4f}]'
        )

    def normalize(self, x):
        '''
            x : values in the data space --> values in [ -1, 1 ]
        '''
        if self.constant:
            return x

        # interp1d rejects points outside its domain, so clamp to the training range
        x = np.clip(x, self.xmin, self.xmax)
        ## [ 0, 1 ]
        y = self.fn(x)
        ## [ -1, 1 ]
        y = 2 * y - 1
        return y

    def unnormalize(self, x, eps=1e-4):
        '''
            x : [ -1, 1 ] --> values in the data space

            Inputs whose CDF value falls more than `eps` outside the fitted
            range trigger a printed warning; all inputs are then clipped to
            that range before inversion.
        '''
        if self.constant:
            return x

        ## [ -1, 1 ] --> [ 0, 1 ]
        x = (x + 1) / 2.

        if (x < self.ymin - eps).any() or (x > self.ymax + eps).any():
            print(
                f'''[ dataset/normalization ] Warning: out of range in unnormalize: '''
                f'''[{x.min()}, {x.max()}] | '''
                f'''x : [{self.xmin}, {self.xmax}] | '''
                f'''y: [{self.ymin}, {self.ymax}]'''
            )

        x = np.clip(x, self.ymin, self.ymax)

        y = self.inv(x)
        return y

class NoNormalizer(Normalizer):
    '''
        Identity normalizer: both directions return the input unchanged.
        Column-wise mins / maxs / range are still recorded.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # not used by the identity maps; kept as a public attribute
        span = self.maxs - self.mins
        self.range = span

    def __repr__(self):
        return 'mins: {}, maxs: {}'.format(self.mins, self.maxs)

    def normalize(self, x):
        '''Identity map.'''
        return x

    def unnormalize(self, x):
        '''Identity map.'''
        return x

class LinearNormalizer(Normalizer):
    '''
        Min-max normalizer: affinely maps each dimension's training range
        [ mins, maxs ] onto [ -1, 1 ].
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.range = self.maxs - self.mins

    def __repr__(self):
        return (
            f'mins: {self.mins}, maxs: {self.maxs}'
        )

    def normalize(self, x):
        # A dimension that is constant in the training data has zero range;
        # dividing by it would produce nan / inf. Use 1 as the divisor there
        # (so the training value maps to -1). ones_like keeps range's dtype.
        scale = np.where(self.range == 0, np.ones_like(self.range), self.range)
        return 2 * (x - self.mins) / scale - 1

    def unnormalize(self, x):
        # zero-range dimensions multiply by 0 and so always return mins
        return (x + 1) / 2 * self.range + self.mins

class GaussianNormalizer(Normalizer):
    '''
        normalizes to zero mean and unit variance

        Per-dimension (axis 0) statistics of the float32 data. `normalize` /
        `unnormalize` use the NumPy statistics; the `_torch` variants use
        float32 tensor copies, moved to the input tensor's device if needed.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.means = self.X.mean(axis=0)
        self.stds = self.X.std(axis=0)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.means_torch = torch.tensor(self.means).to(device, dtype=torch.float32)
        # floor the stds so (near-)constant dimensions don't divide by ~0
        self.stds = np.maximum(self.stds, 1e-4)
        self.stds_torch = torch.tensor(self.stds).to(device, dtype=torch.float32)
        # scale factor applied to stds in __repr__ only
        self.z = 1

    def __repr__(self):
        return (
            f'''[ Normalizer ] dim: {self.mins.size}\n    '''
            f'''means: {np.round(self.means, 2)}\n    '''
            f'''stds: {np.round(self.z * self.stds, 2)}\n'''
        )

    def normalize(self, x):
        return (x - self.means) / self.stds

    def unnormalize(self, x):
        return x * self.stds + self.means

    def _torch_stats(self, x):
        '''
            Return the (means, stds) tensors, placed on x's device.

            The cached tensors live on CUDA whenever it is available; without
            this, a CPU tensor on a CUDA machine raises a device-mismatch
            error. Non-tensor inputs get the cached tensors unchanged.
        '''
        means, stds = self.means_torch, self.stds_torch
        if isinstance(x, torch.Tensor) and x.device != means.device:
            means, stds = means.to(x.device), stds.to(x.device)
        return means, stds

    def normalize_torch(self, x):
        means, stds = self._torch_stats(x)
        return (x - means) / stds

    def unnormalize_torch(self, x):
        means, stds = self._torch_stats(x)
        return x * stds + means

def empirical_cdf(sample):
    '''
        Empirical CDF of `sample` (flattened by np.unique).

        Returns (quantiles, cumprob): the sorted distinct values of the
        sample and, for each one, the float64 fraction of entries <= it.
    '''
    values, freq = np.unique(sample, return_counts=True)
    total = sample.size
    probs = np.cumsum(freq).astype(np.double)
    probs /= total
    return values, probs

