#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

def _p(pp, name):
    return '%s_%s' % (pp, name)

# Tag embedded in the training log file name (logs/<MODEL_PATH>_<keep_prob>.log).
MODEL_PATH = "imdb_vd"
# Keep probability for the Bernoulli variational-dropout masks (default of create_config).
DROPOUT_KEEP_PROB = 0.9


def load_imdb_data(n_words=5000, sort_by_len=True,
                   train_path="data/imdb_train.pkl",
                   valid_path="data/imdb_valid.pkl"):
    """Load the IMDB training and validation splits.

    Each file must unpickle to a ``(sequences, labels)`` pair; sequences are
    presumably lists of integer word ids (TODO confirm against the producer).

    :param n_words: vocabulary cut-off; every word id ``>= n_words`` is
        replaced by ``1``, the id this script uses for out-of-vocabulary words.
    :param sort_by_len: if True, reorder each split by ascending sequence
        length, permuting the labels alongside.
    :param train_path: pickle file holding the training split.
    :param valid_path: pickle file holding the validation split.
    :return: ``(train, valid)``, each an ``(x, y)`` tuple; ``x`` is a list of
        lists, ``y`` is a list when sorted and otherwise as unpickled.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # ``with`` closes the handles (the previous version leaked them).
    with open(train_path, 'rb') as f:
        train_set_x, train_set_y = pkl.load(f)
    with open(valid_path, 'rb') as f:
        valid_set_x, valid_set_y = pkl.load(f)

    def remove_unk(x):
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda i: len(seq[i]))

    train_set_x = remove_unk(train_set_x)
    valid_set_x = remove_unk(valid_set_x)

    if sort_by_len:
        sorted_index = len_argsort(valid_set_x)
        valid_set_x = [valid_set_x[i] for i in sorted_index]
        valid_set_y = [valid_set_y[i] for i in sorted_index]

        sorted_index = len_argsort(train_set_x)
        train_set_x = [train_set_x[i] for i in sorted_index]
        train_set_y = [train_set_y[i] for i in sorted_index]

    return (train_set_x, train_set_y), (valid_set_x, valid_set_y)


def load_imdb_data_test(n_words=5000, sort_by_len=True,
                        test_path="data/imdb_test.pkl"):
    """Load the IMDB test split.

    The file must unpickle to a ``(sequences, labels)`` pair; sequences are
    presumably lists of integer word ids (TODO confirm against the producer).

    :param n_words: vocabulary cut-off; every word id ``>= n_words`` is
        replaced by ``1``, the id this script uses for out-of-vocabulary words.
    :param sort_by_len: if True, reorder samples by ascending sequence length,
        permuting the labels alongside.
    :param test_path: pickle file holding the test split.
    :return: ``(x, y)``; ``x`` is a list of lists, ``y`` is a list when sorted
        and otherwise as unpickled.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # ``with`` closes the handle (the previous version leaked it).
    with open(test_path, 'rb') as f:
        test_set_x, test_set_y = pkl.load(f)

    def remove_unk(x):
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda i: len(seq[i]))

    test_set_x = remove_unk(test_set_x)

    if sort_by_len:
        sorted_index = len_argsort(test_set_x)
        test_set_x = [test_set_x[i] for i in sorted_index]
        test_set_y = [test_set_y[i] for i in sorted_index]

    return (test_set_x, test_set_y)
    
    
