#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict
#from peentree import *
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

def _p(pp, name):
    return '%s_%s' % (pp, name)




def load_imdb_data(n_words=5000, sort_by_len=True, data_dir="data"):
    """Load the pickled IMDB train/validation splits.

    Each pickle must contain a ``(sequences, labels)`` pair; sequences are
    lists of integer word indices (assumed from how they are processed
    below -- confirm against the data-preparation script).

    :param n_words: vocabulary size; every index >= n_words is mapped to 1,
        which this loader treats as the unknown-word index
    :param sort_by_len: if true, both splits are reordered by ascending
        sequence length (stable sort), labels following their sequences
    :param data_dir: directory holding imdb_train.pkl and imdb_valid.pkl
    :return: ``(train, valid)``, each a ``(sequences, labels)`` tuple
    """
    # Context managers so the file handles are closed deterministically.
    with open(os.path.join(data_dir, "imdb_train.pkl"), 'rb') as f:
        train_set_x, train_set_y = pkl.load(f)
    with open(os.path.join(data_dir, "imdb_valid.pkl"), 'rb') as f:
        valid_set_x, valid_set_y = pkl.load(f)

    def remove_unk(x):
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    train_set_x = remove_unk(train_set_x)
    valid_set_x = remove_unk(valid_set_x)

    if sort_by_len:
        def reorder(xs, ys):
            # Indices of xs sorted by sequence length; applied to xs and ys alike.
            order = sorted(range(len(xs)), key=lambda i: len(xs[i]))
            return [xs[i] for i in order], [ys[i] for i in order]

        valid_set_x, valid_set_y = reorder(valid_set_x, valid_set_y)
        train_set_x, train_set_y = reorder(train_set_x, train_set_y)

    return (train_set_x, train_set_y), (valid_set_x, valid_set_y)
    
    
