import torch
import numpy as np
from pytorch_lightning import LightningDataModule
from sklearn.preprocessing import StandardScaler

def pushforward_nintervention(sample_size, rng=None):
    """Sample paired observational / interventional latents for a chain A -> B -> C.

    Observational mechanism:
        A ~ 10.5 + 0.8 * N(0, 1)
        B = 0.02 * A + 0.5 * N(0, 1)
        C = (B + 3) + N(0, 1)

    Interventional mechanism (intervention on B):
        A~ is a fresh draw from the same distribution as A.
        B~ = u + 0.02 * N(0, 1), with u ~ Uniform(2, 5) drawn ONCE per call.
        C~ = (B + 3) + N(0, 1), i.e. it reuses the *observational* B.

    Args:
        sample_size: Number of rows to draw for each of the two regimes.
        rng: Optional random source exposing ``standard_normal(size)`` and
            ``uniform(low, high)`` -- e.g. ``np.random.RandomState`` or
            ``np.random.Generator``. ``None`` (the default) uses the global
            ``np.random`` state, exactly as before this parameter existed.

    Returns:
        Tuple ``(obs_zs, int_zs)`` of arrays, each of shape
        ``(sample_size, 3)`` with columns ordered ``[A, B, C]``.
    """
    if rng is None:
        # The np.random module itself offers the same two functions and
        # consumes the legacy global stream.
        rng = np.random
    shape = (sample_size, 1)

    # NOTE: the order of the draws below is kept fixed so that results under a
    # given seed do not change.
    As = 10.5 + 0.8 * rng.standard_normal(shape)
    Bs = As * 0.02 + .5 * rng.standard_normal(shape)
    no_noise_Cs = Bs + 3.
    Cs = no_noise_Cs + rng.standard_normal(shape)

    tilde_As = 10.5 + 0.8 * rng.standard_normal(shape)
    # uniform(2, 5) without a size yields a single scalar, so every intervened
    # sample in this call shares the same offset.
    # NOTE(review): confirm a shared offset (not a per-sample one) is intended.
    tilde_Bs = rng.uniform(2, 5) + 0.02 * rng.standard_normal(shape)
    # Use the old (observational) B so that C~ does not depend on the
    # intervened B~, which breaks the B -> C link.
    tilde_Cs = no_noise_Cs + rng.standard_normal(shape)

    obs_zs = np.concatenate([As, Bs, Cs], axis=-1)
    int_zs = np.concatenate([tilde_As, tilde_Bs, tilde_Cs], axis=-1)
    return obs_zs, int_zs
    
# DATASET
class SyntheticDataSet(torch.utils.data.Dataset):
    """Paired observational / interventional samples pushed through a mixing map.

    Latents come from ``pushforward_nintervention``; the observed data is
    ``mixing(latents)``. Both are standardised per feature, with the scaler
    fitted on observational and interventional rows pooled together.

    Attributes:
        sample_size: Number of (observational, interventional) pairs.
        mixing: Callable applied to a float32 tensor of latents; its result
            must support ``.numpy()``.
        data: Standardised mixed observations, shape ``(sample_size, 2, d_x)``.
            Index 0 on axis 1 is observational, index 1 is interventional.
        zs: Standardised latents, shape ``(sample_size, 2, d_z)``, same layout.
    """

    def __init__(self, sample_size, mixing):
        """Draw the latents, mix them, and standardise both arrays.

        Args:
            sample_size: Number of sample pairs to generate.
            mixing: Callable mapping a ``(sample_size, d_z)`` float32 tensor
                to a tensor whose last axis is the observed feature axis.
        """
        self.sample_size = sample_size
        self.mixing = mixing

        # initialize
        obs_zs, int_zs = pushforward_nintervention(sample_size)

        # Only the values are needed here. Disabling autograd (and detaching)
        # keeps .numpy() from failing when `mixing` has trainable parameters.
        with torch.no_grad():
            obs_xs = self.mixing(torch.from_numpy(obs_zs).float()).detach().numpy()
            int_xs = self.mixing(torch.from_numpy(int_zs).float()).detach().numpy()

        # NOTE(review): the latents stay in NumPy's float64 while the mixed
        # data originates from float32 tensors, so the two items returned by
        # __getitem__ may differ in dtype -- confirm consumers expect that.
        self.data = self._standardize(np.stack([obs_xs, int_xs], 1))  # bs, 2, d_x
        self.zs = self._standardize(np.stack([obs_zs, int_zs], 1))  # bs, 2, d_z

    @staticmethod
    def _standardize(pairs):
        """Standardise each feature of ``pairs`` over all leading axes."""
        # Take the feature width from the array instead of assuming 3, so a
        # mixing function with a different output width still works.
        flat = pairs.reshape(-1, pairs.shape[-1])
        return StandardScaler().fit_transform(flat).reshape(pairs.shape)

    def __len__(self):
        """Return the number of sample pairs."""
        return self.sample_size

    def __getitem__(self, idx):
        """Return ``(data[idx], zs[idx])`` for the requested pair."""
        return self.data[idx], self.zs[idx]

class DataModule(LightningDataModule):
    """Lightning data module serving an 80/20 train/validation split of
    ``SyntheticDataSet``.

    The dataset is generated eagerly in ``__init__`` and split once with
    ``torch.utils.data.random_split``.
    """

    def __init__(self, batch_size=100, sample_size=200_000, num_workers=7):
        """Build the synthetic dataset and split it.

        Args:
            batch_size: Batch size used by both dataloaders.
            sample_size: Number of sample pairs to generate.
            num_workers: Worker processes per dataloader.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # NOTE(review): NonlinearMixing is not defined in the visible part of
        # this file; it is assumed to be provided elsewhere in the module.
        self.dataset = SyntheticDataSet(sample_size=sample_size, mixing=NonlinearMixing(3,3))
        train_size = int(0.8 * len(self.dataset))
        # The remainder goes to validation so the two sizes always sum to
        # len(dataset).
        val_size = len(self.dataset) - train_size
        self.train_set, self.val_set = torch.utils.data.random_split(self.dataset, [train_size, val_size])

    def _make_loader(self, subset, training):
        """Create a dataloader; training loaders shuffle and drop the last partial batch."""
        return torch.utils.data.DataLoader(subset, num_workers=self.num_workers,
                                           batch_size=self.batch_size,
                                           shuffle=training, pin_memory=True, drop_last=training)

    def train_dataloader(self):
        """Return the shuffled training dataloader (incomplete final batch dropped)."""
        return self._make_loader(self.train_set, training=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order, all samples kept)."""
        return self._make_loader(self.val_set, training=False)