import numpy as np
import os
from scipy import stats, signal
import tensorflow as tf
import hashlib
import inspect
import random
from tensorflow.contrib import layers
import time
import matplotlib.pyplot as plt
from keras import backend as K
import time

def smoothness_regularizer_2d(W, weight=1.0):
    """Laplacian smoothness penalty for a 2-D convolution kernel.

    Every (out, in) filter slice of ``W`` is filtered with a 3x3 discrete
    Laplacian. The squared response is summed over space and divided by the
    slice's own squared L2 norm, and the ratios are summed. The result is
    scaled by ``weight``, added to the ``'smoothness_regularizer_2d'``
    collection and returned.

    Args:
        W: kernel tensor, presumably (height, width, in_channels,
            out_channels) as for conv2d. TODO confirm with callers.
        weight: scalar multiplier applied to the penalty.

    Returns:
        Scalar tensor named ``penalty`` inside the ``smoothness`` scope.
    """
    with tf.compat.v1.variable_scope('smoothness'):
        laplacian = tf.constant([[0.25, 0.5, 0.25],
                                 [0.5, -3.0, 0.5],
                                 [0.25, 0.5, 0.25]])
        # [3, 3] -> [3, 3, 1, 1] so it can be tiled into a depthwise filter.
        laplacian = tf.reshape(laplacian, [3, 3, 1, 1])
        num_channels = W.get_shape().as_list()[2]
        # Put axis 3 in the batch position so each slice along axis 2 becomes
        # a channel that the depthwise conv filters independently.
        kernels_as_images = tf.transpose(W, perm=[3, 0, 1, 2])
        depthwise_filter = tf.tile(laplacian, [1, 1, num_channels, 1])
        filtered = tf.nn.depthwise_conv2d(kernels_as_images,
                                          depthwise_filter,
                                          strides=[1, 1, 1, 1],
                                          padding='SAME')
        # Both tensors below are indexed [axis-3 index, axis-2 index] of W.
        roughness = tf.reduce_sum(tf.square(filtered), [1, 2])
        energy = tf.transpose(tf.reduce_sum(tf.square(W), [0, 1]))
        total = tf.reduce_sum(roughness / energy)
        penalty = tf.identity(weight * total, name='penalty')
        tf.compat.v1.add_to_collection('smoothness_regularizer_2d', penalty)
        return penalty


def group_sparsity_regularizer_2d(W, weight=1.0):
    """Group-lasso penalty over the spatial extent of a 2-D conv kernel.

    For every index pair over axes 2 and 3 of ``W`` (presumably
    (in_channels, out_channels) of a (h, w, in, out) kernel; TODO confirm),
    this takes the L2 norm over the first two axes and sums those norms.
    The sum is scaled by ``weight``, added to the
    ``'group_sparsity_regularizer_2d'`` collection and returned.

    Uses the ``tf.compat.v1`` API, consistent with
    ``smoothness_regularizer_2d``.

    Args:
        W: kernel tensor.
        weight: scalar multiplier applied to the penalty.

    Returns:
        Scalar tensor named ``penalty`` inside the ``group_sparsity`` scope.
    """
    with tf.compat.v1.variable_scope('group_sparsity'):
        group_norms = tf.sqrt(tf.reduce_sum(tf.square(W), [0, 1]))
        penalty = tf.identity(weight * tf.reduce_sum(group_norms), name='penalty')
        tf.compat.v1.add_to_collection('group_sparsity_regularizer_2d', penalty)
        return penalty
     
def kl(z_mu, z_log_sigma, axis=1):
    """KL divergence of a diagonal Gaussian posterior from a standard-normal prior.

    Computes d_kl(q(z|x) || p(z)) with q = N(mu, sigma^2) and p = N(0, I)
    in closed form:

        -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    ``z_log_sigma`` is interpreted as log(sigma), the log standard
    deviation: log(sigma^2) is computed as ``2 * z_log_sigma`` and sigma^2
    as ``exp(2 * z_log_sigma)``.

    :param z_mu: posterior means.
    :param z_log_sigma: posterior log standard deviations, same shape as z_mu.
    :param axis: axis holding the latent dimensions to sum over. Defaults to
        1, which matches a (batch, latent) layout.
    :return: KL divergence per remaining index, with ``axis`` reduced away.
    """
    return -0.5 * tf.reduce_sum(
        1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma),
        axis)

