import numpy as np
from numpy import savez_compressed
import tensorflow as tf
import edward as ed
from sklearn.model_selection import train_test_split
from edward.models import Bernoulli, Normal
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd
from glob import glob
import os

from utils import config_nnbuilder

# config_nnbuilder(activation=tf.nn.relu)
# config_nnbuilder()
from utils import *

import matplotlib.pyplot as plt

# Root directory (relative to the working directory) of the on-disk datasets
# read by ``CelebA`` and ``Pokec``. Paths are built from it partly by plain
# string concatenation, so the trailing slash is required.
data_path = './data/'


class IHDP(object):
    """IHDP benchmark read from the ``ihdp_npci_1-1000`` train/test ``.npz`` archives.

    Each archive stores all replications along its last axis: ``x`` is indexed
    ``[:, :, i]`` and ``t``, ``yf``, ``ycf``, ``mu0``, ``mu1`` are indexed ``[:, i]``.
    Both archives are read into memory once, in ``__init__``. This avoids
    re-decompressing the full arrays for every replication and leaves no file
    handle open.
    """

    def __init__(self, path_data="datasets/IHDP/csv", replications=10):
        # NOTE: ``path_data`` is unused (left over from the per-replication CSV
        # loader); the archives are read from fixed paths relative to the CWD.
        self.train_data = self._load_npz('../ihdp_npci_1-1000.train.npz')
        self.test_data = self._load_npz('../ihdp_npci_1-1000.test.npz')
        self.replications = replications
        # which features are binary
        self.binfeats = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
        # which features are continuous
        self.contfeats = [i for i in range(25) if i not in self.binfeats]

    @staticmethod
    def _load_npz(path):
        """Return ``{name: array}`` for every array in the ``.npz`` at *path*; the file is closed."""
        with np.load(path) as archive:
            return {name: archive[name] for name in archive.files}

    @staticmethod
    def _replication(data, i):
        """Extract replication *i* of *data* as ``(x, t, y), (y_cf, mu_0, mu_1)``.

        Every returned array is a copy, so neither the in-place feature fix below
        nor any caller mutation can alter the cached archive arrays.
        """
        def column(key):
            # (n, 1) column vector of replication i
            return data[key][:, i][:, np.newaxis].copy()

        x = data['x'][:, :, i].copy()
        # this binary feature is in {1, 2}; shift it to {0, 1}
        x[:, 13] -= 1
        return (x, column('t'), column('yf')), (column('ycf'), column('mu0'), column('mu1'))

    def get_train_valid_test(self, index=None):
        """Yield ``(train, valid, test, contfeats, binfeats)`` per replication.

        Args:
            index: if given, only replication ``index`` is yielded.

        Each split is ``((x, t, y), (y_cf, mu_0, mu_1))``. ``train``/``valid`` are
        a fixed 70/30 split (``random_state=1``) of the training archive; ``test``
        is the test archive.
        """
        for i in range(self.replications):
            if index is not None and i != index:
                continue

            (x, t, y), (y_cf, mu_0, mu_1) = self._replication(self.train_data, i)

            itr, iva = train_test_split(np.arange(x.shape[0]), test_size=0.3, random_state=1)
            train = (x[itr], t[itr], y[itr]), (y_cf[itr], mu_0[itr], mu_1[itr])
            valid = (x[iva], t[iva], y[iva]), (y_cf[iva], mu_0[iva], mu_1[iva])
            yield train, valid, self._replication(self.test_data, i), self.contfeats, self.binfeats


class CelebA(object):
    """CelebA face attributes paired with (currently placeholder) aligned face images."""

    def __init__(self):
        attr_file = data_path + 'CelebA/list_attr_celeba.txt'
        # assume first line of the original file was deleted
        # Columns are whitespace-separated; ``sep=r'\s+'`` is the documented
        # replacement for the ``delim_whitespace=True`` flag deprecated in pandas 2.2.
        self.attributes = pd.read_csv(attr_file, sep=r'\s+')
        self.points = len(self.attributes)

    # last attr as treatment
    def get_train_valid_test(self, attr_names=('Bald', 'Young', 'No_Beard', 'Male'), split=0.8):
        """Split the selected attribute columns and images into train / valid by row order.

        Args:
            attr_names: attribute columns to use; the last one is the treatment,
                the others are returned as covariates.
            split: fraction of rows (taken from the front) that go to ``train``.

        Returns:
            ``(train, valid)``, each ``(covariates, treatment, images)``.

        Raises:
            FileNotFoundError: if no ``*.jpg`` exists under ``CelebA/img_align_celeba/``.
        """
        attrs = self.attributes[list(attr_names)].to_numpy()
        from utils import get_image
        img_dir = os.path.join(data_path, 'CelebA/img_align_celeba/')
        img_files = glob(os.path.join(img_dir, "*.jpg"))
        if not img_files:
            raise FileNotFoundError('no *.jpg images found in %s' % img_dir)

        # TODO: remove this, give real data
        # Every row currently receives the first image, so decode only that one
        # rather than decoding (and then discarding) the whole directory.
        first_image = np.asarray(get_image(img_files[0], 128, is_crop=True, resize_w=64, is_grayscale=0))
        img_data = np.repeat(first_image[np.newaxis], self.points, axis=0)

        n_train = int(self.points * split)
        data = (attrs[:, :-1], attrs[:, -1:], img_data)
        train = tuple(d[:n_train] for d in data)
        valid = tuple(d[n_train:] for d in data)
        return train, valid


class Pokec(object):
    """Semi-synthetic treatment / outcome data simulated on Pokec covariates.

    ``self.data`` is the 7-tuple ``(t, y, ycf, y0, y1, network, features)``
    returned by ``simulate_from_pokec_covariate``; ``features`` is a mapping
    whose values are stacked column-wise into the covariate matrix.
    ``self.conf_data`` holds the simulated confounder, one entry per point.
    Points are shuffled once (seed 0) and split into ``replications`` folds;
    each fold serves as the test set of one replica.
    """

    def __init__(self, confounder='region_cat', replications=10, is_networked=True, beta=10.0):
        self.replications = replications
        self.confounder = confounder

        from data.Pokec.simulate_treatment_outcome import simulate_from_pokec_covariate

        self.data, self.conf_data = simulate_from_pokec_covariate(
            data_path + 'Pokec/regional_subset',
            covariate=self.confounder, beta0=1.0,
            beta1=beta, gamma=1.0
        )

        # fixed seed so the fold assignment is reproducible
        np.random.seed(0)
        self.points = len(self.data[0])
        indexes = list(range(self.points))
        np.random.shuffle(indexes)
        self.folds = np.array_split(indexes, self.replications)
        self.is_networked = is_networked

    def network_data(self):
        """Return the network element (second to last) of the simulated data."""
        return self.data[-2]

    def confounder_data(self, mask=None):
        """Return the confounder values, optionally restricted to one split.

        Args:
            mask: ``0`` (train), ``1`` (valid) or ``2`` (test) for the replica most
                recently produced by ``get_train_valid_test``; ``None`` for all points.
                ``self.mask`` is only set when ``replications > 1``.
        """
        if mask is not None:
            return self.conf_data[self.mask == mask]
        return self.conf_data

    def get_train_valid_test(self, index=None, split=0.66):
        """Yield ``(train, valid, test, contfeats, binfeats)`` for each selected replica.

        Args:
            index: ``None`` runs every replica; an ``int`` runs only that replica;
                a tuple ``(start, None)`` runs replicas ``start`` onwards. Any other
                tuple does not filter.
            split: fraction of the non-test points that go to ``train`` (the rest
                go to ``valid``).

        Each split is ``((x, t, y), (ycf, y0, y1))``. In networked mode ``x`` is the
        covariate matrix of *all* points; otherwise it holds only the split's rows.
        All covariates are reported as continuous (``binfeats`` is empty).
        """
        t, y, ycf, y0, y1, network_data, features = self.data

        features = np.vstack(list(features.values())).T
        n_features = features.shape[1]

        all_data = (t, y, ycf, y0, y1, features)
        # promote 1-D arrays to column vectors so every split is 2-D
        all_data = tuple(d[:, np.newaxis] if len(d.shape) == 1 else d for d in all_data)

        # NOTE: there is actually no need to distinguish type of covariates, unless we want to generate them in decoder

        def wanted(i):
            """Return whether replica ``i`` is selected by ``index``."""
            if index is None:
                return True
            if isinstance(index, int):
                # run this replica only
                return i == index
            if isinstance(index, tuple) and index[-1] is None:
                # start from this replica
                return i >= index[0]
            return True

        def format_data(part):
            """Arrange one split as ``((x, t, y), (ycf, y0, y1))``."""
            p_t, p_y, p_ycf, p_y0, p_y1, p_features = part
            x = all_data[-1] if self.is_networked else p_features
            return (x, p_t, p_y), (p_ycf, p_y0, p_y1)

        for i, test_ids in enumerate(self.folds):
            if not wanted(i):
                continue

            test = tuple(d[test_ids, :] for d in all_data)

            if self.replications == 1:
                # a single fold: train, valid and test are the same data
                train = valid = test
            else:
                rest_ids = np.concatenate([f for j, f in enumerate(self.folds) if j != i])
                # per-replica seed keeps the train/valid split reproducible
                np.random.seed(i)
                np.random.shuffle(rest_ids)
                n_train = int(len(rest_ids) * split)
                train_ids, valid_ids = rest_ids[:n_train], rest_ids[n_train:]
                # 0 = train, 1 = valid, 2 = test (read by confounder_data).
                # ``dtype=int`` because the ``np.int`` alias was removed in NumPy 1.24.
                self.mask = np.empty((self.points, ), dtype=int)
                self.mask[train_ids] = 0
                self.mask[valid_ids] = 1
                self.mask[test_ids] = 2
                train = tuple(d[train_ids, :] for d in all_data)
                valid = tuple(d[valid_ids, :] for d in all_data)

            yield format_data(train), format_data(valid), format_data(test), list(range(n_features)), []


