from params import params
import ppdb_utils
import lasagne
import random
import numpy as np
import sys
import utils
import argparse
import theano
import numpy as np
from theano import tensor as T
from theano import config
from lasagne_lstm_nooutput import lasagne_lstm_nooutput
import lasagne
import cPickle
import os

class lstm_model_sentiment(object):
    """LSTM sentence encoder followed by a sigmoid + 2-way softmax classifier.

    A padded batch of word-index sequences is embedded, run through an LSTM,
    and the hidden state at the last time step is classified by a sigmoid
    hidden layer of size ``params.memsize`` and a 2-class softmax. Compiled
    Theano functions are stored on the instance:

      * ``feedforward_function(idx, mask)`` -> sentence encodings
      * ``scoring_function(idx, mask)``     -> argmax class per sentence
      * ``cost_function(scores, idx, mask)`` -> regularised mean cost
      * ``train_function(scores, idx, mask)`` -> cost, applying one update

    ``params.traintype`` selects the training regime:
      * "normal": embeddings start from ``We_initial``. They are trained only
        when ``params.updatewords`` is true.
      * "reg": embeddings and LSTM weights start from the pickled model in
        ``params.regfile`` and are regularised towards it. Everything is
        trained.
      * "rep": embeddings and LSTM weights start from ``params.regfile``. The
        embedding matrix is excluded from the trainable set.
    """

    def getRegTerm(self, params, We, initial_We, l_out, l_softmax, pickled_params):
        """Return the L2 regularisation term that is added to the mean cost.

        normal: LC * L2 over ``self.network_params``, plus
                LW * L2(We - initial_We) when ``params.updatewords`` is set.
        reg:    LRC * L2 distance of each encoder parameter (except index 0)
                to its pickled counterpart, plus LC * L2 over the classifier
                parameters, plus LRW * L2(We - initial_We).
        rep:    LC * L2 over the classifier parameters only.
        (Each weight above is additionally scaled by 0.5.)

        Raises:
            ValueError: if ``params.traintype`` is not one of the above.
        """
        if params.traintype == "normal":
            l2 = 0.5*params.LC*sum(lasagne.regularization.l2(x) for x in self.network_params)
            if params.updatewords:
                return l2 + 0.5*params.LW*lasagne.regularization.l2(We-initial_We)
            return l2
        elif params.traintype == "reg":
            # Index 0 of the encoder parameters is skipped. It is presumably the
            # embedding matrix, which is regularised separately below (the same
            # assumption drives the removal of index 0 from network_params in
            # __init__).
            encoder_params = lasagne.layers.get_all_params(l_out, trainable=True)
            l2 = 0.
            for idx in range(1, len(encoder_params)):
                anchor = np.asarray(pickled_params[idx].get_value(), dtype = config.floatX)
                l2 += 0.5*params.LRC*(lasagne.regularization.l2(encoder_params[idx]-anchor))
            classifier_params = lasagne.layers.get_all_params(l_softmax, trainable=True)
            l2 += 0.5*params.LC*sum(lasagne.regularization.l2(x) for x in classifier_params)
            return l2 + 0.5*params.LRW*lasagne.regularization.l2(We-initial_We)
        elif params.traintype == "rep":
            classifier_params = lasagne.layers.get_all_params(l_softmax, trainable=True)
            return 0.5*params.LC*sum(lasagne.regularization.l2(x) for x in classifier_params)
        raise ValueError('params.traintype not set correctly.')

    def getTrainableParams(self, params):
        """Return the parameter list that the optimiser should update.

        "rep" never updates the embedding matrix. "reg" always updates it. In
        "normal" mode it is updated only if ``params.updatewords`` is true.
        """
        if params.traintype == "rep":
            return self.network_params
        if params.updatewords or params.traintype == "reg":
            return self.all_params
        return self.network_params

    def __init__(self, We_initial, params):
        """Build the network and compile its Theano functions.

        Args:
            We_initial: initial word-embedding matrix, indexed by word id. It
                is used for "normal" training and as the anchor in "rep" mode.
            params: configuration object. This constructor reads traintype,
                regfile, useoutgate, usepeep, layersize, memsize, clip, learner
                and eta, plus the lambdas read by getRegTerm.
        """
        p = None
        if params.traintype == "reg" or params.traintype == "rep":
            # NOTE(review): cPickle.load can execute arbitrary code; only point
            # params.regfile at trusted model files.
            with open(params.regfile, 'rb') as f:
                p = cPickle.load(f)
            # p holds shared variables, in order:
            # [We, W_in_to_ingate, W_hid_to_ingate, b_ingate, W_in_to_forgetgate,
            #  W_hid_to_forgetgate, b_forgetgate, W_in_to_cell, W_hid_to_cell,
            #  b_cell, W_in_to_outgate, W_hid_to_outgate, b_outgate,
            #  W_cell_to_ingate, W_cell_to_forgetgate, W_cell_to_outgate]
            # For a model without an output gate, the three outgate entries are
            # absent and the peephole weights start at index 10.

        def _pickled(idx):
            # Copy of pickled parameter ``idx`` as a floatX numpy array.
            return np.asarray(p[idx].get_value(), dtype = config.floatX)

        # ``We`` is the embedding matrix used by the network. ``initial_We`` is
        # the anchor for the word-regularisation term.
        if params.traintype == "reg":
            initial_We = theano.shared(_pickled(0))
            We = theano.shared(_pickled(0))
        else:
            initial_We = theano.shared(np.asarray(We_initial, dtype = config.floatX))
            if params.traintype == "rep":
                We = theano.shared(_pickled(0))
            else:
                We = theano.shared(np.asarray(We_initial, dtype = config.floatX))

        g1batchindices = T.imatrix()
        g1mask = T.matrix()
        scores = T.matrix()

        l_in = lasagne.layers.InputLayer((None, None, 1))
        l_mask = lasagne.layers.InputLayer(shape=(None, None))
        l_emb = lasagne.layers.EmbeddingLayer(l_in, input_size=We.get_value().shape[0], output_size=We.get_value().shape[1], W=We)

        # A default-initialised LSTM is always built first, even when it is
        # replaced below. Its random initialisation presumably advances
        # lasagne's RNG, so keeping this step keeps seeded runs reproducible.
        if params.useoutgate:
            l_lstm = lasagne.layers.LSTMLayer(l_emb, params.layersize, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)
        else:
            l_lstm = lasagne_lstm_nooutput(l_emb, params.layersize, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)

        if params.traintype == "reg" or params.traintype == "rep":
            # Rebuild the LSTM with every gate initialised from the pickled model.
            peep = 13 if params.useoutgate else 10
            ingate = lasagne.layers.Gate(W_in=_pickled(1), W_hid=_pickled(2), W_cell=_pickled(peep), b=_pickled(3))
            forgetgate = lasagne.layers.Gate(W_in=_pickled(4), W_hid=_pickled(5), W_cell=_pickled(peep + 1), b=_pickled(6))
            cell = lasagne.layers.Gate(W_in=_pickled(7), W_hid=_pickled(8), W_cell=None, b=_pickled(9), nonlinearity=lasagne.nonlinearities.tanh)
            if params.useoutgate:
                outgate = lasagne.layers.Gate(W_in=_pickled(10), W_hid=_pickled(11), W_cell=_pickled(15), b=_pickled(12))
                l_lstm = lasagne.layers.LSTMLayer(l_emb, params.layersize, ingate = ingate, forgetgate = forgetgate,
                                  outgate = outgate, cell = cell, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)
            else:
                l_lstm = lasagne_lstm_nooutput(l_emb, params.layersize, ingate = ingate, forgetgate = forgetgate,
                                  cell = cell, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)

        # Hidden state at the last time step of each sequence.
        l_out = lasagne.layers.SliceLayer(l_lstm, -1, 1)

        embg = lasagne.layers.get_output(l_out, {l_in:g1batchindices, l_mask:g1mask})

        # The classifier consumes ``embg``, so its declared input width must
        # equal the LSTM size. The DenseLayer weight shape is derived from it.
        l_in2 = lasagne.layers.InputLayer((None, params.layersize))
        l_sigmoid = lasagne.layers.DenseLayer(l_in2, params.memsize, nonlinearity=lasagne.nonlinearities.sigmoid)

        l_softmax = lasagne.layers.DenseLayer(l_sigmoid, 2, nonlinearity=T.nnet.softmax)
        X = lasagne.layers.get_output(l_softmax, {l_in2:embg})
        cost = T.nnet.categorical_crossentropy(X,scores)
        prediction = T.argmax(X, axis=1)

        self.all_params = lasagne.layers.get_all_params(l_out, trainable=True) + lasagne.layers.get_all_params(l_softmax, trainable=True)
        # Everything except index 0, presumably the embedding matrix We.
        self.network_params = self.all_params[1:]

        reg = self.getRegTerm(params, We, initial_We, l_out, l_softmax, p)
        self.trainable = self.getTrainableParams(params)
        cost = T.mean(cost) + reg

        self.feedforward_function = theano.function([g1batchindices,g1mask], embg)
        self.scoring_function = theano.function([g1batchindices,
                             g1mask],prediction)
        self.cost_function = theano.function([scores, g1batchindices,
                             g1mask], cost)

        grads = theano.gradient.grad(cost, self.trainable)
        if params.clip:
            grads = [lasagne.updates.norm_constraint(grad, params.clip, range(grad.ndim)) for grad in grads]
        updates = params.learner(grads, self.trainable, params.eta)
        self.train_function = theano.function([scores, g1batchindices,
                             g1mask], cost, updates=updates)
