import tensorflow as tf
from tensorflow.contrib import slim
import numpy as np
import sys
from tensorflow.contrib.layers.python.layers import initializers
from edward.models import Bernoulli, Normal
import tensorflow.contrib.graph_editor as ge
from functools import partial

from collections import namedtuple
# Hyper-parameters shared by every network built through NNBuilder:
#   initializer: callable accepting ``seed=`` that returns a weight initializer
#   activation:  activation function of the hidden layers
#   l2_lambda:   scale of the L2 weight regularizer
#   dropout:     drop probability (fc_net passes ``1 - dropout`` as keep_prob)
# The typename matches the module attribute so instances repr as
# ``NNBuilderConfig(...)`` and can be pickled by reference.
NNBuilderConfig = namedtuple('NNBuilderConfig',
                             ['initializer', 'activation', 'l2_lambda', 'dropout'])

# Defaults for every field, so ``NNBuilderConfig()`` / ``NNBuilder()`` work
# without arguments.
NNBuilderConfig.__new__.__defaults__ = NNBuilderConfig(
    initializers.variance_scaling_initializer,
    tf.nn.elu,
    1e-4,
    0,
)


# class Normal_patch(Normal):
#     def __init__(self, *args, **kwargs):
#         super(Normal_patch, self).__init__(*args, **kwargs)
#
#     def _z(self, x):
#         with tf.name_scope("standardize", values=[x]):
#             return (x - self.loc) / tf.maximum(self.scale, .1)


class NNBuilder(NNBuilderConfig):
    """Builds slim fully connected networks and Edward ``Normal`` outputs.

    Uses the ``NNBuilderConfig`` fields: ``initializer`` (called with
    ``seed=`` to create the weight initializer), ``activation`` (hidden-layer
    activation), ``l2_lambda`` (L2 weight-regularizer scale) and ``dropout``
    (drop probability applied after every hidden layer).
    """

    def fc_net(self, inp, layers, out_layers, scope=None, reuse=None, seed=None):
        """Build a fully connected network with optional output heads.

        Args:
            inp: input tensor of the first layer.
            layers: hidden-layer widths; may be empty or None.
            out_layers: sequence of ``(outdim, activation)`` pairs. Each pair
                adds one head applied to the last hidden representation.
                May be empty.
            scope: optional name prefix. Hidden layer ``i`` uses
                ``scope + str(i)`` and head ``i`` uses ``scope + '_{i + 1}'``.
                With None, slim chooses the names.
            reuse: forwarded to ``slim.fully_connected``.
            seed: seed passed to ``self.initializer``.

        Returns:
            The last hidden representation when ``out_layers`` is empty
            (``inp`` itself if ``layers`` is empty as well), the head tensor
            when there is exactly one head, otherwise the list of head tensors
            in ``out_layers`` order.
        """
        with slim.arg_scope([slim.fully_connected],
                            activation_fn=self.activation,
                            normalizer_fn=None,
                            weights_initializer=self.initializer(seed=seed),
                            reuse=reuse,
                            weights_regularizer=slim.l2_regularizer(self.l2_lambda)):
            h = inp
            for i, n in enumerate(layers or []):
                h = slim.fully_connected(h, n, scope=scope and scope + str(i))
                # Dropout after each hidden layer (second arg is keep_prob).
                # NOTE(review): applied unconditionally, also at inference time.
                h = tf.nn.dropout(h, 1. - self.dropout)

            if not out_layers:
                # Previously fell through to ``outputs[0]`` and raised
                # IndexError when ``layers`` was empty too.
                return h

            outputs = [
                slim.fully_connected(h, outdim, activation_fn=activation,
                                     scope=scope and scope + '_{}'.format(i + 1))
                for i, (outdim, activation) in enumerate(out_layers)
            ]
            return outputs if len(outputs) > 1 else outputs[0]

    def mu_sigma(self, input, layers, nout, var_sigma, switch, name=None, seed=None):
        """Build a mean head and a scale for it on top of ``fc_net``.

        Args:
            input: input tensor.
            layers: hidden-layer widths, shared by the heads of one network.
            nout: width of the mean (and scale) head.
            var_sigma: if True, sigma is a learned softplus head. Otherwise
                sigma is ``tf.ones_like(mu)``.
            switch: if True, build two independent networks, named with the
                suffixes ``_t0`` / ``_t1`` and seeded ``seed`` / ``seed + 1``.
            name: optional scope prefix forwarded to ``fc_net``.
            seed: optional initializer seed.

        Returns:
            ``(mu, sigma)``, or ``((mu0, sigma0), (mu1, sigma1))`` if ``switch``.
        """
        def _mu_sigma(name, seed):
            if var_sigma:
                return self.fc_net(input, layers, [[nout, None], [nout, tf.nn.softplus]], name, seed=seed)
            else:
                mu = self.fc_net(input, layers, [[nout, None]], name, seed=seed)
                return mu, tf.ones_like(mu)

        if switch:
            return _mu_sigma(name and name+'_t0', seed), _mu_sigma(name and name+'_t1', seed+1 if seed is not None else None)
        else:
            return _mu_sigma(name, seed)

    def normal(self, input, layers, nout, var_sigma, switch, t, name=None, reweight=1., linear_out=False):
        """Return an Edward ``Normal`` parameterised by ``input`` and ``t``.

        With ``switch``, two separate networks are built and blended as
        ``t * net1 + (1 - t) * net0`` for both loc and scale. This selects one
        network when ``t`` is 0/1 (presumably a binary treatment; confirm with
        callers). Without ``switch``, ``t`` is concatenated to ``input`` along
        axis 1 and fed to a single network.

        Args:
            reweight: multiplier applied to the scale.
            linear_out: if True, loc comes from an extra network with no hidden
                layers (a linear map of the network input). The scale still
                comes from the deep network.
                NOTE(review): with ``var_sigma`` the linear call also builds a
                sigma head that is discarded.
            name: optional name for the distribution and a prefix for the
                network scopes.
        """
        if switch:
            (mu0, sigma0), (mu1, sigma1) = self.mu_sigma(input, layers, nout, var_sigma, True, name and name+'_mu_sigma')
            if linear_out:
                (mu0, _), (mu1, _) = self.mu_sigma(input, [], nout, var_sigma, True, name and name+'_mu_lin')
            return Normal(loc=t*mu1+(1.-t)*mu0, scale=reweight*(t*sigma1+(1.-t)*sigma0), name=name or 'Normal')
        else:
            mu, sigma = self.mu_sigma(tf.concat([input, t], 1), layers, nout, var_sigma, False, name and name+'_mu_sigma')
            if linear_out:
                mu, _ = self.mu_sigma(tf.concat([input, t], 1), [], nout, var_sigma, False, name and name+'_mu_lin')
            return Normal(loc=mu, scale=reweight*sigma, name=name or 'Normal')