def prepare_data(seqs, labels, maxlen=100):
    """Pack variable-length token sequences into time-major NumPy arrays.

    :param seqs: list of token-id lists, one per sample.
    :param labels: one label per sequence.
    :param maxlen: if not None, every sequence is truncated or right-padded
        with ``0`` to exactly ``maxlen`` tokens; if None, sequences keep their
        own length and the batch is padded to the longest one.
    :return: ``(x, x_mask, y)``: ``x`` is int32 of shape ``(T, n_samples)``,
        ``x_mask`` is float32 of the same shape (1 where a position is
        counted, 0 elsewhere), and ``y`` holds the labels as int32.

    NOTE(review): in the ``maxlen`` branch every position -- including the
    ``0`` padding -- receives mask 1. Confirm that padding is meant to be fed
    to the model as real tokens.
    """
    if maxlen is None:
        lengths = [len(s) for s in seqs]
    else:
        # zip() pairs seqs with labels (truncating to the shorter of the two).
        paired = list(zip(seqs, labels))
        seqs = [s + [0] * (maxlen - len(s)) if len(s) < maxlen else s[0:maxlen]
                for s, _ in paired]
        labels = [y for _, y in paired]
        lengths = [maxlen] * len(paired)

    n_steps = np.max(lengths)
    x = np.zeros((n_steps, len(seqs)))
    x_mask = np.zeros_like(x)
    for col, (seq, n) in enumerate(zip(seqs, lengths)):
        x[:n, col] = seq
        x_mask[:n, col] = 1.

    return x.astype('int32'), x_mask.astype('float32'), np.asarray(labels).astype('int32')

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Split the indices ``0..n-1`` into consecutive minibatches.

    :param n: number of samples.
    :param minibatch_size: size of every batch except possibly the last,
        which holds the remainder when ``n`` is not a multiple.
    :param shuffle: if True, permute the indices in place with the global
        NumPy RNG before splitting.
    :return: list of int32 index arrays.
    """
    order = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(order)

    n_full = n // minibatch_size
    batches = [order[k * minibatch_size:(k + 1) * minibatch_size]
               for k in range(n_full)]

    tail_start = n_full * minibatch_size
    if tail_start != n:
        # Leftover samples form a final, smaller batch.
        batches.append(order[tail_start:])
    return batches

def pred_error(f_pred, prepare_data, data, iterator, verbose=False):
    """Compute the classification error rate of ``f_pred`` on ``data``.

    :param f_pred: callable ``f_pred(x, mask)`` returning one predicted label
        per sample in the batch.
    :param prepare_data: batching helper with the signature
        ``prepare_data(seqs, labels, maxlen=None)`` returning
        ``(x, mask, labels)``.
    :param data: ``(sequences, labels)`` pair.
    :param iterator: iterable of ``(_, batch_indices)`` pairs; the first
        element is ignored.
    :param verbose: unused; kept for call compatibility.
    :return: ``1 - n_correct / len(data[0])`` as a float.
    :raises ValueError: if ``data`` contains no samples.
    """
    n_total = len(data[0])
    if n_total == 0:
        raise ValueError("pred_error: data contains no samples")

    # Convert labels once instead of once per batch.
    labels = np.asarray(data[1])
    n_correct = 0
    for _, batch_idx in iterator:
        batch_labels = labels[batch_idx]
        # prepare_data returns exactly three values; unpacking four raised ValueError.
        x, mask, _ = prepare_data([data[0][t] for t in batch_idx],
                                  batch_labels,
                                  maxlen=None)
        preds = np.asarray(f_pred(x, mask))
        n_correct += int((preds == batch_labels).sum())

    return 1. - float(n_correct) / n_total

class FillWeight(tfe.Network):
    """Factory for randomly initialised float32 NumPy weight arrays.

    All draws use the global NumPy RNG, so results depend on ``np.random.seed``.
    """

    def __init__(self):
        super(FillWeight, self).__init__()

    def uniform_weight(self, n_in, n_out=None, fix=False):
        """Return an ``(n_in, n_out)`` float32 matrix with uniform entries.

        The range is the Glorot/Xavier bound ``+-sqrt(6 / (n_in + n_out))``.
        With ``fix=True`` the entries come from ``[-0.5, 0.5)`` instead.

        :param n_in: number of rows.
        :param n_out: number of columns; defaults to ``n_in``.
        :param fix: use the fixed ``[-0.5, 0.5)`` range.
        """
        if n_out is None:
            n_out = n_in
        bound = np.sqrt(6. / (n_in + n_out))
        W = np.random.uniform(low=-bound, high=bound, size=(n_in, n_out))

        if fix:
            # The Glorot draw above is still consumed so the global RNG stream
            # (and hence seeded reproducibility) is unchanged.
            W = np.random.uniform(low=-0.5, high=0.5, size=(n_in, n_out))
        return W.astype('float32')

    def uniform_vector(self, n_, fix=False):
        """Return a ``(1, n_)`` float32 row vector with uniform entries.

        The range is ``+-sqrt(6 / n_)``, or ``[-0.5, 0.5)`` with ``fix=True``.
        """
        bound = np.sqrt(6. / (n_))
        W = np.random.uniform(low=-bound, high=bound, size=(1, n_))

        if fix:
            # Previously referenced undefined ``n_in``/``n_out`` (NameError);
            # the vector shape is (1, n_).
            W = np.random.uniform(low=-0.5, high=0.5, size=(1, n_))

        return W.astype('float32')

    def ortho_weight(self, ndim):
        """Return an ``(ndim, ndim)`` float32 orthogonal matrix.

        It is the left singular vectors of a standard-normal random matrix.
        """
        W = np.random.randn(ndim, ndim)
        u, _, _ = np.linalg.svd(W)
        return u.astype('float32')
        
class Softmax_Layer(tfe.Network):
    """Affine layer followed by a softmax: ``softmax(x W + b)``.

    ``W`` has shape ``(dim_hidden, y_dim)`` and ``b`` has shape ``(y_dim,)``.
    Both are stored in ``self.params`` under ``<prefix>_W`` / ``<prefix>_b``.
    """

    def __init__(self, config, weight_filler, prefix='Dense'):
        super(Softmax_Layer, self).__init__()

        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']
        # Remembered so __call__ looks up the same keys created here.
        self.prefix = prefix
        self.params[_p(prefix, 'W')] = tf.Variable(
            weight_filler.uniform_weight(config['dim_hidden'], config['y_dim']).astype('float32'),
            name=_p(prefix, "W"))
        self.params[_p(prefix, 'b')] = tf.Variable(
            np.zeros((config['y_dim'],)).astype('float32'), name=_p(prefix, "b"))

    def __call__(self, x, prefix=None):
        """Return row-wise softmax probabilities for ``x``.

        :param x: input matrix whose last dimension is ``dim_hidden``.
        :param prefix: parameter-name prefix. Defaults to the one given at
            construction; previously it was hard-coded to ``'Dense'``, which
            raised KeyError for any other prefix.
        """
        if prefix is None:
            prefix = self.prefix
        logits = tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]
        return tf.nn.softmax(logits)
    

class Dense_Layer(tfe.Network):
    """Affine output layer producing logits ``x W + b``.

    When freshly initialised, ``W`` has shape ``(2 * dim_hidden, y_dim)``,
    sized for the concatenated ``[h, c]`` LSTM features, and ``b`` has shape
    ``(y_dim,)``. Both live in ``self.params`` under ``<prefix>_W`` /
    ``<prefix>_b``. When ``trained_param`` is given they are taken from that
    mapping instead.
    """

    def __init__(self, config, weight_filler, prefix='Dense', trained_param=None):
        super(Dense_Layer, self).__init__()

        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']
        # Remembered so __call__ looks up the same keys created here.
        self.prefix = prefix

        if trained_param is not None:
            print("Dense_Layer")
            self.params[_p(prefix, 'W')] = trained_param[_p(prefix, 'W')]
            self.params[_p(prefix, 'b')] = trained_param[_p(prefix, 'b')]
        else:
            self.params[_p(prefix, 'W')] = tf.Variable(
                weight_filler.uniform_weight(2 * config['dim_hidden'], config['y_dim']).astype('float32'),
                name=_p(prefix, "W"))
            self.params[_p(prefix, 'b')] = tf.Variable(
                np.zeros((config['y_dim'],)).astype('float32'), name=_p(prefix, "b"))

    def __call__(self, x, prefix=None):
        """Return the unnormalised logits ``x W + b``.

        :param x: input matrix whose last dimension matches ``W``'s rows.
        :param prefix: parameter-name prefix. Defaults to the one given at
            construction; previously it was hard-coded to ``'Dense'``, which
            raised KeyError for any other prefix.
        """
        if prefix is None:
            prefix = self.prefix
        return tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]
    

class Embedding_Layer(tfe.Network):
    """Word-embedding lookup table stored in ``self.params['Wemb']``.

    When freshly initialised the table has shape ``(n_words, dim_hidden)``.
    When ``trained_param`` is given, its ``'Wemb'`` entry is used as-is.
    """

    def __init__(self, config, weight_filler, prefix='Embedding', trained_param=None):
        super(Embedding_Layer, self).__init__()

        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']

        if trained_param is None:
            table = weight_filler.uniform_weight(config['n_words'], config['dim_hidden'])
            self.params['Wemb'] = tf.Variable(table.astype('float32'), name=_p(prefix, "Wemb"))
        else:
            print("Embedding_Layer")
            self.params['Wemb'] = trained_param['Wemb']

    def __call__(self, x, prefix='Embedding'):
        """Gather the embedding rows for the integer ids in ``x``.

        The output has ``x``'s shape plus a trailing embedding axis.
        ``prefix`` is accepted for interface symmetry but not used.
        """
        table = self.params['Wemb']
        return tf.gather(table, x)


class Variational_Dropout_Layer(tfe.Network):
    """Holds Bernoulli dropout masks that are reused across time steps.

    ``reset_mask`` samples one mask for the projected input and one for the
    recurrent state. ``__call__`` multiplies the stored masks in unchanged
    until the next ``reset_mask``, so the same units stay dropped for every
    step in between.

    NOTE(review): the masks are not rescaled by ``1 / dropout_keep_prob``, and
    nothing here distinguishes training from evaluation. Confirm both are
    intended.
    """

    def __init__(self, config):
        super(Variational_Dropout_Layer, self).__init__()
        self.config = config
        self.variational_dropout_mask_x = None
        self.variational_dropout_mask_h = None

    def reset_mask(self, batch_size):
        """Sample fresh masks for a batch of ``batch_size`` sequences.

        Afterwards ``variational_dropout_mask_x`` has shape
        ``(batch_size, 4 * dim_hidden)``: one ``dim_hidden`` mask tiled once
        per gate block. ``variational_dropout_mask_h`` has shape
        ``(batch_size, dim_hidden)``. Entries are 1 with probability
        ``config['dropout_keep_prob']``, else 0.
        """
        dim_hidden = self.config['dim_hidden']
        bernoulli = tf.contrib.distributions.Bernoulli(probs=[self.config['dropout_keep_prob']],
                                                       dtype=tf.dtypes.float32)
        # Samples have shape (batch, dim, 1). Squeeze only the trailing
        # singleton so that a batch of size 1 keeps its batch axis.
        mask_x = tf.compat.v1.squeeze(bernoulli.sample([batch_size, dim_hidden]), axis=[-1])
        # Tile along the feature axis to match the (batch, 4 * dim) gate
        # pre-activations. Concatenating along axis 0 (the old behaviour)
        # produced a (4 * batch, dim) mask that cannot broadcast against them.
        self.variational_dropout_mask_x = tf.concat([mask_x] * 4, axis=-1)
        self.variational_dropout_mask_h = tf.compat.v1.squeeze(
            bernoulli.sample([batch_size, dim_hidden]), axis=[-1])

    def __call__(self, x, h, prefix='Variational_Dropout_Layer'):
        """Return ``(x * mask_x, h * mask_h)``; ``prefix`` is unused."""
        x_ = x * self.variational_dropout_mask_x
        h_ = h * self.variational_dropout_mask_h
        return x_, h_
        
          
class LSTM_Peephole_Layer(tfe.Network):
    """Single-layer LSTM with variational dropout, unrolled with ``tf.scan``.

    NOTE(review): despite the name, the gates never read ``c_prev``, so there
    are no peephole connections. This is a standard LSTM.

    When freshly initialised, ``self.params`` holds:

    * ``<prefix>_W``: the ``(dim, 4 * dim)`` input projection.
    * ``<prefix>_U``: the ``(dim, 4 * dim)`` recurrent projection.
    * ``<prefix>_b``: the ``(4 * dim,)`` zero bias.

    In ``one_step`` the four column blocks feed, in order, the input gate, the
    forget gate, the candidate cell and the output gate.
    """

    def __init__(self, config, weight_filler, prefix='LSTM', trained_param=None):
        super(LSTM_Peephole_Layer, self).__init__()
        self.config = config
        self.params = dict()

        self.dim_hidden = config['dim_hidden']
        # Remembered so __call__ reads back the same parameter keys.
        self.prefix = prefix

        if trained_param is not None:
            print("LSTM_Peephole_Layer")
            self.params[_p(prefix, 'W')] = trained_param[_p(prefix, "W")]
            self.params[_p(prefix, 'U')] = trained_param[_p(prefix, "U")]
            self.params[_p(prefix, 'b')] = trained_param[_p(prefix, "b")]
        else:
            # Four (dim, dim) blocks side by side -> (dim, 4 * dim), one per gate.
            W = np.concatenate([weight_filler.uniform_weight(self.dim_hidden)
                                for _ in range(4)], axis=1)
            self.params[_p(prefix, 'W')] = tf.Variable(W, name=_p(prefix, "W"))

            U = np.concatenate([weight_filler.uniform_weight(self.dim_hidden)
                                for _ in range(4)], axis=1)
            self.params[_p(prefix, 'U')] = tf.Variable(U, name=_p(prefix, "U"))

            b = np.zeros((4 * self.dim_hidden,))
            self.params[_p(prefix, 'b')] = tf.Variable(b.astype('float32'), name=_p(prefix, "b"))

        self.variational_dropout_layer = Variational_Dropout_Layer(config)

    def __call__(self, x, mask, prefix=None):
        """Run the LSTM over a batch of sequences.

        :param x: embedded inputs. Axis 0 is treated as time and axis 1 as
            batch, i.e. ``(n_steps, batch, dim_hidden)``.
        :param mask: ``(n_steps, batch)`` float mask; 1 marks a real step and
            0 a padded step.
        :param prefix: parameter-name prefix. Defaults to the one given at
            construction; previously it was hard-coded to ``'LSTM'``.
        :return: ``(h, c)``, each of shape ``(n_steps, batch, dim_hidden)``.
        """
        if prefix is None:
            prefix = self.prefix
        dim = self.dim_hidden
        batch_size = int(x.shape[1])
        U = self.params[_p(prefix, 'U')]

        def _slice(t, n):
            return t[:, n * dim:(n + 1) * dim]

        def one_step(prev, inputs):
            """Advance the recurrence by one time step.

            :param prev: stacked ``[h(t-1), c(t-1)]`` of shape ``(2, batch, dim)``.
            :param inputs: ``(x_proj(t), mask(t))``, where ``x_proj(t)`` is the
                step-t slice of ``x W + b`` and ``mask(t)`` has shape ``(batch,)``.
            :return: stacked ``[h(t), c(t)]``.
            """
            x_curr, x_mask = inputs
            h_prev, c_prev = tf.unstack(prev, axis=0)

            # Dropout applies only to what feeds the gates. The undropped
            # h_prev is kept for carrying state through padded steps.
            x_curr, h_drop = self.variational_dropout_layer(x_curr, h_prev)
            preact = x_curr + tf.matmul(h_drop, U)

            i = tf.sigmoid(_slice(preact, 0))
            f = tf.sigmoid(_slice(preact, 1))
            c_tilde = tf.tanh(_slice(preact, 2))
            o = tf.sigmoid(_slice(preact, 3))

            keep = x_mask[:, None]
            c = f * c_prev + i * c_tilde
            # Padded steps (mask 0) carry the previous state forward unchanged.
            c = keep * c + (1. - keep) * c_prev
            h = o * tf.tanh(c)
            h = keep * h + (1. - keep) * h_prev

            return tf.stack([h, c], axis=0)

        h0 = np.zeros([batch_size, dim]).astype('float32')
        c0 = np.zeros([batch_size, dim]).astype('float32')
        init_state = tf.stack([h0, c0], axis=0)

        self.variational_dropout_layer.reset_mask(batch_size)
        # tensordot contracts x's feature axis with W's rows. It works for the
        # rank-3 input without relying on tf.matmul broadcasting a rank-2 operand.
        x_in = tf.tensordot(x, self.params[_p(prefix, 'W')], axes=1) + self.params[_p(prefix, 'b')]
        states = tf.scan(one_step, [x_in, mask],
                         initializer=init_state,
                         name='h_states')

        # states: (n_steps, 2, batch, dim) -> split the stacked [h, c] axis.
        return states[:, 0], states[:, 1]


def create_config(dim_hidden=256,
                  weight_decay=0.0005,
                  lrate=0.01,
                  n_words=5000,
                  batch_size=8,
                  y_dim=2,
                  model_path='new_models/VD_IMDB',
                  dropout_keep_prob=DROPOUT_KEEP_PROB):
    """Bundle the model/training hyper-parameters into a plain dict.

    Every keyword argument becomes a key of the same name in the returned
    dict. ``model_path`` is the checkpoint path prefix (``.pkl`` is appended
    when saving).

    NOTE(review): ``lrate`` is stored but this file's training loop hard-codes
    Adam's learning rate instead of reading it.
    """
    return dict(dim_hidden=dim_hidden,
                weight_decay=weight_decay,
                lrate=lrate,
                n_words=n_words,
                batch_size=batch_size,
                y_dim=y_dim,
                model_path=model_path,
                dropout_keep_prob=dropout_keep_prob)

def save_models(modelparams):
    """Flatten per-layer parameter lists into a single ordered mapping.

    Despite its name this writes nothing; callers pickle the result.

    :param modelparams: mapping of layer name to an iterable of
        ``(param_name, value)`` pairs.
    :return: ``OrderedDict`` of ``param_name -> value`` in layer order. Later
        duplicate names overwrite earlier ones.
    """
    flat = OrderedDict()
    for layer_params in modelparams.values():
        for item in layer_params:
            flat[item[0]] = item[1]
    return flat
    
  
def load_models(path):
    """Print ``path`` and return whatever ``np.load`` reads from it.

    ``allow_pickle=True`` lets this read the pickled parameter dicts written
    during training. NOTE: that also means loading can execute arbitrary code,
    so only pass trusted files.
    """
    print(path)
    return np.load(path, allow_pickle=True)
        

class LSTM_Model(tfe.Network):
    """Embedding -> LSTM -> dense classifier on the last time step.

    :param lstm_config: hyper-parameter dict, e.g. from ``create_config``.
    :param trained_param: optional mapping of saved parameters (keys
        ``'Wemb'``, ``'LSTM_W'``, ``'LSTM_U'``, ``'LSTM_b'``, ``'Dense_W'``,
        ``'Dense_b'`` with the default layer prefixes). It is used instead of
        fresh initialisation.
    :param return_h_c: if truthy, ``call`` returns the LSTM ``(h, c)``
        sequences instead of logits.
    """

    def __init__(self, lstm_config, trained_param=None, return_h_c=False):
        super(LSTM_Model, self).__init__()
        # layer name -> that layer's (name, param) pairs; flattened by save_models.
        self.lstm_params = OrderedDict()
        self.config = lstm_config
        self.return_h_c = return_h_c

        print("config", self.config)
        self.weight_filler = FillWeight()

        print('...building embedding layer...')
        self.emb_layer = self.track_layer(Embedding_Layer(self.config, self.weight_filler,
                                                          trained_param=trained_param))
        self.lstm_params['emb_layer'] = self.emb_layer.params.items()

        print('...building lstm layer...')
        self.lstm_layer = self.track_layer(LSTM_Peephole_Layer(self.config, self.weight_filler,
                                                               trained_param=trained_param))
        self.lstm_params['lstm_layer'] = self.lstm_layer.params.items()

        print('...building dense layer...')
        self.dense_layer = self.track_layer(Dense_Layer(self.config, self.weight_filler,
                                                        trained_param=trained_param))
        self.lstm_params['dense_layer'] = self.dense_layer.params.items()

    def call(self, x, mask):
        """Run the model.

        :param x: time-major token ids, as produced by ``prepare_data``.
        :param mask: matching float mask.
        :return: ``(h, c)`` when ``return_h_c`` is set, otherwise the logits.
        """
        emb = self.emb_layer(x)
        h, c = self.lstm_layer(emb, mask)

        if self.return_h_c:
            # Skip the dense layer entirely; its output would be discarded.
            return h, c

        # Classify from the concatenated [h, c] features at the last time step.
        hc = tf.concat([h, c], axis=2)[-1]
        return self.dense_layer(hc)



def set_seed(seed, n_gpu=1):
    """Seed Python's, NumPy's and TensorFlow's global random generators.

    :param seed: seed to use; ``-1`` draws one at random via ``random``
    :param n_gpu: unused; kept for call compatibility
    :return: always 0
    """
    chosen = random.randrange(2**32 - 1) if seed == -1 else seed
    for seeder in (random.seed, np.random.seed, tf.set_random_seed):
        seeder(chosen)
    return 0
    
def accuracy(logits, Y):
    """Return the fraction of rows of ``logits`` whose argmax equals ``Y``.

    :param logits: ``(batch, n_classes)`` scores.
    :param Y: ``(batch,)`` integer labels of any integer dtype.
    :return: scalar float32 tensor in ``[0, 1]``.
    """
    # Softmax is monotonic, so the argmax over raw logits picks the same class.
    predicted = tf.argmax(logits, 1)
    # tf.argmax yields int64, while prepare_data emits int32 labels. Cast so
    # tf.equal compares matching dtypes.
    correct_pred = tf.equal(predicted, tf.cast(Y, predicted.dtype))
    return tf.reduce_mean(tf.cast(correct_pred, tf.float32))
       
def prepare_loss(logits, y):
    """Return the batch-mean sparse softmax cross-entropy of ``logits`` vs labels ``y``."""
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    return tf.reduce_mean(per_example)

def model_loss(model, x, mask, y):
    """Cross-entropy loss plus L2 weight decay on the dense layer's weights only.

    The penalty is ``weight_decay * tf.nn.l2_loss(Dense_W)``, where
    ``l2_loss`` is ``sum(w ** 2) / 2``.
    """
    data_loss = prepare_loss(model(x, mask), y)
    dense_w = model.dense_layer.params['Dense_W']
    return data_loss + model.config['weight_decay'] * tf.nn.l2_loss(dense_w)
    


def _dump_params(lstm_model, path):
    """Flatten ``lstm_model.lstm_params`` with ``save_models`` and pickle the result to ``path``."""
    modelparams = save_models(lstm_model.lstm_params)
    with open(path, 'wb') as output:
        pickle.dump(modelparams, output)


def _valid_accuracy(lstm_model, valid, valid_batches):
    """Return the accuracy of ``lstm_model`` on ``valid``, weighted by batch size.

    Weighting keeps a smaller final batch from counting as much as a full one.
    Returns NaN when there are no validation samples.
    """
    n_correct = 0.
    n_seen = 0
    for valid_batch in valid_batches:
        y_v = [valid[1][i] for i in valid_batch]
        x_v = [valid[0][i] for i in valid_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        valid_logits = lstm_model(x_, mask_)
        n_correct += float(accuracy(valid_logits, y_)) * len(valid_batch)
        n_seen += len(valid_batch)
    if n_seen == 0:
        return float('nan')
    return n_correct / n_seen


def train(
        max_epochs=5000,
        dim_hidden=256,
        batch_size=8,
        valid_batch_size=8,
        save_iters=1000,
        valid_iters=1000,
        display_iters=1,
):
    """Train ``LSTM_Model`` on the IMDB data and checkpoint its parameters.

    Runs Adam (learning rate 0.001) over shuffled minibatches for up to
    ``max_epochs`` epochs.

    * Every ``display_iters`` iterations the training loss is printed and
      appended to ``logs/<MODEL_PATH>_<DROPOUT_KEEP_PROB>.log``.
    * Every ``valid_iters`` iterations the validation error (1 - accuracy) is
      computed. When it is the lowest so far, the parameters are pickled to
      ``<model_path>.pkl``.
    * After training ends (normally or via Ctrl-C), a final checkpoint is
      written to ``<model_path>_<n_iters>.pkl``.

    :param max_epochs: number of passes over the training set.
    :param dim_hidden: LSTM hidden size, forwarded to ``create_config``.
    :param batch_size: training minibatch size, forwarded to ``create_config``
        and used for batching.
    :param valid_batch_size: validation minibatch size.
    :param save_iters: currently unused.
    :param valid_iters: validation period, in iterations.
    :param display_iters: logging period, in iterations.
    """
    lowest_loss = float("inf")
    log_path = "logs/" + MODEL_PATH + "_" + str(DROPOUT_KEEP_PROB) + ".log"
    # The context manager closes the log even if training raises something
    # other than KeyboardInterrupt (the old handle was never closed).
    with open(log_path, "a") as logger:
        print('\nSetting seed')
        set_seed(12345)

        # Forward the architecture arguments; previously they were accepted
        # but never reached the model config.
        lstm_config = create_config(dim_hidden=dim_hidden, batch_size=batch_size)
        lstm_model = LSTM_Model(lstm_config)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        grad_fn = tfe.implicit_value_and_gradients(model_loss)
        logger.write(str(lstm_config))
        print(lstm_config)

        print('...Loading Data...')
        train, valid = load_imdb_data()
        valid_batches = get_minibatches_idx(len(valid[0]), valid_batch_size)

        print('%d traning samples' % (len(train[0])))
        print('%d test samples' % (len(valid[0])))
        print('....Training....')

        n_iters = 0
        # NOTE(review): this session is never used (execution is eager). It is
        # kept only so that process-level side effects stay as before.
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        session = tf.Session(config=config)

        try:
            for epoch in range(max_epochs):
                print('epoch: ' + str(epoch))
                batches = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)

                for train_batch in batches:
                    n_iters += 1
                    y_t = [train[1][i] for i in train_batch]
                    x_t = [train[0][i] for i in train_batch]
                    x_, mask_, y_ = prepare_data(x_t, y_t, maxlen=None)

                    loss, grads = grad_fn(lstm_model, x_, mask_, y_)
                    optimizer.apply_gradients(grads)

                    if n_iters % display_iters == 0:
                        print('iter [{}] loss : {}'.format(n_iters, loss))
                        logger.write('iter [{}] loss : {}\n'.format(n_iters, loss))
                        logger.flush()

                    if n_iters % valid_iters == 0:
                        averaged_accuracy = _valid_accuracy(lstm_model, valid, valid_batches)
                        valid_error = 1 - averaged_accuracy
                        if valid_error < lowest_loss:
                            _dump_params(lstm_model, lstm_config['model_path'] + ".pkl")
                            print('----save new best model----')
                            logger.write('----save new best model----\n')
                            lowest_loss = valid_error
                        print('iter [{}] valid error : {} '.format(n_iters, valid_error))
                        logger.write('iter [{}] valid error : {} \n'.format(n_iters, valid_error))
                        logger.write(
                            'iter [{}] valid acc : {} \n'.format(n_iters, averaged_accuracy))
                        logger.flush()
        except KeyboardInterrupt:
            print("Training interupted")

        print('----save checkpoint----')
        _dump_params(lstm_model, lstm_config['model_path'] + "_%s.pkl" % (n_iters))

    
if __name__ == "__main__":
    # Script entry point: train with the default hyper-parameters.
    train()