def elu(x):
    """Softplus nonlinearity, log(1 + exp(x)), despite the historical name.

    Implemented with ``tf.nn.softplus`` rather than the literal
    ``log(exp(x) + 1)``. In the literal form, exp(x) overflows to inf for
    large x (around 88 in float32), which makes the output inf and the
    gradient NaN (0 * inf). The graph op keeps the name ``'elu'``.
    """
    return tf.nn.softplus(x, name='elu')
    

def inv_elu(x):
    """Inverse of ``elu`` (softplus): returns y such that log(1 + exp(y)) == x.

    Evaluated as ``x + log(1 - exp(-x))`` using ``expm1``. This is
    algebraically equal to ``log(exp(x) - 1)``, but it does not overflow for
    large x and stays accurate for small positive x.

    Only x > 0 has a finite preimage. As with the naive formula, x == 0
    gives -inf and x < 0 gives nan, and numpy emits RuntimeWarnings.

    Args:
        x: scalar or array-like of target softplus outputs.

    Returns:
        numpy scalar or array of the same shape as ``x``.
    """
    x = np.asarray(x)
    return x + np.log(-np.expm1(-x))


def poisson(prediction, response, eps=1e-5):
    """Poisson negative log-likelihood (up to a constant), averaged over samples.

    Per sample, sums ``prediction - response * log(prediction + eps)`` over
    axis 1, then takes the mean over the remaining axis. Presumably the
    inputs are (batch, neurons); TODO confirm.

    Args:
        prediction: predicted (non-negative) rates.
        response: observed responses, same shape as ``prediction``.
        eps: small constant added inside the log to avoid log(0).
            Defaults to the previously hard-coded 1e-5.

    Returns:
        Scalar tensor named ``'poisson'``.
    """
    return tf.reduce_mean(
        tf.reduce_sum(prediction - response * tf.math.log(prediction + eps), 1),
        name='poisson')


def envelope(w, k=25):
    """Smoothed energy map of each 2-D map in ``w``.

    For every map ``wi`` obtained by iterating ``w`` along its first axis:
    subtract the median, square the result, and convolve it (``'same'``
    mode) with an unnormalised k x k Gaussian window sampled on
    [-2.5, 2.5] along both axes.

    Args:
        w: iterable of 2-D arrays, e.g. a (maps, height, width) array.
        k: side length of the Gaussian window.

    Returns:
        np.ndarray stacking the per-map results along a new first axis.
    """
    grid = np.linspace(-2.5, 2.5, k, endpoint=True)
    gx, gy = np.meshgrid(grid, grid)
    window = np.exp(-(gx ** 2 + gy ** 2) / 2)
    smoothed = []
    for wi in w:
        centred = wi - np.median(wi)
        smoothed.append(signal.convolve2d(centred ** 2, window, 'same'))
    return np.array(smoothed)


def est_rf_location(x, y, k=25):
    """Estimate each neuron's receptive-field centre from a spike-triggered average.

    Channel 0 of the stimuli is z-scored (global mean and std). It is then
    contracted with the responses over the sample axis to get one STA map
    per neuron. The maps are smoothed with ``envelope``, and the location
    of each map's maximum is returned.

    Args:
        x: stimuli indexed as ``x[:, :, :, 0]``, presumably
            (samples, height, width, channels). Only channel 0 is used.
        y: responses with samples on axis 0, presumably (samples, neurons).
        k: Gaussian window size passed to ``envelope``.

    Returns:
        ``(col, row)``: integer arrays of each neuron's peak column (x) and
        row (y) index.
    """
    stim = x[:, :, :, 0]
    stim = (stim - stim.mean()) / stim.std()
    sta = np.tensordot(y, stim, axes=[[0], [0]])
    energy = envelope(sta, k)
    n_maps = energy.shape[0]
    width = energy.shape[2]
    flat_peak = np.argmax(energy.reshape(n_maps, -1), axis=1)
    row, col = np.divmod(flat_peak, width)
    return col, row