# IMPORTANT: call config_nnbuilder() before using the module-level
# fc_net, mu_sigma and normal helpers.
def config_nnbuilder(**kwargs):
    """Bind a configured NNBuilder's public methods as module globals.

    Creates ``NNBuilder()`` with its default field values, overrides the fields
    named in ``kwargs`` via ``_replace``, and then assigns every public bound
    method of that builder (``fc_net``, ``mu_sigma``, ``normal``) to the
    module global of the same name. Earlier bindings are replaced.
    """
    builder = NNBuilder()._replace(**kwargs)

    import inspect
    public_methods = {
        attr_name: bound_method
        for attr_name, bound_method in inspect.getmembers(builder, inspect.ismethod)
        if not attr_name.startswith('_')
    }
    globals().update(public_methods)


# Placeholders so these names exist at import time; config_nnbuilder()
# rebinds them to the bound methods of a configured NNBuilder.
fc_net = mu_sigma = normal = None


def get_y0_y1(sess, y, f0, f1, shape=(), L=1, verbose=True):
    """Average ``y.mean()`` over ``L`` evaluations under two feed dicts.

    Each draw is divided by ``L`` before it is accumulated.

    Args:
        sess: object exposing ``run(fetch, feed_dict=...)``.
        y: object whose ``mean()`` provides the fetch to evaluate.
        f0, f1: feed dicts for the two settings (presumably t=0 / t=1;
            confirm with callers).
        shape: shape of the float32 accumulators. Every ``sess.run`` result
            must be addable to them in place.
        L: number of evaluations per feed dict.
        verbose: write a ``Sample i/L`` progress line when ``L > 1``.

    Returns:
        ``(y0, y1)`` float32 numpy arrays of shape ``shape``.
    """
    acc0 = np.zeros(shape, dtype=np.float32)
    acc1 = np.zeros(shape, dtype=np.float32)
    mean_op = y.mean()
    show_progress = L > 1 and verbose

    for sample in range(1, L + 1):
        if show_progress:
            sys.stdout.write('\r Sample {}/{}'.format(sample, L))
            sys.stdout.flush()
        acc0 += sess.run(mean_op, feed_dict=f0) / L
        acc1 += sess.run(mean_op, feed_dict=f1) / L

    if show_progress:
        print()
    return acc0, acc1


