import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.transforms.functional as VF
from PIL import Image
import torchvision
import cv2
import os
def get_data(dataset, exp_num):
    """Build the datasets for the experiment named ``dataset``.

    Args:
        dataset: one of 'complex_ge', 'mnist', 'dshapes', 'omniglot', 'cisim',
            'dissim', 'ihdp', 'confounder', 'toy', 'nonmono'.
        exp_num: forwarded as ``exp_num`` to ``get_ihdp100`` and as ``case``
            to ``get_confounder``, ``simulate_toy`` and ``simulate_nonmono``;
            ignored by the other loaders.

    Returns:
        The 4-tuple of datasets produced by the selected loader. The role of
        each slot is loader-specific (some loaders repeat the training set in
        the first two slots, others repeat the test set in the last three).

    Raises:
        ValueError: if ``dataset`` is not a supported name (instead of
            silently returning None).
    """
    if dataset == 'complex_ge':
        return simulate_complex_ge()
    if dataset == 'mnist':
        return get_mnist()
    if dataset == 'dshapes':
        return get_dshapes()
    if dataset == 'omniglot':
        train_dset = CounterfactualOmniglot(split='train', mode='train')
        eval_train_dset = CounterfactualOmniglot(split='train', mode='test')
        test_dset = CounterfactualOmniglot(split='test', mode='test')
        return train_dset, train_dset, eval_train_dset, test_dset
    if dataset == 'cisim':
        return get_cisim()
    if dataset == 'dissim':
        return get_dissim()
    if dataset == 'ihdp':
        return get_ihdp100(exp_num=exp_num)
    if dataset == 'confounder':
        return get_confounder(case=exp_num)
    if dataset == 'toy':
        return simulate_toy(case=exp_num)
    if dataset == 'nonmono':
        return simulate_nonmono(case=exp_num)
    raise ValueError('Unknown dataset: %r' % (dataset,))

def simulate(num_samples=10000):
    """Draw (x, t, y) with y = log(x**2 + t**2 + 1e-3) + noise.

    x, t and the additive noise are independent standard normals of shape
    (num_samples, 1), drawn in that order from NumPy's global RNG.
    """
    shape = (num_samples, 1)
    covariate, treatment, noise = (np.random.normal(size=shape) for _ in range(3))
    # The small constant keeps the log finite at x = t = 0.
    outcome = np.log(covariate ** 2 + treatment ** 2 + 1e-3) + noise
    return covariate, treatment, outcome



def get_confounder(num_samples=81000, case=1):
    """Simulate a linear outcome model with a hidden confounder ``c``.

    After seeding NumPy's global RNG with 42, draws (in this order)
    c ~ U(-0.5, 0.5), base covariate x ~ U(0, 1), base treatment t ~ U(0, 1)
    and noise e ~ N(0, 1), then adds ``c`` to the variables selected by
    ``case``:

    * case 1: X <- C -> T   (c added to x and t)
    * case 2: X <- C -> Y   (c added to x and y)
    * case 3: T <- C -> Y   (c added to t and y)

    The factual outcome is y = x + t + e (+ c). The counterfactual treatment is
    fixed at 0.5, with the same x, e and c.

    Args:
        num_samples: total number of rows; the last 1000 form the test split.
        case: confounding structure, 1, 2 or 3.

    Returns:
        (train, train, test, test) DictDatasets. Only the test split carries
        'cf_treatment' and 'cf_outcome'.

    Raises:
        ValueError: if ``case`` is not 1, 2 or 3, or ``num_samples <= 1000``.
    """
    # (c -> x, c -> t, c -> y) for each supported case.
    structures = {1: (True, True, False), 2: (True, False, True), 3: (False, True, True)}
    if case not in structures:
        raise ValueError('Unsupported confounder case: %r' % (case,))
    n_test = 1000
    if num_samples <= n_test:
        raise ValueError('num_samples must exceed %d, got %r' % (n_test, num_samples))
    c_to_x, c_to_t, c_to_y = structures[case]

    np.random.seed(42)
    print('>>>>> Using confounder case %d <<<<' % case)
    shape = (num_samples, 1)
    # Case 1 previously referenced `c` without defining it (NameError); all
    # cases now draw the confounder first, as cases 2 and 3 always did.
    c = np.random.uniform(size=shape) - 0.5
    x = np.random.uniform(size=shape)
    t = np.random.uniform(size=shape)
    e = np.random.normal(size=shape)
    if c_to_x:
        x = x + c
    if c_to_t:
        t = t + c
    y_confounding = c if c_to_y else 0.0
    cf_t = np.full_like(t, 0.5)
    y = x + t + e + y_confounding
    cf_y = x + cf_t + e + y_confounding

    def to_tensor(array):
        return torch.from_numpy(array).float()

    ntrain = num_samples - n_test
    train_dset = DictDataset({'covariate': to_tensor(x[:ntrain]), 'treatment': to_tensor(t[:ntrain]),
                              'outcome': to_tensor(y[:ntrain])})
    t_cf_test = to_tensor(cf_t[ntrain:])
    y_cf_test = to_tensor(cf_y[ntrain:])
    test_dset = DictDataset({'covariate': to_tensor(x[ntrain:]), 'treatment': to_tensor(t[ntrain:]),
                             'outcome': to_tensor(y[ntrain:]),
                             'cf_treatment': t_cf_test, 'cf_outcome': y_cf_test})
    print(y_cf_test[:10], ' >>>> y_cf_test ', t_cf_test[:10], ' >>>> t_cf_test <<<<')
    return train_dset, train_dset, test_dset, test_dset



def monte_carlo(x, t, y, target_y, x0=0.5, t0=0.5, tol=0.01):
    """Estimate P(Y <= target_y | X ~= x0, T ~= t0) by rejection sampling.

    Keeps the samples whose covariate and treatment both lie within ``tol``
    of (``x0``, ``t0``) and returns the fraction of their outcomes that are
    <= ``target_y``. Prints a few diagnostics.

    Args:
        x, t, y: NumPy arrays of the same shape with covariates, treatments
            and outcomes.
        target_y: outcome threshold at which the empirical CDF is evaluated
            (a scalar or a broadcastable 1-element array).
        x0, t0: conditioning point; the defaults reproduce the original
            hard-coded 0.5.
        tol: half-width of the acceptance window, applied inclusively to
            both x and t (the original mixed ``<=`` and ``<``).

    Returns:
        The empirical conditional CDF, or 0 when no sample is accepted.
    """
    accepted = np.logical_and(np.abs(x - x0) <= tol, np.abs(t - t0) <= tol)
    yy = y[accepted]
    print(yy[:10], ' >>>> yy ', target_y)
    print(len(y), len(yy), ' >>>len yy')
    if len(yy) == 0:
        # No sample near the conditioning point: keep the original sentinel
        # instead of dividing by zero.
        return 0
    return np.sum(yy <= target_y) / len(yy)

