#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
# Expose only physical GPU 3 to CUDA. This is set before TensorFlow is
# imported below so the restriction is in place when TensorFlow initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="3"
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
# NOTE(review): `cPickle` exists only on Python 2, and the `pkl` alias is not
# referenced in the visible code; `six.moves.cPickle` (imported as `pickle`)
# is the 2/3-compatible module actually used for saving checkpoints.
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict

import pandas as pd
import tensorflow.contrib.eager as tfe
# Enable eager execution for the whole process at import time, before any
# TensorFlow op or variable is created by the code below.
tfe.enable_eager_execution()

def _p(pp, name):
    return '%s_%s' % (pp, name)

def load_mit_file(file):
    """Read one MIT-BIH CSV file into ragged signal lists and labels.

    Columns 0..186 hold the signal and column 187 the label. Trailing zeros
    are stripped from every signal row, so each row keeps its own length.

    Args:
        file: path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        (X, Y): X is a list of per-row Python float lists, Y a list of ints.
    """
    frame = pd.read_csv(file, header=None)

    labels = np.array(frame[187].values).astype(np.int8)
    signals = np.array(frame[list(range(187))].values)

    # trim='b' removes only the zero padding at the back; interior zeros stay.
    trimmed = [np.trim_zeros(row, trim='b').tolist() for row in signals]

    return trimmed, labels.tolist()


def load_mit_bih_a_data(path="data/mitbih/"):
    """Load the MIT-BIH training and validation splits.

    ``path`` is prepended verbatim to each file name, so it must end with a
    separator. Returns ``(train, valid)``, each as returned by load_mit_file.
    """
    train_split, valid_split = [load_mit_file(path + fname)
                                for fname in ("mitbih_new_train.csv",
                                              "mitbih_valid.csv")]
    return train_split, valid_split


def load_mit_bih_a_testdata(path="data/mitbih/"):
    """Load the MIT-BIH test split; ``path`` is prepended verbatim to the file name."""
    test_file = path + "mitbih_test.csv"
    return load_mit_file(test_file)
       