def safe_sqrt(x, lbound=1e-10):
    """Return ``tf.sqrt`` of ``x`` after clipping it to ``[lbound, inf)``.

    Clipping keeps zero or negative inputs away from sqrt, whose gradient is
    infinite at 0. This keeps both the value and the gradient finite.
    """
    clipped = tf.clip_by_value(x, lbound, np.inf)
    return tf.sqrt(clipped)


def wasserstein(X, t, p, lam=10, its=10, sq=False, backpropT=False):
    """Return a Sinkhorn-approximated Wasserstein distance between treatment groups.

    Rows of ``X`` with ``t > 0`` form the treated group and rows with ``t < 1``
    the control group (so ``t`` is presumably a 0/1 indicator; confirm with
    callers). The cost matrix between the groups is augmented with one dummy
    row and one dummy column of cost ``max(M)`` (dummy-to-dummy cost 0). An
    entropic transport plan is then computed with ``its`` Sinkhorn iterations.

    Args:
        X: rank-2 tensor, one row per unit (rank 2 is required by the
            ``tf.matmul`` in the distance computation).
        t: treatment tensor; only the first coordinate of ``tf.where`` is used.
        p: total marginal mass of the treated units (control units get
            ``1 - p``; each dummy node gets the complementary mass).
        lam: entropic regularisation strength, divided by the mean cost.
        its: number of Sinkhorn iterations.
        sq: if True use squared Euclidean costs, else Euclidean costs.
        backpropT: if False, gradients do not flow through the transport plan.

    Returns:
        ``(D, Mlam)`` where ``D = 2 * sum(T * Mt)`` for transport plan ``T``
        and augmented cost matrix ``Mt``, and ``Mlam`` is ``Mt`` scaled by the
        effective lambda.
    """

    def pdist2sq(X, Y):
        """Squared Euclidean distance between every row of X and every row of Y."""
        C = -2 * tf.matmul(X, tf.transpose(Y))
        nx = tf.reduce_sum(tf.square(X), 1, keep_dims=True)
        ny = tf.reduce_sum(tf.square(Y), 1, keep_dims=True)
        return (C + tf.transpose(ny)) + nx

    treated_coords = tf.where(t > 0)
    control_coords = tf.where(t < 1)
    it = treated_coords[:, 0]
    ic = control_coords[:, 0]
    Xc = tf.gather(X, ic)
    Xt = tf.gather(X, it)
    nc = tf.to_float(tf.shape(Xc)[0])
    nt = tf.to_float(tf.shape(Xt)[0])

    # Cost matrix: treated units (rows) x control units (columns).
    if sq:
        M = pdist2sq(Xt, Xc)
    else:
        M = safe_sqrt(pdist2sq(Xt, Xc))

    # Regularisation scale and dummy-node cost. Both are constants for the gradient.
    M_mean = tf.reduce_mean(M)
    delta = tf.stop_gradient(tf.reduce_max(M))
    eff_lam = tf.stop_gradient(lam / M_mean)

    # Augment M with a dummy row/column of cost delta; the corner cost is 0.
    row = delta * tf.ones(tf.shape(M[0:1, :]))
    col = tf.concat([delta * tf.ones(tf.shape(M[:, 0:1])), tf.zeros((1, 1))], 0)
    Mt = tf.concat([tf.concat([M, row], 0), col], 1)

    # Marginals: p/nt per treated unit plus (1-p) on the dummy row;
    # (1-p)/nc per control unit plus p on the dummy column.
    a = tf.concat([p * tf.ones(tf.shape(treated_coords[:, 0:1])) / nt,
                   (1 - p) * tf.ones((1, 1))], 0)
    b = tf.concat([(1 - p) * tf.ones(tf.shape(control_coords[:, 0:1])) / nc,
                   p * tf.ones((1, 1))], 0)

    # Gibbs kernel; the added constant avoids NaNs from divisions by ~0.
    Mlam = eff_lam * Mt
    K = tf.exp(-Mlam) + 1e-6
    ainvK = K / a

    # Sinkhorn fixed-point iterations for the scaling vectors u and v.
    u = a
    for _ in range(its):
        u = 1.0 / tf.matmul(ainvK, b / tf.transpose(tf.matmul(tf.transpose(u), K)))
    v = b / tf.transpose(tf.matmul(tf.transpose(u), K))

    T = u * (tf.transpose(v) * K)
    if not backpropT:
        T = tf.stop_gradient(T)

    D = 2 * tf.reduce_sum(T * Mt)
    return D, Mlam