from collections import namedtuple

# Configuration record for the synthetic ``Artificial`` data generators.
# Field names are part of the interface (note the historical spelling
# ``satisify_icvae``), and the positional order below is the order used by
# the defaults tuple.
ArtConfig = namedtuple('ArtConfig', ['z_dim', 'z_sigma',
                                     'conf_dim', 'conf_mix',
                                     'proxied_conf_dim',
                                     'proxy_t_dim', 'proxy_y_dim', 'proxy_no_dim', 'proxy_mix',
                                     'backdoor_t', 'backdoor_y',
                                     'y_noise_level', 'proxy_value_range', 'proxy_noise_level',
                                     'satisify_icvae', 'standalone_z',
                                     'logit_scale', 'last_z_scale', 'dep_noise'])

# Install a default for every field, so ``ArtConfig()`` and keyword-only
# construction work. The defaults are themselves built positionally (one value
# per field, in declaration order), because no defaults exist yet at this point.
# Default setting: all proxy dims are 0 and both back-door flags are True, i.e.
# All conf (z) observed directly
ArtConfig.__new__.__defaults__ = ArtConfig(
    10, 1,  # z_dim, z_sigma
    10, False,  # conf_dim, conf_mix
    10,  # proxied_conf_dim
    0, 0, 0, False,  # proxy_t_dim, proxy_y_dim, proxy_no_dim, proxy_mix
    True, True,  # backdoor_t, backdoor_y
    0.1, 0, 0.2,  # y_noise_level, proxy_value_range, proxy_noise_level
    False, False,  # satisify_icvae, standalone_z
    10, 1, False  # logit_scale, last_z_scale, dep_noise
)

# every field at its default value
unconfounded = ArtConfig()

# Backdoor criteria satisfied: t and y see the confounder only through proxies.
no_backdoor = ArtConfig(
    proxy_mix='nonlinear',
    proxy_t_dim=20, proxy_y_dim=4,
    # proxy_no_dim=5,
    backdoor_t=False, backdoor_y=False,
    proxy_value_range=0.5,
    # z_dim=1,
    # z_sigma=0.5,
    # test icvae
)

# Overrides the definition above; this is the ``no_backdoor`` actually exported
# and the base of ``y_backdoor`` / ``t_backdoor`` / ``yes_backdoor`` below
# (note proxy_t_dim becomes 30 and proxy_value_range .1).
no_backdoor = no_backdoor._replace(
    satisify_icvae=True,
    z_dim=1, conf_dim=1, proxied_conf_dim=1,
    proxy_t_dim=30, proxy_no_dim=0, proxy_value_range=.1,
    y_noise_level=.1, proxy_noise_level=.2
)

# pure proxy
# no_backdoor = no_backdoor._replace(
#     satisify_icvae=True,
#     proxy_mix='linear',
#     z_dim=1, conf_dim=1, proxied_conf_dim=1,
#     proxy_t_dim=0, proxy_no_dim=3, proxy_y_dim=0, proxy_value_range=.1,
#     y_noise_level=.1, proxy_noise_level=.2
# )

# y proxy
# no_backdoor = no_backdoor._replace(
#     # satisify_icvae=True,
#     proxy_mix='linear',
#     z_dim=1, conf_dim=1, proxied_conf_dim=1,
#     proxy_t_dim=0, proxy_no_dim=0, proxy_value_range=.1,
#     y_noise_level=.1, proxy_noise_level=.2
# )

# Variants of ``no_backdoor`` that open the back-door path into y, into t, or both.
y_backdoor = no_backdoor._replace(backdoor_y=True)
t_backdoor = no_backdoor._replace(backdoor_t=True)
yes_backdoor = no_backdoor._replace(backdoor_y=True, backdoor_t=True)


