import numpy as np
import scipy.sparse as sp
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch_geometric.data.data import Data
from torch_geometric.loader import DataLoader

class _ScalerTransformation(object):
    """Base for transforms that fit an sklearn scaler on ``data.x`` batches.

    The scaler is fitted incrementally with ``partial_fit`` over every batch
    yielded by ``data_loader``. Calling the transform replaces ``data.x`` with
    the scaled features as a ``torch.float`` tensor.

    Subclasses set ``_scaler_cls`` to a scaler class that supports
    ``partial_fit`` and ``transform``.
    """

    _scaler_cls = None

    def __init__(self, data_loader: DataLoader, **scaler_kwargs):
        """Fit the scaler on the ``x`` attribute of every batch in the loader.

        Args:
            data_loader: Iterable of batches, each exposing an ``x`` feature
                matrix (torch tensor or array-like).
            **scaler_kwargs: Forwarded to the scaler constructor (e.g.
                ``feature_range`` for ``MinMaxScaler``). With no kwargs the
                scaler's own defaults are used.
        """
        self.scaler = self._scaler_cls(**scaler_kwargs)
        for data in data_loader:
            self.scaler.partial_fit(self._to_numpy(data.x))

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a host NumPy array that sklearn can consume."""
        # sklearn converts inputs with np.asarray, which raises for CUDA
        # tensors and for tensors that require grad; detach and move to CPU.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def __call__(self, data: Data) -> Data:
        """Replace ``data.x`` with its scaled version and return ``data``.

        The new ``data.x`` is a ``torch.float`` tensor on the same device as
        the incoming ``data.x`` (CPU if ``data.x`` was not a tensor).
        """
        x = data.x
        transformed = self.scaler.transform(self._to_numpy(x))
        device = x.device if isinstance(x, torch.Tensor) else None
        data.x = torch.as_tensor(transformed, dtype=torch.float, device=device)
        return data


class MinMaxTransformation(_ScalerTransformation):
    """Rescale each column of ``data.x`` with sklearn's ``MinMaxScaler``.

    Per-column min/max are accumulated over all batches of the loader passed
    to the constructor; with default settings that range is mapped to [0, 1].
    """

    _scaler_cls = MinMaxScaler


class StandardTransformation(_ScalerTransformation):
    """Standardise each column of ``data.x`` with sklearn's ``StandardScaler``.

    Per-column mean and variance are accumulated over all batches of the
    loader passed to the constructor.
    """

    _scaler_cls = StandardScaler

class RowNormalisation(object):
    """Divide each row of ``data.x`` by its row sum.

    Rows with a non-zero sum then sum to 1 (up to rounding). Rows whose sum
    is zero are mapped to zeros instead of being divided by zero. The
    transform is stateless.
    """

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any arguments, presumably so it can be built
        # with the same call as the fitted transforms (e.g. with a loader).
        pass

    def __call__(self, data):
        """Row-normalise ``data.x`` and return ``data``.

        ``data.x`` may be a 2-D tensor or array-like. Integer inputs are
        promoted to ``torch.float`` before dividing. The new ``data.x`` is a
        ``torch.float`` tensor on the same device as the input.
        """
        x = torch.as_tensor(data.x)
        if not x.is_floating_point():
            # Integer feature counts: do the division in floating point.
            x = x.to(torch.float)

        rowsum = x.sum(dim=1, keepdim=True)
        r_inv = rowsum.reciprocal()
        # A zero row sum gives +/-inf; map it to 0 so the row becomes zeros.
        # masked_fill is out-of-place, so the autograd graph stays valid.
        r_inv = r_inv.masked_fill(torch.isinf(r_inv), 0.)

        data.x = (x * r_inv).to(torch.float)
        return data