def pinv(a, rcond=1e-15):
    """Moore-Penrose pseudo-inverse of ``a`` computed from its SVD.

    Singular values ``<= rcond * max(s)`` are treated as zero. The maximum is
    taken per matrix (last axis of ``s``), so batched inputs ``[..., M, N]``
    are handled independently, as in ``numpy.linalg.pinv``. Uses a plain
    transpose of ``u``, so it is intended for real-valued ``a``.

    Args:
        a: tensor of shape ``[..., M, N]``.
        rcond: relative cutoff for small singular values.

    Returns:
        Tensor of shape ``[..., N, M]`` with the same dtype as ``a``.
    """
    s, u, v = tf.svd(a)
    # Per-matrix cutoff, expanded so it broadcasts against s[..., k].
    limit = rcond * tf.expand_dims(tf.reduce_max(s, axis=-1), -1)
    non_zero = tf.greater(s, limit)

    # Substitute 1 for discarded values before the reciprocal so no inf is
    # produced (an inf in either tf.where branch yields NaN gradients).
    # zeros_like/ones_like keep s's dtype and dynamic shape (tf.zeros(s.shape)
    # was float32-only and required a static shape).
    safe_s = tf.where(non_zero, s, tf.ones_like(s))
    reciprocal = tf.where(non_zero, tf.reciprocal(safe_s), tf.zeros_like(s))
    lhs = tf.matmul(v, tf.matrix_diag(reciprocal))
    return tf.matmul(lhs, u, transpose_b=True)

import imageio
from PIL import Image
import numpy as np


def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """Cut the central ``crop_h x crop_w`` window of ``x`` and resize it.

    Args:
        x: image array indexed as ``x[row, col, ...]``.
        crop_h: crop height in pixels.
        crop_w: crop width in pixels; defaults to ``crop_h`` (square crop).
        resize_w: side length of the square output produced by PIL's resize.

    Returns:
        numpy array of the resized crop. Values are cast to uint8 before
        resizing.
    """
    if crop_w is None:
        crop_w = crop_h
    height, width = x.shape[:2]
    top = int(round((height - crop_h) / 2.))
    left = int(round((width - crop_w) / 2.))
    window = x[top:top + crop_h, left:left + crop_w]
    pil_image = Image.fromarray(window.astype(np.uint8))
    return np.array(pil_image.resize([resize_w, resize_w]))


def merge(images, size):
    """Tile a batch of images into one grid image.

    Image ``idx`` goes to row ``idx // size[1]``, column ``idx % size[1]``.
    Unused grid cells stay 0.

    Args:
        images: numpy array of shape (N, h, w, c) or (N, h, w).
        size: (rows, cols) of the grid, e.g. [8, 8] for 64 images.

    Returns:
        float64 array of shape (h * rows, w * cols, channels) for 4-D input.
        Single-channel images are broadcast into 3 channels; other channel
        counts are preserved. For 3-D input the shape is (h * rows, w * cols).

    Raises:
        ValueError: if there are more images than grid cells.
    """
    rows, cols = size[0], size[1]
    n, h, w = images.shape[:3]
    if n > rows * cols:
        raise ValueError(
            'merge: {} images do not fit in a {}x{} grid'.format(n, rows, cols))

    if images.ndim == 3:
        canvas_shape = (h * rows, w * cols)
    else:
        channels = images.shape[3]
        canvas_shape = (h * rows, w * cols, 3 if channels == 1 else channels)

    img = np.zeros(canvas_shape)
    for idx, image in enumerate(images):
        i = idx % cols   # column
        j = idx // cols  # row
        img[j * h:(j + 1) * h, i * w:(i + 1) * w] = image
    return img