def simulate_toy(num_samples=100000, case=1):
    """Simulate one of five toy outcome models y = f(x, t, e).

    Training rows come from NumPy's global RNG, drawn in this order:
    x ~ U(0, 1), t ~ U(0, 1), e ~ N(0, 1). The test split is 1000 copies of
    x = t = e = 0.5. Its counterfactual treatments sweep
    linspace(0, 1, 1000) with x and e held fixed.

    Outcome models by ``case``:
        1: x + t + e
        2: sin(2*pi*t + x) + e
        3: exp(t - x + 0.5) * e
        4: exp(sin(pi*t + x) + e)
        5: exp(-5*t + x) + exp(t + x - 0.5) * e

    Side effects: prints diagnostics, including a Monte Carlo estimate of
    the conditional CDF at the test point. Also writes the train and test
    tables to CSV files under ``data/``.

    Returns:
        (train, test, test, test) DictDatasets.
    """
    assert case in [1, 2, 3, 4, 5]
    shape = (num_samples, 1)
    x = np.random.uniform(size=shape)
    t = np.random.uniform(size=shape)
    e = np.random.normal(size=shape)
    x_test, t_test, e_test = (np.full((1000, 1), 0.5) for _ in range(3))
    cf_t = np.linspace(0, 1, 1000).reshape(-1, 1)

    outcome_models = {
        1: lambda a, b, n: a + b + n,
        2: lambda a, b, n: np.sin(2 * np.pi * b + a) + n,
        3: lambda a, b, n: np.exp(b - a + 0.5) * n,
        4: lambda a, b, n: np.exp(np.sin(np.pi * b + a) + n),
        5: lambda a, b, n: np.exp(-5 * b + a) + np.exp(b + a - 0.5) * n,
    }

    def outcome(a, b, n):
        print('>>>>>>>>>>>>>> CASEE ', case)
        return outcome_models[case](a, b, n)

    y = outcome(x, t, e)
    y_test = outcome(x_test, t_test, e_test)
    y_cf = outcome(x_test, cf_t, e_test)

    mc_cdf = monte_carlo(x, t, y, target_y=y_test[0])
    print('>>>>> Monte Carlo CDF ', mc_cdf)

    train_x, train_t, train_y = (torch.from_numpy(a).float() for a in (x, t, y))
    test_x, test_t, test_y, cf_t, y_cf = (
        torch.from_numpy(a).float() for a in (x_test, t_test, y_test, cf_t, y_cf))

    train_dset = DictDataset({'covariate': train_x, 'treatment': train_t, 'outcome': train_y})
    test_dset = DictDataset({'covariate': test_x, 'treatment': test_t, 'outcome': test_y,
                             'cf_treatment': cf_t, 'cf_outcome': y_cf})

    train_table = pd.DataFrame({'x': train_x.reshape(-1), 't': train_t.reshape(-1), 'y': train_y.reshape(-1)})
    train_table.to_csv('data/toy_%s_train.csv' % case, index=False)
    # NOTE(review): the test file uses the prefix 'toy1_' while the train file
    # uses 'toy_'. Kept as-is in case downstream readers rely on it; confirm.
    test_table = pd.DataFrame({'x': test_x.reshape(-1), 't': test_t.reshape(-1), 'y': test_y.reshape(-1),
                               'cf_t': cf_t.reshape(-1), 'cf_y': y_cf.reshape(-1)})
    test_table.to_csv('data/toy1_%s_test.csv' % case, index=False)
    return train_dset, test_dset, test_dset, test_dset


import torch.nn.init as init
def weights_init(init_type='gaussian'):
    """Build a callback for ``nn.Module.apply`` that re-initialises layers.

    The callback acts only on modules whose class name starts with 'Conv' or
    'Linear' and that have a ``weight`` attribute. With
    ``init_type='gaussian'``, their weights are redrawn from N(0, 0.5**2) and
    any bias is set to zero. Any other ``init_type`` fails with an
    AssertionError when the callback reaches such a module.
    """
    def _reinit(module):
        layer_name = type(module).__name__
        is_target = layer_name.startswith(('Conv', 'Linear')) and hasattr(module, 'weight')
        if not is_target:
            return
        if init_type != 'gaussian':
            assert 0, "Unsupported initialization: {}".format(init_type)
        else:
            init.normal_(module.weight.data, 0.0, 0.5)
        if getattr(module, 'bias', None) is not None:
            init.constant_(module.bias.data, 0.0)

    return _reinit

def simulate_nonmono(num_samples=100000, case=1):
    """Simulate five outcome models, several non-monotone in the noise e.

    Training rows come from NumPy's global RNG: x ~ U(0, 1), t ~ U(-3, 3),
    e ~ N(0, 1), each of shape (num_samples, 1). The test split is 1000
    copies of x = t = e = 0.5. Its counterfactual treatments sweep
    linspace(-3, 3, 1000) with x and e held fixed.

    Outcome models by ``case``:
        1: exp(cos(pi*t + 3x) + cos(e))
        2: exp(cos(pi*t + 3x) + e**2)
        3: exp(cos(pi*t + 3x) + mlp_e(e))
        4: exp(mlp_xt(x, t) + mlp_e(e))
        5: exp(mlp_xte(x, t, e) + mlp_e(e))

    The MLPs are small ReLU networks with a tanh output, re-initialised by
    ``weights_init('gaussian')``. Their weights come from torch's global RNG,
    which is not seeded here, so cases 3-5 differ between calls unless the
    caller seeds torch.

    Side effects: prints diagnostics and writes
    ``data/nonmono_<case>_train.csv`` and ``data/nonmono_<case>_test.csv``.

    Returns:
        (train, test, test, test) DictDatasets.
    """
    assert case in [1, 2, 3, 4, 5]
    # Training inputs (NumPy global RNG, drawn in this order).
    x = np.random.uniform(size=(num_samples, 1))
    t = np.random.uniform(-3,3,size=(num_samples, 1))
    e = np.random.normal(size=(num_samples, 1))
    # Test point: 1000 copies of x = t = e = 0.5.
    test_x = np.array([0.5]*1000).reshape(-1,1)
    test_t = np.array([0.5]*1000).reshape(-1,1)
    test_e = np.array([0.5]*1000).reshape(-1,1)
    # Counterfactual treatments sweep the full training range of t.
    cf_t = np.linspace(-3, 3, 1000).reshape(-1,1)
    # Random MLPs used by cases 3-5. They are built (consuming torch's global
    # RNG) for every case and shared by all construct() calls below, so the
    # train, test and counterfactual outcomes use the same function.
    mlp_e = nn.Sequential(nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 100),
                          nn.ReLU(), nn.Linear(100, 1),
                        nn.Tanh())
    mlp_xt = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 100), nn.ReLU(),
                           nn.Linear(100, 1),
                        nn.Tanh())
    mlp_xte = nn.Sequential(nn.Linear(3, 100), nn.ReLU(), nn.Linear(100, 100), nn.ReLU(),
                            nn.Linear(100, 1),
                        nn.Tanh())

    mlp_e.apply(weights_init('gaussian'))
    mlp_xt.apply(weights_init('gaussian'))
    mlp_xte.apply(weights_init('gaussian'))
    def construct(x,t,e):
        # Outcome equation for the selected case; called below with
        # (n, 1) NumPy arrays. MLP outputs are converted back to NumPy.
        print('>>>>>>>>>>>>>> Nonmoton CASEE ', case)
        if case == 1:
            y = np.exp(np.cos(np.pi*t+3*x)+np.cos(e))
        elif case == 2:
            y = np.exp(np.cos(np.pi * t + 3 * x) + e**2)
        elif case == 3:
            y = np.exp(np.cos(np.pi*t+3*x) + mlp_e(torch.from_numpy(e).float()).detach().numpy())
        elif case == 4:
            y = np.exp(mlp_xt(torch.from_numpy(np.concatenate([x,t], 1)).float()).detach().numpy()
                       + mlp_e(torch.from_numpy(e).float()).detach().numpy())
        elif case == 5:
            y = np.exp(mlp_xte(torch.from_numpy(np.concatenate([x,t,e], 1)).float()).detach().numpy()
                       +mlp_e(torch.from_numpy(e).float()).detach().numpy())
        return y

    y = construct(x, t, e)
    test_y = construct(test_x, test_t, test_e)
    y_cf_test = construct(test_x, cf_t, test_e)

    # Empirical conditional CDF of the training outcomes near (x, t) = (0.5, 0.5),
    # evaluated at the first test outcome (diagnostic only).
    mc_cdf = monte_carlo(x,t,y, target_y=test_y[0])
    print('>>>>> Monte Carlo CDF ', mc_cdf)

    train_x = torch.from_numpy(x).float()
    train_t = torch.from_numpy(t).float()
    train_y = torch.from_numpy(y).float()
    test_x = torch.from_numpy(test_x).float()
    test_t = torch.from_numpy(test_t).float()
    test_y = torch.from_numpy(test_y).float()
    cf_t = torch.from_numpy(cf_t).float()
    y_cf_test = torch.from_numpy(y_cf_test).float()

    train_dset = DictDataset({'covariate': train_x, 'treatment': train_t, 'outcome': train_y})
    test_dset = DictDataset({'covariate': test_x, 'treatment': test_t, 'outcome': test_y,
                             'cf_treatment': cf_t, 'cf_outcome': y_cf_test})

    # Also dump the simulated data as flat CSV tables under data/.
    df = pd.DataFrame({'x': train_x.reshape(-1), 't': train_t.reshape(-1), 'y': train_y.reshape(-1)})
    df.to_csv('data/nonmono_%s_train.csv' %case, index=False)
    df = pd.DataFrame({'x': test_x.reshape(-1), 't': test_t.reshape(-1), 'y': test_y.reshape(-1),
                       'cf_t': cf_t.reshape(-1), 'cf_y': y_cf_test.reshape(-1)
                       })
    df.to_csv('data/nonmono_%s_test.csv' %case, index=False)
    return train_dset, test_dset, test_dset, test_dset