class Net:
    """Base wrapper around a TF1 graph and session for training response models.

    Holds the input placeholders, the session, checkpointing, and a generic
    early-stopping training loop. The attributes used by ``train`` and
    ``get_test_ops`` must be created inside ``self.graph`` before
    ``initialize`` is called: ``train_step``, ``total_loss``, ``recon_loss``,
    ``poisson_loss``, ``imgs_r`` and ``prediction``.
    """

    def __init__(self, data=None, log_dir=None, log_hash=None, global_step=None):
        """Set up log paths and the seed, and build placeholders if ``data`` is given.

        Args:
            data: dataset object. It must expose ``px_y``, ``px_x`` and
                ``num_neurons``, plus the ``val()``, ``test()`` and
                ``minibatch(batch_size)`` methods used below. When None, no
                graph is created.
            log_dir: sub-directory under ``<this file's dir>/train_logs``.
                Defaults to ``'cnn_tmp'``.
            log_hash: run identifier. A random 10-hex-digit string is used if
                None.
            global_step: initial step counter. Defaults to 0.
        """
        self.data = data
        self.best_loss = 1e100
        # Directory containing this source file. getframeinfo on the current
        # frame gives the same filename as inspect.stack()[0][1], but without
        # walking (and reading source for) every frame on the call stack.
        this_file = inspect.getframeinfo(inspect.currentframe()).filename
        base_dir = os.path.join(os.path.dirname(this_file), 'train_logs',
                                'cnn_tmp' if log_dir is None else log_dir)
        if log_hash is None:
            log_hash = '%010x' % random.getrandbits(40)
        self.log_dir = os.path.join(base_dir, log_hash)
        self.log_hash = log_hash
        # Deterministic seed derived from the first 4 characters of the hash.
        self.seed = int.from_bytes(log_hash[:4].encode('utf8'), 'big')
        self.global_step = 0 if global_step is None else global_step
        self.session = None

        if data is None:
            return
        # placeholders
        with tf.Graph().as_default() as self.graph:
            self.is_training = tf.compat.v1.placeholder(tf.bool)
            self.learning_rate = tf.compat.v1.placeholder(tf.float32)
            self.images = tf.compat.v1.placeholder(
                tf.float32, shape=[None, data.px_y, data.px_x, 1])
            self.responses = tf.compat.v1.placeholder(
                tf.float32, shape=[None, data.num_neurons])

    def initialize(self):
        """Open a session on ``self.graph``, initialise all variables and create a Saver."""
        self.session = tf.Session(graph=self.graph)
        self.session.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(max_to_keep=1)

    def __del__(self):
        """Best-effort cleanup of the session (and a ``writer``, if one was attached)."""
        session = getattr(self, 'session', None)
        if session is None:
            return
        # Finalizers must never raise; failures during interpreter shutdown
        # are expected and deliberately ignored.
        try:
            session.close()
        except Exception:
            pass
        writer = getattr(self, 'writer', None)
        if writer is not None:
            try:
                writer.close()
            except Exception:
                pass

    def close(self):
        """Close the session if one is open. Does nothing when there is no session."""
        if self.session is not None:
            self.session.close()

    def save(self):
        """Checkpoint the session to ``<log_dir>/model.ckpt``, creating ``log_dir`` if needed."""
        # Saver.save requires the parent directory to exist.
        os.makedirs(self.log_dir, exist_ok=True)
        self.saver.save(self.session, os.path.join(self.log_dir, 'model.ckpt'))

    def load(self):
        """Restore the session from ``<log_dir>/model.ckpt``."""
        self.saver.restore(self.session, os.path.join(self.log_dir, 'model.ckpt'))

    def train(self,
              max_iter,
              learning_rate,
              batch_size,
              val_steps,
              early_stopping_steps):
        """Run training steps with periodic validation and early stopping.

        This is a generator. Every ``val_steps`` steps the model is evaluated
        on ``self.data.val()``. If the first evaluated value (``total_loss``)
        improves on ``best_loss``, a checkpoint is saved. After
        ``early_stopping_steps`` consecutive validations without improvement,
        ``global_step`` is rolled back by the unproductive steps, the best
        checkpoint is restored, and the loop ends.

        Yields:
            ``(step, seconds since this step's training run started,
            eval result)`` after each validation that does not trigger early
            stopping.
        """
        with self.graph.as_default():
            # UPDATE_OPS (e.g. batch-norm moving statistics) are fetched and
            # run alongside train_step on every step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            imgs_val, res_val = self.data.val()
            not_improved = 0
            for i in range(self.global_step + 1, self.global_step + max_iter + 1):
                # training step
                imgs_batch, res_batch = self.data.minibatch(batch_size)
                self.global_step = i
                feed_dict = {self.images: imgs_batch,
                             self.responses: res_batch,
                             self.is_training: True,
                             self.learning_rate: learning_rate}
                t = time.time()
                self.session.run([self.train_step, update_ops], feed_dict)
                # validate/save periodically
                if i % val_steps == 0:
                    result = self.eval(images=imgs_val, responses=res_val)
                    if result[0] < self.best_loss:
                        self.best_loss = result[0]
                        self.save()
                        not_improved = 0
                    else:
                        not_improved += 1
                    if not_improved == early_stopping_steps:
                        # Roll back to the step of the best checkpoint.
                        self.global_step -= early_stopping_steps * val_steps
                        self.load()
                        not_improved = 0
                        break
                    yield (i, time.time() - t, result)

    def eval_test(self, images=None, responses=None):
        """Restore the best checkpoint, then evaluate (defaults to ``self.data.test()``)."""
        if images is None or responses is None:
            images, responses = self.data.test()
        self.load()
        return self.eval(images=images, responses=responses)

    def eval(self, images=None, responses=None):
        """Evaluate ``get_test_ops()`` in inference mode.

        If either argument is None, both are taken from ``self.data.test()``.
        Returns the list of fetched values, in ``get_test_ops()`` order.
        """
        if images is None or responses is None:
            images, responses = self.data.test()
        ops = self.get_test_ops()
        feed_dict = {self.images: images,
                     self.responses: responses,
                     self.is_training: False}
        return self.session.run(ops, feed_dict)

    def get_test_ops(self):
        """Tensors fetched by ``eval``. Index 0 (``total_loss``) drives early stopping."""
        return [self.total_loss,
                self.recon_loss,
                self.poisson_loss,
                self.imgs_r,
                self.images,
                self.responses,
                self.prediction]