def transform(image, npx=64, is_crop=True, resize_w=64):
    """Optionally centre-crop ``image``, then rescale pixels from [0, 255] to [-1, 1].

    Args:
        image: input image array.
        npx: crop size passed to ``center_crop`` when ``is_crop`` is True.
        is_crop: whether to centre-crop and resize first.
        resize_w: output side length used by ``center_crop``.
    """
    source = center_crop(image, npx, resize_w=resize_w) if is_crop else image
    return np.array(source) / 127.5 - 1.


def inverse_transform(images):
    """Map values from [-1, 1] (e.g. tanh outputs) back to [0, 1]."""
    shifted = images + 1.
    return shifted / 2.


def imread(path, is_grayscale = False):
    """Read an image file with imageio and return it as a float64 numpy array.

    Args:
        path: file path (or anything ``imageio.imread`` accepts).
        is_grayscale: if True, ask imageio to convert the image to grayscale
            (``as_gray``), presumably giving a (height, width) array.
            Otherwise the decoded array is returned as-is, typically
            (height, width, channels) for colour images.
    """
    if is_grayscale:
        # ``flatten`` is a scipy.misc.imread option, not an imageio one;
        # imageio's grayscale conversion is requested with ``as_gray``.
        img = imageio.imread(path, as_gray=True)
    else:
        img = imageio.imread(path)
    # ``np.float`` was an alias of the builtin float (float64) and was removed
    # in NumPy 1.24, so use the explicit dtype.
    return img.astype(np.float64)


def imsave(images, size, path):
    """Tile ``images`` into a ``size`` = (rows, cols) grid with ``merge`` and write it to ``path``."""
    grid = merge(images, size)
    return imageio.imsave(path, grid)


def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    """Load ``image_path`` with ``imread`` and preprocess it with ``transform``.

    ``image_size`` is the crop size and ``resize_w`` the output side length
    used when ``is_crop`` is True. The returned values lie in [-1, 1].
    """
    raw = imread(image_path, is_grayscale)
    return transform(raw, image_size, is_crop, resize_w)