# class Artificial(object):
#     def __init__(self, replications=1, samples=500, config=unconfounded):
#         self.replications = replications
#         self.points = samples
#         self.config = config
#
#     def model(self, seed):
#         (z_dim, z_sigma,
#          conf_dim, conf_mix,
#          proxied_conf_dim,
#          proxy_t_dim, proxy_y_dim, proxy_no_dim, proxy_mix,
#          backdoor_t, backdoor_y,
#          y_noise_level, proxy_value_range,
#          satisify_icvae) = self.config
#
#         proxy_var_sigma = not satisify_icvae
#         y_var_sigma = not satisify_icvae
#         if satisify_icvae:
#             proxy_mix = 'linear'  # TODO: nonlinear
#             proxy_y_dim = 0
#
#         proxy_dim = proxy_t_dim + proxy_y_dim + proxy_no_dim
#
#         z_shape = [self.points, z_dim]
#         z = Normal(loc=tf.zeros(z_shape), scale=z_sigma * tf.ones(z_shape))
#         # z_p = Normal(loc=mu_multiplier*tf.ones([self.points, z_dim/2]), scale=tf.ones([self.points, z_dim/2]))
#         # z_n = Normal(loc=-mu_multiplier*tf.ones([self.points, z_dim/2]), scale=tf.ones([self.points, z_dim/2]))
#         # z = tf.concat([z_p, z_n], 1)
#
#         lamba = 1e-4
#         nn_depth = 3
#         nn_width = 10
#         layers = nn_depth * [nn_width]
#
#         if conf_mix:
#             confounder = fc_net(z, layers, [conf_dim, tf.nn.elu], 'confounder', seed=seed)
#         else:
#             confounder = z
#
#         proxied_conf = confounder[:, :proxied_conf_dim]
#         if proxy_mix:
#             if proxy_mix == 'nonlinear':
#                 mux, sigmax = mu_sigma(proxied_conf, layers, proxy_dim, var_sigma=True, switch=False, seed=seed,
#                                        name='x')
#             elif proxy_mix == 'linear':
#                 # mux is just a linear trans. of conf, and sigmax is id. matrix
#                 mux, sigmax = mu_sigma(proxied_conf, [], proxy_dim, var_sigma=proxy_var_sigma, switch=False, seed=seed,
#                                        name='x')
#         else:
#             mux = proxied_conf[:, :proxy_dim]
#             sigmax = tf.ones([self.points, proxy_dim])
#
#         if proxy_value_range:
#             proxy = Normal(loc=mux, scale=sigmax * proxy_value_range, name='proxy')
#         else:
#             proxy = mux
#
#         if satisify_icvae:
#             with tf.variable_scope('x_1', reuse=True):
#                 w = tf.get_variable('weights')
#
#             pinv_w = pinv(w)
#
#             mux_x = Normal(loc=proxy, scale=sigmax * proxy_value_range, name='mux_x')
#             z_x = Normal(loc=tf.matmul(mux_x, pinv_w), scale=tf.matmul(sigmax, pinv_w) * proxy_value_range, name='z_x')
#
#         proxy_t = proxy[:, :proxy_t_dim]
#         proxy_y = proxy[:, proxy_t_dim: proxy_t_dim + proxy_y_dim]
#
#         def bd_dim(attr):
#             return conf_dim if attr is True else int(attr)
#
#         logits = fc_net(tf.concat([proxy_t, confounder[:, :bd_dim(backdoor_t)]], 1),
#                         layers, [[1, None]], 'pt_z', seed=seed + 1)
#         t = Bernoulli(logits=logits, dtype=tf.float32)
#
#         ((mu0, sigma0), (mu1, sigma1)) = mu_sigma(tf.concat([proxy_y, confounder[:, :bd_dim(backdoor_y)]], 1),
#                                                   layers, 1, var_sigma=y_var_sigma, switch=True,
#                                                   seed=seed + 1,
#                                                   name='y')
#         y_standard = Normal(loc=tf.zeros([self.points, 1]), scale=tf.ones([self.points, 1]), name='y_standard')
#
#         y = y_standard * (t * sigma1 + (1. - t) * sigma0) * y_noise_level + t * mu1 + (1. - t) * mu0
#         ycf = y_standard * (t * sigma0 + (1. - t) * sigma1) * y_noise_level + t * mu0 + (1. - t) * mu1
#
#         # self.z_dim = z_dim
#         # self.proxy_dim = proxy_dim
#
#         return (proxy, t, y, ycf), (mu0, mu1), (z, z_x, mux_x)
#
#     def data(self, model_seed, z_seed=None, sample_seed=None):
#         with tf.Graph().as_default():
#             ed.set_seed(sample_seed)  # eq to set np and tf
#             # tf.set_random_seed(sample_seed)
#
#             sess = tf.InteractiveSession()
#
#             observed_and_outcome, mu, (z, z_x, mux_x) = self.model(model_seed)
#
#             tf.global_variables_initializer().run()
#
#             observed_and_outcome_data = sess.run(
#                 observed_and_outcome,
#                 feed_dict={z: sess.run(z.sample(seed=z_seed))}
#             )
#
#             # compute mu by sampling
#             print('Compute mu by sampling...')
#             x, x_data = observed_and_outcome[0], observed_and_outcome_data[0]
#
#             delta = 0.01
#             alpha = 0.1
#
#             sum0_data, sum1_data = \
#                 np.zeros(mu[0].shape, dtype=np.float32), np.zeros(mu[1].shape, dtype=np.float32)
#             i = 0  # i should be positive
#             while True:
#                 upper = n_sampling = 2 ** i
#                 print('n_sampling = ', n_sampling)
#                 lower = 2 ** (i - 1) if i else 0
#                 # sum0_data = sum0_data_[i % 2]
#                 # sum1_data = sum1_data_[i % 2]
#
#                 if i % 2:
#                     sum0_data_, sum1_data_ = np.copy(sum0_data), np.copy(sum1_data)
#
#                 # BUG: not converge, maybe I misused graph_replace
#                 # mu_x_ops = [ge.graph_replace(mu, {z.value(): z_x.sample(seed=p),
#                 #                                   x.value(): x_data})
#                 #             for p in range(lower, upper)]
#                 # mu_x_list = sess.run(mu_x_ops)
#                 #
#                 # for mu0_x_p, mu1_x_p in mu_x_list:
#                 #     sum0_data += mu0_x_p
#                 #     sum1_data += mu1_x_p
#
#                 for p in range(lower, upper):
#                     # sampling mux_x
#                     mux_x_p = sess.run(mux_x.sample(seed=p), feed_dict={x: x_data})
#                     # sampling z_x
#                     z_x_p = sess.run(z_x.sample(seed=p), feed_dict={mux_x: mux_x_p})
#                     # substitute into mu
#                     mu0_x_p, mu1_x_p = sess.run(mu, feed_dict={x: x_data, z: z_x_p})
#                     sum0_data += mu0_x_p
#                     sum1_data += mu1_x_p
#
#                 if i % 2:
#                     mu0_data_ = sum0_data_ / n_sampling * 2
#                     mu1_data_ = sum1_data_ / n_sampling * 2
#                     mu0_data__ = (sum0_data - sum0_data_) / n_sampling * 2
#                     mu1_data__ = (sum1_data - sum1_data_) / n_sampling * 2
#
#                     if np.mean(np.abs(mu0_data_ - mu0_data__) > delta) < alpha \
#                             and np.mean(np.abs(mu1_data_ - mu1_data__) > delta) < alpha:
#                         mu0_data = sum0_data / n_sampling
#                         mu1_data = sum1_data / n_sampling
#                         break
#
#                 i += 1
#             print('Done\n')
#         return observed_and_outcome_data[:3], (observed_and_outcome_data[3], mu0_data, mu1_data)
#
#     def get_train_valid_test(self, index=None, save='', saved=''):
#         if not saved:
#             replicas = []
#             for i in range(self.replications):
#                 if index is not None:
#                     if i != index:
#                         continue
#                 replica = self.data(i, i, i), self.data(i, i + 1, i + 1), self.data(i, i + 2, i + 2)
#                 if save:
#                     replicas.append([np.hstack(t0 + t1) for t0, t1 in replica])
#                     if i + 1 == self.replications:
#                         savez_compressed(save, replicas)
#                 yield replica + (list(range(replica[0][0][0].shape[1])), [])
#
#         else:
#             for replica in np.load(saved, allow_pickle=True)['arr_0']:
#                 yield tuple(((d[:, :-5], d[:, [-5]], d[:, [-4]]),
#                              (d[:, [-3]], d[:, [-2]], d[:, [-1]])) for d in replica) + \
#                       (list(range(replica[0].shape[1] - 5)), [])


