import torch
import numpy as np
from scipy.stats import ecdf
import matplotlib.pyplot as plt

class QuantileNormalizer:

    def __init__(self):
        pass

    def fit(self, x):

        self.n_samples = x.shape[0]
        self.n_features = x.shape[1]
        self.fitted = True
        
        self.ecdfs = []
        self.quantiles = []
        self.grid = np.linspace(0., 1., 1000)
        for i_feature in range(self.n_features):
            self.ecdfs.append(ecdf(x[:, i_feature].numpy()))
            self.quantiles.append(np.quantile(x[:, i_feature], q=self.grid))

    def transform(self, x):

        if not self.fitted:
            raise ValueError('Fit the QuantileNormalizer before calling `transform` method.')

        z = torch.empty_like(x)
        for i_feature in range(self.n_features):
            u = self.ecdfs[i_feature].cdf.evaluate(x[:, i_feature].numpy())
            tol = 1/self.n_samples
            u[u == 1] = 1.-tol
            u[u == 0] = tol        
            
            z[:, i_feature] = torch.distributions.normal.Normal(0., 1.).icdf(torch.from_numpy(u))

        return z

    def inverse(self, z):

        if not self.fitted:
            raise ValueError('Fit the QuantileNormalizer before calling `inverse` method.')

        x = torch.empty_like(z)
        for i_feature in range(self.n_features):
            u = torch.distributions.normal.Normal(0., 1.).cdf(z[:, i_feature])
            x[:, i_feature] = torch.from_numpy(np.interp(u.numpy(), self.grid, self.quantiles[i_feature]))

        return x


class StandardScaler:
    """Column-wise standardisation: ``(x - mean) / std`` over dim 0.

    Uses ``torch.Tensor.std``'s default (unbiased) estimator. A column whose
    std is zero (constant column) or NaN (single-sample fit) gets a scale of
    1, so ``transform`` stays finite and ``inverse_transform`` still
    round-trips.
    """

    def __init__(self):
        # None until fit() is called; used to detect unfitted use.
        self.means = None
        self.stds = None

    def fit(self, x):
        """Compute per-column means and stds from a 2-D tensor ``x``."""
        self.means = x.mean(dim=0, keepdim=True)
        stds = x.std(dim=0, keepdim=True)
        # `stds > 0` is False for both 0 and NaN, so both are replaced by 1.
        self.stds = torch.where(stds > 0, stds, torch.ones_like(stds))

    def _check_fitted(self, method):
        """Raise ValueError if ``fit`` has not been called yet."""
        if self.means is None or self.stds is None:
            raise ValueError(f'Fit the StandardScaler before calling `{method}` method.')

    def transform(self, x):
        """Return ``(x - means) / stds`` using the fitted statistics."""
        self._check_fitted('transform')
        return (x - self.means) / self.stds

    def inverse_transform(self, x):
        """Return ``x * stds + means``, undoing ``transform``."""
        self._check_fitted('inverse_transform')
        return (x * self.stds) + self.means


def compare_histograms(tensors, colors, labels, **kwargs):
    """Overlay density histograms of several 2-D datasets, one figure per column.

    Args:
        tensors: sequence of 2-D tensors/arrays sharing the same number of
            columns.
        colors: one matplotlib color per entry of ``tensors``.
        labels: one legend label per entry of ``tensors``.
        **kwargs: forwarded to ``plt.hist`` (e.g. ``bins``).

    Raises:
        ValueError: if ``tensors``, ``colors`` and ``labels`` differ in length,
            or if the inputs do not all have the same number of columns.
    """
    tensors, colors, labels = list(tensors), list(colors), list(labels)
    # zip() would otherwise silently skip datasets lacking a color/label.
    if not (len(tensors) == len(colors) == len(labels)):
        raise ValueError(
            f'Got {len(tensors)} tensors, {len(colors)} colors and '
            f'{len(labels)} labels; the counts must match.'
        )
    if len(tensors) == 0:
        return

    n_features = tensors[0].shape[1]
    for tensor in tensors:
        if tensor.shape[1] != n_features:
            raise ValueError(
                f'All inputs must have {n_features} columns, got shape {tuple(tensor.shape)}.'
            )

    for i_feature in range(n_features):
        for tensor, color, label in zip(tensors, colors, labels):
            column = tensor[:, i_feature]
            if isinstance(column, torch.Tensor):
                # matplotlib needs host memory without autograd history.
                column = column.detach().cpu().numpy()
            plt.hist(column, color=color, alpha=0.2, label=label, density=True, **kwargs)
        plt.legend()
        plt.show()

def check_mem(device, msg=None):
    """Print free and allocated CUDA memory for ``device``.

    The CUDA caching allocator is emptied first, so cached but unused blocks
    are returned before free memory is measured.

    Args:
        device: a ``torch.device`` or anything ``torch.device`` accepts
            (e.g. ``'cuda:1'``, ``'cpu'``).
        msg: optional label printed in the header.

    Returns:
        None. When CUDA is unavailable or ``device`` is not a CUDA device,
        only a notice is printed.
    """
    device = torch.device(device)
    if not (torch.cuda.is_available() and device.type == 'cuda'):
        print('Using CPU: no memory check.')
        return None

    # Clean GPU cache
    torch.cuda.empty_cache()

    print('------------------------------------------------------------')
    if msg is not None:
        print('MEMORY CHECK: ' + msg)
    else:
        print('MEMORY CHECK')

    free_memory, _ = torch.cuda.mem_get_info(device)
    # Query the requested device; without an argument this reports the
    # *current* device, which may differ from `device`.
    allocated_memory = torch.cuda.memory_allocated(device)
    print(f'Current device: {device}')
    print(f'Free memory: {free_memory*1e-9:.2f} GB')
    print(f'Allocated memory: {allocated_memory*1e-9:.2f} GB')
    print('------------------------------------------------------------')