def save_images(images, size, image_path):
    """Save a batch of [-1, 1] images as a single grid image.

    Args:
        images: batch shaped [batchsize, height, width, color].
        size: (rows, cols) grid layout, e.g. [8, 8].
        image_path: output file, e.g.
            './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx)
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)


def save_images_256(images, size, image_path):
    """Resize each image to a 64x64 tile, arrange the tiles in a grid and save.

    Args:
        images: batch of images with values in [-1, 1]. Assumed shaped
            (N, H, W, 3), since tiles are written into a 3-channel canvas.
        size: (rows, cols) grid layout. Image ``idx`` goes to row
            ``idx // cols`` and column ``idx % cols``.
        image_path: output file passed to ``imageio.imsave``.

    Returns:
        The return value of ``imageio.imsave``.
    """
    images = inverse_transform(images)  # [-1, 1] -> [0, 1]
    h, w = 64, 64  # tile size in pixels
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        # Image.fromarray cannot build an RGB image from float64 data, so
        # convert the [0, 1] floats to uint8 (clipped to avoid wrap-around)
        # and scale the resized tile back to [0, 1].
        pixels = np.clip(np.round(image * 255.), 0, 255).astype(np.uint8)
        # PIL sizes are given as (width, height).
        tile = np.asarray(Image.fromarray(pixels).resize((w, h))) / 255.
        i = idx % size[1]   # column
        j = idx // size[1]  # row
        img[j*h:j*h+h, i*w:i*w+w, :] = tile
    return imageio.imsave(image_path, img)


def flatten_fully_connected(inputs,
                            num_outputs,
                            activation_fn=tf.nn.relu,
                            normalizer_fn=None,
                            normalizer_params=None,
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=None,
                            biases_initializer=tf.zeros_initializer(),
                            biases_regularizer=None,
                            reuse=None,
                            variables_collections=None,
                            outputs_collections=None,
                            trainable=True,
                            scope=None):
    """``slim.fully_connected`` that first flattens inputs of static rank > 2.

    Everything runs inside ``tf.variable_scope(scope, 'flatten_fully_connected')``.
    All arguments, including ``scope``, are then forwarded unchanged to
    ``slim.fully_connected``.
    """
    with tf.variable_scope(scope, 'flatten_fully_connected', [inputs]):
        flat = slim.flatten(inputs) if inputs.shape.ndims > 2 else inputs
        # Forwarded positionally, in slim.fully_connected's parameter order.
        layer_args = (num_outputs, activation_fn, normalizer_fn, normalizer_params,
                      weights_initializer, weights_regularizer,
                      biases_initializer, biases_regularizer, reuse,
                      variables_collections, outputs_collections, trainable,
                      scope)
        return slim.fully_connected(flat, *layer_args)


def lrelu(z, name=None):
    """Leaky ReLU with negative slope 0.01, i.e. ``max(0.01 * z, z)``.

    The leaky term stays the first operand of ``tf.maximum`` as in the
    original; this determines which branch receives the gradient at ``z == 0``.
    """
    leaked = 0.01 * z
    return tf.maximum(leaked, z, name=name)


def conv_endec():
    """Return ``(Enc, Dec)``, two closures that build a conv encoder and decoder.

    ``Enc`` applies four 5x5 stride-2 convolutions followed by fully connected
    heads. ``Dec`` applies a fully connected layer, a reshape to a 4x4 map, and
    5x5 stride-2 transposed convolutions. Each builds its variables under
    ``tf.variable_scope(scope)``.
    NOTE(review): variable names come from slim's automatic naming, so the
    order of layer calls below is significant; do not reorder.
    """

    # Layer factories with activation disabled by default. The closures
    # below add normalisation/activation through further partials.
    conv = partial(slim.conv2d, activation_fn=None)
    dconv = partial(slim.conv2d_transpose, activation_fn=None)
    fc = partial(flatten_fully_connected, activation_fn=None)
    relu = tf.nn.relu
    batch_norm = partial(slim.batch_norm, scale=True)  # , updates_collections=None

    def Enc(img, z_dim, dim=64, is_training=True, sigma=False, scope='Enc'):
        """Encode ``img`` into ``(z_mu, second_head)``, each of width ``z_dim``.

        The second head is ``softplus(fc(y, z_dim))`` when ``sigma`` is True.
        Otherwise it is the raw linear output ``fc(y, z_dim)``, named
        ``z_log_sigma_sq`` (presumably a log-variance; confirm with callers).
        ``is_training`` is forwarded to batch norm.
        """
        bn = partial(batch_norm, is_training=is_training)
        conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu)

        with tf.variable_scope(scope):
            # Four 5x5, stride-2 conv + batch-norm + leaky-ReLU blocks with
            # dim, 2*dim, 4*dim and 8*dim filters.
            y = conv_bn_lrelu(img, dim, 5, 2)
            y = conv_bn_lrelu(y, dim * 2, 5, 2)
            y = conv_bn_lrelu(y, dim * 4, 5, 2)
            y = conv_bn_lrelu(y, dim * 8, 5, 2)
            z_mu = fc(y, z_dim)
            if sigma:
                z_sigma = tf.nn.softplus(fc(y, z_dim))
                return z_mu, z_sigma
            else:
                z_log_sigma_sq = fc(y, z_dim)
                return z_mu, z_log_sigma_sq

    def Dec(z, dim=64, channels=3, is_training=True, sigma=True, scope='Dec'):
        """Decode ``z`` into an image with ``channels`` channels and tanh outputs.

        When ``sigma`` is True, also returns ``softplus`` of a second
        transposed convolution applied to the same features (a positive
        per-element output). Spatial size starts at 4x4 and each stride-2
        layer presumably doubles it under slim's default padding (4 -> 64);
        confirm against the encoder's input size.
        """
        bn = partial(batch_norm, is_training=is_training)
        dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu)

        with tf.variable_scope(scope):
            # Project the latent to a 4x4 map with dim*8 channels.
            y = relu(fc(z, 4 * 4 * dim * 8))
            y = tf.reshape(y, [-1, 4, 4, dim * 8])
            y = dconv_bn_relu(y, dim * 4, 5, 2)
            y = dconv_bn_relu(y, dim * 2, 5, 2)
            y = dconv_bn_relu(y, dim * 1, 5, 2)
            img = tf.tanh(dconv(y, channels, 5, 2))
            if sigma:
                sigmai = tf.nn.softplus(dconv(y, channels, 5, 2))
                return img, sigmai
            else:
                return img

    return Enc, Dec