def getSeqs(p1,p2,words):
    """Convert two phrases into lists of word indices.

    Both phrases are split on whitespace. Each token is mapped with
    ``ppdb_utils.lookupIDX(words, token)``.

    Returns:
        A tuple ``(ids_of_p1, ids_of_p2)``.
    """
    first_ids = [ppdb_utils.lookupIDX(words, token) for token in p1.split()]
    second_ids = [ppdb_utils.lookupIDX(words, token) for token in p2.split()]
    return first_ids, second_ids

def getSeq(p1,words):
    """Return the word indices of the whitespace-split phrase ``p1``.

    Each token is mapped with ``ppdb_utils.lookupIDX(words, token)``.
    """
    return [ppdb_utils.lookupIDX(words, token) for token in p1.split()]

def predict(model,words,text):
    preds = []
    golds = []
    seq1 = []
    ct = 0
    if not isinstance(text,list):
        text=[text]
    for x in text:
        X1 = getSeq(x,words)
        print x
        print X1
        seq1.append(X1)
    x1,m1 = utils.prepare_data(seq1)
    scores = model.scoring_function(x1,m1)
    scores = np.squeeze(scores)
    # print text
    # print "score"
    return scores.tolist()



if __name__ == '__main__':
    print "begin"
    new_path = '../similarity'
    if new_path not in sys.path:
        sys.path.append(new_path)
    new_path = '../sentiment'
    if new_path not in sys.path:
        sys.path.append(new_path)

    def str2bool(v):
        if v is None:
            return False
        if v.lower() in ("yes", "true", "t", "1"):
            return True
        if v.lower() in ("no", "false", "f", "0"):
            return False
        raise ValueError('A type that was supposed to be boolean is not boolean.')

    def learner2bool(v):
        if v is None:
            return lasagne.updates.adam
        if v.lower() == "adagrad":
            return lasagne.updates.adagrad
        if v.lower() == "adam":
            return lasagne.updates.adam
        raise ValueError('A type that was supposed to be a learner is not.')

    random.seed(1)
    np.random.seed(1)

    params = params()

    parser = argparse.ArgumentParser()
    parser.add_argument("-LW", help="Lambda for word embeddings (normal training).", type=float)
    parser.add_argument("-LC", help="Lambda for composition parameters (normal training).", type=float)
    parser.add_argument("-outfile", help="Output file name.")
    parser.add_argument("-batchsize", help="Size of batch.", type=int)
    parser.add_argument("-dim", help="Size of input.", type=int)
    parser.add_argument("-memsize", help="Size of classification layer.",
                        type=int)
    parser.add_argument("-wordfile", help="Word embedding file.")
    parser.add_argument("-layersize", help="Size of output layers in models.", type=int)
    parser.add_argument("-updatewords", help="Whether to update the word embeddings")
    parser.add_argument("-wordstem", help="Nickname of word embeddings used.")
    parser.add_argument("-save", help="Whether to pickle the model.")
    parser.add_argument("-traindata", help="Training data file.")
    parser.add_argument("-devdata", help="Training data file.")
    parser.add_argument("-testdata", help="Testing data file.")
    parser.add_argument("-peephole", help="Whether to use peephole connections in LSTM.")
    parser.add_argument("-outgate", help="Whether to use output gate in LSTM.")
    parser.add_argument("-nonlinearity", help="Type of nonlinearity in projection and DAN model.",
                        type=int)
    parser.add_argument("-nntype", help="Type of neural network.")
    parser.add_argument("-evaluate", help="Whether to evaluate the model during training.")
    parser.add_argument("-epochs", help="Number of epochs in training.", type=int)
    parser.add_argument("-regfile", help="Path to model file that we want to regularize towards.")
    parser.add_argument("-minval", help="Min rating possible in scoring.", type=int)
    parser.add_argument("-maxval", help="Max rating possible in scoring.", type=int)
    parser.add_argument("-LRW", help="Lambda for word embeddings (regularization training).", type=float)
    parser.add_argument("-LRC", help="Lambda for composition parameters (regularization training).", type=float)
    parser.add_argument("-traintype", help="Either normal, reg, or rep.")
    parser.add_argument("-clip", help="Threshold for gradient clipping.",type=int)
    parser.add_argument("-eta", help="Learning rate.", type=float)
    parser.add_argument("-learner", help="Either AdaGrad or Adam.")
    parser.add_argument("-task", help="Either sim, ent, or sentiment.")
    parser.add_argument("-numlayers", help="Number of layers in DAN Model.", type=int)

    args = parser.parse_args()

    params.LW = args.LW
    params.LC = args.LC
    params.outfile = args.outfile
    params.batchsize = args.batchsize
    params.hiddensize = args.dim
    params.memsize = args.memsize
    params.wordfile = args.wordfile
    params.nntype = args.nntype
    params.layersize = args.layersize
    params.updatewords = str2bool(args.updatewords)
    params.wordstem = args.wordstem
    params.save = str2bool(args.save)
    params.traindata = args.traindata
    params.devdata = args.devdata
    params.testdata = args.testdata
    params.usepeep = str2bool(args.peephole)
    params.useoutgate = str2bool(args.outgate)
    params.nntype = args.nntype
    params.epochs = args.epochs
    params.traintype = args.traintype
    params.evaluate = str2bool(args.evaluate)
    params.LRW = args.LRW
    params.LRC = args.LRC
    params.learner = learner2bool(args.learner)
    params.task = args.task
    params.numlayers = args.numlayers
    print "test1"
    if args.eta:
        params.eta = args.eta

    params.clip = args.clip
    if args.clip:
        if params.clip == 0:
            params.clip = None

    params.regfile = args.regfile
    params.minval = args.minval
    params.maxval = args.maxval

    if args.nonlinearity:
        if args.nonlinearity == 1:
            params.nonlinearity = lasagne.nonlinearities.linear
        if args.nonlinearity == 2:
            params.nonlinearity = lasagne.nonlinearities.tanh
        if args.nonlinearity == 3:
            params.nonlinearity = lasagne.nonlinearities.rectify
        if args.nonlinearity == 4:
            params.nonlinearity = lasagne.nonlinearities.sigmoid

    (words, We) = ppdb_utils.getWordmap(params.wordfile)

    if args.task == "sim" or args.task == "ent":
        train_data = ppdb_utils.getSimEntDataset(params.traindata,words,params.task)
    elif args.task == "sentiment":
        train_data = ppdb_utils.getSentimentDataset(params.traindata,words)
    else:
        raise ValueError('Task should be ent, sim, or sentiment.')

    model = None

    # print sys.argv

    model = lstm_model_sentiment(We, params)
    print words['UUUNKKK']

    predictin = '/home/jhliu/Myalgorithm_5nd/mypipe'
    # os.mkfifo(path=predictin, mode=0o666)
    predictout = '/home/jhliu/Myalgorithm_5nd/mypipe2'
    # os.mkfifo(path=predictout, mode=0o666)
    print "open1"
    pin = open(predictin, "r")
    print "open2"
    pout = open(predictout, "w")
    while 1:
        print "waiting"
        x = pin.readline()
        print "get"
        print x
        res = predict(model,words,x)
        pout.write(str(res) + '\n')