def simulate_complex_ge(num_samples=100000, device='cuda'):
    """Simulate y = (x + sin(2*pi*t) * g(e) + 2 * g(e)) / 4 with non-Gaussian noise.

    x, t, e ~ N(0, 1) are drawn i.i.d. from NumPy's global RNG, in that order.
    The noise transform is g(e) = e/4 + e**2/8 + e**3/32, a strictly
    increasing cubic, so the noise enters non-linearly. Its scale,
    (sin(2*pi*t) + 2) / 4, depends on the treatment.

    The test split holds 1000 copies of x = t = e = 0.5. Its counterfactual
    treatments sweep linspace(-1, 1, 1000) with x and e held fixed.

    Args:
        num_samples: number of training rows.
        device: device the returned tensors are moved to. Defaults to 'cuda'
            (the previous hard-coded behaviour); pass 'cpu' on machines
            without a GPU.

    Returns:
        (train, train, test, test) DictDatasets.
    """
    shape = (num_samples, 1)
    x = torch.from_numpy(np.random.normal(size=shape)).float()
    t = torch.from_numpy(np.random.normal(size=shape)).float()
    e = torch.from_numpy(np.random.normal(size=shape)).float()

    test_x = torch.full((1000, 1), 0.5)
    test_t = torch.full((1000, 1), 0.5)
    test_e = torch.full((1000, 1), 0.5)
    cf_actions = torch.linspace(-1, 1, 1000).view(1000, 1)

    def construct(xx, tt, ee):
        # Vectorised g(e); replaces a per-element Python loop over every sample.
        ee = (1 / 4 * ee + 1 / 8 * ee ** 2 + 1 / 32 * ee ** 3).view(-1, 1)
        out = (xx + torch.sin(2 * np.pi * tt) * ee + 2 * ee) / 4
        print(ee.min(), ee.max(), ' >>>> ee min max ', out.min(), out.max(), ' >>>> out min max ')
        return out

    y = construct(x, t, e)
    test_y = construct(test_x, test_t, test_e)
    y_cf_test = construct(test_x, cf_actions, test_e)

    train_dset = DictDataset({'covariate': x.to(device), 'treatment': t.to(device), 'outcome': y.to(device)})
    test_dset = DictDataset({'covariate': test_x.to(device), 'treatment': test_t.to(device),
                             'outcome': test_y.to(device),
                             'cf_treatment': cf_actions.to(device), 'cf_outcome': y_cf_test.to(device)})

    return train_dset, train_dset, test_dset, test_dset


class DictDataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of indexable columns.

    Item ``i`` is a dict with the same keys as ``data_dict``, mapping each key
    to that column's ``i``-th entry.
    """

    def __init__(self, data_dict):
        super().__init__()
        # Stored as given; columns are only indexed lazily in __getitem__.
        self.data_dict = data_dict

    def __len__(self):
        # The first column (in insertion order) defines the length; the
        # others are assumed to be equally long.
        first_key = [*self.data_dict][0]
        return len(self.data_dict[first_key])

    def __getitem__(self, index):
        item = {}
        for key, column in self.data_dict.items():
            item[key] = column[index]
        return item



class RotationMNIST:
    """MNIST digits paired with a rotated or re-lit version of themselves.

    Each item's 'covariate' is an MNIST digit resized to 32x32 and tiled to
    3 channels.

    In the 'train' split, the 'outcome' is the covariate rotated by a random
    integer angle in [-45, 45).

    In any other split, the outcome is the unrotated covariate with its green
    channel overwritten by a brightness level (0, 0.25, 0.5, 0.75 or 1,
    chosen by ``index % 5``) and 0.01 times that level added to red and
    blue. The 'test' split also returns a counterfactual rotated by +30
    (even index) or -30 (odd index) degrees, with the same green channel.

    Treatments are angle / 45. All tensors are moved to ``device``.
    """

    def __init__(self, split='train', device='cuda'):
        use_train_images = split == 'train'
        self.mnist = MNIST(root='data', download=True, train=use_train_images)
        self.train_data, self.train_labels = self.mnist.data, self.mnist.targets
        self.device = device
        self.split = split

    def __len__(self):
        return len(self.train_data)

    def _digit_tensor(self, index):
        """Return digit ``index`` as a float (3, 32, 32) tensor produced by ToTensor."""
        pil_digit = Image.fromarray(self.train_data[index].numpy(), mode='L')
        pil_digit = VF.resize(pil_digit, (32, 32))
        return transforms.ToTensor()(pil_digit).float().view(1, 32, 32).repeat(3, 1, 1)

    def __getitem__(self, index):
        base = self._digit_tensor(index)
        if self.split == 'train':
            # This brightness draw is unused, but it is kept so the NumPy RNG
            # sequence (and hence the sampled angles) stays the same.
            np.random.uniform(0.0, 1.)
            angle = np.random.randint(-45, 45)
            result = VF.rotate(base.clone(), angle)
        else:
            angle = 0
            level = np.array([0, 0.25, 0.5, 0.75, 1])[index % 5]
            result = base.clone()
            result[0, :, :] += 0.01 * level
            result[1, :, :] = level
            result[2, :, :] += 0.01 * level
            cf_angle = 30 if index % 2 == 0 else -30
            cf_result = VF.rotate(base.clone(), cf_angle)
            cf_result[1, :, :] = level

        example = {
            'covariate': base.view(3, 32, 32).to(self.device),
            'treatment': torch.tensor(angle).to(self.device).view(-1) / 45,
            'outcome': result.view(3, 32, 32).to(self.device),
        }
        if self.split == 'test':
            example['cf_treatment'] = torch.tensor(cf_angle).to(self.device).view(-1) / 45
            example['cf_outcome'] = cf_result.view(3, 32, 32).to(self.device)
        return example


class CounterfactualMNIST:
    """Colored MNIST where the treatment is a rotation angle.

    Each item starts from an MNIST digit resized to 32x32, tiled to 3
    channels and tinted with a palette color (the 'covariate'). The
    'outcome' is that image rotated by the treatment angle. Its green
    channel is overwritten by a brightness value, and its red/blue channels
    are blended 0.99/0.01 with it. Treatments are reported as angle / 45.

    mode == 'train': random integer angle in [-45, 45) and brightness drawn
    from U[0, 1).

    Any other mode: the angle is -30 for even indices and +30 for odd ones.
    A counterfactual with the opposite angle is added under the keys
    'cf_treatment' and 'cf_outcome'.

    NOTE(review): the palette color and the test-mode brightness are computed
    and then overwritten with constants (color = 4, brightness = 0.5). Every
    digit therefore gets the same tint, and test brightness never varies. The
    random color draw in train mode still advances NumPy's RNG. Confirm these
    overrides are intended.
    """
    def __init__(self, split='train', device='cuda', mode='train'):
        # `split` picks MNIST's train or test images; `mode` picks how the
        # treatment is assigned in __getitem__ ('train' = random).
        self.mnist = MNIST(root='data', download=True, train=(split=='train'))
        self.train_data = self.mnist.data
        self.train_labels = self.mnist.targets
        self.device = device
        self.split = split
        self.mode = mode
        # RGB tints (0-255) indexed by `color` in __getitem__.
        self.pallette = [[31, 119, 180],
                         [255, 127, 14],
                         [44, 160, 44],
                         [214, 39, 40],
                         [148, 103, 189],
                         [140, 86, 75],
                         [227, 119, 194],
                         [127, 127, 127],
                         [188, 189, 34],
                         [23, 190, 207]]

    def __len__(self):
        return len(self.train_data)
    def __getitem__(self, index):
        """Build one example dict; tensors are moved to ``self.device``."""
        img, target = self.train_data[index], int(self.train_labels[index])
        img = Image.fromarray(img.numpy(), mode='L')
        img = VF.resize(img, (32, 32))
        img = transforms.ToTensor()(img)
        # Grayscale digit tiled to 3 channels: shape (3, 32, 32).
        img = img.float().view(1, 32, 32).repeat(3,1,1)
        if self.mode == 'train':
            color = np.random.randint(0, 10)
        else:
            color = index%10
        # NOTE(review): overrides the color chosen above, so every digit gets
        # palette entry 4; the random draw above still advances NumPy's RNG.
        color = 4
        # Tint: scale each channel by the palette color's intensity in [0, 1].
        img[0,:,:] *= (self.pallette[color][0]/255)
        img[1,:,:] *= (self.pallette[color][1]/255)
        img[2,:,:] *= (self.pallette[color][2]/255)

        if self.mode == 'train':
            # Random brightness in [0, 1) and integer angle in [-45, 45).
            brightness = np.random.uniform(0.0, 1.)
            rotate = torch.clone(img)
            rotate_angle = np.random.randint(-45, 45)
            rotate = VF.rotate(rotate, rotate_angle)
            # Green channel is replaced by the brightness; red/blue are
            # blended 99/1 with it.
            rotate[0,:,:] = 0.99*rotate[0,:,:] + 0.01*brightness
            rotate[1,:,:] = brightness
            rotate[2,:,:] = 0.99*rotate[2,:,:] + 0.01*brightness

        else:
            # Deterministic factual angle: -30 for even indices, +30 for odd.
            rotate_angle = -30 if index%2==0 else 30
            brightness = np.array([0, 0.25, 0.5, 0.75, 1])[index%5]
            # NOTE(review): overrides the per-index brightness above with a constant.
            brightness = 0.5
            rotate = VF.rotate(torch.clone(img), rotate_angle)
            rotate[0,:,:] = 0.99*rotate[0,:,:] + 0.01*brightness
            rotate[1,:,:] = brightness
            rotate[2,:,:] = 0.99*rotate[2,:,:] + 0.01*brightness

            # Counterfactual: the opposite angle, same brightness treatment.
            cf_t = 30 if index%2 ==0 else -30
            cf_out = VF.rotate(torch.clone(img), cf_t)
            cf_out[0,:,:] = 0.99*cf_out[0,:,:] + 0.01*brightness
            cf_out[1,:,:] = brightness
            cf_out[2,:,:] = 0.99*cf_out[2,:,:] + 0.01*brightness

        example = { 'covariate': img.view(3,32,32).to(self.device),
                    'treatment': torch.tensor(rotate_angle).to(self.device).view(-1)/45,
                    'outcome':rotate.view(3,32,32).to(self.device),
                    }
        if self.mode != 'train':
            example['cf_treatment'] = torch.tensor(cf_t).to(self.device).view(-1)/45
            example['cf_outcome'] = cf_out.view(3,32,32).to(self.device)

        return example


def get_mnist():
    """Build the colored/rotated MNIST splits.

    Returns (train, train, eval_train, test):
      * train: the training images in 'train' mode (random treatments);
      * eval_train: the training images in 'test' mode, for evaluation;
      * test: the MNIST test images in 'test' mode.
    """
    train = CounterfactualMNIST(split='train', mode='train')
    return (
        train,
        train,
        CounterfactualMNIST(split='train', mode='test'),
        CounterfactualMNIST(split='test', mode='test'),
    )

class SlowCounterfactualOmniglot:
    """Omniglot glyphs with stroke thickness as the treatment, read via torchvision.

    Uses torchvision's Omniglot dataset (downloaded to 'data') with its
    default constructor arguments.

    NOTE(review): the train/test partition is driven by ``mode``, not
    ``split``. 'train' mode serves the first 80% of the dataset. Other modes
    report the size of the remaining 20%, and only 'test' mode offsets
    indices into it. ``split`` is stored but not otherwise used here.

    The outcome is the glyph with its dark strokes dilated by a 1x1 or 5x5
    kernel (treatment 0 or 1), scaled by a brightness factor.
    """
    def __init__(self, split='train', mode='train', device='cpu'):
        super().__init__()
        self.omniglot = torchvision.datasets.Omniglot(root='data', download=True)
        self.split = split
        self.device = device
        self.mode = mode


    def __len__(self):
        # Size of the first 80% ('train' mode) or the remaining 20% (otherwise).
        train_id = int(0.8*len(self.omniglot))
        if self.mode == 'train':
            return train_id
        else:
            return len(self.omniglot)-train_id

    def adjust_thick(self, img, kernel_size):
        """Dilate the dark strokes of ``img`` and return 1 minus the stroke mask.

        ``img`` is rescaled to 0-255 when its maximum is below 2. Pixels <= 128
        become foreground (THRESH_BINARY_INV) and are dilated once with a
        ``kernel_size`` x ``kernel_size`` square. The boolean mask is converted
        via PIL + ToTensor, so strokes end up near 0 and background near 1
        (assumes ToTensor maps the boolean image to 0/1 -- confirm).
        """
        img = np.array(img)
        threshold = 0.5
        if img.max()<2:
            img *= 255
        _, binary_image = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV)
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        processed_image = cv2.dilate(binary_image, kernel, iterations=1)>(threshold)
        processed_pil_image = Image.fromarray(processed_image)
        return 1-transforms.ToTensor()(processed_pil_image)

    def __getitem__(self, index):
        # In 'test' mode, indices address the last 20% of the dataset.
        if self.mode == 'test':
            index += int(0.8*len(self.omniglot))
        img, target = self.omniglot.__getitem__(index)
        img = img.resize((32, 32))
        # Kernel sizes for treatment 0 (unchanged strokes) and 1 (thick strokes).
        options = [1, 5]
        img = transforms.ToTensor()(img)[0]
        if self.mode == 'train':
            brightness = np.random.uniform(0.0, 1.)
            opt = np.random.choice(range(len(options)))
            out = self.adjust_thick(img, options[opt])
            out *= brightness
        else:
            # Deterministic brightness and treatment from the (possibly offset) index.
            brightness = np.array([0.25, 0.5, 0.75, 1])[index%4]
            opt = index%len(options)
            out = self.adjust_thick(img, options[opt])
            out *= brightness

            # Counterfactual: the other kernel size, same brightness.
            cf_t = len(options)-1-opt
            cf_out = self.adjust_thick(img, options[cf_t])
            cf_out *= brightness

            cf_out = cf_out.float().view(1,32,32).repeat(3,1,1)

        img = img.float().view(1, 32, 32).repeat(3,1,1)
        out = out.float().view(1, 32, 32).repeat(3,1,1)
        example = { 'covariate': img.view(3,32,32).to(self.device),
                    'treatment': torch.tensor(opt).to(self.device).view(-1).float(),
                    'outcome':out.view(3,32,32).to(self.device),
                    }
        if self.mode != 'train':
            example['cf_treatment'] = torch.tensor(cf_t).to(self.device).view(-1).float()
            example['cf_outcome'] = cf_out.view(3,32,32).to(self.device)
            example['brightness'] = brightness
        return example


class CounterfactualOmniglot:
    """Pre-rendered Omniglot glyphs with stroke thickness as the treatment.

    Loads an array of glyphs from ``root`` with ``np.load``. The first 80%
    is kept for ``split == 'train'`` and the rest otherwise.

    Each item's outcome is the glyph with its dark strokes dilated by a 1x1
    or 5x5 kernel (treatment 0 or 1), scaled by a brightness factor.

    In 'train' mode, the kernel and the brightness (U[0, 1)) are random. In
    any other mode, the kernel alternates with the index, brightness is fixed
    at 0.9, and the opposite kernel is returned as the counterfactual.
    """

    def __init__(self, split='train', mode='train', device='cpu', root='data/omni.npy'):
        super().__init__()
        glyphs = np.load(root)
        print(glyphs.shape, '  >>>> omni shape')
        boundary = int(0.8 * len(glyphs))
        self.omniglot = glyphs[:boundary] if split == 'train' else glyphs[boundary:]
        self.split = split
        self.device = device
        self.mode = mode

    def __len__(self):
        return len(self.omniglot)

    def adjust_thick(self, img, kernel_size):
        """Return a float mask shaped like ``img``: 0 on dilated dark strokes, 1 elsewhere.

        ``img`` is scaled to 0-255 when its maximum is below 2. Pixels <= 128
        are treated as strokes (THRESH_BINARY_INV) and dilated once with a
        ``kernel_size`` x ``kernel_size`` square kernel.
        """
        pixels = np.array(img)
        if pixels.max() < 2:
            pixels *= 255
        _, strokes = cv2.threshold(pixels, 128, 255, cv2.THRESH_BINARY_INV)
        square = np.ones((kernel_size, kernel_size), np.uint8)
        stroke_mask = cv2.dilate(strokes, square, iterations=1) > 0.5
        return 1 - torch.from_numpy(stroke_mask).float()

    def __getitem__(self, index):
        kernel_sizes = [1, 5]
        glyph = torch.from_numpy(self.omniglot[index]).float()
        if self.mode == 'train':
            brightness = np.random.uniform(0.0, 1.)
            opt = np.random.choice(range(len(kernel_sizes)))
        else:
            brightness = 0.9
            opt = index % len(kernel_sizes)
        out = self.adjust_thick(glyph, kernel_sizes[opt])
        out *= brightness

        def as_rgb(plane):
            # Tile a single 32x32 plane into a (3, 32, 32) tensor on self.device.
            return plane.float().view(1, 32, 32).repeat(3, 1, 1).to(self.device)

        example = {
            'covariate': as_rgb(glyph),
            'treatment': torch.tensor(opt).to(self.device).view(-1).float(),
            'outcome': as_rgb(out),
        }
        if self.mode != 'train':
            cf_opt = len(kernel_sizes) - 1 - opt
            cf_out = self.adjust_thick(glyph, kernel_sizes[cf_opt])
            cf_out *= brightness
            example['cf_treatment'] = torch.tensor(cf_opt).to(self.device).view(-1).float()
            example['cf_outcome'] = as_rgb(cf_out)
            example['brightness'] = brightness
        return example



import h5py
class SlowGammaShapes:
    """3D Shapes image pairs read from the original HDF5 file.

    Reads 'labels' (reshaped to (-1, 6), float64) and takes 'images' from the
    HDF5 file. The first 80% of rows form the 'train' split, the rest the
    other split.

    For item ``index``, the outcome is another image whose labels agree on
    every column except column 3 (presumably the object scale -- confirm
    against the 3dshapes label layout) and differ from this item's labels.
    The outcome is raised elementwise to a power ``noise``. The treatment is
    the ratio of the partner's column-3 label to this item's.

    NOTE(review): ``self.file`` is opened and never closed. Slicing
    ``self.file['images'][:train_id]`` presumably reads that whole slice into
    a NumPy array (h5py slicing semantics -- confirm), so the data is not
    streamed lazily after __init__. The partner search in __getitem__ scans
    every label row on each call. If no partner exists,
    ``np.random.randint(0)`` / ``matching_indices[0]`` raise.
    """
    def __init__(self, root='data/3dshapes/3dshapes.h5', split='train', mode='train', device='cuda'):
        super().__init__()
        with h5py.File(root, 'r') as f:
            print(f.keys())
            labels = f['labels'][:]
            self.labels = np.frombuffer(labels, dtype=np.float64).reshape(-1, 6)
            train_id = int(0.8 * len(self.labels))
        self.file = h5py.File(root, 'r')
        self.data_api = self.file['images']
        if split == 'train':
            self.labels = self.labels[:train_id]
            self.data_api = self.data_api[:train_id]
        else:
            self.labels = self.labels[train_id:]
            self.data_api = self.data_api[train_id:]
        # Every label column except column 3: partners must match on these.
        self.relevant = np.concatenate([self.labels[:, :3], self.labels[:, 4:]], axis=1)
        assert len(self.labels) == len(self.relevant)
        self.mode = mode
        self.split = split
        self.device = device
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):

        # Covariate: the 64x64 RGB image downsampled to 32x32, CHW, in [0, 1].
        x = np.frombuffer(self.data_api[index], dtype=np.uint8).reshape(64, 64, 3)
        x = np.array((Image.fromarray(x).resize((32, 32)))).transpose(2, 0, 1)/255.
        # Candidate partners: same non-column-3 labels, but not identical labels.
        x_relavent = self.relevant[index]
        mask = np.all(self.relevant==x_relavent, axis=1)
        mask &= (~np.all(self.labels==self.labels[index], axis=1))

        matching_indices = np.nonzero(mask)[0]
        matching_labels = self.labels[matching_indices]

        if self.mode == 'train':
            # Random partner and a random exponent in [0, 4).
            pick_index = np.random.randint(len(matching_indices))
            pick_whole_index = matching_indices[pick_index]
            noise = np.random.uniform(0, 4)
        else:
            # Deterministic: first partner, exponent chosen by index % 3.
            pick_index = 0
            pick_whole_index = matching_indices[pick_index]
            noise = np.array([0.1, 1, 3])[index%3]
        outcome = np.frombuffer(self.data_api[pick_whole_index], dtype=np.uint8).reshape(64, 64, 3)
        outcome = np.array(Image.fromarray(outcome).resize((32, 32))).transpose(2, 0, 1)/255.
        outcome = torch.from_numpy(outcome).float()
        outcome = outcome**noise
        treat = matching_labels[pick_index,3]/self.labels[index, 3]
        example = {}
        example['covariate'] = torch.from_numpy(x).float()
        example['treatment'] = torch.tensor(treat).float().view(1)
        example['outcome'] = outcome
        if self.mode != 'train':
            # Counterfactual partner chosen by index; it may coincide with the
            # factual partner (e.g. when only one partner exists).
            cf_index = index%len(matching_indices)
            cf_whole_index = matching_indices[cf_index]
            cf_outcome = np.frombuffer(self.data_api[cf_whole_index], dtype=np.uint8).reshape(64, 64, 3)
            cf_outcome = np.array(Image.fromarray(cf_outcome).resize((32, 32))).transpose(2, 0, 1)/255.
            cf_outcome = torch.from_numpy(cf_outcome).float()
            cf_outcome = cf_outcome**noise
            cf_label = matching_labels[cf_index,3]
            cf_treat = cf_label/self.labels[index, 3]
            example['cf_treatment'] = torch.tensor(cf_treat).float().view(1)
            example['cf_outcome'] = cf_outcome
        return example

class GammaShapes:
    """Pre-processed 3D Shapes pairs where the treatment is a ratio of object sizes.

    Loads ``3dshapes_train.npz`` (split 'train') or ``3dshapes_test.npz``
    (otherwise) from ``root``. The archive must provide 'covariate', 'data'
    and 'size' arrays.

    For item ``index``, the outcome is another image whose covariate row is
    identical, raised elementwise to a power ``noise``. The treatment is that
    image's size divided by the size of image ``index``.

    In 'train' mode, the partner and the exponent (U[0, 4)) are random. Other
    modes use the first partner and an exponent from (0.1, 1, 3) chosen by
    ``index % 3``, and add a counterfactual chosen among the remaining
    partners.
    """

    def __init__(self, root='data/3dshapes', split='train', mode='train', device='cuda'):
        super().__init__()
        file_name = '3dshapes_train.npz' if split == 'train' else '3dshapes_test.npz'
        path = os.path.join(root, file_name)
        with np.load(path) as archive:
            arrays = {key: archive[key] for key in archive}
        self.covariates = arrays['covariate']
        self.data = arrays['data']
        self.object_sizes = arrays['size']
        self.mode = mode
        self.split = split
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        same_covariates = np.all(self.covariates == self.covariates[index], axis=1)
        partners = np.flatnonzero(same_covariates)
        partners = partners[partners != index]

        if self.mode == 'train':
            partner = partners[np.random.randint(len(partners))]
            noise = np.random.uniform(0, 4)
        else:
            partner = partners[0]
            # The factual partner is excluded from counterfactual candidates.
            partners = partners[partners != partner]
            noise = np.array([0.1, 1, 3])[index % 3]

        sizes = self.object_sizes
        example = {
            'covariate': torch.from_numpy(image).float(),
            'treatment': torch.tensor(sizes[partner] / sizes[index]).float().view(1),
            'outcome': torch.from_numpy(self.data[partner]).float() ** noise,
        }
        if self.mode != 'train':
            cf_partner = partners[index % len(partners)]
            example['cf_treatment'] = torch.tensor(sizes[cf_partner] / sizes[index]).float().view(1)
            example['cf_outcome'] = torch.from_numpy(self.data[cf_partner]).float() ** noise
        return example

def resize_tensor(x):
    """Downsample a (C, H, W) image to a (C, 32, 32) float32 tensor.

    ``x`` may be a torch tensor (on any device, with or without grad) or a
    NumPy array. Values are expected in [0, 1], and C is presumably 3 (RGB)
    for PIL. The image is quantised to uint8 via ``(x * 255).astype``
    (truncation), resized to 32x32 with PIL's default resampling filter and
    rescaled by 1/255.

    Both input kinds return float32. The previous bare ``except`` fallback
    returned float64 for NumPy input and hid unrelated errors.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors that require grad can be converted too.
        x = x.detach().cpu().numpy()
    pixels = (np.asarray(x) * 255).astype(np.uint8).transpose(1, 2, 0)
    resized = np.array(Image.fromarray(pixels).resize((32, 32)))
    return torch.from_numpy(resized.transpose(2, 0, 1) / 255.).float()

