import torch
import numpy as np
from scipy.stats import ecdf

class QuantileNormalizer:

    def __init__(self):
        pass

    def fit(self, x):

        self.n_samples = x.shape[0]
        self.n_features = x.shape[1]
        self.fitted = True
        
        self.ecdfs = []
        self.quantiles = []
        self.grid = np.linspace(0., 1., 1000)
        for i_feature in range(self.n_features):
            self.ecdfs.append(ecdf(x[:, i_feature].numpy()))
            self.quantiles.append(np.quantile(x[:, i_feature], q=self.grid))

    def transform(self, x):

        if not self.fitted:
            raise ValueError('Fit the QuantileNormalizer before calling `transform` method.')

        z = torch.empty_like(x)
        for i_feature in range(self.n_features):
            u = self.ecdfs[i_feature].cdf.evaluate(x[:, i_feature].numpy())
            tol = 1/self.n_samples
            u[u == 1] = 1.-tol
            u[u == 0] = tol        
            
            z[:, i_feature] = torch.distributions.normal.Normal(0., 1.).icdf(torch.from_numpy(u))

        return z

    def inverse(self, z):

        if not self.fitted:
            raise ValueError('Fit the QuantileNormalizer before calling `inverse` method.')

        x = torch.empty_like(z)
        for i_feature in range(self.n_features):
            u = torch.distributions.normal.Normal(0., 1.).cdf(z[:, i_feature])
            x[:, i_feature] = torch.from_numpy(np.interp(u.numpy(), self.grid, self.quantiles[i_feature]))

        return x