#!/usr/bin/env python
# -*- coding: utf-8 -*-



import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict
#from peentree import *
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

def _p(pp, name):
    return '%s_%s' % (pp, name)




def load_imdb_data(n_words=5000, sort_by_len=True,
                   train_path="data/imdb_train.pkl",
                   valid_path="data/imdb_valid.pkl"):
    """Load the IMDB train/valid splits from pickled ``(sequences, labels)`` pairs.

    :param n_words: vocabulary size; every word id ``>= n_words`` is mapped to 1.
    :param sort_by_len: if True, sort each split by sequence length (shortest
        first), keeping labels aligned with their sequences.
    :param train_path: pickle file holding the training ``(x, y)`` pair.
    :param valid_path: pickle file holding the validation ``(x, y)`` pair.
    :return: ``(train, valid)``, each a ``(sequences, labels)`` tuple.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(train_path, 'rb') as f:
        train_set_x, train_set_y = pkl.load(f)
    with open(valid_path, 'rb') as f:
        valid_set_x, valid_set_y = pkl.load(f)

    def remove_unk(x):
        # Out-of-vocabulary ids collapse onto id 1.
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    def sort_split(xs, ys):
        # Reorder sequences and labels together by sequence length.
        order = sorted(range(len(xs)), key=lambda i: len(xs[i]))
        return [xs[i] for i in order], [ys[i] for i in order]

    train_set_x = remove_unk(train_set_x)
    valid_set_x = remove_unk(valid_set_x)

    if sort_by_len:
        valid_set_x, valid_set_y = sort_split(valid_set_x, valid_set_y)
        train_set_x, train_set_y = sort_split(train_set_x, train_set_y)

    train = (train_set_x, train_set_y)
    valid = (valid_set_x, valid_set_y)
    return train, valid
    
    
def prepare_data(seqs, labels, maxlen=100):
    """Pack variable-length sequences into time-major, zero-padded arrays.

    :param seqs: list of token-id sequences.
    :param labels: one label per sequence.
    :param maxlen: if not None, the batch is exactly ``maxlen`` steps long and
        longer sequences are truncated; if None, the batch is as long as the
        longest sequence.
    :return: ``(x, x_mask, labels)``: ``x`` is int32 of shape
        ``(n_steps, n_samples)`` with zero padding, ``x_mask`` is float32 of
        the same shape with 1.0 only at real (non-padding) positions, and
        ``labels`` is an int32 array.
    """
    lengths = [len(s) for s in seqs]

    if maxlen is not None:
        # Truncate overlong sequences. The batch width stays fixed at maxlen,
        # but the mask must only cover each sequence's real tokens.
        lengths = [min(l, maxlen) for l in lengths]
        n_steps = maxlen
    else:
        n_steps = max(lengths) if lengths else 0

    n_samples = len(seqs)
    x = np.zeros((n_steps, n_samples), dtype='int32')
    x_mask = np.zeros((n_steps, n_samples), dtype='float32')

    for idx, (s, l) in enumerate(zip(seqs, lengths)):
        x[:l, idx] = s[:l]
        x_mask[:l, idx] = 1.

    return x, x_mask, np.asarray(labels).astype('int32')

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """
    Split the indices ``0..n-1`` into consecutive minibatches.

    All batches hold ``minibatch_size`` indices except possibly the last,
    which holds the leftover. With ``shuffle=True`` the indices are permuted
    in place (via ``np.random.shuffle``) before splitting.
    """
    order = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(order)

    n_full = n // minibatch_size
    batches = [order[k * minibatch_size:(k + 1) * minibatch_size]
               for k in range(n_full)]

    tail_start = n_full * minibatch_size
    if tail_start != n:
        # Leftover indices that did not fill a complete batch.
        batches.append(order[tail_start:])
    return batches

def pred_error(f_pred, prepare_data, data, iterator, verbose=False):
    """Return the classification error rate of ``f_pred`` on ``data``.

    :param f_pred: callable ``f_pred(x, mask)`` returning one predicted label
        per sample in the batch.
    :param prepare_data: batching function returning ``(x, mask, labels)``
        (same contract as the module-level ``prepare_data``).
    :param data: ``(sequences, labels)`` pair.
    :param iterator: iterable of ``(index, batch_indices)`` pairs; only the
        batch indices are used.
    :param verbose: unused; kept for interface compatibility.
    :return: ``1 - n_correct / len(data[0])`` as a float.
    """
    # Build the label array once instead of twice per batch.
    labels = np.asarray(data[1])
    n_correct = 0
    for _, valid_batch in iterator:
        targets = labels[valid_batch]
        # prepare_data returns three values: (x, mask, labels).
        x, mask, _ = prepare_data([data[0][t] for t in valid_batch],
                                  targets,
                                  maxlen=None)
        preds = f_pred(x, mask)
        n_correct += (np.asarray(preds) == targets).sum()
    return 1. - float(n_correct) / len(data[0])

class FillWeight(tfe.Network):
    """Numpy initialisers returning float32 weight matrices and row vectors."""

    def __init__(self):
        super(FillWeight, self).__init__()

    def uniform_weight(self, n_in, n_out=None, fix=False):
        """Return an ``(n_in, n_out)`` float32 matrix drawn uniformly.

        :param n_in: number of rows.
        :param n_out: number of columns; defaults to ``n_in`` (square).
        :param fix: if True, draw from U(-0.5, 0.5); otherwise from
            U(-sqrt(6 / (n_in + n_out)), +sqrt(6 / (n_in + n_out))).
        """
        if n_out is None:
            n_out = n_in
        if fix:
            W = np.random.uniform(low=-0.5, high=0.5, size=(n_in, n_out))
        else:
            bound = np.sqrt(6. / (n_in + n_out))
            W = np.random.uniform(low=-bound, high=bound, size=(n_in, n_out))
        return W.astype('float32')

    def uniform_vector(self, n_, fix=False):
        """Return a ``(1, n_)`` float32 row vector drawn uniformly.

        :param n_: vector length.
        :param fix: if True, draw from U(-0.5, 0.5); otherwise from
            U(-sqrt(6 / n_), +sqrt(6 / n_)).
        """
        if fix:
            # Previously referenced undefined n_in/n_out and raised NameError.
            W = np.random.uniform(low=-0.5, high=0.5, size=(1, n_))
        else:
            bound = np.sqrt(6. / n_)
            W = np.random.uniform(low=-bound, high=bound, size=(1, n_))
        return W.astype('float32')

    def ortho_weight(self, ndim):
        """Return an ``(ndim, ndim)`` float32 orthogonal matrix (left singular
        vectors of a Gaussian random matrix)."""
        W = np.random.randn(ndim, ndim)
        u, _, _ = np.linalg.svd(W)
        return u.astype('float32')
        

class Softmax_Layer(tfe.Network):
    """Affine projection followed by a softmax over ``config['y_dim']`` classes.

    Parameters live in ``self.params`` under ``<prefix>_W`` (shape
    ``(dim_hidden, y_dim)``) and ``<prefix>_b`` (shape ``(y_dim,)``).
    """

    def __init__(self, config, weight_filler, prefix='Dense'):
        super(Softmax_Layer, self).__init__()

        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']
        # Remember the prefix so __call__ looks up the same keys by default.
        self.param_prefix = prefix
        self.params[_p(prefix, 'W')] = tf.Variable(
            weight_filler.uniform_weight(config['dim_hidden'], config['y_dim']).astype('float32'),
            name=_p(prefix, "W"))
        self.params[_p(prefix, 'b')] = tf.Variable(
            np.zeros((config['y_dim'],)).astype('float32'), name=_p(prefix, "b"))

    def __call__(self, x, prefix=None):
        """Return ``softmax(x @ W + b)``.

        :param prefix: parameter-key prefix; defaults to the one given at
            construction, so non-default prefixes no longer raise KeyError.
        """
        if prefix is None:
            prefix = self.param_prefix
        logits = tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]
        return tf.nn.softmax(logits)
    

class Dense_Layer(tfe.Network):
    """Affine output layer producing ``config['y_dim']`` logits.

    The weight has shape ``(2 * dim_hidden, y_dim)``; the input is expected to
    be the concatenated ``(h, c)`` state. Parameters live in ``self.params``
    under ``<prefix>_W`` and ``<prefix>_b``, or are taken from
    ``trained_param`` under the same keys.
    """

    def __init__(self, config, weight_filler, prefix='Dense', trained_param=None):
        super(Dense_Layer, self).__init__()

        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']
        # Remember the prefix so __call__ looks up the same keys by default.
        self.param_prefix = prefix

        if trained_param is not None:
            print("Dense_Layer")
            self.params[_p(prefix, 'W')] = trained_param[_p(prefix, 'W')]
            self.params[_p(prefix, 'b')] = trained_param[_p(prefix, 'b')]
        else:
            self.params[_p(prefix, 'W')] = tf.Variable(
                weight_filler.uniform_weight(2 * config['dim_hidden'], config['y_dim']).astype('float32'),
                name=_p(prefix, "W"))
            self.params[_p(prefix, 'b')] = tf.Variable(
                np.zeros((config['y_dim'],)).astype('float32'), name=_p(prefix, "b"))

    def __call__(self, x, prefix=None):
        """Return ``x @ W + b`` (unnormalised logits).

        :param prefix: parameter-key prefix; defaults to the one given at
            construction, so non-default prefixes no longer raise KeyError.
        """
        if prefix is None:
            prefix = self.param_prefix
        return tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]
    

class Embedding_Layer(tfe.Network):
    """Token-embedding lookup table kept in ``self.params['Wemb']``.

    When freshly initialised the table has shape
    ``(config['n_words'], config['dim_hidden'])``. When ``trained_param`` is
    given, its ``'Wemb'`` entry is reused as-is.
    """

    def __init__(self, config, weight_filler, prefix='Embedding', trained_param=None):
        super(Embedding_Layer, self).__init__()

        self.config = config
        self.params = {}
        self.dim_hidden = config['dim_hidden']

        if trained_param is None:
            init_table = weight_filler.uniform_weight(config['n_words'], config['dim_hidden'])
            self.params['Wemb'] = tf.Variable(init_table.astype('float32'),
                                              name=_p(prefix, "Wemb"))
        else:
            print("Embedding_Layer")
            self.params['Wemb'] = trained_param['Wemb']

    def __call__(self, x, prefix='Embedding'):
        """Gather the embedding rows for the token ids in ``x`` (``prefix`` is unused)."""
        table = self.params['Wemb']
        return tf.gather(table, x)
        
          
class LSTM_Layer(tfe.Network):
    """Single-layer LSTM unrolled with ``tf.scan`` over the leading (time) axis.

    Parameters in ``self.params`` (or taken from ``trained_param``):
      ``<prefix>_W``: input projection, ``(dim_hidden, 4 * dim_hidden)``
      ``<prefix>_U``: recurrent projection, ``(dim_hidden, 4 * dim_hidden)``
      ``<prefix>_b``: bias, ``(4 * dim_hidden,)``
    The four gate blocks are laid out as [input, forget, candidate, output].
    """

    def __init__(self, config, weight_filler, prefix='LSTM', trained_param=None):
        super(LSTM_Layer, self).__init__()
        self.config = config
        self.params = dict()
        self.dim_hidden = config['dim_hidden']
        # Remember the prefix so __call__ looks up the same keys by default.
        self.param_prefix = prefix

        if trained_param is not None:
            self.params[_p(prefix, 'W')] = trained_param[_p(prefix, "W")]
            self.params[_p(prefix, 'U')] = trained_param[_p(prefix, "U")]
            # Keep the trained bias; previously it was overwritten with zeros.
            self.params[_p(prefix, 'b')] = trained_param[_p(prefix, "b")]
        else:
            W = np.concatenate([weight_filler.uniform_weight(self.dim_hidden)
                                for _ in range(4)], axis=1)
            self.params[_p(prefix, 'W')] = tf.Variable(W, name=_p(prefix, "W"))

            U = np.concatenate([weight_filler.uniform_weight(self.dim_hidden)
                                for _ in range(4)], axis=1)
            self.params[_p(prefix, 'U')] = tf.Variable(U, name=_p(prefix, "U"))

            # Zero bias for all four gate blocks.
            b = np.zeros((4 * self.dim_hidden,), dtype='float32')
            self.params[_p(prefix, 'b')] = tf.Variable(b, name=_p(prefix, "b"))

    def __call__(self, x, mask, prefix=None):
        """Run the LSTM over ``x`` and return per-step ``(h, c)``.

        :param x: input of shape ``(n_steps, batch, dim_hidden)``.
        :param mask: ``(n_steps, batch)`` float mask; on steps where the mask
            is 0 the previous ``(h, c)`` is carried forward unchanged.
        :param prefix: parameter-key prefix; defaults to the one given at
            construction.
        :return: ``(h, c)``, each ``(n_steps, batch, dim_hidden)``.
        """
        if prefix is None:
            prefix = self.param_prefix
        dim = self.dim_hidden
        W = self.params[_p(prefix, 'W')]
        U = self.params[_p(prefix, 'U')]
        b = self.params[_p(prefix, 'b')]
        batch_size = int(x.shape[1])

        def _slice(z, n):
            return z[:, n * dim:(n + 1) * dim]

        def one_step(prev, x_curr_mask):
            '''
            @param prev: stacked [h(t-1), c(t-1)]
            @param x_curr_mask: (projected x(t), mask(t))
            '''
            x_curr, x_mask = x_curr_mask
            h_prev, c_prev = tf.unstack(prev, axis=0)
            preact = x_curr + tf.matmul(h_prev, U)

            i = tf.sigmoid(_slice(preact, 0))
            f = tf.sigmoid(_slice(preact, 1))
            c_tilde = tf.tanh(_slice(preact, 2))
            o = tf.sigmoid(_slice(preact, 3))

            c = f * c_prev + i * c_tilde
            c = x_mask[:, None] * c + (1. - x_mask)[:, None] * c_prev

            h = o * tf.tanh(c)
            h = x_mask[:, None] * h + (1. - x_mask)[:, None] * h_prev

            return tf.stack([h, c], axis=0)

        h0 = np.zeros([batch_size, dim]).astype('float32')
        c0 = np.zeros([batch_size, dim]).astype('float32')
        init_state = tf.stack([h0, c0], axis=0)

        # Project every time step at once. tensordot contracts the feature axis
        # of the rank-3 input and does not rely on rank-3 x rank-2 matmul.
        x_in = tf.tensordot(x, W, axes=[[2], [0]]) + b
        states = tf.scan(one_step, [x_in, mask],
                         initializer=init_state,
                         name='h_states')

        # states: (n_steps, 2, batch, dim) -> split into h and c.
        h = states[:, 0]
        c = states[:, 1]
        return h, c


def create_config(dim_hidden=256,
                  weight_decay=0.0005,
                  lrate=0.01,
                  n_words=5000,
                  batch_size=8,
                  y_dim=2,
                  model_path='models/LSTM_IMDB'):
    """Bundle the model/training hyper-parameters into a plain dict keyed by argument name."""
    return {
        'dim_hidden': dim_hidden,
        'weight_decay': weight_decay,
        'lrate': lrate,
        'n_words': n_words,
        'batch_size': batch_size,
        'y_dim': y_dim,
        'model_path': model_path,
    }

def save_models(modelparams):
    """Flatten per-layer parameter lists into one ``name -> value`` mapping.

    :param modelparams: mapping of layer name to an iterable of
        ``(param_name, value)`` pairs (as built in ``LSTM_Model.lstm_params``).
    :return: ``OrderedDict`` of ``param_name -> value``. Values exposing a
        callable ``numpy()`` method (e.g. eager ``tf.Variable``s) are converted
        to numpy arrays so the result can be pickled.
    """
    new_modelparams = OrderedDict()
    for _, param in modelparams.items():
        for key_value in param:
            key, value = key_value[0], key_value[1]
            to_numpy = getattr(value, 'numpy', None)
            new_modelparams[key] = to_numpy() if callable(to_numpy) else value
    return new_modelparams

def load_models(path):
    """Load a parameter checkpoint written by ``train`` (a pickled dict).

    ``np.load`` with ``allow_pickle=True`` falls back to unpickling, which can
    execute arbitrary code -- only load checkpoints from trusted sources.
    """
    return np.load(path, allow_pickle=True)


class LSTM_Model(tfe.Network):
    """Embedding -> LSTM -> dense classifier over the final ``(h, c)`` state.

    ``lstm_params`` maps each sub-layer attribute name to that layer's
    ``params.items()``, which ``save_models`` flattens into one checkpoint.
    """

    def __init__(self, lstm_config, trained_param=None):
        super(LSTM_Model, self).__init__()
        self.lstm_params = OrderedDict()
        self.config = lstm_config

        print("config", self.config)
        self.weight_filler = FillWeight()

        # (attribute name, progress message, layer class), built in this order.
        layer_specs = (
            ('emb_layer', '...building embedding layer...', Embedding_Layer),
            ('lstm_layer', '...building lstm layer...', LSTM_Layer),
            ('dense_layer', '...building dense layer...', Dense_Layer),
        )
        for attr_name, message, layer_cls in layer_specs:
            print(message)
            layer = self.track_layer(
                layer_cls(self.config, self.weight_filler, trained_param=trained_param))
            setattr(self, attr_name, layer)
            self.lstm_params[attr_name] = layer.params.items()

    def call(self, x, mask):
        """Run the model and return unnormalised class logits."""
        h, c = self.lstm_layer(self.emb_layer(x), mask)
        # Take the last time step of the concatenated (h, c) features.
        last_state = tf.concat([h, c], axis=2)[-1]
        return self.dense_layer(last_state)



def set_seed(seed, n_gpu=1):
    """
    Seed Python's ``random``, NumPy's global RNG and TensorFlow's graph-level seed.

    :param seed: seed to set; pass -1 to draw a random seed
    :param n_gpu: unused; kept for interface compatibility
    :return: 0 on success
    """
    if seed == -1:
        seed = random.randrange(2**32 - 1)
    # Format the message explicitly. Passing seed as a second print() argument
    # left "%s" unfilled (and printed a tuple under Python 2).
    print("Seed: %s" % seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    return 0
    
def accuracy(logits, Y):
    """Return the fraction of rows whose arg-max logit equals the label in ``Y``.

    :param logits: 2-D scores, one row per sample (argmax is taken on axis 1).
    :param Y: integer labels, one per row.
    :return: scalar float32 tensor.
    """
    # softmax is monotonic, so the argmax can be taken directly on the logits.
    predicted = tf.argmax(logits, 1)
    # tf.argmax yields int64; tf.equal requires both operands to share a dtype.
    correct_pred = tf.equal(predicted, tf.cast(Y, predicted.dtype))
    return tf.reduce_mean(tf.cast(correct_pred, tf.float32))
       
def prepare_loss(logits, y):
    """Return the batch-mean sparse softmax cross-entropy of ``logits`` against labels ``y``."""
    per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    return tf.reduce_mean(per_sample)

def model_loss(model, x, mask, y):
    """Cross-entropy loss plus ``weight_decay`` times the L2 norm of the dense weights.

    Only ``model.dense_layer.params['Dense_W']`` is regularised.
    """
    data_term = prepare_loss(model(x, mask), y)
    dense_weights = model.dense_layer.params['Dense_W']
    return data_term + model.config['weight_decay'] * tf.nn.l2_loss(dense_weights)
    


def _validation_error(model, valid_set, valid_batches):
    """Return ``1 - accuracy`` of ``model`` over ``valid_batches``, weighted by batch size."""
    n_correct = 0.
    n_seen = 0
    for valid_batch in valid_batches:
        y_v = [valid_set[1][i] for i in valid_batch]
        x_v = [valid_set[0][i] for i in valid_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        batch_acc = accuracy(model(x_, mask_), y_)
        # Weight each batch by its size so a short final batch does not skew the mean.
        n_correct += float(batch_acc.numpy()) * len(valid_batch)
        n_seen += len(valid_batch)
    if n_seen == 0:
        return float('nan')
    return 1. - n_correct / n_seen


def _save_checkpoint(model, model_path, n_iters):
    """Pickle the model's flattened parameters to ``<model_path>_<n_iters>.pkl``."""
    modelparams = save_models(model.lstm_params)
    with open(model_path + "_%s.pkl" % (n_iters), 'wb') as output:
        pickle.dump(modelparams, output)