class ConvNet(Net):
    """Convolutional autoencoder with a neural-response readout.

    Inside ``self.graph`` the ``build_*`` methods create:

    * an encoder (``encoder``);
    * one of three bottlenecks: dense (``latent_deep``), variational
      (``latent_variational``) or vector-quantised (``latent_vq``);
    * a transposed-conv decoder (``decoder``);
    * a readout (``read_out``) that predicts ``self.data.num_neurons``
      responses from a chosen encoder layer.

    The total loss is ``readout_beta * poisson_loss + ae_alpha * ae_loss``.
    """

    def __init__(self, *args, **kwargs):
        super(ConvNet, self).__init__(*args, **kwargs)
        self.conv = []      # encoder layer outputs, in order
        self.W = []         # encoder kernels, in order
        self.conv_tr = []   # decoder layer outputs, in order
        self.W_tr = []      # decoder kernels, in order
        self.W_read_conv = []
        self.conv_read_conv = []
        self.readout_sparseness_regularizer = 0.0

    def encoder(self,
                filter_sizes=(3, 3, 3, 3),
                out_channels=(32, 32, 32, 32),
                strides=(2, 1, 2, 1),
                paddings=('SAME', 'SAME', 'SAME', 'SAME'),
                smooth_weights=(0.00, 0, 0, 0),
                sparse_weights=(0.0, 0.00, 0.00, 0.00)):
        """Stack of batch-normalised conv layers applied to ``self.images``.

        Layer ``i`` uses scope ``conv{i}``. Its kernel is regularised with
        ``smoothness_regularizer_2d`` and ``group_sparsity_regularizer_2d``
        using that layer's weights. Every layer except the last uses ``elu``;
        the last has no activation. Outputs and kernels are appended to
        ``self.conv`` and ``self.W``.

        Returns:
            ``(last layer output, self.images)``.

        Raises:
            ValueError: if ``filter_sizes`` is empty.
        """
        if len(filter_sizes) == 0:
            raise ValueError('encoder needs at least one layer')
        imgs = self.images
        n_layers = len(filter_sizes)
        x = imgs
        layer_params = zip(filter_sizes, out_channels, strides, paddings,
                           smooth_weights, sparse_weights)
        for i, (filter_size, out_chans, stride, padding,
                smooth_weight, sparse_weight) in enumerate(layer_params):
            bn_params = {'decay': 0.9, 'is_training': self.is_training}
            scope = 'conv{}'.format(i)

            # Per-layer weights are bound as defaults so the regulariser does
            # not rely on closure late-binding.
            def reg(w, sw=smooth_weight, gw=sparse_weight):
                return (smoothness_regularizer_2d(w, sw)
                        + group_sparsity_regularizer_2d(w, gw))

            c = layers.convolution2d(
                inputs=x,
                num_outputs=out_chans,
                kernel_size=int(filter_size),
                stride=int(stride),
                padding=padding,
                activation_fn=elu if i < n_layers - 1 else None,
                normalizer_fn=layers.batch_norm,
                normalizer_params=bn_params,
                weights_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01),
                weights_regularizer=reg,
                scope=scope,
            )
            with tf.variable_scope(scope, reuse=True):
                W = tf.get_variable('weights')
            self.W.append(W)
            self.conv.append(c)
            print('encoder_____shape')
            print(c.shape)
            x = c
        return c, imgs

    def decoder(self, x,
                filter_sizes=(3, 3, 3, 3),
                out_channels=(32, 32, 32, 1),
                strides=(1, 2, 1, 2),
                paddings=('SAME', 'SAME', 'SAME', 'SAME'),
                smooth_weights=(0.00, 0.0, 0.0, 0.0),
                sparse_weights=(0, 0.0, 0.0, 0.0),
                act=(elu, elu, elu, tf.nn.tanh)):
        """Stack of batch-normalised transposed-conv layers applied to ``x``.

        Layer ``i`` uses scope ``conv_tr{i}`` and activation ``act[i]``.
        Outputs and kernels are appended to ``self.conv_tr`` and
        ``self.W_tr``.

        Returns:
            Output of the last layer (the reconstruction).

        Raises:
            ValueError: if ``filter_sizes`` is empty.
        """
        if len(filter_sizes) == 0:
            raise ValueError('decoder needs at least one layer')
        z = x
        layer_params = zip(filter_sizes, out_channels, strides, paddings,
                           smooth_weights, sparse_weights, act)
        for i, (filter_size, out_chans, stride, padding,
                smooth_weight, sparse_weight, activation) in enumerate(layer_params):
            bn_params = {'decay': 0.9, 'is_training': self.is_training}
            scope = 'conv_tr{}'.format(i)

            def reg(w, sw=smooth_weight, gw=sparse_weight):
                return (smoothness_regularizer_2d(w, sw)
                        + group_sparsity_regularizer_2d(w, gw))

            c = layers.convolution2d_transpose(
                inputs=z,
                num_outputs=out_chans,
                kernel_size=int(filter_size),
                stride=int(stride),
                padding=padding,
                activation_fn=activation,
                normalizer_fn=layers.batch_norm,
                normalizer_params=bn_params,
                weights_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01),
                weights_regularizer=reg,
                scope=scope,
            )
            with tf.variable_scope(scope, reuse=True):
                W1 = tf.get_variable('weights')
            self.W_tr.append(W1)
            self.conv_tr.append(c)
            print(c.shape)
            z = c
        return c

    def latent_deep(self, c):
        """Deterministic dense bottleneck.

        Flattens ``c``, projects it to 100 units, then projects back and
        reshapes to ``c``'s shape.

        Returns:
            ``(100-d code, tensor shaped like c)``.
        """
        l = layers.flatten(c)
        l_shape = l.shape
        print(l_shape)
        z = tf.layers.dense(l, 100)
        print(z.shape)
        c1 = tf.layers.dense(z, int(l_shape[1]))
        c1 = tf.reshape(c1, tf.shape(c))
        print(c1.shape)
        return z, c1

    def latent_variational(self, c):
        """Variational (reparameterised Gaussian) bottleneck with 50 latents.

        Samples ``z = mu + exp(log_sigma) * eps`` with eps ~ N(0, 1), decodes
        ``z`` back to ``c``'s shape, and stores the per-sample KL term in
        ``self.kl_loss``.

        Returns:
            ``(sampled code, tensor shaped like c)``.
        """
        l = layers.flatten(c)
        l_shape = l.shape
        print(l_shape)
        size = tf.shape(l)[0]
        n_latent = 50
        z_mus = tf.layers.dense(l, n_latent)
        z_log_sigmas = 0.5 * tf.layers.dense(l, n_latent)
        z = z_mus + tf.exp(z_log_sigmas) * tf.random_normal(shape=[size, n_latent])
        c1 = tf.layers.dense(z, int(l_shape[1]))
        c1 = tf.reshape(c1, tf.shape(c))
        print(c1.shape)
        self.kl_loss = kl(z_mus, z_log_sigmas)
        return z, c1

    def latent_vq(self, z_e):
        """Vector-quantised bottleneck.

        Each vector along the last axis of ``z_e`` is replaced by the nearest
        column (squared Euclidean distance) of the 256-entry codebook
        ``self._w``. Gradients flow straight through to ``z_e``.

        Side effects: creates the variables ``embedding``,
        ``ema_cluster_size`` and ``ema_dw``, and sets ``self.l_loss``,
        ``self.encoding_indices`` (flat), ``self.z_e`` and ``self.z_q``.

        Returns:
            ``(encoding indices shaped like z_e without its last axis,
            straight-through quantised tensor shaped like z_e)``.
        """
        # vector quantised
        z_e_shape = z_e.get_shape()
        self._embedding_dim = z_e_shape[-1]
        self._num_embeddings = 256

        self._decay = 0.99
        self._epsilon = 1e-5
        inputs = z_e

        initializer = tf.random_normal_initializer()
        # w is a matrix with an embedding in each column.
        self._w = tf.get_variable(
            'embedding', [self._embedding_dim, self._num_embeddings],
            initializer=initializer, use_resource=True)
        self._ema_cluster_size = tf.get_variable(
            'ema_cluster_size', [self._num_embeddings],
            initializer=tf.constant_initializer(0), use_resource=True)
        self._ema_w = tf.get_variable(
            'ema_dw', initializer=self._w.initialized_value(), use_resource=True)
        print('999999999999999999999999999999999999999999999')
        print(self._embedding_dim)

        with tf.control_dependencies([inputs]):
            w = self._w.read_value()
        input_shape = tf.shape(inputs)

        with tf.control_dependencies([
                tf.Assert(tf.equal(input_shape[-1], self._embedding_dim),
                          [input_shape])]):
            flat_inputs = tf.reshape(inputs, [-1, self._embedding_dim])

        # ||x||^2 - 2 x.w + ||w||^2 for every (input vector, code) pair.
        distances = (tf.reduce_sum(flat_inputs ** 2, 1, keepdims=True)
                     - 2 * tf.matmul(flat_inputs, w)
                     + tf.reduce_sum(w ** 2, 0, keepdims=True))

        encoding_indices = tf.argmax(- distances, 1)
        self.encoding_indices = encoding_indices
        encodings = tf.one_hot(encoding_indices, self._num_embeddings)
        encoding_indices = tf.reshape(encoding_indices, tf.shape(inputs)[:-1])
        quantized = self.quantize(encoding_indices)
        e_latent_loss = tf.reduce_mean((tf.stop_gradient(quantized) - inputs) ** 2)

        is_trainin = True

        # NOTE(review): `self.is_training` is a tf placeholder, so this compares
        # a Python bool with a graph Tensor. In TF1 graph mode, Tensor equality
        # falls back to object identity, so the condition is presumably always
        # False and only the `else` branch runs. Confirm.
        # The branch also uses `moving_averages`, which the visible imports do
        # not provide. As a result the codebook `self._w` is never updated: it
        # receives no EMA update, and no gradient either, because both
        # e_latent_loss and the straight-through output stop gradients
        # through `quantized`.
        if is_trainin == self.is_training:
            updated_ema_cluster_size = moving_averages.assign_moving_average(
                self._ema_cluster_size, tf.reduce_sum(encodings, 0), self._decay)

            dw = tf.matmul(flat_inputs, encodings, transpose_a=True)
            updated_ema_w = moving_averages.assign_moving_average(self._ema_w, dw,
                                                                  self._decay)
            n = tf.reduce_sum(updated_ema_cluster_size)
            updated_ema_cluster_size = ((updated_ema_cluster_size + self._epsilon)
                                        / (n + self._num_embeddings * self._epsilon) * n)

            normalised_updated_ema_w = (
                updated_ema_w / tf.reshape(updated_ema_cluster_size, [1, -1]))

            with tf.control_dependencies([e_latent_loss]):
                update_w = tf.assign(self._w, normalised_updated_ema_w)
                with tf.control_dependencies([update_w]):
                    self.l_loss = e_latent_loss

        else:
            self.l_loss = e_latent_loss

        # Straight-through estimator: forward value is `quantized`, gradient
        # flows to z_e unchanged.
        ze = z_e + tf.stop_gradient(quantized - z_e)

        self.z_e = z_e
        self.z_q = quantized

        return encoding_indices, ze

    def read_out(self, c,
                 fully_connected_readout,
                 fixed_rfs):
        """Predict neural responses from the feature map ``c``.

        Args:
            c: feature map, indexed below as (batch, y, x, channels).
            fully_connected_readout: if True, flatten ``c`` and use a dense
                layer (as in McIntosh et al. 2017).
            fixed_rfs: used only by the factorised readout. If True, the
                spatial mask is a fixed one-hot at each neuron's STA-estimated
                RF location (similar to Batty et al. 2017) instead of being
                learned.

        Returns:
            Predicted responses with one column per neuron, passed through
            ``elu`` (softplus).

        Side effects (factorised readout only): sets ``W_spatial``,
        ``W_features`` (the absolute value of the feature weights), ``h_out``
        and ``b_out``.
        """
        # Output biases are initialised so that elu(b) equals each neuron's
        # mean training response.
        images, responses = self.data.train()
        b = inv_elu(responses.mean(axis=0))

        if fully_connected_readout:
            # Fully connected readout as in McIntosh et al. 2017.
            c = layers.flatten(c)
            prediction = layers.fully_connected(
                c, self.data.num_neurons, activation_fn=elu,
                biases_initializer=tf.constant_initializer(b))
            print('--------------------- FC')
            return prediction

        sz = c.get_shape()
        out_channels = sz[-1]
        px_x_conv = int(sz[2])
        px_y_conv = int(sz[1])
        px_conv = px_x_conv * px_y_conv
        # [batch, y*x (row-major), channels, 1]
        conv_flat = tf.reshape(c, [-1, px_conv, out_channels, 1])
        if fixed_rfs:
            # Instead of learning the spatial mask, extract RF locations from
            # the STA (similar to Batty et al. 2017).
            print('--------------------- FF')
            # Images are (N, y, x, 1) (see the `images` placeholder), so the
            # x crop uses the width and the y crop uses the height.
            # NOTE(review): this assumes the feature map is a centred crop of
            # the image; strided encoder layers downsample instead. Confirm.
            crop_x = (images.shape[2] - px_x_conv) // 2
            crop_y = (images.shape[1] - px_y_conv) // 2
            rf_x, rf_y = est_rf_location(images, responses, k=25)
            rf_x = np.maximum(np.minimum(rf_x - crop_x, px_x_conv - 1), 0)
            rf_y = np.maximum(np.minimum(rf_y - crop_y, px_y_conv - 1), 0)
            # (neuron, y, x) so the row-major flattening below matches the
            # y*x ordering of conv_flat. The previous (neuron, x, y)
            # allocation was wrong for non-square maps.
            w_init = np.zeros([self.data.num_neurons, px_y_conv, px_x_conv])
            for i, (x, y) in enumerate(zip(rf_x, rf_y)):
                w_init[i, y, x] = 1
            w_init = np.reshape(w_init, [w_init.shape[0], -1]).T
            self.W_spatial = tf.get_variable('W_spatial',
                                             shape=[px_conv, self.data.num_neurons],
                                             initializer=tf.constant_initializer(w_init),
                                             trainable=False)
        else:
            self.W_spatial = tf.get_variable('W_spatial',
                                             shape=[px_conv, self.data.num_neurons],
                                             initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01))
            print('--------------------- FR')
        # Non-negative spatial mask applied as a VALID conv over the
        # flattened spatial axis: result is [batch, 1, channels, neurons].
        W_spatial_flat = tf.reshape(self.W_spatial, [px_conv, 1, 1, self.data.num_neurons])
        W_spatial_flat = tf.abs(W_spatial_flat)
        h_spatial = tf.nn.conv2d(conv_flat, W_spatial_flat, strides=[1, 1, 1, 1], padding='VALID')
        self.W_features = tf.get_variable('W_features',
                                          shape=[out_channels, self.data.num_neurons],
                                          initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01))
        # The attribute intentionally ends up holding the non-negative weights.
        self.W_features = tf.abs(self.W_features)
        # Weighted sum over the spatial and channel axes -> [batch, neurons].
        self.h_out = tf.reduce_sum(tf.multiply(h_spatial, self.W_features), [1, 2])

        # output nonlinearity
        self.b_out = tf.get_variable('b_out',
                                     shape=[self.data.num_neurons],
                                     dtype=tf.float32,
                                     initializer=tf.constant_initializer(b))
        prediction = elu(self.h_out + self.b_out)
        return prediction

    def _build_readout_and_optimizer(self, read_ae_conv, fully_connected_readout, fixed_rfs):
        """Shared tail of the ``build_*`` methods: readout, losses, optimizer and session.

        Must be called inside ``self.graph.as_default()`` after
        ``self.ae_alpha`` and ``self.ae_loss`` have been set.
        """
        self.readout_beta = read_ae_conv[1]
        # The readout reads from the encoder layer selected by read_ae_conv[2].
        read_e = self.conv[read_ae_conv[2]]
        self.prediction = self.read_out(read_e, fully_connected_readout, fixed_rfs)
        self.poisson_loss = poisson(self.prediction, self.responses)
        self.total_loss = (self.readout_beta * self.poisson_loss
                           + self.ae_alpha * self.ae_loss)
        self.train_step = tf.compat.v1.train.AdamOptimizer(
            self.learning_rate).minimize(self.total_loss)
        self.initialize()

    def build_vqvae(self,
                    read_ae_conv,
                    fully_connected_readout,
                    fixed_rfs):
        """Build the VQ-VAE autoencoder and readout, then initialise the session.

        Args:
            read_ae_conv: ``(ae_alpha, readout_beta, encoder layer index read
                by the readout)``.
            fully_connected_readout: forwarded to ``read_out``.
            fixed_rfs: forwarded to ``read_out``.
        """
        with self.graph.as_default():
            self.ae_alpha = read_ae_conv[0]
            en_conv, imgs_o = self.encoder()
            z, de_conv = self.latent_vq(en_conv)
            imgs_r = self.decoder(de_conv)
            self.imgs_r = imgs_r
            self.recon_loss = tf.reduce_mean((imgs_r - imgs_o) ** 2, axis=[0, 1, 2, 3])
            beta = 0.25  # weight of the latent (commitment-style) loss
            self.ae_loss = self.recon_loss + beta * self.l_loss
            self._build_readout_and_optimizer(read_ae_conv, fully_connected_readout, fixed_rfs)

    def build_dae(self,
                  read_ae_conv,
                  fully_connected_readout,
                  fixed_rfs):
        """Build the deterministic autoencoder and readout, then initialise the session.

        Args are the same as for ``build_vqvae``. ``ae_loss`` is the mean
        squared reconstruction error.
        """
        with self.graph.as_default():
            self.ae_alpha = read_ae_conv[0]
            en_conv, imgs_o = self.encoder()
            z, de_conv = self.latent_deep(en_conv)
            imgs_r = self.decoder(de_conv)
            self.imgs_r = imgs_r
            self.recon_loss = tf.reduce_mean((imgs_r - imgs_o) ** 2, axis=[0, 1, 2, 3])
            self.ae_loss = self.recon_loss
            self._build_readout_and_optimizer(read_ae_conv, fully_connected_readout, fixed_rfs)

    def build_vae(self,
                  read_ae_conv,
                  fully_connected_readout,
                  fixed_rfs):
        """Build the variational autoencoder and readout, then initialise the session.

        Args are the same as for ``build_vqvae``. ``recon_loss`` is
        per-sample; ``ae_loss`` is the batch mean of
        ``1000 * recon_loss + kl_loss``.
        """
        with self.graph.as_default():
            self.ae_alpha = read_ae_conv[0]
            en_conv, imgs_o = self.encoder()
            z, de_conv = self.latent_variational(en_conv)
            imgs_r = self.decoder(de_conv)
            self.imgs_r = imgs_r
            self.recon_loss = tf.reduce_mean((imgs_r - imgs_o) ** 2, axis=[1, 2, 3])
            print(self.kl_loss.shape)
            self.ae_loss = tf.reduce_mean(1000 * self.recon_loss + self.kl_loss, axis=[0])
            self._build_readout_and_optimizer(read_ae_conv, fully_connected_readout, fixed_rfs)

    @property
    def embeddings(self):
        """The VQ codebook variable (``embedding_dim x num_embeddings``)."""
        return self._w

    def quantize(self, encoding_indices):
        """Look up codebook vectors for ``encoding_indices``; the result has shape indices + (embedding_dim,)."""
        with tf.control_dependencies([encoding_indices]):
            w = tf.transpose(self.embeddings.read_value(), [1, 0])
        return tf.nn.embedding_lookup(w, encoding_indices, validate_indices=False)