class LoadGammaShapes:
    """In-memory 3D Shapes pairs, resized on the fly to 32x32 via resize_tensor.

    Loads every image from the HDF5 file at construction as a (N, 3, 64, 64)
    array scaled by ``/255`` (float64). Labels are reshaped to (-1, 6). The
    first 80% of rows are kept for split 'train' and the rest otherwise.

    For item ``index``, the outcome is another image whose labels agree on
    every column except column 3 (presumably object scale -- confirm) and
    differ from this item's labels. The outcome is raised to a power
    ``noise``. The treatment is the ratio of the partner's column-3 label to
    this item's.

    NOTE(review): holding the whole image set as float64 is memory-heavy;
    confirm the budget. The covariate passes through resize_tensor twice
    (after loading and again when building the example). The second pass
    re-quantises already-resized values and could shift pixels by one uint8
    level -- confirm whether that is intended.
    """
    def __init__(self, root='data/3dshapes/3dshapes.h5', split='train', mode='train', device='cuda'):
        super().__init__()
        with h5py.File(root, 'r') as f:
            print(f.keys())
            labels = f['labels'][:]
            self.labels = np.frombuffer(labels, dtype=np.float64).reshape(-1, 6)
            # (N, 64, 64, 3) uint8 -> (N, 3, 64, 64) float64 in [0, 1].
            self.data = np.frombuffer(f['images'][:], dtype=np.uint8).reshape(-1, 64, 64, 3).transpose(0, 3, 1, 2)/255
            train_id = int(0.8 * len(self.labels))
        if split == 'train':
            self.labels = self.labels[:train_id]
            self.data = self.data[:train_id]
        else:
            self.labels = self.labels[train_id:]
            self.data = self.data[train_id:]
        # Every label column except column 3: partners must match on these.
        self.relevant = np.concatenate([self.labels[:, :3], self.labels[:, 4:]], axis=1)
        assert len(self.labels) == len(self.relevant)
        self.mode = mode
        self.split = split
        self.device = device
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):

        x = self.data[index]
        # First resize (NumPy input); x is resized again below (see class note).
        x = resize_tensor(x)
        # Candidate partners: same non-column-3 labels, but not identical labels.
        x_relavent = self.relevant[index]
        mask = np.all(self.relevant==x_relavent, axis=1)
        mask &= (~np.all(self.labels==self.labels[index], axis=1))

        matching_indices = np.nonzero(mask)[0]
        matching_labels = self.labels[matching_indices]

        if self.mode == 'train':
            # Random partner and a random exponent in [0, 4).
            pick_index = np.random.randint(len(matching_indices))
            pick_whole_index = matching_indices[pick_index]
            noise = np.random.uniform(0, 4)
        else:
            # Deterministic: first partner, exponent chosen by index % 3.
            pick_index = 0
            pick_whole_index = matching_indices[pick_index]
            noise = np.array([0.1, 1, 3])[index%3]
        outcome = self.data[pick_whole_index]
        outcome = torch.from_numpy(outcome).float()
        outcome = outcome**noise
        treat = matching_labels[pick_index,3]/self.labels[index, 3]
        example = {}
        example['covariate'] = resize_tensor(x).float()
        example['treatment'] = torch.tensor(treat).float().view(1)
        example['outcome'] = resize_tensor(outcome)
        if self.mode != 'train':
            # Counterfactual partner chosen by index; it may coincide with the
            # factual partner.
            cf_index = index%len(matching_indices)
            cf_whole_index = matching_indices[cf_index]
            cf_outcome = self.data[cf_whole_index]
            cf_outcome = torch.from_numpy(cf_outcome).float()
            cf_outcome = cf_outcome**noise
            cf_label = matching_labels[cf_index,3]
            cf_treat = cf_label/self.labels[index, 3]
            example['cf_treatment'] = torch.tensor(cf_treat).float().view(1)
            example['cf_outcome'] = resize_tensor(cf_outcome)
        return example