def train(
        max_epochs=5000,
        dim_hidden=256,
        batch_size=8,
        valid_batch_size=8,
        save_iters=1000,
        valid_iters=1000,
        display_iters=1,
):
    """Train the IMDB LSTM classifier, logging to ``logs/lstm_IMDB.log``.

    :param max_epochs: number of passes over the training set.
    :param dim_hidden: LSTM hidden size (forwarded to the model config).
    :param batch_size: training minibatch size (forwarded to the model config).
    :param valid_batch_size: validation minibatch size.
    :param save_iters: currently unused; checkpoints are written whenever the
        validation error improves and once at the end of training.
    :param valid_iters: run validation every this many iterations.
    :param display_iters: print/log the training loss every this many iterations.
    """
    log_path = "logs/lstm_IMDB.log"
    with open(log_path, "a") as logger:
        print('\nSetting seed')
        set_seed(12345)

        # Forward the arguments; they were previously ignored by the config.
        lstm_config = create_config(dim_hidden=dim_hidden, batch_size=batch_size)
        lstm_model = LSTM_Model(lstm_config)
        # NOTE(review): the learning rate is hard-coded and does not use
        # lstm_config['lrate'].
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        grad_fn = tfe.implicit_value_and_gradients(model_loss)
        logger.write(str(lstm_config) + '\n')
        print(lstm_config)

        print('...Loading Data...')
        train_set, valid_set = load_imdb_data()

        valid_batches = get_minibatches_idx(len(valid_set[0]), valid_batch_size)

        print('%d traning samples' % (len(train_set[0])))
        print('%d valid samples' % (len(valid_set[0])))

        print('....Training....')

        # NOTE(review): an unused tf.Session was created here. With eager
        # execution, GPU options belong in tfe.enable_eager_execution(config=...).
        n_iters = 0
        best_val_error = 1
        try:
            for epoch in range(max_epochs):
                print('epoch: ' + str(epoch))
                batches = get_minibatches_idx(len(train_set[0]), batch_size, shuffle=True)

                for train_batch in batches:
                    n_iters += 1
                    y_t = [train_set[1][i] for i in train_batch]
                    x_t = [train_set[0][i] for i in train_batch]
                    x_, mask_, y_ = prepare_data(x_t, y_t, maxlen=None)

                    loss, grads = grad_fn(lstm_model, x_, mask_, y_)
                    optimizer.apply_gradients(grads)

                    if n_iters % display_iters == 0:
                        print('iter [{}] loss : {}'.format(n_iters, loss))
                        logger.write('iter [{}] loss : {}\n'.format(n_iters, loss))
                        logger.flush()

                    if n_iters % valid_iters == 0:
                        val_error = _validation_error(lstm_model, valid_set, valid_batches)
                        if best_val_error > val_error:
                            _save_checkpoint(lstm_model, lstm_config['model_path'], n_iters)
                            best_val_error = val_error
                        print('iter [{}] valid error : {} '.format(n_iters, val_error))
                        logger.write('iter [{}] valid error : {} \n'.format(n_iters, val_error))
                        logger.flush()

        except KeyboardInterrupt:
            print("Training interupted")
        print('----save checkpoint----')

        _save_checkpoint(lstm_model, lstm_config['model_path'], n_iters)

    
if __name__ == '__main__':
    # Script entry point: run training with train()'s default arguments.
    train()


