import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader

class TablePandasDataset(Dataset):
    """Map-style dataset built from columns of a pandas DataFrame.

    Each item is a tuple ``(X, U, A, Z)`` of float32 tensors, where ``X``
    comes from the covariate columns ``cov_list`` and ``U``, ``A``, ``Z``
    come from the ``utility_tag``, ``sensitive_tag`` and
    ``noisy_sensitive_tag`` columns respectively.
    """

    def __init__(self, pd, cov_list, utility_tag='utility_cat',
                 sensitive_tag='sensitive_cat',
                 noisy_sensitive_tag='noisy_cat',
                 transform=None):
        """
        Args:
            pd: Pandas dataframe holding every column named below. (The
                parameter name shadows the usual ``pandas`` alias; it is kept
                so existing keyword callers keep working.)
            cov_list: Column label(s) of the covariates; selected with
                ``pd[cov_list]`` and converted to a float32 tensor ``X``.
            utility_tag: Column whose cells are array-likes (saved as a list
                of numpy arrays) that are row-stacked into the tensor ``U``.
            sensitive_tag: Column row-stacked the same way into ``A``.
            noisy_sensitive_tag: Column row-stacked the same way into ``Z``.
            transform (callable, optional): Applied to ``X`` of each sample
                when not ``None``.

        Raises:
            ValueError: from ``np.vstack`` when the dataframe has no rows.
        """
        df = pd  # local alias so the code reads as a DataFrame, not pandas
        self.U_torch = self._stack_column(df, utility_tag)
        self.A_torch = self._stack_column(df, sensitive_tag)
        self.Z_torch = self._stack_column(df, noisy_sensitive_tag)
        # Cast via numpy first: numpy's astype also handles object-dtype
        # frames (mixed column types) that torch.tensor would reject.
        self.X_torch = torch.tensor(
            df[cov_list].to_numpy().astype('float32'), dtype=torch.float32
        )
        self.transform = transform

    @staticmethod
    def _stack_column(df, tag):
        """Row-stack the per-row arrays of column ``tag`` into a float32 tensor."""
        # Cells are stored as individual numpy arrays, hence np.vstack.
        stacked = np.vstack(df[tag].values).astype('float32')
        # Explicit dtype: torch.Tensor(...) would follow the global default
        # dtype and could silently undo the float32 cast above.
        return torch.tensor(stacked, dtype=torch.float32)

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return self.U_torch.shape[0]

    def __getitem__(self, idx):
        """Return ``(X, U, A, Z)`` for ``idx``, applying ``transform`` to ``X``."""
        U = self.U_torch[idx]
        A = self.A_torch[idx]
        Z = self.Z_torch[idx]
        X = self.X_torch[idx]
        # `is not None` rather than truthiness: a callable transform that
        # defines __len__/__bool__ must not be skipped.
        if self.transform is not None:
            X = self.transform(X)

        return X, U, A, Z

class TablePandasDataset_Noise_Estimate(Dataset):
    """Map-style dataset of covariates and utility labels from a DataFrame.

    Each item is a tuple ``(X, U)`` of float32 tensors, where ``X`` comes
    from the covariate columns ``cov_list`` and ``U`` from ``utility_tag``.
    """

    def __init__(self, pd, cov_list, utility_tag='utility_cat', transform=None):
        """
        Args:
            pd: Pandas dataframe holding the columns named below. (The
                parameter name shadows the usual ``pandas`` alias; it is kept
                so existing keyword callers keep working.)
            cov_list: Column label(s) of the covariates; selected with
                ``pd[cov_list]`` and converted to a float32 tensor ``X``.
            utility_tag: Column whose cells are array-likes (saved as a list
                of numpy arrays) that are row-stacked into the tensor ``U``.
            transform (callable, optional): Applied to ``X`` of each sample
                when not ``None``.

        Raises:
            ValueError: from ``np.vstack`` when the dataframe has no rows.
        """
        df = pd  # local alias so the code reads as a DataFrame, not pandas
        self.U_torch = self._stack_column(df, utility_tag)
        # Cast via numpy first: numpy's astype also handles object-dtype
        # frames (mixed column types) that torch.tensor would reject.
        self.X_torch = torch.tensor(
            df[cov_list].to_numpy().astype('float32'), dtype=torch.float32
        )
        self.transform = transform

    @staticmethod
    def _stack_column(df, tag):
        """Row-stack the per-row arrays of column ``tag`` into a float32 tensor."""
        # Cells are stored as individual numpy arrays, hence np.vstack.
        stacked = np.vstack(df[tag].values).astype('float32')
        # Explicit dtype: torch.Tensor(...) would follow the global default
        # dtype and could silently undo the float32 cast above.
        return torch.tensor(stacked, dtype=torch.float32)

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return self.U_torch.shape[0]

    def __getitem__(self, idx):
        """Return ``(X, U)`` for ``idx``, applying ``transform`` to ``X``."""
        U = self.U_torch[idx]
        X = self.X_torch[idx]
        # `is not None` rather than truthiness: a callable transform that
        # defines __len__/__bool__ must not be skipped.
        if self.transform is not None:
            X = self.transform(X)

        return X, U