class Artificial(object):
    """Base class for synthetic benchmarks whose replicas come from :meth:`model`.

    A subclass's ``model(seed)`` must return ``(train, valid, test)``, each split
    being ``((x, t, y), (y_cf, mu0, mu1), z)``. Replica ``i`` is generated with
    seed ``i``, or all replicas are loaded from ``saved``. Outcome arrays are
    divided by ``ate_scale`` when yielded.
    """

    def __init__(self, replications=1, samples=500, config=unconfounded, save='', saved=''):
        self.replications = replications
        # total points per replica, covering the train, valid and test splits
        self.points = 3*samples
        self.config = config

        self._gen(save, saved)

    def model(self, seed):
        """Generate one replica for ``seed``; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _usable_scale(scale):
        """Return ``scale`` as a float, or 1.0 when it cannot serve as a divisor."""
        scale = float(scale)
        return scale if np.isfinite(scale) and scale > 0 else 1.0

    def _gen(self, save, saved):
        """Generate (or load) ``self.replicas`` and set ``self.ate_scale``.

        ``ate_scale`` is the standard deviation, across replicas, of the mean
        training-split treatment effect ``mean(mu1 - mu0)``. With fewer than two
        replicas (or identical effects) that spread is 0. Dividing by it would
        turn every outcome into inf/nan, so 1.0 (no scaling) is used instead.

        Args:
            save: if non-empty, path for ``savez_compressed`` of the generated data.
            saved: if non-empty, path of a previously saved archive to load
                instead of generating.
        """
        if saved:
            # The payload is a pickled object array: only load archives you trust.
            with np.load(saved, allow_pickle=True) as archive:
                loaded = archive['arr_0']
            self.replicas = loaded[0]
            self.ate_scale = self._usable_scale(loaded[1])
            return

        ATEs = []
        self.replicas = []
        for i in range(self.replications):
            replica = self.model(i)
            train, _valid, _test = replica
            (_, (_, mu0tr, mu1tr), _) = train
            ATEs.append(np.mean(mu1tr - mu0tr))
            self.replicas.append(replica)
        self.ate_scale = self._usable_scale(np.std(ATEs)) if len(ATEs) > 1 else 1.0

        if save:
            # Pack explicitly into a 1-D object array. Letting NumPy infer an
            # array from the ragged (replicas, scale) pair fails on NumPy >= 1.24.
            payload = np.empty(2, dtype=object)
            payload[0] = self.replicas
            payload[1] = self.ate_scale
            savez_compressed(save, payload)

    def get_train_valid_test(self, index=None):
        """Yield ``(train, valid, test, contfeats, binfeats)`` per replica.

        ``y``, ``y_cf``, ``mu0`` and ``mu1`` are divided by ``ate_scale``. Every
        covariate is reported as continuous (``binfeats`` is empty).

        Args:
            index: if given, only replica ``index`` is yielded.
        """
        for i in range(self.replications):
            if index is not None and i != index:
                continue

            replica_scale = []
            for d in self.replicas[i]:
                (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr), ztr = d
                ytr, y_cftr, mu0tr, mu1tr = tuple(o/self.ate_scale for o in (ytr, y_cftr, mu0tr, mu1tr))
                replica_scale.append(((xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr), ztr))

            replica = tuple(replica_scale)

            yield replica + (list(range(replica[0][0][0].shape[1])), [])


def sample_loc(proxy_value_range, x_shape):
    """Draw one location per column, uniform on ``[-2r, 2r)``, repeated on every row.

    Args:
        proxy_value_range: ``r``; the sampling interval is ``[-2r, 2r)``.
        x_shape: target shape; only ``x_shape[0]`` (rows) and ``x_shape[1]``
            (columns) are used.

    Returns:
        Array of shape ``(x_shape[0], x_shape[1])`` whose rows are identical.
        Uses ``x_shape[1]`` draws from the global NumPy RNG.
    """
    n_rows, n_cols = x_shape[0], x_shape[1]
    bound = 2 * proxy_value_range
    column_locs = np.random.uniform(-bound, bound, n_cols)
    return np.tile(column_locs, (n_rows, 1))


def sample_scale(proxy_value_range, x_shape):
    """Draw one scale per column, uniform on ``[0, 2r)``, repeated on every row.

    Args:
        proxy_value_range: ``r``; the sampling interval is ``[0, 2r)``.
        x_shape: target shape; only ``x_shape[0]`` (rows) and ``x_shape[1]``
            (columns) are used.

    Returns:
        Array of shape ``(x_shape[0], x_shape[1])`` whose rows are identical.
        Uses ``x_shape[1]`` draws from the global NumPy RNG.
    """
    upper = 2 * proxy_value_range
    column_scales = np.random.uniform(0, upper, x_shape[1])
    return np.tile(column_scales, (x_shape[0], 1))


class LinearArtificial(Artificial):
    """Synthetic data in which z, t and y depend (generalised-)linearly on the proxies x.

    For a given ``seed`` the global NumPy RNG is reseeded. All model parameters
    are drawn first, then the samples:

    * ``x ~ Normal(loc, scale)``, with one ``loc`` / ``scale`` per column
      (``sample_loc`` / ``sample_scale``);
    * ``z ~ Normal(x @ W_mu, proxy_noise_level * |x @ W_sigma| * z_sigma)``. With
      ``standalone_z``, ``z`` is instead drawn together with ``x`` from its own
      per-column location / scale;
    * ``t ~ Bernoulli(sigmoid(logit_scale * [x_t, z_bd] @ w_t))``;
    * ``y(k) = ([x_y, z_bd] + b_k) @ w_k + noise`` for ``k`` in {0, 1}. The noise
      is one standard-normal draw per point (shared by both arms) times
      ``y_noise_level * std(mu_k)``.

    ``z_bd`` is the leading ``conf_dim`` columns of ``z`` when the matching
    ``backdoor_*`` flag is ``True`` (``int(flag)`` columns otherwise, at most
    ``z_dim``). The returned ``mu0`` / ``mu1`` are the same noiseless outcome
    functions evaluated at the location of ``z``'s distribution.
    ``conf_mix``, ``proxied_conf_dim``, ``proxy_mix``, ``last_z_scale`` and
    ``dep_noise`` are not used by this model.
    """

    def __init__(self, **kwargs):
        super(LinearArtificial, self).__init__(**kwargs)

    def model(self, seed):
        """Generate one replica and return ``(train, valid, test)``.

        Each split is ``((x, t, y), (ycf, mu0, mu1), z)`` and holds one third of
        ``self.points`` rows (``np.split`` along axis 0).
        Side effect: writes a debug scatter plot ``plot<seed>_.png`` into the
        working directory when y has at least one input column.
        """
        np.random.seed(seed)
        Normal = np.random.normal

        # Attribute access: positional unpacking broke whenever fields were
        # added to ArtConfig.
        cfg = self.config
        z_dim, z_sigma, conf_dim = cfg.z_dim, cfg.z_sigma, cfg.conf_dim
        proxy_t_dim, proxy_y_dim, proxy_no_dim = cfg.proxy_t_dim, cfg.proxy_y_dim, cfg.proxy_no_dim
        backdoor_t, backdoor_y = cfg.backdoor_t, cfg.backdoor_y
        y_noise_level, proxy_value_range = cfg.y_noise_level, cfg.proxy_value_range
        proxy_noise_level, standalone_z = cfg.proxy_noise_level, cfg.standalone_z
        logit_scale = cfg.logit_scale  # default 10, the previously hard-coded value

        if cfg.satisify_icvae:
            # no y-only proxies in this setting
            proxy_y_dim = 0

        proxy_dim = proxy_t_dim + proxy_y_dim + proxy_no_dim

        def bd_dim(attr):
            return conf_dim if attr is True else int(attr)

        # Number of confounder columns actually fed to t and y (z has z_dim
        # columns). The weights below are sized from these widths so the
        # matrix products always line up.
        t_conf_dim = min(bd_dim(backdoor_t), z_dim)
        y_conf_dim = min(bd_dim(backdoor_y), z_dim)
        y_in_dim = proxy_y_dim + y_conf_dim

        ## Generate model parameters first (fixed RNG order)
        # for pz_x
        w_muz = Normal(loc=np.zeros([proxy_dim, z_dim]), scale=np.ones([proxy_dim, z_dim]))
        w_sigmaz = Normal(loc=np.zeros([proxy_dim, z_dim]), scale=np.ones([proxy_dim, z_dim]))
        # for t
        w_logits = Normal(loc=np.zeros([proxy_t_dim + t_conf_dim, 1]),
                          scale=np.ones([proxy_t_dim + t_conf_dim, 1]))
        # for y: one linear head (and bias) per treatment arm
        w_y_0 = Normal(0, 1, [y_in_dim, 1])
        w_y_1 = Normal(0, 1, [y_in_dim, 1])
        b_y_0 = np.random.uniform(-1, 1, [1, y_in_dim])
        b_y_1 = np.random.uniform(-1, 1, [1, y_in_dim])

        ## Sampling: px and pz_x
        if standalone_z:
            x_shape = [self.points, proxy_dim + z_dim]
            mur = sample_loc(proxy_value_range, x_shape)
            random_src = Normal(loc=mur, scale=sample_scale(proxy_value_range, x_shape))
            proxy = random_src[:, :proxy_dim]
            # slice from proxy_dim (not -z_dim) so that z_dim == 0 yields no columns
            confounder = random_src[:, proxy_dim:]
            muz = mur[:, proxy_dim:]
        else:
            x_shape = [self.points, proxy_dim]
            proxy = Normal(loc=sample_loc(proxy_value_range, x_shape),
                           scale=sample_scale(proxy_value_range, x_shape))
            muz = np.matmul(proxy, w_muz)
            sigmaz = np.abs(np.matmul(proxy, w_sigmaz))
            confounder = Normal(loc=muz, scale=proxy_noise_level*sigmaz * z_sigma)

        proxy_t = proxy[:, :proxy_t_dim]
        proxy_y = proxy[:, proxy_t_dim: proxy_t_dim + proxy_y_dim]

        # pt_zx
        inp_logits = np.concatenate([proxy_t, confounder[:, :t_conf_dim]], 1)
        logits = np.matmul(inp_logits, w_logits)
        t = np.random.binomial(n=1, p=1./(1+np.exp(-logit_scale*logits)))

        # py_zt
        inp_y = np.concatenate([proxy_y, confounder[:, :y_conf_dim]], 1)
        mu0 = np.matmul(inp_y, w_y_0) + np.matmul(b_y_0, w_y_0)
        mu1 = np.matmul(inp_y, w_y_1) + np.matmul(b_y_1, w_y_1)

        y_standard = Normal(loc=np.zeros([self.points, 1]), scale=np.ones([self.points, 1]))

        mu0_std = np.std(mu0)
        mu1_std = np.std(mu1)
        y0 = y_standard*np.ones_like(mu0)*y_noise_level*mu0_std + mu0
        y1 = y_standard*np.ones_like(mu0)*y_noise_level*mu1_std + mu1

        y = t * y1 + (1. - t) * y0
        ycf = t * y0 + (1. - t) * y1

        import matplotlib.pyplot as plt

        # Debug plot of the noisy potential outcomes; skipped when y has no
        # inputs, since matplotlib cannot plot a zero-column x.
        if inp_y.shape[1] > 0:
            plt.figure()
            plt.plot(inp_y, y0, '.')
            plt.plot(inp_y, y1, '.')
            plt.savefig('plot%d_.png' % seed)
            plt.clf()

        # Noiseless outcomes at z's location. The bias enters exactly as in the
        # y generation above (``b @ w``); the former ``+ b`` disagreed with y
        # and broadcast to the wrong shape for inputs wider than one column.
        inp_mu = np.concatenate([proxy_y, muz[:, :y_conf_dim]], 1)
        mu0 = np.matmul(inp_mu, w_y_0) + np.matmul(b_y_0, w_y_0)
        mu1 = np.matmul(inp_mu, w_y_1) + np.matmul(b_y_1, w_y_1)

        # release plt memory
        plt.close('all')

        import gc
        gc.collect()

        parts = [np.split(d, 3) for d in (proxy, t, y, ycf, mu0, mu1, confounder)]
        return tuple(
            (tuple(p[k] for p in parts[:3]), tuple(p[k] for p in parts[3:6]), parts[6][k])
            for k in range(3)
        )

    # def model(self, seed):
    #     np.random.seed(seed)
    #     Normal = np.random.normal
    #
    #     (z_dim, z_sigma,
    #      conf_dim, conf_mix,
    #      proxied_conf_dim,
    #      proxy_t_dim, proxy_y_dim, proxy_no_dim, proxy_mix,
    #      backdoor_t, backdoor_y,
    #      y_noise_level, proxy_value_range,
    #      satisify_icvae) = self.config
    #
    #     proxy_var_sigma = not satisify_icvae
    #     y_var_sigma = not satisify_icvae
    #     if satisify_icvae:
    #         proxy_mix = 'linear'  # TODO: nonlinear
    #         proxy_y_dim = 0
    #
    #     proxy_dim = proxy_t_dim + proxy_y_dim + proxy_no_dim
    #
    #     lamba = 1e-4
    #     nn_depth = 3
    #     nn_width = 10
    #     layers = nn_depth * [nn_width]
    #     linear_sigma = 5
    #
    #     z_shape = [self.points, z_dim]
    #     z = Normal(loc=np.zeros(z_shape), scale=z_sigma * np.ones(z_shape))
    #     # z_p = Normal(loc=mu_multiplier*tf.ones([self.points, z_dim/2]), scale=tf.ones([self.points, z_dim/2]))
    #     # z_n = Normal(loc=-mu_multiplier*tf.ones([self.points, z_dim/2]), scale=tf.ones([self.points, z_dim/2]))
    #     # z = tf.concat([z_p, z_n], 1)
    #
    #     if conf_mix:
    #         # confounder = fc_net(z, layers, [conf_dim, tf.nn.elu], 'confounder', seed=seed)
    #         pass
    #     else:
    #         confounder = z
    #
    #     proxied_conf = confounder[:, :proxied_conf_dim]
    #     if proxy_mix:
    #         if proxy_mix == 'nonlinear':
    #             # mux, sigmax = mu_sigma(proxied_conf, layers, proxy_dim, var_sigma=True, switch=False, seed=seed, name='x')
    #             pass
    #         elif proxy_mix == 'linear':
    #             # mux is just a linear trans. of conf, and sigmax is id. matrix
    #             # mux, sigmax = mu_sigma(proxied_conf, [], proxy_dim, var_sigma=proxy_var_sigma, switch=False, seed=seed, name='x')
    #             w = Normal(loc=np.zeros([z_dim, proxy_dim]), scale=np.ones([z_dim, proxy_dim]))
    #             mux = np.matmul(confounder, w)
    #             sigmax = np.ones(mux.shape)
    #     else:
    #         mux = proxied_conf[:, :proxy_dim]
    #         sigmax = np.ones([self.points, proxy_dim])
    #
    #     proxy = Normal(loc=mux, scale=sigmax * proxy_value_range)
    #
    #     if satisify_icvae:
    #         pinv_w = np.linalg.pinv(w)
    #
    #         z_x = Normal(loc=np.matmul(proxy, pinv_w), scale=np.abs(np.matmul(sigmax, pinv_w)) * proxy_value_range)
    #
    #     proxy_t = proxy[:, :proxy_t_dim]
    #     proxy_y = proxy[:, proxy_t_dim: proxy_t_dim + proxy_y_dim]
    #
    #     def bd_dim(attr):
    #         return conf_dim if attr is True else int(attr)
    #
    #     # logits = fc_net(np.concatenate([proxy_t, confounder[:, :bd_dim(backdoor_t)]], 1),
    #     #                 layers, [[1, None]], 'pt_z', seed=seed + 1)
    #
    #     inp_logits = np.concatenate([proxy_t, confounder[:, :bd_dim(backdoor_t)]], 1)
    #     w_logits = Normal(loc=np.zeros([inp_logits.shape[1], 1]), scale=np.ones([inp_logits.shape[1], 1]))
    #     logits = np.matmul(inp_logits, w_logits)
    #     t = np.random.binomial(n=1, p=1./(1+np.exp(-logits)))
    #
    #     # ((mu0, sigma0), (mu1, sigma1)) = mu_sigma(np.concatenate([proxy_y, confounder[:, :bd_dim(backdoor_y)]], 1),
    #     #                                           layers, 1, var_sigma=y_var_sigma, switch=True,
    #     #                                           seed=seed+1,
    #     #                                           name='y')
    #
    #     inp_y = np.concatenate([proxy_y, confounder[:, :bd_dim(backdoor_y)]], 1)
    #     w_y_0 = Normal(loc=np.zeros([inp_y.shape[1], 1]), scale=np.ones([inp_y.shape[1], 1]))
    #     w_y_1 = Normal(loc=np.zeros([inp_y.shape[1], 1]), scale=np.ones([inp_y.shape[1], 1]))
    #     mu0 = np.matmul(inp_y, w_y_0)
    #     mu1 = np.matmul(inp_y, w_y_1)
    #
    #     y_standard = Normal(loc=np.zeros([self.points, 1]), scale=np.ones([self.points, 1]))
    #
    #     # y = y_standard * (t * sigma1 + (1. - t) * sigma0) * y_noise_level + t * mu1 + (1. - t) * mu0
    #     # ycf = y_standard * (t * sigma0 + (1. - t) * sigma1) * y_noise_level + t * mu0 + (1. - t) * mu1
    #     y = y_standard * np.ones_like(mu0) * y_noise_level + t * mu1 + (1. - t) * mu0
    #     ycf = y_standard * np.ones_like(mu0) * y_noise_level + t * mu0 + (1. - t) * mu1
    #
    #     # self.z_dim = z_dim
    #     # self.proxy_dim = proxy_dim
    #
    #     inp_y = np.concatenate([proxy_y, np.matmul(proxy, pinv_w)[:, :bd_dim(backdoor_y)]], 1)
    #     mu0 = np.matmul(inp_y, w_y_0)
    #     mu1 = np.matmul(inp_y, w_y_1)
    #
    #     data = [np.split(d, 3) for d in (proxy, t, y, ycf, mu0, mu1)]
    #     return tuple((
    #                      tuple(d[i] for d in data[:3]),
    #                      tuple(d[i] for d in data[3:])
    #                  )
    #                  for i in range(3))

    # def data(self, model_seed, z_seed=None, sample_seed=None):
    #     with tf.Graph().as_default():
    #         ed.set_seed(sample_seed)  # eq to set np and tf
    #         # tf.set_random_seed(sample_seed)
    #
    #         sess = tf.InteractiveSession()
    #
    #         observed_and_outcome, mu, (z, z_x, mux_x) = self.model(model_seed)
    #
    #         tf.global_variables_initializer().run()
    #
    #         observed_and_outcome_data = sess.run(
    #             observed_and_outcome,
    #             feed_dict={z: sess.run(z.sample(seed=z_seed))}
    #         )
    #
    #         # compute mu by sampling
    #         print('Compute mu by sampling...')
    #         x, x_data = observed_and_outcome[0], observed_and_outcome_data[0]
    #
    #         delta = 0.01
    #         alpha = 0.1
    #
    #         sum0_data, sum1_data = \
    #             np.zeros(mu[0].shape, dtype=np.float32), np.zeros(mu[1].shape, dtype=np.float32)
    #         i = 0  # i should be positive
    #         while True:
    #             upper = n_sampling = 2 ** i
    #             print('n_sampling = ', n_sampling)
    #             lower = 2 ** (i - 1) if i else 0
    #             # sum0_data = sum0_data_[i % 2]
    #             # sum1_data = sum1_data_[i % 2]
    #
    #             if i % 2:
    #                 sum0_data_, sum1_data_ = np.copy(sum0_data), np.copy(sum1_data)
    #
    #             # BUG: not converge, maybe I misused graph_replace
    #             # mu_x_ops = [ge.graph_replace(mu, {z.value(): z_x.sample(seed=p),
    #             #                                   x.value(): x_data})
    #             #             for p in range(lower, upper)]
    #             # mu_x_list = sess.run(mu_x_ops)
    #             #
    #             # for mu0_x_p, mu1_x_p in mu_x_list:
    #             #     sum0_data += mu0_x_p
    #             #     sum1_data += mu1_x_p
    #
    #             for p in range(lower, upper):
    #                 # sampling mux_x
    #                 mux_x_p = sess.run(mux_x.sample(seed=p), feed_dict={x: x_data})
    #                 # sampling z_x
    #                 z_x_p = sess.run(z_x.sample(seed=p), feed_dict={mux_x: mux_x_p})
    #                 # substitute into mu
    #                 mu0_x_p, mu1_x_p = sess.run(mu, feed_dict={x: x_data, z: z_x_p})
    #                 sum0_data += mu0_x_p
    #                 sum1_data += mu1_x_p
    #
    #             if i % 2:
    #                 mu0_data_ = sum0_data_ / n_sampling * 2
    #                 mu1_data_ = sum1_data_ / n_sampling * 2
    #                 mu0_data__ = (sum0_data - sum0_data_) / n_sampling * 2
    #                 mu1_data__ = (sum1_data - sum1_data_) / n_sampling * 2
    #
    #                 if np.mean(np.abs(mu0_data_ - mu0_data__) > delta) < alpha \
    #                         and np.mean(np.abs(mu1_data_ - mu1_data__) > delta) < alpha:
    #                     mu0_data = sum0_data / n_sampling
    #                     mu1_data = sum1_data / n_sampling
    #                     break
    #
    #             i += 1
    #         print('Done\n')
    #     return observed_and_outcome_data[:3], (observed_and_outcome_data[3], mu0_data, mu1_data)



class PolyArtificial(object):
    """Synthetic causal dataset whose potential outcomes are random polynomials.

    Each replication draws proxies ``x``, a latent confounder ``z | x``, a binary
    treatment ``t`` from a logistic model on ``(x_t, z)``, and potential outcomes
    from two random odd-power polynomials (one per treatment arm) of ``(x_y, z)``.
    ``3 * samples`` points are drawn and split into three equal parts.
    """

    def __init__(self, replications=1, samples=500, config=unconfounded):
        """
        Args:
            replications: number of datasets; replication ``i`` is generated with seed ``i``.
            samples: points per split (``3 * samples`` points are drawn in total).
            config: parameter tuple unpacked at the top of ``model``.
        """
        self.replications = replications
        self.points = 3*samples
        self.config = config

    def model(self, seed):
        """Generate one replication.

        Args:
            seed: passed to ``np.random.seed``; fixes every random draw below.

        Returns:
            A 3-tuple (one entry per equal split of the points); each entry is
            ``((proxy, t, y), (ycf, mu0, mu1), muz)``. ``y``, ``ycf``, ``mu0`` and
            ``mu1`` are standardized with ``StandardScaler`` over all points before
            splitting. A scatter plot of both outcome surfaces is written to
            ``plot<seed>.png``.
        """
        np.random.seed(seed)
        Normal = np.random.normal

        (z_dim, z_sigma,
         conf_dim, conf_mix,
         proxied_conf_dim,
         proxy_t_dim, proxy_y_dim, proxy_no_dim, proxy_mix,
         backdoor_t, backdoor_y,
         y_noise_level, proxy_value_range,
         satisify_icvae) = self.config

        if satisify_icvae:
            proxy_mix = 'linear'  # TODO: nonlinear (proxy_mix is not used below yet)
            # no proxy feeds y directly in this setting
            proxy_y_dim = 0

        proxy_dim = proxy_t_dim + proxy_y_dim + proxy_no_dim

        # px: every proxy entry gets its own randomly drawn mean and scale
        x_shape = [self.points, proxy_dim]
        proxy = Normal(loc=np.random.uniform(-2*proxy_value_range, 2*proxy_value_range, x_shape),
                       scale=np.random.uniform(0, 2*proxy_value_range, x_shape))

        # pz_x: linear mean and |linear| scale of the confounder given the proxies
        w_muz = Normal(loc=np.zeros([proxy_dim, z_dim]), scale=np.ones([proxy_dim, z_dim]))
        w_sigmaz = Normal(loc=np.zeros([proxy_dim, z_dim]), scale=np.ones([proxy_dim, z_dim]))

        muz = np.matmul(proxy, w_muz)*2
        sigmaz = np.abs(np.matmul(proxy, w_sigmaz))
        z_x = Normal(loc=muz, scale=sigmaz * z_sigma)
        confounder = z_x

        proxy_t = proxy[:, :proxy_t_dim]
        proxy_y = proxy[:, proxy_t_dim: proxy_t_dim + proxy_y_dim]

        def bd_dim(attr):
            # True -> use all conf_dim confounder columns; otherwise attr is a column count.
            return conf_dim if attr is True else int(attr)

        # pt_zx: logistic treatment assignment, linear in (proxy_t, confounder)
        inp_logits = np.concatenate([proxy_t, confounder[:, :bd_dim(backdoor_t)]], 1)
        w_logits = Normal(loc=np.zeros([inp_logits.shape[1], 1]), scale=np.ones([inp_logits.shape[1], 1]))
        logits = np.matmul(inp_logits, w_logits)
        t = np.random.binomial(n=1, p=1./(1+np.exp(-logits)))

        # py_zt: one random polynomial per treatment arm
        inp_y = np.concatenate([proxy_y, confounder[:, :bd_dim(backdoor_y)]], 1)

        # np.poly1d takes coefficients highest power first, so index 2k of this
        # length-20 vector is the coefficient of x**(19 - 2k), an odd power.
        poly_coef0 = np.zeros(20)
        poly_coef0[2 * np.random.choice(range(10), 5, replace=False)] = np.random.uniform(0, 10, 5)
        # constant
        poly_coef0[-1] = Normal(0, 10)

        poly_coef1 = np.zeros(20)
        poly_coef1[2 * np.random.choice(range(10), 5, replace=False)] = np.random.uniform(0, 10, 5)
        poly_coef1[-1] = Normal(0, 10)

        poly0 = np.poly1d(poly_coef0)
        poly1 = np.poly1d(poly_coef1)

        mu0 = poly0(inp_y)
        mu1 = poly1(inp_y)

        # Diagnostic plot of both outcome surfaces.
        plt.figure()
        plt.plot(inp_y, mu0, '.')
        plt.plot(inp_y, mu1, '.')
        plt.savefig('plot%d.png' % seed)
        # close (not just clear) so figures do not accumulate across replications
        plt.close()

        y_standard = Normal(loc=np.zeros([self.points, 1]), scale=np.ones([self.points, 1]))

        # Homoscedastic noise; y is the outcome of the received arm t, ycf the other arm.
        y = y_standard * np.ones_like(mu0) * y_noise_level + t * mu1 + (1. - t) * mu0
        ycf = y_standard * np.ones_like(mu0) * y_noise_level + t * mu0 + (1. - t) * mu1

        # Replace mu0/mu1 by Monte-Carlo estimates of E[poly(x_y, z) | x] over z ~ p(z | x).
        n_sample_z = 1000
        z_x = Normal(loc=[muz]*n_sample_z, scale=[sigmaz * z_sigma]*n_sample_z)

        # np.tile prepends the sample axis -> (n_sample_z, points, proxy_y_dim). This
        # also yields a 0-width array when proxy_y_dim == 0. np.repeat with a 2-tuple
        # of repeats raised a ValueError here.
        inp_y = np.concatenate([np.tile(proxy_y, (n_sample_z, 1, 1)),
                                z_x[:, :, :bd_dim(backdoor_y)]], 2)
        mu0 = np.mean(poly0(inp_y), axis=0)
        mu1 = np.mean(poly1(inp_y), axis=0)

        y, ycf, mu0, mu1 = tuple(StandardScaler().fit_transform(data) for data in (y, ycf, mu0, mu1))

        data = [np.split(d, 3) for d in (proxy, t, y, ycf, mu0, mu1, muz)]
        return tuple((
                         tuple(d[i] for d in data[:3]),
                         tuple(d[i] for d in data[3:-1]),
                         data[-1][i]
                     )
                     for i in range(3))

    def get_train_valid_test(self, index=None, save='', saved=''):
        """Yield generated (or previously saved) replications.

        Args:
            index: if given, only replication ``index`` is produced.
            save: if non-empty, path passed to ``savez_compressed``. Generated
                replications are stored as column-stacked ``[x, t, y, ycf, mu0, mu1]``
                arrays per split (``muz`` is not stored).
            saved: if non-empty, load replications from this ``.npz`` file instead of
                generating them.

        Yields:
            ``(split0, split1, split2, list(range(n_proxies)), [])``. Generated splits
            are ``((x, t, y), (ycf, mu0, mu1), muz)``; splits loaded from ``saved``
            omit ``muz``. NOTE(review): the two trailing lists presumably are
            continuous / binary feature indices; confirm with callers.
        """
        if not saved:
            replicas = []
            for i in range(self.replications):
                if index is not None and i != index:
                    continue
                replica = self.model(i)
                if save:
                    # Each split is ((x, t, y), (ycf, mu0, mu1), muz). Unpack all three
                    # (two-name unpacking raised ValueError) and drop muz.
                    replicas.append([np.hstack(obs + outcome) for obs, outcome, _ in replica])
                    # write once the last requested replication has been generated
                    if i + 1 == self.replications or i == index:
                        savez_compressed(save, replicas)
                yield replica + (list(range(replica[0][0][0].shape[1])), [])

        else:
            # allow_pickle=True unpickles the file: only load trusted, self-written files.
            for replica in np.load(saved, allow_pickle=True)['arr_0']:
                # last five columns are t, y, ycf, mu0, mu1; the rest are proxies
                yield tuple(((d[:, :-5], d[:, [-5]], d[:, [-4]]),
                             (d[:, [-3]], d[:, [-2]], d[:, [-1]])) for d in replica) + \
                      (list(range(replica[0].shape[1] - 5)), [])


class NNnpArtificial(Artificial):
    """Synthetic causal dataset whose potential outcomes come from random numpy neural nets.

    Construction is delegated to ``Artificial`` (not visible here). ``model`` reads
    ``self.points`` and ``self.config`` and presumably relies on that constructor
    to set them.
    """

    def __init__(self, **kwargs):
        super(NNnpArtificial, self).__init__(**kwargs)

    def model(self, seed, return_metrics=False):
        """Generate one replication.

        Args:
            seed: passed to ``np.random.seed``; fixes every random draw below.
            return_metrics: if True, stop once the treatment propensities are computed
                and return overlap diagnostics instead of data. This used to happen
                unconditionally, which left all the data generation unreachable.

        Returns:
            If ``return_metrics``: ``(frac_extreme, mean_diff, median_diff)``, where
            ``diff = |2p - 1|`` for the propensity ``p`` and ``frac_extreme`` is the
            fraction of points with ``diff > 0.998``.
            Otherwise a 3-tuple (one entry per equal split of the points), each entry
            ``((proxy, t, y), (ycf, mu0, mu1), confounder)``. A scatter plot of both
            arms is written to ``plot<seed>.png``.
        """
        np.random.seed(seed)
        Normal = np.random.normal

        (z_dim, z_sigma,
         conf_dim, conf_mix,
         proxied_conf_dim,
         proxy_t_dim, proxy_y_dim, proxy_no_dim, proxy_mix,
         backdoor_t, backdoor_y,
         y_noise_level, proxy_value_range, proxy_noise_level,
         satisify_icvae, standalone_z,
         logit_scale, last_z_scale, dep_noise) = self.config

        if satisify_icvae:
            proxy_mix = 'linear'  # TODO: nonlinear (proxy_mix is not used below yet)
            # no proxy feeds y directly in this setting
            proxy_y_dim = 0

        proxy_dim = proxy_t_dim + proxy_y_dim + proxy_no_dim

        ## Generate model parameters first (keeps the random stream order fixed)
        # for pz_x
        w_muz = Normal(loc=np.zeros([proxy_dim, z_dim]), scale=np.ones([proxy_dim, z_dim]))
        w_sigmaz = Normal(loc=np.zeros([proxy_dim, z_dim]), scale=np.ones([proxy_dim, z_dim]))
        # for t
        w_logits = Normal(loc=np.zeros([proxy_t_dim + (z_dim if backdoor_t else 0), 1]),
                          scale=np.ones([proxy_t_dim + (z_dim if backdoor_t else 0), 1]))
        # for y
        import neuralnets as nn

        class Net(nn.Module):
            """Random LeakyReLU MLP from the y-inputs to one output (one net per arm)."""

            def __init__(self):
                super().__init__()

                self.dim = proxy_y_dim + (z_dim if backdoor_y else 0)
                dim = self.dim
                b_init = nn.init.RandomNormal(-1, 1)

                layers = [
                    nn.layers.Dense(dim, dim, nn.act.LeakyReLU(alpha=.5), nn.init.RandomUniform(-0.9, -1.1), b_init),
                ]

                # List repetition reuses the same Dense object, so the hidden layers
                # presumably share one set of weights.
                self.layers = layers * np.random.randint(3, 9)

                if dim > 1:
                    # final projection to a single output
                    self.layers.append(nn.layers.Dense(dim, 1, nn.act.LeakyReLU(alpha=.5), nn.init.RandomUniform(-0.9, -1.1), b_init))

                if dep_noise:
                    # head used for input-dependent noise scales
                    self.soft_plus_layer = nn.layers.Dense(dim, 1, nn.act.SoftPlus(), nn.init.RandomUniform(-0.9, -1.1),
                                                           b_init)

            def forward(self, x, soft_plus=False):
                """Apply the net; with ``soft_plus`` route through the SoftPlus head instead."""
                if soft_plus:
                    # NOTE(review): this drops the last ``self.dim`` layers, not just the
                    # final projection. Confirm this is intended.
                    for layer in self.layers[:-self.dim]:
                        x = layer(x)
                    x = self.soft_plus_layer(x)
                else:
                    for layer in self.layers:
                        x = layer(x)

                return x

        net0 = Net()
        net1 = Net()

        ## Sampling
        # px
        if standalone_z:
            # proxies and z are drawn together as independent columns of one source
            x_shape = [self.points, proxy_dim+z_dim]
            mur = sample_loc(proxy_value_range, x_shape)
            sigmar = sample_scale(proxy_value_range, x_shape)
            random_src = Normal(loc=mur, scale=sigmar)
            proxy = random_src[:, :proxy_dim]
        else:
            x_shape = [self.points, proxy_dim]
            proxy = Normal(loc=sample_loc(proxy_value_range, x_shape),
                           scale=sample_scale(proxy_value_range, x_shape))

        # pz_x
        if standalone_z:
            confounder = random_src[:, -z_dim:]
            muz = mur[:, -z_dim:]
            sigmaz = sigmar[:, -z_dim:]
        else:
            muz = np.matmul(proxy, w_muz)
            sigmaz = np.abs(np.matmul(proxy, w_sigmaz))
            z_x = Normal(loc=muz, scale=proxy_noise_level*sigmaz * z_sigma)
            z_x[:, -1] = z_x[:, -1]*last_z_scale
            confounder = z_x

        # pt_zx
        proxy_t = proxy[:, :proxy_t_dim]
        proxy_y = proxy[:, proxy_t_dim: proxy_t_dim + proxy_y_dim]

        def bd_dim(attr):
            # True -> use all conf_dim confounder columns; otherwise attr is a column count.
            return conf_dim if attr is True else int(attr)

        inp_logits = np.concatenate([proxy_t, confounder[:, :bd_dim(backdoor_t)]], 1)
        logits = np.matmul(inp_logits, w_logits)
        p = 1. / (1 + np.exp(-logit_scale * logits))

        if return_metrics:
            # Overlap diagnostics: |2p - 1| close to 1 means a propensity near 0 or 1.
            diff = np.abs(2*p-1)
            return np.sum(diff > 0.998) / len(diff), np.mean(diff), np.median(diff)

        t = np.random.binomial(n=1, p=p)

        # py_zt
        inp_y = np.concatenate([proxy_y, confounder[:, :bd_dim(backdoor_y)]], 1)
        mu1 = net1.forward(inp_y).data
        mu0 = net0.forward(inp_y).data

        if dep_noise:
            # input-dependent noise scales, rescaled to [0, 2]
            sig1 = net1.forward(inp_y, soft_plus=True).data
            sig0 = net0.forward(inp_y, soft_plus=True).data
            sig0, sig1 = tuple(MinMaxScaler(feature_range=(0, 2)).fit_transform(data) for data in (sig0, sig1))
            y0_standard = Normal(loc=np.zeros([self.points, 1]), scale=sig0)
            y1_standard = Normal(loc=np.zeros([self.points, 1]), scale=sig1)
        else:
            y0_standard = Normal(loc=np.zeros([self.points, 1]), scale=np.ones([self.points, 1]))
            y1_standard = y0_standard

        # noise is scaled relative to each arm's outcome spread
        mu0_std = np.std(mu0)
        mu1_std = np.std(mu1)
        y0 = y0_standard*np.ones_like(mu0)*y_noise_level*mu0_std + mu0
        y1 = y1_standard*np.ones_like(mu0)*y_noise_level*mu1_std + mu1

        # y is the outcome of the received arm t, ycf the other arm
        y = t * y1 + (1. - t) * y0
        ycf = t * y0 + (1. - t) * y1

        plt.figure()
        plt.plot(inp_y, y0, '.')
        plt.plot(inp_y, y1, '.')
        plt.savefig('plot%d.png' % seed)
        plt.clf()

        # Replace mu0/mu1 by Monte-Carlo estimates of E[net(x_y, z) | x].
        # NOTE(review): z is resampled with scale ``sigmaz * z_sigma``. The confounder
        # above used ``proxy_noise_level * sigmaz * z_sigma`` with its last column
        # scaled by ``last_z_scale`` (or scale ``sigmaz`` without ``z_sigma`` when
        # ``standalone_z``). The two distributions agree only when those factors
        # are 1. Confirm the intent.
        n_sample_z = 1000
        z_x_ = Normal(loc=[muz]*n_sample_z, scale=[sigmaz * z_sigma]*n_sample_z)

        # np.tile prepends the sample axis -> (n_sample_z, points, proxy_y_dim). This
        # also yields a 0-width array when proxy_y_dim == 0. np.repeat with a 2-tuple
        # of repeats raised a ValueError here.
        inp_y = np.concatenate([np.tile(proxy_y, (n_sample_z, 1, 1)),
                                z_x_[:, :, :bd_dim(backdoor_y)]], 2)
        mu0 = np.mean(net0.forward(inp_y).data, axis=0)
        mu1 = np.mean(net1.forward(inp_y).data, axis=0)

        data = [np.split(d, 3) for d in (proxy, t, y, ycf, mu0, mu1, confounder)]

        # release plt memory
        plt.close('all')

        import gc
        gc.collect()

        return tuple((
                         tuple(d[i] for d in data[:3]),
                         tuple(d[i] for d in data[3:-1]),
                         data[-1][i]
                     )
                     for i in range(3))