def prepare_data(seqs, labels, maxlen=100):
    """Pack variable-length index sequences into a time-major, zero-padded batch.

    :param seqs: sequences of integer word indices
    :param labels: one label per sequence
    :param maxlen: if not None, sequences longer than ``maxlen`` are truncated
        and the batch has exactly ``maxlen`` steps; if None, the batch is as
        long as the longest sequence
    :return: ``(x, x_mask, y)``:
        ``x`` int32 ``(n_steps, n_samples)`` with zeros after each sequence,
        ``x_mask`` float32 of the same shape, 1. on real tokens and 0. on
        padding, ``y`` int32 array of the labels
    """
    if maxlen is not None:
        seqs = [s[:maxlen] for s in seqs]
    # True (post-truncation) lengths: the mask must not cover padding, otherwise
    # the recurrent layer would consume the padding index 0 as real input.
    lengths = [len(s) for s in seqs]
    if maxlen is not None:
        n_steps = maxlen
    else:
        n_steps = max(lengths) if lengths else 0

    n_samples = len(seqs)
    x = np.zeros((n_steps, n_samples))
    x_mask = np.zeros((n_steps, n_samples))
    for idx, s in enumerate(seqs):
        x[:lengths[idx], idx] = s
        x_mask[:lengths[idx], idx] = 1.

    return x.astype('int32'), x_mask.astype('float32'), np.asarray(labels).astype('int32')

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Split the indices ``0..n-1`` into consecutive minibatches.

    :param n: number of examples
    :param minibatch_size: size of every batch except possibly the last
    :param shuffle: if true, permute the indices with ``np.random.shuffle``
        first, giving a fresh order on each call (e.g. once per epoch)
    :return: list of int32 numpy arrays; the final one holds any remainder
    """
    indices = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(indices)

    n_full = n // minibatch_size
    batches = [indices[k * minibatch_size:(k + 1) * minibatch_size]
               for k in range(n_full)]

    consumed = len(batches) * minibatch_size
    if consumed != n:
        # Leftover examples form one last, shorter batch.
        batches.append(indices[consumed:])
    return batches

def pred_error(f_pred, prepare_data, data, iterator, verbose=False):
    """Return the classification error rate of ``f_pred`` on ``data``.

    :param f_pred: callable ``(x, mask) -> predicted labels`` for a batch
    :param prepare_data: batching function with the same contract as this
        module's ``prepare_data`` (returns ``(x, mask, y)``)
    :param data: ``(sequences, labels)`` pair
    :param iterator: iterable of ``(anything, batch_indices)`` pairs, e.g.
        ``enumerate(get_minibatches_idx(...))``
    :param verbose: unused
    :return: ``1 - n_correct / len(data[0])``
    """
    labels = np.asarray(data[1])  # built once instead of twice per batch
    n_correct = 0
    for _, batch in iterator:
        batch_labels = labels[batch]
        # prepare_data returns three values (x, mask, y); y is not needed here.
        x, mask, _ = prepare_data([data[0][t] for t in batch],
                                  batch_labels,
                                  maxlen=None)
        preds = f_pred(x, mask)
        n_correct += (preds == batch_labels).sum()
    return 1. - float(n_correct) / len(data[0])

class FillWeight(tfe.Network):
    """Numpy-based initializers used to build the layers' ``tf.Variable``s.

    Every method draws from the global ``np.random`` state (which
    ``set_seed`` seeds) and returns a float32 ``numpy.ndarray``.
    """

    def __init__(self):
        super(FillWeight, self).__init__()

    def uniform_weight(self, n_in, n_out=None, fix=False):
        """Return an (n_in, n_out) Glorot/Xavier-uniform matrix.

        :param n_in: number of rows
        :param n_out: number of columns; defaults to ``n_in``
        :param fix: if true, return a U(-0.5, 0.5) matrix instead
        """
        if n_out is None:
            n_out = n_in
        limit = np.sqrt(6. / (n_in + n_out))
        W = np.random.uniform(low=-limit, high=limit, size=(n_in, n_out))

        if fix:
            # The Glorot draw above is still made (then discarded) so the global
            # RNG stream, and hence every seeded initialisation, is unchanged.
            W = np.random.uniform(low=-0.5, high=0.5, size=(n_in, n_out))
        return W.astype('float32')

    def uniform_vector(self, n_, fix=False):
        """Return a (1, n_) row vector drawn from U(-sqrt(6/n_), sqrt(6/n_)).

        :param n_: vector length
        :param fix: if true, return a U(-0.5, 0.5) vector instead
        """
        limit = np.sqrt(6. / (n_))
        W = np.random.uniform(low=-limit, high=limit, size=(1, n_))

        if fix:
            # Must have the same (1, n_) shape as the default draw; this branch
            # used to reference undefined names n_in/n_out and raised NameError.
            W = np.random.uniform(low=-0.5, high=0.5, size=(1, n_))
        return W.astype('float32')

    def ortho_weight(self, ndim):
        """Return an (ndim, ndim) orthogonal matrix (left singular vectors of a Gaussian matrix)."""
        W = np.random.randn(ndim, ndim)
        u, _, _ = np.linalg.svd(W)
        return u.astype('float32')


class Softmax_Layer(tfe.Network):
    """Affine layer followed by a softmax over ``config['y_dim']`` classes.

    Parameters are stored in ``self.params`` under '<prefix>_W'
    (dim_hidden x y_dim) and '<prefix>_b' (y_dim,).
    """

    def __init__(self, config, weight_filler, prefix='Dense'):
        super(Softmax_Layer, self).__init__()

        self.config = config
        self.params = dict()
        # Remembered so that __call__ looks up the same keys created here.
        self.prefix = prefix
        self.dim_hidden = config['dim_hidden']
        self.params[_p(prefix, 'W')] = tf.Variable(
            weight_filler.uniform_weight(config['dim_hidden'], config['y_dim']).astype('float32'),
            name=_p(prefix, "W"))
        self.params[_p(prefix, 'b')] = tf.Variable(
            np.zeros((config['y_dim'],)).astype('float32'), name=_p(prefix, "b"))

    def __call__(self, x, prefix=None):
        """Return ``softmax(x @ W + b)`` computed row-wise.

        :param x: 2-D tensor whose last dimension is ``dim_hidden``
        :param prefix: parameter-name prefix; defaults to the constructor's
            prefix. A fixed 'Dense' default would raise KeyError for a layer
            built with any other prefix.
        """
        if prefix is None:
            prefix = self.prefix
        logits = tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]
        return tf.nn.softmax(logits)
    

class Dense_Layer(tfe.Network):
    """Affine output layer producing unnormalised class scores (logits).

    Its input width is ``2 * dim_hidden`` because ``SRLSTM_Model.call`` feeds
    it the concatenation of the final hidden and cell states. Parameters are
    stored in ``self.params`` under '<prefix>_W' and '<prefix>_b'.
    """

    def __init__(self, config, weight_filler, prefix='Dense', trained_param=None):
        super(Dense_Layer, self).__init__()

        self.config = config
        self.params = dict()
        # Remembered so that __call__ looks up the same keys created here.
        self.prefix = prefix
        self.dim_hidden = config['dim_hidden']

        w_key = _p(prefix, 'W')
        b_key = _p(prefix, 'b')
        if trained_param is not None:
            print("Dense_Layer")
            self.params[w_key] = trained_param[w_key]
            self.params[b_key] = trained_param[b_key]
        else:
            W = weight_filler.uniform_weight(2 * config['dim_hidden'], config['y_dim'])
            self.params[w_key] = tf.Variable(W.astype('float32'), name=w_key)
            self.params[b_key] = tf.Variable(np.zeros((config['y_dim'],)).astype('float32'), name=b_key)

    def __call__(self, x, prefix=None):
        """Return ``x @ W + b``.

        :param x: 2-D tensor of shape (batch, 2 * dim_hidden)
        :param prefix: parameter-name prefix; defaults to the constructor's
            prefix. A fixed 'Dense' default would raise KeyError for a layer
            built with any other prefix.
        """
        if prefix is None:
            prefix = self.prefix
        return tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]
    

class Embedding_Layer(tfe.Network):
    """Word-embedding lookup table of shape (n_words, dim_hidden).

    The table is kept in ``self.params['Wemb']``, either freshly initialised
    or taken from ``trained_param``.
    """

    def __init__(self, config, weight_filler, prefix='Embedding', trained_param=None):
        super(Embedding_Layer, self).__init__()
        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']
        print("Embedding_Layer")

        if trained_param is None:
            table = weight_filler.uniform_weight(config['n_words'], config['dim_hidden'])
            self.params['Wemb'] = tf.Variable(table.astype('float32'), name=_p(prefix, "Wemb"))
        else:
            self.params['Wemb'] = trained_param['Wemb']

    def __call__(self, x, prefix='Embedding'):
        """Return the embedding rows selected by the integer indices ``x`` (prefix is unused)."""
        table = self.params['Wemb']
        return tf.gather(table, x)
        
          
class SRLSTM_Peephole_Layer(tfe.Network):
    """State-regularized peephole LSTM layer.

    A peephole LSTM is unrolled along axis 0 (time) with ``tf.scan``. After
    each step, the hidden state is scored against ``n_centroid`` learnable
    centroids. A Gumbel-softmax (with a learned temperature) turns those
    scores into mixing weights. The corresponding weighted sum of centroids
    becomes the hidden state carried to the next step.
    """

    def __init__(self, config, weight_filler, prefix='SRLSTM', trained_param=None):
        super(SRLSTM_Peephole_Layer, self).__init__()
        self.config = config
        self.params = dict()
        # Remembered so that __call__ looks up the same keys created here.
        self.prefix = prefix
        self.dim_hidden = config['dim_hidden']

        if trained_param is None:
            # Fresh initialisation (the branches used to be inverted, so a new
            # model tried to index None and a restored one ignored its weights).
            dim = self.dim_hidden

            # Input (W) and recurrent (U) weights for the four gates, stacked
            # along axis 1 -> (dim, 4 * dim).
            W = np.concatenate([weight_filler.uniform_weight(dim) for _ in range(4)], axis=1)
            self.params[_p(prefix, 'W')] = tf.Variable(W, name=_p(prefix, "W"))

            U = np.concatenate([weight_filler.uniform_weight(dim) for _ in range(4)], axis=1)
            self.params[_p(prefix, 'U')] = tf.Variable(U, name=_p(prefix, "U"))

            # Peephole weights: row 0 input gate, row 1 forget gate, row 2
            # output gate -> (3, dim).
            P = np.concatenate([weight_filler.uniform_vector(dim) for _ in range(3)], axis=0)
            self.params[_p(prefix, 'P')] = tf.Variable(P, name=_p(prefix, "P"))

            # Bias: only gate slice 1 (forget gate) starts at 1, the rest at 0.
            b = np.concatenate([np.zeros((dim,)), np.ones((dim,)),
                                np.zeros((dim,)), np.zeros((dim,))], axis=0)
            self.params[_p(prefix, 'b')] = tf.Variable(b.astype('float32'), name=_p(prefix, "b"))

            temperature = np.ones((1,))
            self.params[_p(prefix, 'temperature')] = tf.Variable(
                temperature.astype('float32'), name=_p(prefix, "temperature"))
            centroids = weight_filler.uniform_weight(dim, config['n_centroid'], fix=True).astype('float32')
            self.params[_p(prefix, 'centroid')] = tf.Variable(centroids, name=_p(prefix, "centroid"))
        else:
            # Reuse tensors restored from a checkpoint, keyed '<prefix>_<name>'.
            for name in ('W', 'U', 'P', 'b', 'centroid', 'temperature'):
                key = _p(prefix, name)
                self.params[key] = trained_param[key]

    ''' gumbel softmax '''
    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1); ``eps`` guards both logs against log(0)."""
        U = tf.random_uniform(shape, minval=0, maxval=1)
        return -tf.log(-tf.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """Draw a sample from the Gumbel-Softmax distribution."""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, logits, temperature, hard=False):
        """
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if hard:
            y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
            # Straight-through: forward value is y_hard, gradient is that of y.
            y = tf.stop_gradient(y_hard - y) + y
        return y

    def __call__(self, x, mask, prefix=None, temperature=0.5):
        """Run the layer over a whole time-major sequence.

        :param x: inputs of shape (n_steps, batch, dim_hidden)
        :param mask: (n_steps, batch) float mask, 1. on real tokens and 0. on
            padding (as produced by ``prepare_data``)
        :param prefix: parameter-name prefix; defaults to the constructor's
        :param temperature: unused -- the learned '<prefix>_temperature'
            variable is used; kept for interface compatibility
        :return: ``(h, c)``, each (n_steps, batch, dim_hidden); ``h`` holds the
            centroid-mixed hidden states
        """
        if prefix is None:
            prefix = self.prefix
        dim = self.dim_hidden
        W = self.params[_p(prefix, 'W')]
        U = self.params[_p(prefix, 'U')]
        P = self.params[_p(prefix, 'P')]
        b = self.params[_p(prefix, 'b')]
        centroid = self.params[_p(prefix, 'centroid')]
        tau = self.params[_p(prefix, 'temperature')]

        def _slice(z, n):
            # n-th gate block of a (batch, 4 * dim) pre-activation.
            return z[:, n * dim:(n + 1) * dim]

        def one_step(prev, x_curr_mask):
            '''
            @param prev: stacked [h(t-1), c(t-1)], shape (2, batch, dim)
            @param x_curr_mask: (x(t) @ W + b, mask(t))
            '''
            x_curr, x_mask = x_curr_mask
            h_prev, c_prev = tf.unstack(prev, axis=0)
            m = x_mask[:, None]  # (batch, 1), broadcast over features

            preact = x_curr + tf.matmul(h_prev, U)
            i = tf.sigmoid(_slice(preact, 0) + c_prev * P[0])
            f = tf.sigmoid(_slice(preact, 1) + c_prev * P[1])
            g = tf.tanh(_slice(preact, 2))
            c_new = f * c_prev + i * g
            # Output-gate peephole looks at the *updated* cell state, as in the
            # standard peephole LSTM (it previously used the candidate g).
            o = tf.sigmoid(_slice(preact, 3) + c_new * P[2])

            c = m * c_new + (1. - m) * c_prev
            h = o * tf.tanh(c)
            h = m * h + (1. - m) * h_prev

            scores = tf.matmul(h, centroid)  # (B, D) x (D, K) -> (B, K)
            weights = self.gumbel_softmax(scores, tau[0])  # (B, K)
            hc = tf.matmul(weights, tf.transpose(centroid))  # (B, K) x (K, D) -> (B, D)
            # Padded steps must not change the state; otherwise the centroid
            # mixture would be re-sampled at every padded position.
            hc = m * hc + (1. - m) * h_prev

            return tf.stack([hc, c], axis=0)

        batch_size = int(x.shape[1])
        h0 = np.zeros([batch_size, dim], dtype='float32')
        c0 = np.zeros([batch_size, dim], dtype='float32')
        init_state = tf.stack([h0, c0], axis=0)

        # Project all time steps at once; tensordot contracts the feature axis
        # of the rank-3 input (TF 1.x tf.matmul requires equal ranks).
        x_in = tf.tensordot(x, W, axes=[[2], [0]]) + b
        states = tf.scan(one_step, [x_in, mask],
                         initializer=init_state,
                         name='h_states')  # (n_steps, 2, batch, dim)

        return states[:, 0], states[:, 1]


def create_config(dim_hidden=256,
                  weight_decay=0.0005,
                  lrate=0.01,
                  n_words=5000,
                  batch_size=8,
                  n_centroid=5,
                  y_dim=2,
                  model_path='models/ST_IMDB'):
    """Return the hyper-parameters as a plain dict keyed by argument name."""
    return dict(dim_hidden=dim_hidden,
                weight_decay=weight_decay,
                lrate=lrate,
                n_words=n_words,
                batch_size=batch_size,
                n_centroid=n_centroid,
                y_dim=y_dim,
                model_path=model_path)

def save_models(modelparams):
    """Flatten per-layer parameter lists into one ordered name -> value mapping.

    :param modelparams: mapping of layer name to an iterable of
        ``(param_name, value)`` pairs (e.g. ``dict.items()``)
    :return: ``OrderedDict`` of every ``param_name`` -> ``value`` in iteration
        order; later duplicates overwrite earlier ones
    """
    flat = OrderedDict()
    for layer_items in modelparams.values():
        flat.update((pair[0], pair[1]) for pair in layer_items)
    return flat
    
  
def load_models(path):
    """Print ``path`` and return whatever ``np.load`` reads from it.

    ``allow_pickle=True`` lets numpy unpickle the file (e.g. the .pkl
    checkpoints written during training). Unpickling can execute arbitrary
    code, so only load files from trusted sources.
    """
    print(path)
    return np.load(path, allow_pickle=True)
        

class SRLSTM_Model(tfe.Network):
    """Embedding -> SR-LSTM (peephole) -> dense classifier over the final step.

    ``srlstm_params`` maps each layer's attribute name to that layer's
    ``params.items()``, which ``save_models`` flattens for checkpointing.
    """

    def __init__(self, srlstm_config, trained_param=None):
        super(SRLSTM_Model, self).__init__()
        self.srlstm_params = OrderedDict()
        self.config = srlstm_config

        print("config", self.config)
        self.weight_filler = FillWeight()

        # (attribute / params key, progress banner, layer class), built in order.
        layer_specs = (
            ('emb_layer', '...building embedding layer...', Embedding_Layer),
            ('srlstm_layer', '...building lstm layer...', SRLSTM_Peephole_Layer),
            ('dense_layer', '...building dense layer...', Dense_Layer),
        )
        for attr_name, banner, layer_cls in layer_specs:
            print(banner)
            layer = self.track_layer(
                layer_cls(self.config, self.weight_filler, trained_param=trained_param))
            setattr(self, attr_name, layer)
            self.srlstm_params[attr_name] = layer.params.items()

    def call(self, x, mask):
        """Return class logits computed from the last time step of [h, c]."""
        h, c = self.srlstm_layer(self.emb_layer(x), mask)
        last_state = tf.concat([h, c], axis=2)[-1]
        return self.dense_layer(last_state)



def set_seed(seed, n_gpu=1):
    """
    Seed Python's ``random``, NumPy's and TensorFlow's random generators.

    :param seed: seed to set, set -1 to draw a random number
    :param n_gpu: unused; kept for call compatibility
    :return: 0 on success
    """
    if seed == -1:
        seed = random.randrange(2**32 - 1)
    # Format explicitly: print("Seed: %s", seed) prints the raw template and
    # the value side by side (a tuple under Python 2) instead of interpolating.
    print("Seed: %s" % seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    return 0
    
def accuracy(logits, Y):
    """Return the fraction of rows whose highest-scoring class equals the label.

    :param logits: (batch, n_classes) unnormalised scores
    :param Y: (batch,) integer labels (any integer dtype)
    :return: scalar float32 tensor in [0, 1]
    """
    # Softmax is monotonic, so argmax over the logits picks the same class.
    predicted = tf.argmax(logits, 1)
    # tf.argmax yields int64 while prepare_data produces int32 labels; cast
    # explicitly instead of relying on implicit conversion inside tf.equal.
    correct_pred = tf.equal(predicted, tf.cast(Y, predicted.dtype))
    return tf.reduce_mean(tf.cast(correct_pred, tf.float32))
       
def prepare_loss(logits, y):
    """Return the batch-mean sparse softmax cross-entropy of ``logits`` against labels ``y``."""
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    return tf.reduce_mean(per_example)

def model_loss(model, x, mask, y):
    """Training objective: cross-entropy plus L2 decay on the output weights.

    Only ``model.dense_layer.params['Dense_W']`` is decayed, scaled by
    ``model.config['weight_decay']``.
    """
    data_term = prepare_loss(model(x, mask), y)
    decay_term = model.config['weight_decay'] * tf.nn.l2_loss(model.dense_layer.params['Dense_W'])
    return data_term + decay_term
    


def _save_checkpoint(srlstm_params, path_prefix, n_iters):
    """Pickle the flattened model parameters to '<path_prefix>_<n_iters>.pkl'."""
    modelparams = save_models(srlstm_params)
    # NOTE(review): values are tf.Variable objects; confirm they pickle under the
    # TF version in use (converting with .numpy() may be required).
    with open(path_prefix + "_%s.pkl" % (n_iters), 'wb') as output:
        pickle.dump(modelparams, output)


def _validation_error(model, data, batches):
    """Return 1 - accuracy of ``model`` over every example in ``batches``.

    Batch accuracies are weighted by batch size so each example counts
    equally, even when the last batch is shorter. Returns NaN if there are
    no examples.
    """
    n_correct = 0.
    n_seen = 0
    for batch in batches:
        y_v = [data[1][i] for i in batch]
        x_v = [data[0][i] for i in batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        batch_accuracy = float(accuracy(model(x_, mask_), y_).numpy())
        n_correct += batch_accuracy * len(batch)
        n_seen += len(batch)
    if n_seen == 0:
        return float('nan')
    return 1. - n_correct / n_seen


def train(
     max_epochs=5000,
     dim_hidden=256,
     batch_size = 8,
     valid_batch_size = 8,
     save_iters = 1000,
     valid_iters = 1000,
     display_iters = 1,
):
    """Train the SR-LSTM classifier on IMDB with periodic validation.

    A checkpoint is written whenever the validation error improves, and once
    more when training ends (including after Ctrl-C).

    :param max_epochs: number of passes over the training set
    :param dim_hidden: hidden/embedding size, forwarded to ``create_config``
    :param batch_size: training minibatch size (also recorded in the config)
    :param valid_batch_size: validation minibatch size
    :param save_iters: currently unused
    :param valid_iters: run validation every this many iterations
    :param display_iters: print/log the training loss every this many iterations
    """
    log_path = "logs/ST_IMDB.log"
    with open(log_path, "a") as logger:
        print('\nSetting seed')
        set_seed(12345)

        # Forward the caller's sizes; the config was previously built with defaults only.
        srlstm_config = create_config(dim_hidden=dim_hidden, batch_size=batch_size)
        srlstm_model = SRLSTM_Model(srlstm_config)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        grad_fn = tfe.implicit_value_and_gradients(model_loss)
        logger.write(str(srlstm_config) + '\n')
        print(srlstm_config)

        print('...Loading Data...')
        # Named train_data so the function name `train` is not shadowed.
        train_data, valid = load_imdb_data()

        valid_batches = get_minibatches_idx(len(valid[0]), valid_batch_size)

        print('%d traning samples'%(len(train_data[0])))
        print('%d valid samples'%(len(valid[0])))

        print('....Training....')
        # No tf.Session is created: everything runs eagerly, and the session
        # was never used or closed.
        n_iters = 0
        best_val_error = 1
        try:
            for epoch in range(max_epochs):
                print('epoch: '+str(epoch))
                batches = get_minibatches_idx(len(train_data[0]), batch_size, shuffle=True)

                for train_batch in batches:
                    n_iters += 1
                    y_t = [train_data[1][i] for i in train_batch]
                    x_t = [train_data[0][i] for i in train_batch]
                    x_, mask_, y_ = prepare_data(x_t, y_t, maxlen=None)

                    loss, grads = grad_fn(srlstm_model, x_, mask_, y_)
                    optimizer.apply_gradients(grads)

                    if n_iters % display_iters == 0:
                        print('iter [{}] loss : {}'.format(n_iters, loss))
                        logger.write('iter [{}] loss : {}\n'.format(n_iters, loss))
                        logger.flush()

                    if n_iters % valid_iters == 0:
                        val_error = _validation_error(srlstm_model, valid, valid_batches)
                        if best_val_error > val_error:
                            _save_checkpoint(srlstm_model.srlstm_params,
                                             srlstm_config['model_path'], n_iters)
                            best_val_error = val_error
                        print('iter [{}] valid error : {} '.format(n_iters, val_error))
                        # Trailing newline so successive log entries stay on separate lines.
                        logger.write('iter [{}] valid error : {} \n'.format(n_iters, val_error))
                        logger.flush()
        except KeyboardInterrupt:
            print("Training interupted")

        print('----save checkpoint----')
        _save_checkpoint(srlstm_model.srlstm_params, srlstm_config['model_path'], n_iters)
 
    
if __name__ == "__main__":
    # Script entry point: train with the default settings.
    train()


