import os
import pickle
import torch
from torch.utils.data import TensorDataset, DataLoader
from pytorch_lightning import LightningDataModule
from utils import add_label_bias

class FairDataModule(LightningDataModule):
    """Lightning data module serving pickled ``(x, s, y)`` train/test splits.

    Each split is read from ``<data_location>/<dataset_name>/train.pkl`` and
    ``.../test.pkl``. Each file must hold a pickled mapping with the keys
    ``'x'``, ``'s'`` and ``'y'``; ``s`` is presumably the sensitive attribute
    and ``y`` the label (inferred from naming, not checked here). Optionally,
    the training labels are passed through ``utils.add_label_bias``; the test
    labels are never modified.

    Args:
        data_location: Root directory containing one sub-directory per dataset.
        dataset_name: Name of the dataset sub-directory to load.
        bias_amount: Value stored under ``'theta_0_p'`` and ``'theta_1_m'`` in
            the dict handed to ``add_label_bias``. The other two entries are 0.
            Its exact meaning is defined by ``add_label_bias``.
        add_bias: If True, apply ``add_label_bias`` to the training labels.
        batch_size: Batch size used by both dataloaders.
    """

    def __init__(self, 
                 data_location, 
                 dataset_name, 
                 bias_amount=0.2,
                 add_bias=False,
                 batch_size=128):
        super().__init__()
        self.data_location = data_location
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.add_bias = add_bias
        # Parameters forwarded verbatim to add_label_bias.
        self.theta_dict = {'theta_0_p':bias_amount,'theta_0_m':0,'theta_1_p':0,'theta_1_m':bias_amount}
        
        self.train_file_name = os.path.join(data_location, dataset_name, 'train.pkl')
        self.test_file_name = os.path.join(data_location, dataset_name, 'test.pkl')

    def setup(self, stage= None):
        """Load both splits from disk and build ``train_dataset``/``test_dataset``.

        ``stage`` is accepted for Lightning compatibility but is ignored. Both
        splits are reloaded on every call, and label bias (if enabled) is
        re-applied to freshly loaded training labels.
        """
        x_train, s_train, y_train = self._load_split(self.train_file_name)
        x_test, s_test, y_test = self._load_split(self.test_file_name)

        if self.add_bias:
            # add_label_bias is given the squeezed labels. Its result is a
            # tensor (hence .unsqueeze); the column axis is restored here.
            y_train = add_label_bias(y_train.squeeze(), s_train, self.theta_dict).unsqueeze(1)

        self.train_dataset = self._make_dataset(x_train, s_train, y_train)
        self.test_dataset = self._make_dataset(x_test, s_test, y_test)

    @staticmethod
    def _load_split(file_name):
        """Unpickle one split file and return its ``('x', 's', 'y')`` entries."""
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only point this module at data files from a trusted source.
        with open(file_name, "rb") as f:
            raw = pickle.load(f)
        return raw['x'], raw['s'], raw['y']

    @staticmethod
    def _to_tensor(data, dtype):
        """Return a new tensor of ``dtype`` holding a copy of ``data``.

        ``torch.tensor(existing_tensor)`` emits a UserWarning recommending
        ``clone().detach()``. Tensors (e.g. the labels returned by
        ``add_label_bias``) are therefore copied via ``detach().to(...)``.
        Anything else (NumPy arrays, lists) goes through ``torch.tensor``.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().to(dtype=dtype, copy=True)
        return torch.tensor(data, dtype=dtype)

    @classmethod
    def _make_dataset(cls, x, s, y):
        """Wrap one split as a ``TensorDataset`` of (float32 x, long s, long y).

        The last axis of ``s`` and ``y`` is squeezed away. This assumes a
        trailing singleton axis, presumably (N, 1); NumPy's ``squeeze(-1)``
        raises if that axis is not of size 1.
        """
        return TensorDataset(
            cls._to_tensor(x, torch.float32),
            cls._to_tensor(s.squeeze(-1), torch.long),
            cls._to_tensor(y.squeeze(-1), torch.long),
        )

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        """Return a non-shuffling DataLoader over the test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)
