import tensorflow as tf
import numpy as np

# import tensorflow_probability as tfp
# tfd = tfp.distributions

from util import *
from misc import FLAGS

class cfr_net(object):
    """
    This code implements the tensorflow graph for P-VAE-CI.
    It is based on the code-base of CFR-Net, maintained by
    F. Johansson at https://github.com/clinicalml/cfrnet
    """

    def __init__(self, x, t, y_, dims, do_in, do_out, p_t, FLAGS, pi_0=None):
        """Build the complete model graph.

        Args:
            x: covariate tensor; concatenated with a one-column tensor along
                axis 1 and fed to layers of width dims[0] + 1, so presumably
                shaped (batch, dims[0]).
            t: treatment tensor with one column (rows are split on t < 1 /
                t > 0 and reweighted with p_t).
            y_: observed (factual) outcome tensor with one column.
            dims: sequence whose first three entries are
                [dim_input, dim_in, dim_out].
            do_in: dropout argument for encoder/decoder layers, passed
                positionally to tf.nn.dropout (keep_prob under TF1).
            do_out: dropout argument for the output layers.
            p_t: treated fraction used for sample reweighting and, when
                FLAGS.use_p_correction is set, for the imbalance penalty.
            FLAGS: configuration object; stored so every builder reads the
                same flags.
            pi_0: accepted for interface compatibility; currently unused.
        """
        self.variables = {}
        self.wd_loss = 0
        # Configuration used by the private builders below.
        self._flags = FLAGS
        # Learnable batch-norm offsets/scales created by _build_latent_graph
        # (only when FLAGS.batch_norm is set and normalization != 'bn_fixed').
        self.bn_biases = []
        self.bn_scales = []

        if FLAGS.nonlin.lower() == 'elu':
            self.nonlin = tf.nn.elu
        else:
            self.nonlin = tf.nn.relu

        self._build_graph(x, t, y_, dims, do_in, do_out, p_t, FLAGS, pi_0)

    def _add_variable(self, var, name):
        ''' Adds variables to the internal track-keeper.

        If `name` is already taken, the first free name of the form
        '<name>_0', '<name>_1', ... is used instead.
        '''
        basename = name
        i = 0
        while name in self.variables:
            name = '%s_%d' % (basename, i) #@TODO: not consistent with TF internally if changed
            i += 1

        self.variables[name] = var

    def _create_variable(self, var, name):
        ''' Create a tf.Variable from the initial value `var`, register it
            under `name` and return it. '''
        var = tf.Variable(var, name=name)
        self._add_variable(var, name)
        return var

    def _create_variable_with_weight_decay(self, initializer, name, wd):
        ''' Create and register a variable, and add wd * l2_loss(variable)
            to self.wd_loss. '''
        var = self._create_variable(initializer, name)
        self.wd_loss += wd*tf.nn.l2_loss(var)
        return var

    def _build_graph(self, x, t, y_, dims, do_in, do_out, p_t, FLAGS, pi_0):
        """
        Constructs a TensorFlow subgraph for counterfactual regression.
        Sets the following member variables (to TF nodes):

        self.output         The output prediction "y"
        self.tot_loss       The total objective to minimize
        self.imb_loss       The imbalance term of the objective
        self.imb_dist       The unscaled imbalance distance
        self.pred_loss      Unweighted factual RMSE (reported; the objective
                            uses the sample-weighted squared error)
        self.sample_weight  Per-sample weights applied to the factual risk
        self.weights_in     Weights of all encoder/decoder/treatment layers
        self.weights_out    The output/post-representation layer weights
        self.weights_pred   The (linear) prediction layer weights
        self.biases_in      Biases matching self.weights_in
        self.biases_out     Biases matching self.weights_out
        self.bias_pred      The prediction layer bias
        self.h_rep          Mean of the z1 encoder (pre-normalisation)
        self.h_rep_norm     Its normalised version

        pi_0 is currently unused.

        Raises:
            ValueError: if FLAGS.loss is not 'l2', the only factual loss
                implemented here (checked before any op is created).
        """
        if FLAGS.loss != 'l2':
            raise ValueError(
                "Unsupported FLAGS.loss %r: only 'l2' is implemented" % (FLAGS.loss,))

        r_alpha = FLAGS.p_alpha
        r_lambda = FLAGS.p_lambda

        self.x = x
        self.t = t
        self.y_ = y_
        self.do_in = do_in
        self.do_out = do_out
        self.p_t = p_t
        self.r_alpha = r_alpha
        self.r_lambda = r_lambda

        dim_input = dims[0]
        dim_in = dims[1]
        dim_out = dims[2]

        if FLAGS.n_in == 0 or (FLAGS.n_in == 1 and FLAGS.varsel):
            dim_in = dim_input
        if FLAGS.n_out == 0:
            # With no hidden output layers the prediction layer sees the
            # representation directly, plus the t column when the output is
            # not split (same truthiness test as _build_output_graph).
            dim_out = dim_in if FLAGS.split_output else dim_in + 1

        # Scale of the reparameterisation noise. With 0 every latent
        # "sample" below equals the (normalised) encoder mean.
        k_eps = 0.  # 1e-3

        # --- Encoders and decoders ---
        # z1: mean / log-variance encoders of [x, t].
        e_z1_mu_r, e_z1_mu_rn, e_z1_mu_w, e_z1_mu_b = self._build_latent_graph([tf.concat([x, t], axis=1)], dim_input+1, dim_in)
        e_z1_lv_r, e_z1_lv_rn, e_z1_lv_w, e_z1_lv_b = self._build_latent_graph([tf.concat([x, t], axis=1)], dim_input+1, dim_in)

        eps = tf.random.normal(shape=(tf.shape(x)[0], dim_in))
        z1_sample = e_z1_mu_rn + k_eps*eps*tf.exp(e_z1_lv_rn*.5)

        # Outcome head on z1 (split by, or conditioned on, t).
        y, e_y_wo, e_y_wp, e_y_bo, e_y_bp = self._build_output_graph(z1_sample, t, dim_in, dim_out, do_out, FLAGS)

        # z3: mean / log-variance encoders of [x, y_].
        e_z3_mu_r, e_z3_mu_rn, e_z3_mu_w, e_z3_mu_b = self._build_latent_graph([tf.concat([x, y_], axis=1)], dim_input+1, dim_in)
        e_z3_lv_r, e_z3_lv_rn, e_z3_lv_w, e_z3_lv_b = self._build_latent_graph([tf.concat([x, y_], axis=1)], dim_input+1, dim_in)

        eps = tf.random.normal(shape=(tf.shape(x)[0], dim_in))
        z3_sample = e_z3_mu_rn + k_eps*eps*tf.exp(e_z3_lv_rn*.5)

        # Treatment head on z3; e_t_mu is the predicted probability of the
        # observed treatment (e_t_cost is built but not used in the loss).
        e_t_mu_w, e_t_mu_b, e_t_cost, e_t_mu = self._build_treatment_graph(z3_sample, dim_in)

        # z2: mean / log-variance encoders of [z1, y].
        e_z2_mu_r, e_z2_mu_rn, e_z2_mu_w, e_z2_mu_b = self._build_latent_graph([tf.concat([z1_sample, y], axis=1)], dim_in+1, dim_in)
        e_z2_lv_r, e_z2_lv_rn, e_z2_lv_w, e_z2_lv_b = self._build_latent_graph([tf.concat([z1_sample, y], axis=1)], dim_in+1, dim_in)

        eps = tf.random.normal(shape=(tf.shape(x)[0], dim_in))
        z2_sample = e_z2_mu_rn + k_eps*eps*tf.exp(e_z2_lv_rn*.5)

        # z4: mean / log-variance encoders of [z3, e_t_mu].
        e_z4_mu_r, e_z4_mu_rn, e_z4_mu_w, e_z4_mu_b = self._build_latent_graph([tf.concat([z3_sample, e_t_mu], axis=1)], dim_in+1, dim_in)
        e_z4_lv_r, e_z4_lv_rn, e_z4_lv_w, e_z4_lv_b = self._build_latent_graph([tf.concat([z3_sample, e_t_mu], axis=1)], dim_in+1, dim_in)

        eps = tf.random.normal(shape=(tf.shape(x)[0], dim_in))
        z4_sample = e_z4_mu_rn + k_eps*eps*tf.exp(e_z4_lv_rn*.5)

        # Decoders: z1 from [z2, y_], z3 from [z4, t], x from [z1, z3].
        d_z1_mu_r, d_z1_mu_rn, d_z1_mu_w, d_z1_mu_b = self._build_latent_graph([tf.concat([z2_sample, y_], axis=1)], dim_in+1, dim_in)

        d_z3_mu_r, d_z3_mu_rn, d_z3_mu_w, d_z3_mu_b = self._build_latent_graph([tf.concat([z4_sample, t], axis=1)], dim_in+1, dim_in)

        d_x_mu_r, d_x_mu_rn, d_x_mu_w, d_x_mu_b = self._build_latent_graph([tf.concat([z1_sample, z3_sample], axis=1)], 2*dim_in, dim_input)

        self.weights_in = e_z1_mu_w + e_z1_lv_w + e_z2_mu_w + e_z2_lv_w + d_z1_mu_w + d_x_mu_w \
                        + e_z3_mu_w + e_z3_lv_w + e_z4_mu_w + e_z4_lv_w + d_z3_mu_w + [e_t_mu_w]
        self.weights_out = e_y_wo
        self.weights_pred = e_y_wp
        self.biases_in  = e_z1_mu_b + e_z1_lv_b + e_z2_mu_b + e_z2_lv_b + d_z1_mu_b + d_x_mu_b \
                        + e_z3_mu_b + e_z3_lv_b + e_z4_mu_b + e_z4_lv_b + d_z3_mu_b + [e_t_mu_b]
        self.biases_out = e_y_bo
        self.bias_pred  = e_y_bp

        # --- Evidence lower bound ---
        RECL = logp(x, d_x_mu_rn, 1.)

        KLD = 0.
        KLD += kldivergence(e_z1_mu_rn, e_z1_lv_rn, d_z1_mu_rn, tf.ones(tf.shape(e_z1_lv_rn)))
        KLD += kldivergence(e_z3_mu_rn, e_z3_lv_rn, d_z3_mu_rn, tf.ones(tf.shape(e_z3_lv_rn)))
        KLD += kldivergence(e_z2_mu_rn, e_z2_lv_rn, tf.zeros(tf.shape(e_z2_mu_rn)), tf.ones(tf.shape(e_z2_lv_rn)))
        KLD += kldivergence(e_z4_mu_rn, e_z4_lv_rn, tf.zeros(tf.shape(e_z4_mu_rn)), tf.ones(tf.shape(e_z4_lv_rn)))

        ELBO = FLAGS.p_gamma*(RECL - FLAGS.p_beta*KLD)

        # --- Sample weights ---
        if FLAGS.reweight_sample:
            w_t = t/(2.*p_t)
            w_c = (1.-t)/(2.*(1.-p_t))
            sample_weight = (w_t+w_c)   # Default to Population-Based (PB) weighting scheme
        else:
            sample_weight = 1.0
        self.sample_weight = sample_weight

        # --- Factual loss (FLAGS.loss == 'l2' was validated above) ---
        risk = tf.reduce_mean(sample_weight*tf.square(y_ - y))
        pred_error = tf.sqrt(tf.reduce_mean(tf.square(y_ - y)))

        # --- Regularization ---
        # NOTE(review): self.weights_in concatenates every sub-network, so
        # the first FLAGS.n_in entries are the z1-mean encoder's weights;
        # only those are penalized here. Confirm this is intended.
        if FLAGS.p_lambda > 0 and FLAGS.rep_weight_decay:
            for i in range(FLAGS.n_in):
                if not (FLAGS.varsel and i == 0): # No penalty on W in variable selection
                    self.wd_loss += tf.nn.l2_loss(self.weights_in[i])

        # --- Imbalance error ---
        imb_error, imb_dist = self._calculate_disc(z1_sample, r_alpha, FLAGS)

        # --- Total error ---
        tot_error = risk + logp(t, e_t_mu, 1.) - ELBO

        if FLAGS.p_alpha > 0:
            tot_error = tot_error + imb_error

        if FLAGS.p_lambda > 0:
            tot_error = tot_error + r_lambda*self.wd_loss

        self.output = y
        self.tot_loss = tot_error
        self.imb_loss = imb_error
        self.imb_dist = imb_dist
        self.pred_loss = pred_error

        self.h_rep = e_z1_mu_r
        self.h_rep_norm = e_z1_mu_rn

    def _build_latent_graph(self, h_in, dim_input, dim_in):
        """Build a fully connected stack of FLAGS.n_in layers.

        Every call creates fresh variables.

        Args:
            h_in: list whose single element is the input tensor. The list is
                copied, so the caller's list is left untouched.
            dim_input: number of columns of h_in[0].
            dim_in: width of each layer.

        Returns:
            (h_rep, h_rep_norm, weights_in, biases_in): the last layer's
            output; that output divided by its row-wise L2 norm when
            FLAGS.normalization == 'divide' (otherwise 1.0 * output); and
            the per-layer weights and biases. A variable-selection layer
            contributes an empty bias entry ([]).
        """
        # Prefer the flags given to __init__; fall back to the module-level
        # FLAGS for instances that were not built through __init__.
        flags = getattr(self, '_flags', FLAGS)
        h_in = list(h_in)
        weights_in = []
        biases_in = []

        for i in range(flags.n_in):
            fan_in = dim_input if i == 0 else dim_in

            if flags.varsel and i == 0:
                # Variable selection: the first layer only rescales each input
                # column (elementwise product; no bias, nonlinearity or dropout).
                weights_in.append(tf.Variable(1.0/dim_input*tf.ones([dim_input])))
                biases_in.append([])
                # tf.mul was removed in TF 1.0; tf.multiply replaces it.
                h_in.append(tf.multiply(h_in[i], weights_in[i]))
                continue

            weights_in.append(tf.Variable(tf.random_normal(
                [fan_in, dim_in], stddev=flags.weight_init/np.sqrt(fan_in))))
            biases_in.append(tf.Variable(tf.zeros([1, dim_in])))
            z = tf.matmul(h_in[i], weights_in[i]) + biases_in[i]

            if flags.batch_norm:
                batch_mean, batch_var = tf.nn.moments(z, [0])

                if flags.normalization == 'bn_fixed':
                    z = tf.nn.batch_normalization(z, batch_mean, batch_var, 0, 1, 1e-3)
                else:
                    # Learnable offset/scale, tracked on the instance (these
                    # were previously undefined names in this method).
                    self.bn_biases.append(tf.Variable(tf.zeros([dim_in])))
                    self.bn_scales.append(tf.Variable(tf.ones([dim_in])))
                    z = tf.nn.batch_normalization(z, batch_mean, batch_var, self.bn_biases[-1], self.bn_scales[-1], 1e-3)

            h_in.append(tf.nn.dropout(self.nonlin(z), self.do_in))

        h_rep = h_in[-1]

        if flags.normalization == 'divide':
            h_rep_norm = h_rep / safe_sqrt(tf.reduce_sum(tf.square(h_rep), axis=1, keepdims=True))
        else:
            h_rep_norm = 1.0*h_rep

        return h_rep, h_rep_norm, weights_in, biases_in

    def _build_output(self, h_input, dim_in, dim_out, do_out, FLAGS):
        """Build FLAGS.n_out hidden output layers plus a linear prediction.

        Hidden-layer weights are weight-decayed (factor 1.0) when created;
        the prediction weights are added to self.wd_loss below.

        Returns:
            (y, weights_out, weights_pred, biases_out, bias_pred)
        """
        h_out = [h_input]
        dims = [dim_in] + [dim_out]*FLAGS.n_out

        weights_out = []
        biases_out = []

        for i in range(FLAGS.n_out):
            wo = self._create_variable_with_weight_decay(
                    tf.random_normal([dims[i], dims[i+1]],
                        stddev=FLAGS.weight_init/np.sqrt(dims[i])),
                    'w_out_%d' % i, 1.0)
            weights_out.append(wo)

            biases_out.append(tf.Variable(tf.zeros([1, dim_out])))
            z = tf.matmul(h_out[i], weights_out[i]) + biases_out[i]
            # No batch norm on output because p_cf != p_f

            h_out.append(tf.nn.dropout(self.nonlin(z), do_out))

        weights_pred = self._create_variable(tf.random_normal([dim_out, 1],
            stddev=FLAGS.weight_init/np.sqrt(dim_out)), 'w_pred')
        bias_pred = self._create_variable(tf.zeros([1]), 'b_pred')

        if FLAGS.varsel or FLAGS.n_out == 0:
            # Leave the last row of weights_pred unpenalized: it is the
            # treatment coefficient when t is the last input column.
            self.wd_loss += tf.nn.l2_loss(tf.slice(weights_pred, [0, 0], [dim_out-1, 1]))
        else:
            self.wd_loss += tf.nn.l2_loss(weights_pred)

        # Linear prediction layer.
        y = tf.matmul(h_out[-1], weights_pred) + bias_pred

        return y, weights_out, weights_pred, biases_out, bias_pred

    def _build_output_graph(self, rep, t, dim_in, dim_out, do_out, FLAGS):
        """Construct output/regression layers.

        With FLAGS.split_output, separate heads are built for rows with
        t < 1 and rows with t > 0, and their predictions are stitched back
        into the original row order. Otherwise a single head receives
        [rep, t].

        Returns:
            (y, weights_out, weights_pred, biases_out, bias_pred)
        """
        if FLAGS.split_output:
            i0 = tf.cast(tf.where(t < 1)[:, 0], dtype=tf.int32)
            i1 = tf.cast(tf.where(t > 0)[:, 0], dtype=tf.int32)

            rep0 = tf.gather(rep, i0)
            rep1 = tf.gather(rep, i1)

            y0, weights_out0, weights_pred0, biases_out0, bias_pred0 = self._build_output(rep0, dim_in, dim_out, do_out, FLAGS)
            y1, weights_out1, weights_pred1, biases_out1, bias_pred1 = self._build_output(rep1, dim_in, dim_out, do_out, FLAGS)

            y = tf.dynamic_stitch([i0, i1], [y0, y1])
            weights_out = weights_out0 + weights_out1
            # NOTE(review): weights_pred*/bias_pred* are single variables, so
            # `+` adds their values (a tensor) instead of collecting both
            # heads as the list concatenations do. Kept unchanged; confirm.
            weights_pred = weights_pred0 + weights_pred1
            biases_out = biases_out0 + biases_out1
            bias_pred = bias_pred0 + bias_pred1
        else:
            # TF >= 1.0 signature is tf.concat(values, axis); the legacy
            # tf.concat(1, [rep, t]) argument order fails there.
            h_input = tf.concat([rep, t], axis=1)
            y, weights_out, weights_pred, biases_out, bias_pred = self._build_output(h_input, dim_in+1, dim_out, do_out, FLAGS)

        return y, weights_out, weights_pred, biases_out, bias_pred

    def _calculate_disc(self, h_rep_norm, coef, FLAGS):
        """Imbalance penalty between the self.t groups of a representation.

        Args:
            h_rep_norm: representation tensor to compare across groups.
            coef: weight of the penalty.
            FLAGS: FLAGS.imb_fun selects the distance; FLAGS.use_p_correction
                selects p_ipm = self.p_t instead of 0.5.

        Returns:
            (imb_error, imb_dist): the weighted penalty and the raw distance.
        """
        t = self.t
        p_ipm = self.p_t if FLAGS.use_p_correction else 0.5

        if FLAGS.imb_fun == 'mmd2_rbf':
            imb_dist = mmd2_rbf(h_rep_norm, t, p_ipm, FLAGS.rbf_sigma)
            imb_error = coef*imb_dist
        elif FLAGS.imb_fun == 'mmd2_lin':
            imb_dist = mmd2_lin(h_rep_norm, t, p_ipm)
            # Reuse the distance rather than building the MMD graph twice.
            imb_error = coef*imb_dist
        elif FLAGS.imb_fun == 'mmd_rbf':
            imb_dist = tf.abs(mmd2_rbf(h_rep_norm, t, p_ipm, FLAGS.rbf_sigma))
            imb_error = safe_sqrt(tf.square(coef)*imb_dist)
        elif FLAGS.imb_fun == 'mmd_lin':
            imb_dist = mmd2_lin(h_rep_norm, t, p_ipm)
            imb_error = safe_sqrt(tf.square(coef)*imb_dist)
        elif FLAGS.imb_fun in ('wass', 'wass2'):
            # 'wass2' differs from 'wass' only by sq=True.
            imb_dist, imb_mat = wasserstein(h_rep_norm, t, p_ipm,
                                            lam=FLAGS.wass_lambda,
                                            its=FLAGS.wass_iterations,
                                            sq=(FLAGS.imb_fun == 'wass2'),
                                            backpropT=FLAGS.wass_bpt)
            imb_error = coef*imb_dist
            self.imb_mat = imb_mat # FOR DEBUG
        else:
            imb_dist = lindisc(h_rep_norm, t, p_ipm)
            imb_error = coef*imb_dist

        return imb_error, imb_dist

    def _build_treatment_graph(self, h_rep_norm, dim_in):
        """Logistic-regression treatment model on a latent representation.

        Args:
            h_rep_norm: input representation with dim_in columns.
            dim_in: number of columns of h_rep_norm.

        Returns:
            (W, b, cost, pi_0): the weights and bias; the mean binary
            cross-entropy of self.t plus FLAGS.p_lambda * l2_loss(W); and
            the predicted probability of the observed treatment (sigma
            where t == 1, 1 - sigma where t == 0).
        """
        # Prefer the flags given to __init__; fall back to the module-level
        # FLAGS for instances that were not built through __init__.
        flags = getattr(self, '_flags', FLAGS)
        t = self.t

        W = tf.Variable(tf.random.normal([dim_in, 1], stddev=flags.weight_init/np.sqrt(dim_in)), name='W')
        b = tf.Variable(tf.zeros([1]), name='b')    #-tf.math.log(1-1./p_t)
        logits = tf.matmul(h_rep_norm, W) + b
        sigma = tf.nn.sigmoid(logits)

        pi_0 = tf.multiply(t, sigma) + tf.multiply(1.0-t, 1.0-sigma)

        # Same value as -t*log(sigma) - (1-t)*log(1-sigma), but computed from
        # the logits so a saturated sigmoid cannot produce log(0) = -inf and
        # NaN gradients.
        cost = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=t, logits=logits)
            ) \
            + flags.p_lambda*tf.nn.l2_loss(W)
        return W, b, cost, pi_0