def get_dshapes():
    """Build the 3D Shapes splits backed by the pre-processed .npz archives.

    Returns (train, train, eval_train, test):
      * train: the training archive in 'train' mode (random partners and exponents);
      * eval_train: the training archive in 'test' mode, for evaluation;
      * test: the test archive in 'test' mode.
    """
    train_split = GammaShapes(split='train', mode='train')
    eval_train_split = GammaShapes(split='train', mode='test')
    test_split = GammaShapes(split='test', mode='test')
    return (train_split,) * 2 + (eval_train_split, test_split)


def get_dissim(device='cpu'):
    """Load the discrete CI simulation as factual / counterfactual datasets.

    Reads ``data/simulated_CI_discrete.xlsx``. Column 0 holds a unit id
    (1..1000), column 1 the covariate, column 2 the treatment and column 3 the
    outcome. For each id, the row at position ``id % 2`` among that id's rows
    is the factual observation. The row at ``1 - id % 2`` is the
    counterfactual, so every id needs at least two rows. Covariates are
    divided by 100. The first 80% of ids form the training split and the rest
    the test split.

    A ``RandomForestRegressor`` fit on the factual training pairs is a
    baseline. Its counterfactual MSE on both splits is printed.

    Args:
        device: Unused; kept for interface compatibility with other loaders.

    Returns:
        ``(train_dset, train_dset, eval_train_dset, test_dset)`` as
        ``DictDataset`` objects.
        - ``train_dset`` holds only the first 50 training samples and has no
          counterfactual keys.
        - ``eval_train_dset`` covers all training samples and also carries
          ``cf_treatment`` / ``cf_outcome``.
        - ``test_dset`` covers the test split, with the same keys as
          ``eval_train_dset``.
    """
    from sklearn.ensemble import RandomForestRegressor

    table = pd.read_excel('data/simulated_CI_discrete.xlsx').to_numpy()

    cov, treat, out, cf_treat, cf_out = [], [], [], [], []
    for unit_id in range(1, 1001):
        rows = table[table[:, 0] == unit_id]
        pick_index = unit_id % 2
        cf_index = 1 - pick_index
        cov.append(rows[pick_index, 1])
        treat.append(rows[pick_index, 2])
        out.append(rows[pick_index, 3])
        cf_treat.append(rows[cf_index, 2])
        cf_out.append(rows[cf_index, 3])
    cov = np.array(cov).reshape(-1, 1) / 100.
    treat = np.array(treat).reshape(-1, 1)
    out = np.array(out).reshape(-1, 1)
    cf_treat = np.array(cf_treat).reshape(-1, 1)
    cf_out = np.array(cf_out).reshape(-1, 1)

    # All five arrays have one entry per id, so one split point serves them all.
    n_train = int(0.8 * len(cov))
    train_cov, val_covariate = cov[:n_train], cov[n_train:]
    train_treat, val_treatment = treat[:n_train], treat[n_train:]
    train_out, val_effect = out[:n_train], out[n_train:]
    train_cf_treat, val_cf_treatment = cf_treat[:n_train], cf_treat[n_train:]
    train_cf_out, val_cf_effect = cf_out[:n_train], cf_out[n_train:]

    # Seeded like get_cisim so the printed baseline is reproducible. The
    # target is flattened because sklearn expects a 1-D y; an (n, 1) column
    # vector triggers DataConversionWarning.
    model = RandomForestRegressor(random_state=2)
    model.fit(np.concatenate([train_cov, train_treat], axis=1), train_out.ravel())
    train_cf_pred = model.predict(np.concatenate([train_cov, train_cf_treat], axis=1))
    val_cf_pred = model.predict(np.concatenate([val_covariate, val_cf_treatment], axis=1))

    print('RF train mse: %.4f' % np.mean((train_cf_pred.reshape(-1)-train_cf_out.reshape(-1))**2), len(train_cf_out))
    print('RF val mse: %.4f' % np.mean((val_cf_pred.reshape(-1)-val_cf_effect.reshape(-1))**2), len(val_cf_pred))

    # Only a small prefix of the training data is exposed for fitting; the
    # evaluation copy below still covers every training sample.
    num_train_used = 50
    train_dict = {'covariate': torch.from_numpy(train_cov).float()[:num_train_used],
                  'treatment': torch.from_numpy(train_treat).float()[:num_train_used],
                  'outcome': torch.from_numpy(train_out).float()[:num_train_used],
                  }
    train_dset = DictDataset(train_dict)

    eval_train_dict = {'covariate': torch.from_numpy(train_cov).float(),
                       'treatment': torch.from_numpy(train_treat).float(),
                       'outcome': torch.from_numpy(train_out).float(),
                       'cf_treatment': torch.from_numpy(train_cf_treat).float(),
                       'cf_outcome': torch.from_numpy(train_cf_out).float(),
                       }
    eval_train_dset = DictDataset(eval_train_dict)

    test_dict = {'covariate': torch.from_numpy(val_covariate).float(),
                 'treatment': torch.from_numpy(val_treatment).float(),
                 'outcome': torch.from_numpy(val_effect).float(),
                 'cf_treatment': torch.from_numpy(val_cf_treatment).float(),
                 'cf_outcome': torch.from_numpy(val_cf_effect).float(),
                 }
    test_dset = DictDataset(test_dict)
    return train_dset, train_dset, eval_train_dset, test_dset





