import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

class DataScaler:
    """Standardize model inputs and targets with scikit-learn scalers.

    One ``StandardScaler`` is fitted to the training inputs and another to
    the training targets (flattened to a single column). The transform
    methods accept torch tensors or array-likes and always return torch
    tensors. Tensor inputs keep their device, and their dtype when it is a
    floating-point dtype.

    Args:
        x_train: Training inputs of shape ``(n_samples, n_features)``;
            ``StandardScaler`` requires 2-D data.
        y_train: Training targets; flattened to ``(n_samples, 1)`` before
            fitting.
    """

    def __init__(self, x_train, y_train):
        # fit() does not modify its input, so no defensive clone is needed.
        x_train = self._to_numpy(x_train)
        y_train = self._to_numpy(y_train)

        self.input_scaler = StandardScaler().fit(x_train)
        self.output_scaler = StandardScaler().fit(y_train.reshape(-1, 1))

    @staticmethod
    def _to_numpy(a):
        """Return ``a`` as a NumPy array.

        Tensors are detached before conversion so that inputs with
        ``requires_grad=True`` do not make ``.numpy()`` raise. The scaled
        results carry no autograd history either way.
        """
        if torch.is_tensor(a):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    @staticmethod
    def _tensor_kwargs(a):
        """Return ``torch.tensor`` keyword arguments for a result derived from ``a``.

        Tensor inputs keep their device and, if floating point, their dtype.
        Integer/bool tensors fall back to the default float dtype so the
        standardized values are not truncated. Non-tensor inputs get no
        overrides (result on CPU, dtype inferred from the NumPy array).
        """
        if not torch.is_tensor(a):
            return {}
        dtype = a.dtype if a.is_floating_point() else torch.get_default_dtype()
        return {"dtype": dtype, "device": a.device}

    def transform_x(self, x):
        """Standardize inputs with the scaler fitted on ``x_train``.

        Args:
            x: Inputs with the same number of features as ``x_train``
                (2-D, as ``StandardScaler.transform`` requires).

        Returns:
            Standardized inputs as a tensor with the same shape as ``x``.
        """
        kwargs = self._tensor_kwargs(x)
        # copy=True matters: for CPU tensors, _to_numpy returns an array that
        # shares memory with ``x``, and an in-place transform would mutate it.
        x_scaled = self.input_scaler.transform(self._to_numpy(x), copy=True)
        return torch.tensor(x_scaled, **kwargs)

    def transform_y(self, y):
        """Standardize targets with the scaler fitted on ``y_train``.

        Args:
            y: Targets of any shape; they are flattened to one column.

        Returns:
            Standardized targets as a tensor of shape ``(y.size, 1)``.
        """
        kwargs = self._tensor_kwargs(y)
        # copy=True keeps the caller's (possibly memory-shared) data intact.
        y_scaled = self.output_scaler.transform(
            self._to_numpy(y).reshape(-1, 1), copy=True
        )
        return torch.tensor(y_scaled, **kwargs)

    def inverse_transform_y(self, y, y_std=None):
        """Map standardized targets (and optionally their std) back to data units.

        Args:
            y: Standardized targets of any shape, e.g. the ``(n, 1)`` output
                of :meth:`transform_y` or a flat ``(n,)`` vector.
            y_std: Optional standard deviations in standardized units. They
                are multiplied by the scaler's ``scale_`` but not shifted by
                the mean. A variance would instead need ``scale_ ** 2``.

        Returns:
            The un-standardized ``y`` as a tensor with the same shape as the
            input. If ``y_std`` is given, returns a ``(y, y_std)`` tuple; both
            tensors use ``y``'s dtype and device.
        """
        kwargs = self._tensor_kwargs(y)
        y_np = self._to_numpy(y)
        # The output scaler was fitted on a single column, so flatten for
        # sklearn (which needs 2-D input) and restore the caller's shape.
        y_orig = self.output_scaler.inverse_transform(
            y_np.reshape(-1, 1), copy=True
        ).reshape(y_np.shape)
        y_out = torch.tensor(y_orig, **kwargs)

        if y_std is None:
            return y_out

        # Multiplication allocates a new array, so the caller's data is safe.
        y_std_orig = self._to_numpy(y_std) * self.output_scaler.scale_
        return y_out, torch.tensor(y_std_orig, **kwargs)
    
    