def prepare_data(seqs, labels, maxlen=100):
    """Pack variable-length sequences into time-major float32 arrays.

    Args:
        seqs: per-sample sequences (lists or 1-D arrays of numbers).
        labels: per-sample labels aligned with ``seqs``.
        maxlen: if not None, each sequence is truncated or right-padded with
            zeros to exactly ``maxlen`` steps, and its mask covers all
            ``maxlen`` steps (the zero padding is marked valid). If None,
            sequences keep their own lengths. They are padded to the longest
            one in the batch, and the mask covers only the real steps.

    Returns:
        (x, x_mask, labels): ``x`` and ``x_mask`` are float32 arrays of shape
        (n_steps, n_samples); ``labels`` is an int32 array of shape
        (n_samples,). An empty ``seqs`` yields arrays with zero columns
        instead of raising.
    """
    if maxlen is not None:
        # zip() keeps the original pairing semantics: seqs and labels are cut
        # to the shorter of the two.
        pairs = list(zip(seqs, labels))
        # list() lets 1-D arrays be padded by concatenation too. For sequences
        # longer than maxlen, the padding term is empty and the slice truncates.
        seqs = [list(s[:maxlen]) + [0] * (maxlen - len(s)) for s, _ in pairs]
        labels = [y for _, y in pairs]
        lengths = [maxlen] * len(seqs)
    else:
        lengths = [len(s) for s in seqs]

    n_samples = len(seqs)
    # Fall back to maxlen (or 0) for an empty batch, where max() has no input.
    n_steps = max(lengths) if lengths else (maxlen or 0)
    x = np.zeros((n_steps, n_samples), dtype='float32')
    x_mask = np.zeros((n_steps, n_samples), dtype='float32')
    for idx, (seq, length) in enumerate(zip(seqs, lengths)):
        x[:length, idx] = seq
        x_mask[:length, idx] = 1.

    return x, x_mask, np.asarray(labels).astype('int32')

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Split the indices 0..n-1 into consecutive minibatches.

    When ``shuffle`` is true, the index order is shuffled in place first with
    ``np.random`` (so the global NumPy seed controls it). Every batch holds
    ``minibatch_size`` indices except possibly the last, which holds the rest.

    Returns:
        list of int32 index arrays.
    """
    order = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(order)

    n_full = n // minibatch_size
    batches = [order[k * minibatch_size:(k + 1) * minibatch_size]
               for k in range(n_full)]

    tail_start = n_full * minibatch_size
    if tail_start != n:
        # Indices that do not fill a whole batch form one final short batch.
        batches.append(order[tail_start:])
    return batches

def pred_error(f_pred, prepare_data, data, iterator, verbose=False):
    """Return the classification error rate of ``f_pred`` over ``data``.

    Args:
        f_pred: callable ``(x, mask) -> predicted labels`` for one batch.
        prepare_data: callable with this module's ``prepare_data`` contract,
            returning ``(x, mask, labels)``.
        data: pair ``(seqs, labels)`` of aligned per-sample sequences.
        iterator: iterable of ``(ignored, index_array)`` pairs, e.g.
            ``enumerate(get_minibatches_idx(...))``.
        verbose: unused; kept for interface compatibility.

    Returns:
        ``1 - n_correct / len(data[0])`` as a float.
    """
    seqs = data[0]
    # Convert once instead of on every batch.
    labels = np.asarray(data[1])
    n_correct = 0
    for _, batch_idx in iterator:
        targets = labels[batch_idx]
        # prepare_data returns exactly three values (x, mask, labels).
        x, mask, _ = prepare_data([seqs[t] for t in batch_idx],
                                  targets,
                                  maxlen=None)
        preds = f_pred(x, mask)
        n_correct += (preds == targets).sum()
    return 1. - float(n_correct) / len(seqs)

class FillWeight(tfe.Network):
    """Factory for NumPy float32 initial weight values used by the layers."""

    def __init__(self):
        super(FillWeight, self).__init__()

    def uniform_weight(self, n_in, n_out=None, fix=False):
        """Return an (n_in, n_out) float32 matrix of uniform random values.

        By default the values come from U(-sqrt(6/(n_in+n_out)), +sqrt(...))
        (Glorot/Xavier uniform). With ``fix=True`` they come from U(-0.5, 0.5)
        instead. ``n_out`` defaults to ``n_in``.
        """
        if n_out is None:
            n_out = n_in
        limit = np.sqrt(6. / (n_in + n_out))
        W = np.random.uniform(low=-limit, high=limit, size=(n_in, n_out))
        if fix:
            # The Glorot draw above is still made and discarded, so the
            # global NumPy RNG advances identically and seeded runs keep
            # producing the same initial weights.
            W = np.random.uniform(low=-0.5, high=0.5, size=(n_in, n_out))
        return W.astype('float32')

    def uniform_vector(self, n_, fix=False):
        """Return a (1, n_) float32 row vector of uniform random values.

        By default the values come from U(-sqrt(6/n_), +sqrt(6/n_)). With
        ``fix=True`` they come from U(-0.5, 0.5).
        """
        limit = np.sqrt(6. / n_)
        W = np.random.uniform(low=-limit, high=limit, size=(1, n_))
        if fix:
            # Same (1, n_) shape as the default draw.
            W = np.random.uniform(low=-0.5, high=0.5, size=(1, n_))
        return W.astype('float32')

    def ortho_weight(self, ndim):
        """Return a random (ndim, ndim) float32 orthogonal matrix.

        It is the left singular-vector matrix of a standard-normal matrix.
        """
        W = np.random.randn(ndim, ndim)
        u, _, _ = np.linalg.svd(W)
        return u.astype('float32')
        
    

class Dense_Layer(tfe.Network):
    """Affine output layer: logits = x @ W + b.

    At fresh initialization, W has shape (2 * dim_hidden, y_dim) and b has
    shape (y_dim,). When ``trained_param`` is given, both are taken from it
    under the keys ``<prefix>_W`` / ``<prefix>_b``.
    """

    def __init__(self, config, weight_filler, prefix='Dense', trained_param=None):
        super(Dense_Layer, self).__init__()

        self.config = config
        self.params = dict()
        # Remembered so __call__ can default to the same parameter keys.
        self.prefix = prefix

        self.dim_hidden = config['dim_hidden']

        w_key, b_key = _p(prefix, 'W'), _p(prefix, 'b')
        if trained_param is not None:
            self.params[w_key] = trained_param[w_key]
            self.params[b_key] = trained_param[b_key]
        else:
            # Input width is 2 * dim_hidden because the model concatenates
            # the LSTM's h and c before this layer.
            self.params[w_key] = tf.Variable(
                weight_filler.uniform_weight(2 * config['dim_hidden'], config['y_dim']).astype('float32'),
                name=w_key)
            self.params[b_key] = tf.Variable(
                np.zeros((config['y_dim'],)).astype('float32'), name=b_key)

    def __call__(self, x, prefix=None):
        """Return ``x @ W + b`` for a 2-D batch ``x`` whose width matches W's rows.

        ``prefix`` selects the parameter keys. It defaults to the prefix the
        layer was built with, so a non-default prefix need not be repeated.
        """
        if prefix is None:
            prefix = self.prefix
        return tf.matmul(x, self.params[_p(prefix, 'W')]) + self.params[_p(prefix, 'b')]


class Embedding_Layer(tfe.Network):
    """Per-step linear input projection: emb = x @ Wemb + bemb.

    At fresh initialization, Wemb has shape (1, dim_hidden) and bemb has shape
    (dim_hidden,), so the last axis of the input must have size 1. When
    ``trained_param`` is given, both are taken from it under the keys
    ``<prefix>_Wemb`` / ``<prefix>_bemb``.
    """

    def __init__(self, config, weight_filler, prefix='Embedding', trained_param=None):
        super(Embedding_Layer, self).__init__()

        self.config = config
        self.params = dict()
        # Remembered so __call__ can default to the same parameter keys.
        self.prefix = prefix

        self.dim_hidden = config['dim_hidden']

        w_key, b_key = _p(prefix, 'Wemb'), _p(prefix, 'bemb')
        if trained_param is not None:
            self.params[w_key] = trained_param[w_key]
            self.params[b_key] = trained_param[b_key]
        else:
            self.params[w_key] = tf.Variable(
                weight_filler.uniform_weight(1, config['dim_hidden']).astype('float32'),
                name=w_key)
            self.params[b_key] = tf.Variable(
                np.zeros((config['dim_hidden'],)).astype('float32'), name=b_key)

    def __call__(self, x, prefix=None):
        """Project ``x`` (last axis matching Wemb's rows) to dim_hidden features.

        ``prefix`` selects the parameter keys. It defaults to the prefix the
        layer was built with, so a non-default prefix need not be repeated.
        """
        if prefix is None:
            prefix = self.prefix
        return tf.matmul(x, self.params[_p(prefix, 'Wemb')]) + self.params[_p(prefix, 'bemb')]
        
          
class SRLSTM_Peephole_Layer(tfe.Network):
    """Peephole LSTM whose carried hidden state is a mixture of learned centroids.

    At every step the new hidden state h is scored against ``n_centroid``
    centroid vectors, and a (soft) Gumbel-softmax sample over those scores is
    drawn. The hidden state carried to the next step is the resulting
    probability-weighted mixture of centroids, not h itself.

    Parameters (keys ``<prefix>_<name>``; shapes at fresh initialization):
        W: (dim_hidden, 4 * dim_hidden) input weights for the four gate blocks.
        U: (dim_hidden, 4 * dim_hidden) recurrent weights.
        P: (3, dim_hidden) peephole weights, rows used for the i, f, o gates.
        b: (4 * dim_hidden,) bias; block 1 (forget gate) starts at ones.
        centroid: (dim_hidden, n_centroid) centroid matrix.
        temperature: (1,) learned Gumbel-softmax temperature.
    """

    def __init__(self, config, weight_filler, prefix='LSTM', trained_param=None):
        """Take parameters from ``trained_param`` if given, else initialize them.

        ``config`` must provide 'dim_hidden' and, for fresh initialization,
        'n_centroid'. ``weight_filler`` supplies the NumPy initial values.
        """
        super(SRLSTM_Peephole_Layer, self).__init__()
        self.config =  config
        self.params = dict()

        self.dim_hidden = config['dim_hidden']

        if trained_param is not None:
            # Reuse the stored values as-is; no new tf.Variable is created here.
            self.params[_p(prefix, 'W')] = trained_param[_p(prefix,"W")]
            self.params[_p(prefix, 'U')] = trained_param[_p(prefix,"U")]
            self.params[_p(prefix, 'P')] =  trained_param[_p(prefix,"P")]
            self.params[_p(prefix, 'b')] = trained_param[_p(prefix,"b")]
            self.params[_p(prefix, 'centroid')] = trained_param[_p(prefix,"centroid")]
            self.params[_p(prefix, 'temperature')] = trained_param[_p(prefix,"temperature")]
        else:
            # Input-to-gate weights: four (dim_hidden, dim_hidden) blocks side by side.
            W = np.concatenate([weight_filler.uniform_weight(self.dim_hidden, self.dim_hidden),
                                weight_filler.uniform_weight(self.dim_hidden, self.dim_hidden),
                                weight_filler.uniform_weight(self.dim_hidden, self.dim_hidden),
                                weight_filler.uniform_weight(self.dim_hidden, self.dim_hidden)], axis=1)
            self.params[_p(prefix, 'W')] = tf.Variable(W, name=_p(prefix,"W"))

            # Hidden-to-gate (recurrent) weights with the same four-block layout.
            U = np.concatenate([weight_filler.uniform_weight(self.dim_hidden),
                                weight_filler.uniform_weight(self.dim_hidden),
                                weight_filler.uniform_weight(self.dim_hidden),
                                weight_filler.uniform_weight(self.dim_hidden)], axis=1)
            self.params[_p(prefix, 'U')] = tf.Variable(U, name=_p(prefix,"U"))


            # Peephole weights: one (1, dim_hidden) row per gate that reads
            # the cell state, stacked into (3, dim_hidden).
            P = np.concatenate([weight_filler.uniform_vector(self.dim_hidden),
                                weight_filler.uniform_vector(self.dim_hidden),
                                weight_filler.uniform_vector(self.dim_hidden)], axis=0)

            self.params[_p(prefix, 'P')] =  tf.Variable(P, name=_p(prefix,"P"))

            # Forget-gate bias starts at 1, all other gate biases at 0.
            # NOTE(review): the names below suggest block order [i, f, o, c],
            # but __call__ uses block 2 as the cell candidate and block 3 as
            # the output gate. This is harmless while both start at zero.
            b_i = np.zeros((self.dim_hidden,))
            b_f = np.ones((self.dim_hidden,))
            b_o = np.zeros((self.dim_hidden,))
            b_c = np.zeros((self.dim_hidden,))

            b = np.concatenate([b_i, b_f, b_o, b_c],axis=0)
            self.params[_p(prefix, 'b')] = tf.Variable(b.astype('float32'), name=_p(prefix,"b"))


            # Learned Gumbel-softmax temperature, starting at 1.
            temperature = np.ones((1,))
            self.params[_p(prefix, 'temperature')] = tf.Variable(temperature.astype('float32'), name=_p(prefix,"temperature"))

            # Centroid matrix drawn from U(-0.5, 0.5) (fix=True).
            centroids = weight_filler.uniform_weight(self.dim_hidden, config['n_centroid'], fix=True).astype('float32')
            self.params[_p(prefix, 'centroid')] = tf.Variable(centroids, name=_p(prefix,"centroid"))

    # --- Gumbel-softmax helpers ---
    ''' gumbel softmax '''
    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random_uniform(shape, minval=0, maxval=1)
        # eps keeps both logarithms away from log(0).
        return -tf.log(-tf.log(U + eps) + eps)


    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax( y / temperature)

    def gumbel_softmax(self, logits, temperature, hard=False):
        """
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if hard:
            # Straight-through estimator: the forward value is the one-hot
            # y_hard, while gradients flow through the soft sample y.
            # NOTE(review): `keep_dims` is deprecated in favor of `keepdims`
            # in later TF 1.x releases.
            y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        return y

    def __call__(self, x, mask, prefix='LSTM', temperature=1):
        """Run the recurrence over all time steps.

        Args:
            x: time-major input; x.shape[0] is the number of steps and
                x.shape[1] the batch size. Presumably (n_steps, batch,
                dim_hidden) so it matches W. TODO confirm.
            mask: per-step validity, sliced per step to (batch,). Presumably
                (n_steps, batch) as built by prepare_data.
            prefix: parameter-key prefix; must match the constructor's.
            temperature: NOTE(review): unused. The learned ``temperature``
                parameter is what scales the Gumbel-softmax.

        Returns:
            (h, c), each of shape (n_steps, batch, dim_hidden). ``h`` holds the
            centroid mixture ``hc`` of every step, not the raw LSTM output.
        """

        n_steps = x.shape[0]
        batch_size = x.shape[1]

        # Column block n (of width dim) from a (batch, 4 * dim) matrix.
        def _slice(x, n, dim):
            return x[:, n * dim:(n + 1) * dim]

        def one_step(prev, x_curr_mask):
            # prev: stacked [h, c] carried from the previous step.
            # x_curr_mask: (projected input, mask) for the current step.
            x_curr, x_mask = x_curr_mask
            h_prev, c_prev = tf.unstack(prev, axis=0)
            # Pre-activations of all four gate blocks.
            X_ = x_curr + tf.matmul(h_prev, self.params[_p(prefix, 'U')])

            peephone_i = c_prev * (self.params[_p(prefix, 'P')])[0] ## BxD * 1xD
            peephone_f = c_prev * (self.params[_p(prefix, 'P')])[1]

            i = tf.sigmoid(_slice(X_, 0, self.dim_hidden) + peephone_i)
            f = tf.sigmoid(_slice(X_, 1, self.dim_hidden) + peephone_f)
            c = tf.tanh(_slice(X_, 2, self.dim_hidden))

            # NOTE(review): the output-gate peephole reads the candidate `c`
            # computed just above, not the updated cell state a standard
            # peephole LSTM would use. Confirm intent before changing, since a
            # change alters the behavior of already-trained parameters.
            peephone_o = c * (self.params[_p(prefix, 'P')])[2]
            o = tf.sigmoid(_slice(X_, 3, self.dim_hidden) + peephone_o)


            c = f * c_prev + i * c
            # Where the mask is 0, keep the previous cell state.
            c = x_mask[:, None] * c + (1. - x_mask)[:, None] * c_prev

            h = o * tf.tanh(c)
            h = x_mask[:, None] * h + (1. - x_mask)[:, None] * h_prev

            # Score h against each centroid column, then draw a soft
            # Gumbel-softmax sample over the centroids.
            trans_probs = tf.matmul(h, self.params[_p(prefix, 'centroid')])  #B xD  DxK--> BxK
            trans_probs = self.gumbel_softmax(trans_probs, self.params[_p(prefix, 'temperature')][0], hard=False) # BxK

            # The carried hidden state is the probability-weighted mix of the
            # centroids. NOTE(review): this also happens on masked steps, so
            # there the carried state is the centroid mix of h_prev rather
            # than h_prev itself.
            hc = tf.matmul(trans_probs, tf.transpose(self.params[_p(prefix, 'centroid')])) # BxK KxD-> BxD

            return tf.stack([hc, c], axis=0)

        # Zero initial hidden and cell states, stacked like one_step's carry.
        h0  =  np.zeros([batch_size, self.dim_hidden]).astype('float32')
        c0  =  np.zeros([batch_size, self.dim_hidden]).astype('float32')
        init_state = tf.stack([h0, c0], axis=0)

        # Project the inputs of all steps at once; scan then walks the steps.
        x_in  = tf.matmul(x, self.params[_p(prefix, 'W')])+self.params[_p(prefix, 'b')]
        # states: (n_steps, 2, batch, dim_hidden), i.e. per step the stacked [hc, c].
        states = tf.scan(one_step, [x_in, mask],
                        initializer = init_state,
                        name = 'h_states')


        h = tf.transpose(states, [1,0,2,3])[0]
        c = tf.transpose(states, [1,0,2,3])[1]
        return h, c


def create_config(dim_hidden=128,
                  lrate=0.01,
                  n_words=5000,
                  batch_size=256,
                  n_centroid=5,
                  y_dim=5,
                  weight_decay=0.0005,
                  model_path='new_models/ST-tau_BIH_c5'):
    """Bundle the model/training hyper-parameters into a plain dict.

    Every keyword argument becomes one entry keyed by its own name, in
    signature order.
    """
    return dict(dim_hidden=dim_hidden,
                lrate=lrate,
                n_words=n_words,
                batch_size=batch_size,
                n_centroid=n_centroid,
                y_dim=y_dim,
                weight_decay=weight_decay,
                model_path=model_path)

def save_models(modelparams):
    """Flatten per-layer parameter pairs into a single ordered mapping.

    Args:
        modelparams: mapping of layer name -> iterable of
            ``(param_name, value)`` pairs (e.g. a layer's ``params.items()``).

    Returns:
        OrderedDict of param_name -> value, in layer order and then pair
        order. A repeated param_name is overwritten by its later value.
    """
    flat = OrderedDict()
    for layer_pairs in modelparams.values():
        for pair in layer_pairs:
            flat[pair[0]] = pair[1]
    return flat
    
  
def load_models(path):
    """Print ``path`` and return ``numpy.load(path)`` with pickles allowed.

    SECURITY: pickled content can execute arbitrary code when loaded; only
    pass paths to checkpoint files this project wrote itself.
    """
    print(path)
    return np.load(path, allow_pickle=True)
        

class SRLSTM_Model(tfe.Network):
    """Embedding -> SR-LSTM -> dense classifier applied to the last time step.

    Args:
        srlstm_config: hyper-parameter dict (see create_config).
        trained_param: optional mapping of saved parameter values; when given,
            every layer reuses them instead of initializing fresh ones.
    """

    def __init__(self, srlstm_config, trained_param=None):
        super(SRLSTM_Model, self).__init__()
        # layer name -> that layer's (param_name, value) pairs; flattened by
        # save_models() when a checkpoint is written.
        self.srlstm_params = OrderedDict()
        self.config = srlstm_config

        print("config", self.config)
        self.weight_filler = FillWeight()

        print('...building embedding layer...')
        self.emb_layer = self.track_layer(Embedding_Layer(self.config, self.weight_filler, trained_param=trained_param))
        self.srlstm_params['emb_layer'] = self.emb_layer.params.items()

        print('...building lstm layer...')
        self.srlstm_layer = self.track_layer(SRLSTM_Peephole_Layer(self.config, self.weight_filler, trained_param=trained_param))
        self.srlstm_params['srlstm_layer'] = self.srlstm_layer.params.items()

        print('...building dense layer...')
        self.dense_layer = self.track_layer(Dense_Layer(self.config, self.weight_filler, trained_param=trained_param))
        self.srlstm_params['dense_layer'] = self.dense_layer.params.items()

    def call(self, x, mask):
        """Run the model and return logits of shape (batch, y_dim)."""
        emb = self.emb_layer(x)
        h, c = self.srlstm_layer(emb, mask)
        # Classify from the last time step's concatenated [h, c] features.
        hc = tf.concat([h, c], axis=2)[-1]
        logits = self.dense_layer(hc, prefix='Dense')
        return logits



def set_seed(seed, n_gpu=1):
    """
    Seed Python's `random`, NumPy's global RNG and TensorFlow's random seed.

    :param seed: seed to set, set -1 to draw a random number
    :param n_gpu: unused; kept for interface compatibility
    :return: 0 on success
    """
    if seed == -1:
        seed = random.randrange(2**32 - 1)
    # Substitute the seed into the message; print() does not %-format its
    # extra arguments.
    print("Seed: %s" % seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    return 0
    
def accuracy(logits, Y):
    """Return the fraction of rows whose argmax over ``logits`` equals ``Y``.

    Args:
        logits: (batch, n_classes) unnormalized class scores.
        Y: (batch,) integer labels, of any integer dtype.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    # softmax is monotonic, so argmax over the raw logits gives the same
    # prediction without the extra op.
    predicted = tf.argmax(logits, 1)
    # argmax yields int64; cast the labels (int32 from prepare_data) so
    # tf.equal compares matching dtypes.
    correct_pred = tf.equal(predicted, tf.cast(Y, tf.int64))
    return tf.reduce_mean(tf.cast(correct_pred, tf.float32))
       
def prepare_loss(logits, y):
    """Return the batch-mean sparse softmax cross-entropy of ``logits`` vs labels ``y``."""
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y, logits=logits)
    return tf.reduce_mean(per_example)

def model_loss(model, x, mask, y):
    """Training objective: cross-entropy plus L2 weight decay on the dense weights.

    Only ``model.dense_layer.params['Dense_W']`` is penalized, scaled by
    ``model.config['weight_decay']``.
    """
    data_term = prepare_loss(model(x, mask), y)
    dense_weights = model.dense_layer.params['Dense_W']
    decay_term = model.config['weight_decay'] * tf.nn.l2_loss(dense_weights)
    return data_term + decay_term
    


def _save_checkpoint(model, model_path, n_iters):
    """Pickle the model's flattened parameters to '<model_path>_<n_iters>.pkl'."""
    modelparams = save_models(model.srlstm_params)
    with open(model_path + "_%s.pkl" % (n_iters), 'wb') as output:
        pickle.dump(modelparams, output)


def _validation_accuracy(model, valid, valid_batches):
    """Return the mean per-batch accuracy of ``model`` on the validation split.

    This is an unweighted mean over batches, so a short final batch counts as
    much as a full one.
    """
    accuracies = []
    for valid_batch in valid_batches:
        y_v = [valid[1][i] for i in valid_batch]
        x_v = [valid[0][i] for i in valid_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        # Add a trailing feature axis of size 1: (n_steps, batch, 1).
        x_ = x_.reshape((x_.shape[0], x_.shape[1], 1))
        valid_logits = model(x_, mask_)
        accuracies.append(accuracy(valid_logits, y_))
    return np.mean(accuracies)


def train(max_epochs=5000,
          dim_hidden=128,
          batch_size=256,
          valid_batch_size=256,
          save_iters=1000,
          valid_iters=1000,
          display_iters=1):
    """Train the SR-LSTM model on MIT-BIH data and checkpoint the best model.

    A checkpoint is written whenever the validation error improves, and once
    more when training ends (normally or via Ctrl-C).

    Args:
        max_epochs: number of passes over the training split.
        dim_hidden: unused; the model size comes from create_config().
        batch_size: training minibatch size.
        valid_batch_size: validation minibatch size.
        save_iters: unused; checkpoints follow validation improvements.
        valid_iters: run validation every this many iterations.
        display_iters: log the training loss every this many iterations.
    """
    log_path = "logs/st_gumbel_BIH_c5_soft.log"
    # `with` guarantees the log file is closed even if training raises.
    with open(log_path, "a") as logger:
        print('\nSetting seed')
        set_seed(12345)

        srlstm_config = create_config()  # returns dictionary
        print(srlstm_config)
        logger.write(str(srlstm_config) + '\n')
        srlstm_model = SRLSTM_Model(srlstm_config)
        # NOTE(review): the learning rate is fixed at 0.001 here;
        # srlstm_config['lrate'] is not used.
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        grad_fn = tfe.implicit_value_and_gradients(model_loss)

        print('...Loading Data...')
        train_set, valid_set = load_mit_bih_a_data()

        valid_batches = get_minibatches_idx(len(valid_set[0]), valid_batch_size)

        print('%d traning samples' % (len(train_set[0])))
        print('%d valid samples' % (len(valid_set[0])))

        print('....Training....')

        # Eager execution is enabled at import time, so no tf.Session is
        # needed to run the model.
        n_iters = 0
        best_val_error = 1
        try:
            for epoch in range(max_epochs):
                print('epoch: ' + str(epoch))
                batches = get_minibatches_idx(len(train_set[0]), batch_size, shuffle=True)

                for train_batch in batches:
                    n_iters += 1
                    y_t = [train_set[1][i] for i in train_batch]
                    x_t = [train_set[0][i] for i in train_batch]

                    x_, mask_, y_ = prepare_data(x_t, y_t, maxlen=None)
                    # Add a trailing feature axis of size 1: (n_steps, batch, 1).
                    x_ = x_.reshape((x_.shape[0], x_.shape[1], 1))

                    loss, grads = grad_fn(srlstm_model, x_, mask_, y_)
                    optimizer.apply_gradients(grads)

                    if np.mod(n_iters, display_iters) == 0:
                        print('iter [{}] loss : {}'.format(n_iters, loss))
                        logger.write('iter [{}] loss : {}\n'.format(n_iters, loss))
                        logger.flush()

                    if np.mod(n_iters, valid_iters) == 0:
                        averaged_accuracy = _validation_accuracy(srlstm_model, valid_set, valid_batches)
                        val_error = 1 - averaged_accuracy
                        if best_val_error > val_error:
                            _save_checkpoint(srlstm_model, srlstm_config['model_path'], n_iters)
                            best_val_error = val_error
                            logger.write('saving new best model: [{}] valid error : {} \n'.format(n_iters, val_error))
                            print('saving new best model: [{}] valid error : {} '.format(n_iters, val_error))
                        print('iter [{}] valid error : {} '.format(n_iters, val_error))
                        logger.write('iter [{}] valid error : {} \n'.format(n_iters, val_error))
                        logger.flush()

        except KeyboardInterrupt:
            print("Training interupted")
        print('----save checkpoint----')
        _save_checkpoint(srlstm_model, srlstm_config['model_path'], n_iters)

    
if __name__ == "__main__":
    # Script entry point: run the full training loop with default settings.
    train()