def get_cisim(device='cpu'):
    """Load the continuous CI simulation as factual / counterfactual datasets.

    Reads ``data/simulated_CI2.xlsx``. Column 0 holds a unit id (1..1000),
    column 1 the covariate, column 2 the treatment and column 3 the outcome.
    For each id, the row at position ``id % 20`` among that id's rows is the
    factual observation. The row at ``19 - id % 20`` is the counterfactual;
    the two positions never coincide, and every id needs at least 20 rows.
    Covariates are divided by 100. The first 80% of ids form the training
    split and the rest the test split.

    A seeded ``RandomForestRegressor`` fit on the factual training pairs is a
    baseline. Its counterfactual MSE on both splits is printed.

    Args:
        device: Unused; kept for interface compatibility with other loaders.

    Returns:
        ``(train_dset, train_dset, eval_train_dset, test_dset)`` as
        ``DictDataset`` objects.
        - ``train_dset`` holds the factual training samples.
        - ``eval_train_dset`` holds the same training samples and also carries
          ``cf_treatment`` / ``cf_outcome``.
        - ``test_dset`` covers the test split, with the same keys as
          ``eval_train_dset``.
    """
    from sklearn.ensemble import RandomForestRegressor

    table = pd.read_excel('data/simulated_CI2.xlsx').to_numpy()

    cov, treat, out, cf_treat, cf_out = [], [], [], [], []
    for unit_id in range(1, 1001):
        rows = table[table[:, 0] == unit_id]
        pick_index = unit_id % 20
        cf_index = 20 - 1 - pick_index
        cov.append(rows[pick_index, 1])
        treat.append(rows[pick_index, 2])
        out.append(rows[pick_index, 3])
        cf_treat.append(rows[cf_index, 2])
        cf_out.append(rows[cf_index, 3])
    cov = np.array(cov).reshape(-1, 1) / 100.
    treat = np.array(treat).reshape(-1, 1)
    out = np.array(out).reshape(-1, 1)
    cf_treat = np.array(cf_treat).reshape(-1, 1)
    cf_out = np.array(cf_out).reshape(-1, 1)

    # All five arrays have one entry per id, so one split point serves them all.
    # (A former `[:100000]` cap on the training arrays was a no-op: there are
    # at most 800 training rows.)
    n_train = int(0.8 * len(cov))
    train_cov, val_covariate = cov[:n_train], cov[n_train:]
    train_treat, val_treatment = treat[:n_train], treat[n_train:]
    train_out, val_effect = out[:n_train], out[n_train:]
    train_cf_treat, val_cf_treatment = cf_treat[:n_train], cf_treat[n_train:]
    train_cf_out, val_cf_effect = cf_out[:n_train], cf_out[n_train:]

    model = RandomForestRegressor(random_state=2)
    model.fit(np.concatenate([train_cov, train_treat], axis=1), train_out.ravel())
    train_cf_pred = model.predict(np.concatenate([train_cov, train_cf_treat], axis=1))
    val_cf_pred = model.predict(np.concatenate([val_covariate, val_cf_treatment], axis=1))
    print('RF train mse: %.4f' % np.mean((train_cf_pred.reshape(-1)-train_cf_out.reshape(-1))**2), len(train_cf_out))
    print('RF val mse: %.4f' % np.mean((val_cf_pred.reshape(-1)-val_cf_effect.reshape(-1))**2), len(val_cf_pred))

    train_dict = {'covariate': torch.from_numpy(train_cov).float(),
                  'treatment': torch.from_numpy(train_treat).float(),
                  'outcome': torch.from_numpy(train_out).float(),
                  }
    train_dset = DictDataset(train_dict)

    eval_train_dict = {'covariate': torch.from_numpy(train_cov).float(),
                       'treatment': torch.from_numpy(train_treat).float(),
                       'outcome': torch.from_numpy(train_out).float(),
                       'cf_treatment': torch.from_numpy(train_cf_treat).float(),
                       'cf_outcome': torch.from_numpy(train_cf_out).float(),
                       }
    eval_train_dset = DictDataset(eval_train_dict)

    test_dict = {'covariate': torch.from_numpy(val_covariate).float(),
                 'treatment': torch.from_numpy(val_treatment).float(),
                 'outcome': torch.from_numpy(val_effect).float(),
                 'cf_treatment': torch.from_numpy(val_cf_treatment).float(),
                 'cf_outcome': torch.from_numpy(val_cf_effect).float(),
                 }
    test_dset = DictDataset(test_dict)
    return train_dset, train_dset, eval_train_dset, test_dset


def get_ihdp100(exp_num, root='data', val_rate=0):
    """Load replication ``exp_num`` of the IHDP-100 benchmark.

    Reads ``ihdp_npci_1-100.train.npz`` and ``ihdp_npci_1-100.test.npz`` from
    ``root``. The replication is selected on the last axis of ``x`` and on
    the second axis of ``t`` / ``yf`` / ``ycf``. Training rows are shuffled
    with the global NumPy RNG. The counterfactual treatment is ``1 - t``,
    which assumes a binary treatment.

    A seeded ``RandomForestRegressor`` fit on ``[x, t] -> yf`` is a baseline.
    Its counterfactual MSE on train and test is printed.

    Args:
        exp_num: Index of the IHDP replication to use.
        root: Directory containing the two ``.npz`` files.
        val_rate: Fraction of training rows held out after shuffling.
            NOTE(review): the held-out rows are dropped and not returned
            anywhere.

    Returns:
        ``(train_dset, train_dset, eval_train_dset, test_dset)`` as
        ``DictDataset`` objects.
        - ``train_dset`` holds the factual training data only.
        - ``eval_train_dset`` holds the same training data and adds
          ``cf_treatment`` / ``cf_outcome``.
        - ``test_dset`` covers the test split, with the same keys as
          ``eval_train_dset``.
    """
    from sklearn.ensemble import RandomForestRegressor

    train_path = os.path.join(root, 'ihdp_npci_1-100.train.npz')
    test_path = os.path.join(root, 'ihdp_npci_1-100.test.npz')

    def covariates(arrays):
        # Covariates of replication exp_num (selected on the last axis).
        return torch.from_numpy(arrays['x']).float()[:, :, exp_num]

    def column(arrays, key):
        # Per-replication column, kept 2-D via the slice exp_num:exp_num + 1.
        return torch.from_numpy(arrays[key]).float()[:, exp_num:exp_num + 1]

    # np.load on an .npz returns an NpzFile that keeps the file open; dict()
    # materializes every array, then the with-block closes the handle.
    with np.load(train_path) as npz:
        train_in = dict(npz)
    n_total = len(train_in['x'])
    n_valid = int(n_total * val_rate)
    train_idx = np.random.permutation(n_total)[:n_total - n_valid]

    x_train = covariates(train_in)[train_idx]
    t_train = column(train_in, 't')[train_idx]
    y_f_train = column(train_in, 'yf')[train_idx]
    y_cf_train = column(train_in, 'ycf')[train_idx]

    with np.load(test_path) as npz:
        test_in = dict(npz)
    x_test = covariates(test_in)
    t_test = column(test_in, 't')
    y_f_test = column(test_in, 'yf')
    y_cf_test = column(test_in, 'ycf')

    model = RandomForestRegressor(random_state=2)
    model.fit(torch.cat([x_train, t_train], 1).numpy(), y_f_train.ravel().numpy())
    train_cf_pred = model.predict(torch.cat([x_train, 1 - t_train], 1).numpy())
    val_cf_pred = model.predict(torch.cat([x_test, 1 - t_test], 1).numpy())
    print('RF train mse: %.4f' % np.mean((train_cf_pred.reshape(-1)-y_cf_train.reshape(-1).numpy())**2), len(y_f_train))
    print('RF val mse: %.4f' % np.mean((val_cf_pred.reshape(-1)-y_cf_test.reshape(-1).numpy())**2), len(val_cf_pred))

    train_dset = DictDataset({'covariate': x_train, 'treatment': t_train,
                              'outcome': y_f_train,
                              })
    eval_train_dset = DictDataset({'covariate': x_train, 'treatment': t_train,
                                   'cf_treatment': 1 - t_train,
                                   'outcome': y_f_train, 'cf_outcome': y_cf_train,
                                   })
    test_dset = DictDataset({'covariate': x_test, 'treatment': t_test,
                             'cf_treatment': 1 - t_test,
                             'outcome': y_f_test, 'cf_outcome': y_cf_test,
                             })

    return train_dset, train_dset, eval_train_dset, test_dset



if __name__ == '__main__':
    # Disabled scratch snippets follow. They are kept as bare string literals
    # so they never execute; only the smoke test at the bottom runs.

    # Omniglot and LoadGammaShapes batch previews.
    """
    dset = CounterfactualOmniglot(mode='test', split='test')
    print(len(dset))
    loader = DataLoader(dset, batch_size=128, shuffle=False, drop_last=True, num_workers=16)
    batch = next(iter(loader))
    save_image(fp='checkpoints/omni.png', tensor=batch['covariate'][:100], nrow=10)
    save_image(fp='checkpoints/omni_out.png', tensor=batch['outcome'][:100], nrow=10)
    save_image(fp='checkpoints/omni_cf_out.png', tensor=batch['cf_outcome'][:100], nrow=10)
    dset = LoadGammaShapes(mode='test')
    print(len(dset))
    loader = DataLoader(dset, batch_size=5, shuffle=True, drop_last=True, num_workers=4)
    batch = next(iter(loader))
    xy = []
    for xx,yy,zz in zip(batch['covariate'].to('cuda'),batch['outcome'].to('cuda'), batch['cf_outcome'].to('cuda')):
        xy.append(xx)
        xy.append(yy)
        xy.append(zz)
    xy = torch.stack(xy, dim=0)
    print(batch['treatment'].view(-1))
    print(batch['cf_treatment'].view(-1))
    save_image(fp='checkpoints/gamma.png', tensor=xy, nrow=3, normalize=False)
    """

    # Export torchvision Omniglot, resized to 32x32, to data/omni.npy.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    dset = torchvision.datasets.Omniglot(root='data', download=True, transform=transform)
    loader = DataLoader(dset, batch_size=128, shuffle=False, drop_last=False, num_workers=8)
    all_array = []

    for batch in loader:
        all_array.append(batch[0].numpy())
    all_array = np.concatenate(all_array, axis=0)
    print(all_array.shape)
    np.save('data/omni.npy', all_array)

    """
    # CounterfactualOmniglot preview (train split, test mode).
    """
    dset = CounterfactualOmniglot(split='train', mode='test')
    loader = DataLoader(dset, batch_size=128, shuffle=True, drop_last=True, num_workers=8)
    batch = next(iter(loader))
    save_image(fp='checkpoints/omni.png', tensor=batch['covariate'][:100], nrow=10)
    save_image(fp='checkpoints/omni_out.png', tensor=batch['outcome'][:100], nrow=10)
    save_image(fp='checkpoints/omni_cf_out.png', tensor=batch['cf_outcome'][:100], nrow=10)
    """
    # Preprocess 3dshapes.h5 into 3dshapes_train.npz / 3dshapes_test.npz.
    # NOTE(review): this snippet uses `h5py`, which is not among the imports
    # at the top of this file.
    """
    root = 'data/3dshapes/3dshapes.h5'
    from tqdm import tqdm
    import pickle
    with h5py.File(root, 'r') as f:
        labels = f['labels'][:]
        labels = np.frombuffer(labels, dtype=np.float64).reshape(-1, 6)
        data = np.frombuffer(f['images'][:], dtype=np.uint8).reshape(-1, 64, 64,3)
    new_data = {}
    relevant = np.concatenate([labels[:, :3], labels[:, 4:]], axis=1)
    images = []
    treatments = []
    covariates = []
    for i in tqdm(range(4*60000)):
        img = data[i]
        x_relavent = relevant[i]
        if str(x_relavent) in new_data:
            continue

        cur_img = []
        img = np.array((Image.fromarray(img).resize((32, 32)))).transpose(2, 0, 1)/255.
        cur_img.append(img)
        x = data[i]

        new_data[str(x_relavent)] = img
        mask = np.all(relevant==x_relavent, axis=1)
        mask &= (~np.all(labels==labels[i], axis=1))

        matching_indices = np.nonzero(mask)[0]
        matching_labels = labels[matching_indices]
        for j in range(len(matching_indices)):
            pick_whole_index = matching_indices[j]
            neighbor = data[pick_whole_index]
            neighbor = np.array((Image.fromarray(neighbor).resize((32, 32)))).transpose(2, 0, 1)/255.
            cur_img.append(neighbor)
        cur_img = np.stack(cur_img, axis=0)
        cur_cov = np.repeat(x_relavent[None,:], len(cur_img), axis=0)
        cur_treat = np.array([labels[i,3]] + [matching_labels[j,3] for j in range(len(matching_indices))]).reshape(len(cur_img), 1)
        new_data[str(x_relavent)] = 1
        images.append(cur_img)
        treatments.append(cur_treat)
        covariates.append(cur_cov)

        if len(new_data)>=10000:
            break

    rnd_ids = np.random.permutation(len(images))
    train_ids = rnd_ids[:int(0.8*len(images))]
    test_ids = rnd_ids[int(0.8*len(images)):]
    train_images = []
    train_treatments = []
    train_covariates = []
    for i in range(len(train_ids)):
        sub_id = np.random.choice(range(len(images[train_ids[i]])), 4, replace=False)
        train_images.append(images[train_ids[i]][sub_id])
        train_treatments.append(treatments[train_ids[i]][sub_id])
        train_covariates.append(covariates[train_ids[i]][sub_id])
    train_images = np.concatenate(train_images, axis=0)
    train_treatments = np.concatenate(train_treatments, axis=0)
    train_covariates = np.concatenate(train_covariates, axis=0)

    print(train_images.shape, train_treatments.shape, train_covariates.shape)


    test_images = []
    test_treatments = []
    test_covariates = []
    for i in range(len(test_ids)):
        sub_id = np.random.choice(range(len(images[test_ids[i]])), 4, replace=False)
        test_images.append(images[test_ids[i]][sub_id])
        test_treatments.append(treatments[test_ids[i]][sub_id])
        test_covariates.append(covariates[test_ids[i]][sub_id])
    test_images = np.concatenate(test_images, axis=0)
    test_treatments = np.concatenate(test_treatments, axis=0)
    test_covariates = np.concatenate(test_covariates, axis=0)
    print(test_images.shape, test_treatments.shape, test_covariates.shape)

    np.savez('data/3dshapes/3dshapes_train.npz', data=train_images, size=train_treatments, covariate=train_covariates)
    np.savez('data/3dshapes/3dshapes_test.npz', data=test_images, size=test_treatments, covariate=test_covariates)
    """


    # Print the name and shape of each array in data/3dshapes/3dshapes.npz.
    """
    with np.load('data/3dshapes/3dshapes.npz') as f:
        load_dict = {key: f[key] for key in f}
        for k in load_dict:
            print(k, load_dict[k].shape)
    """

    # Active smoke test: draw one shuffled batch of GammaShapes (train split,
    # test mode), print its treatments, and save a 3-row grid of
    # covariate / outcome / cf_outcome images.
    demo_loader = DataLoader(GammaShapes(split='train', mode='test'),
                             batch_size=10, shuffle=True, drop_last=True,
                             num_workers=8)
    demo_batch = next(iter(demo_loader))
    for treatment_key in ('treatment', 'cf_treatment'):
        print(demo_batch[treatment_key].view(-1))
    grid_rows = [demo_batch[image_key] for image_key in ('covariate', 'outcome', 'cf_outcome')]
    save_image(fp='checkpoints/gamma.png', tensor=torch.cat(grid_rows, dim=0), nrow=10)