from params import params
import ppdb_utils
import lasagne
import random
import numpy as np
import sys
import utils
import argparse
import theano
import numpy as np
from theano import tensor as T
from theano import config
from lasagne_lstm_nooutput import lasagne_lstm_nooutput
import lasagne
import cPickle

class lstm_model_sentiment(object):
    """LSTM sentence encoder followed by a two-way softmax classifier.

    The network embeds word indices, runs them through an LSTM (with or
    without an output gate), keeps the last time step, and feeds it through
    a sigmoid dense layer and a 2-unit softmax layer.

    After construction the instance exposes:
        network_params        -- trainable parameters without the first
                                 (embedding) parameter
        all_params            -- every trainable parameter
        trainable             -- the parameters actually updated in training
        feedforward_function  -- (indices, mask) -> last LSTM output
        scoring_function      -- (indices, mask) -> argmax class per sentence
        cost_function         -- (scores, indices, mask) -> regularized cost
        train_function        -- same as cost_function, but applies updates
    """

    def getRegTerm(self, params, We, initial_We, l_out, l_softmax, pickled_params):
        """Build the regularization expression for the chosen training type.

        params.traintype selects the scheme:
          "normal" -- L2 on the network parameters (weight LC) plus, when
                      params.updatewords is set, L2 on the drift of the
                      embeddings from their initial value (weight LW).
          "reg"    -- L2 on the drift of every encoder parameter after the
                      first from its pickled value (weight LRC), L2 on the
                      classifier parameters (weight LC), and L2 on the
                      embedding drift (weight LRW).
          "rep"    -- L2 on the classifier parameters only (weight LC).

        Raises:
            ValueError: if params.traintype is none of the above.
        """
        if params.traintype == "normal":
            l2 = 0.5*params.LC*sum(lasagne.regularization.l2(x) for x in self.network_params)
            if params.updatewords:
                return l2 + 0.5*params.LW*lasagne.regularization.l2(We-initial_We)
            return l2

        if params.traintype == "reg":
            encoder_params = lasagne.layers.get_all_params(l_out, trainable=True)
            l2 = 0.
            # Index 0 is skipped here; the embedding drift is penalized
            # separately below with its own weight.
            for idx in range(1, len(encoder_params)):
                anchor = np.asarray(pickled_params[idx].get_value(), dtype = config.floatX)
                l2 += 0.5*params.LRC*(lasagne.regularization.l2(encoder_params[idx]-anchor))
            classifier_params = lasagne.layers.get_all_params(l_softmax, trainable=True)
            l2 += 0.5*params.LC*sum(lasagne.regularization.l2(x) for x in classifier_params)
            return l2 + 0.5*params.LRW*lasagne.regularization.l2(We-initial_We)

        if params.traintype == "rep":
            classifier_params = lasagne.layers.get_all_params(l_softmax, trainable=True)
            return 0.5*params.LC*sum(lasagne.regularization.l2(x) for x in classifier_params)

        raise ValueError('params.traintype not set correctly.')

    def getTrainableParams(self, params):
        """Return the list of parameters that training should update.

        "rep" never updates the embeddings; otherwise the embeddings are
        included when params.updatewords is set or the type is "reg".
        """
        if params.traintype == "rep":
            return self.network_params
        if params.updatewords or params.traintype == "reg":
            return self.all_params
        return self.network_params

    def __init__(self, We_initial, params):
        """Build the network and compile the Theano functions.

        Args:
            We_initial: initial word-embedding matrix (anything np.asarray
                accepts); its shape fixes the vocabulary size and width.
            params: settings object. Attributes read here and in the helper
                methods: traintype, regfile, useoutgate, usepeep, layersize,
                memsize, updatewords, LC, LW, LRC, LRW, clip, learner, eta.
        """
        p = None

        if params.traintype == "reg" or params.traintype == "rep":
            # NOTE: unpickling can execute arbitrary code; params.regfile
            # must come from a trusted source.
            # The context manager closes the handle, which the previous
            # cPickle.load(file(...)) call never did.
            with open(params.regfile, 'rb') as regfile:
                p = cPickle.load(regfile)
            # p contains [<TensorType(float64, matrix)>,
            # W_in_to_ingate, W_hid_to_ingate, b_ingate, W_in_to_forgetgate,
            # W_hid_to_forgetgate, b_forgetgate, W_in_to_cell, W_hid_to_cell,
            # b_cell, W_in_to_outgate, W_hid_to_outgate, b_outgate, W_cell_to_ingate,
            # W_cell_to_forgetgate, W_cell_to_outgate]

        def pretrained(i):
            # Value of the i-th pickled shared variable as a floatX array.
            return np.asarray(p[i].get_value(), dtype = config.floatX)

        #params
        initial_We = theano.shared(np.asarray(We_initial, dtype = config.floatX))
        We = theano.shared(np.asarray(We_initial, dtype = config.floatX))

        if params.traintype == "reg":
            # Start from, and regularize towards, the pickled embeddings.
            initial_We = theano.shared(pretrained(0))
            We = theano.shared(pretrained(0))

        if params.traintype == "rep":
            # Use the pickled embeddings as they are.
            We = theano.shared(pretrained(0))

        g1batchindices = T.imatrix()
        g1mask = T.matrix()
        scores = T.matrix()

        l_in = lasagne.layers.InputLayer((None, None, 1))
        l_mask = lasagne.layers.InputLayer(shape=(None, None))
        l_emb = lasagne.layers.EmbeddingLayer(l_in, input_size=We.get_value().shape[0], output_size=We.get_value().shape[1], W=We)

        # A default-initialized LSTM is always built first; for "reg"/"rep"
        # it is replaced below by one that uses the pickled weights.
        if params.useoutgate:
            l_lstm = lasagne.layers.LSTMLayer(l_emb, params.layersize, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)
        else:
            l_lstm = lasagne_lstm_nooutput(l_emb, params.layersize, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)

        if params.traintype == "reg" or params.traintype == "rep":
            # Input, forget and cell blocks sit at the same pickle indices in
            # both layouts; only the peephole indices differ.
            cell = lasagne.layers.Gate(W_in=pretrained(7), W_hid=pretrained(8), W_cell=None, b=pretrained(9), nonlinearity=lasagne.nonlinearities.tanh)
            if params.useoutgate:
                ingate = lasagne.layers.Gate(W_in=pretrained(1), W_hid=pretrained(2), W_cell=pretrained(13), b=pretrained(3))
                forgetgate = lasagne.layers.Gate(W_in=pretrained(4), W_hid=pretrained(5), W_cell=pretrained(14), b=pretrained(6))
                outgate = lasagne.layers.Gate(W_in=pretrained(10), W_hid=pretrained(11), W_cell=pretrained(15), b=pretrained(12))
                l_lstm = lasagne.layers.LSTMLayer(l_emb, params.layersize, ingate = ingate, forgetgate = forgetgate,
                                  outgate = outgate, cell = cell, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)
            else:
                ingate = lasagne.layers.Gate(W_in=pretrained(1), W_hid=pretrained(2), W_cell=pretrained(10), b=pretrained(3))
                forgetgate = lasagne.layers.Gate(W_in=pretrained(4), W_hid=pretrained(5), W_cell=pretrained(11), b=pretrained(6))
                l_lstm = lasagne_lstm_nooutput(l_emb, params.layersize, ingate = ingate, forgetgate = forgetgate,
                                  cell = cell, peepholes=params.usepeep, learn_init=False, mask_input = l_mask)

        # Keep only the last time step of the LSTM output.
        l_out = lasagne.layers.SliceLayer(l_lstm, -1, 1)

        embg = lasagne.layers.get_output(l_out, {l_in:g1batchindices, l_mask:g1mask})

        # NOTE(review): the classifier input width is the embedding width,
        # while the LSTM is built with params.layersize units -- this looks
        # like it only works when the two are equal; confirm.
        l_in2 = lasagne.layers.InputLayer((None, We.get_value().shape[1]))
        l_sigmoid = lasagne.layers.DenseLayer(l_in2, params.memsize, nonlinearity=lasagne.nonlinearities.sigmoid)

        l_softmax = lasagne.layers.DenseLayer(l_sigmoid, 2, nonlinearity=T.nnet.softmax)
        X = lasagne.layers.get_output(l_softmax, {l_in2:embg})
        cost = T.nnet.categorical_crossentropy(X,scores)
        prediction = T.argmax(X, axis=1)

        self.network_params = lasagne.layers.get_all_params(l_out, trainable=True) + lasagne.layers.get_all_params(l_softmax, trainable=True)
        # Drop the first parameter (the embedding matrix) from the network set.
        self.network_params.pop(0)
        self.all_params = lasagne.layers.get_all_params(l_out, trainable=True) + lasagne.layers.get_all_params(l_softmax, trainable=True)

        reg = self.getRegTerm(params, We, initial_We, l_out, l_softmax, p)
        self.trainable = self.getTrainableParams(params)
        cost = T.mean(cost) + reg

        self.feedforward_function = theano.function([g1batchindices,g1mask], embg)
        self.scoring_function = theano.function([g1batchindices,
                             g1mask],prediction)
        self.cost_function = theano.function([scores, g1batchindices,
                             g1mask], cost)

        grads = theano.gradient.grad(cost, self.trainable)
        if params.clip:
            grads = [lasagne.updates.norm_constraint(grad, params.clip, range(grad.ndim)) for grad in grads]
        updates = params.learner(grads, self.trainable, params.eta)
        self.train_function = theano.function([scores, g1batchindices,
                             g1mask], cost, updates=updates)
def getSeqs(p1,p2,words):
    """Convert two whitespace-separated phrases into lists of word indices.

    Each token is looked up with ppdb_utils.lookupIDX(words, token); the
    first phrase is processed completely before the second one.

    Returns:
        A pair (X1, X2) holding the index lists for p1 and p2.
    """
    first_ids = [ppdb_utils.lookupIDX(words, token) for token in p1.split()]
    second_ids = [ppdb_utils.lookupIDX(words, token) for token in p2.split()]
    return first_ids, second_ids


def getSeq(p1,words):
    """Convert one whitespace-separated phrase into a list of word indices.

    Each token is looked up with ppdb_utils.lookupIDX(words, token).
    """
    return [ppdb_utils.lookupIDX(words, token) for token in p1.split()]

def predict(model,words,text):
    """Score one sentence or a list of sentences with the model.

    Each sentence and its index sequence are printed, followed by the
    predictions returned by model.scoring_function.

    Args:
        model: object exposing scoring_function(x, mask).
        words: word-to-index map handed to getSeq.
        text: a single sentence or a list of sentences.

    Returns:
        The squeezed predictions converted with tolist() (a plain scalar
        when only one sentence is scored).
    """
    if not isinstance(text,list):
        text=[text]
    sequences = []
    for sentence in text:
        indices = getSeq(sentence,words)
        # Single-argument parenthesized prints behave the same under the
        # Python 2 print statement.
        print(sentence)
        print(indices)
        sequences.append(indices)
    x1,m1 = utils.prepare_data(sequences)
    scores = np.squeeze(model.scoring_function(x1,m1))
    predictions = scores.tolist()
    print(predictions)
    return predictions



if __name__ == '__main__':
    # Script entry point: parse the command line into a params object, load
    # the word embeddings and training data, build the model and print its
    # predictions for a few example sentences.

    # Add the sibling project directories to the import path.
    for new_path in ('../similarity', '../sentiment'):
        if new_path not in sys.path:
            sys.path.append(new_path)

    def str2bool(v):
        # Parse a yes/no style command-line value; a missing value is False.
        if v is None:
            return False
        if v.lower() in ("yes", "true", "t", "1"):
            return True
        if v.lower() in ("no", "false", "f", "0"):
            return False
        raise ValueError('A type that was supposed to be boolean is not boolean.')

    def learner2bool(v):
        # Map a learner name to its lasagne update rule; Adam is the default.
        if v is None:
            return lasagne.updates.adam
        if v.lower() == "adagrad":
            return lasagne.updates.adagrad
        if v.lower() == "adam":
            return lasagne.updates.adam
        raise ValueError('A type that was supposed to be a learner is not.')

    random.seed(1)
    np.random.seed(1)

    # Rebinds the imported class name to an instance of it.
    params = params()

    parser = argparse.ArgumentParser()
    parser.add_argument("-LW", help="Lambda for word embeddings (normal training).", type=float)
    parser.add_argument("-LC", help="Lambda for composition parameters (normal training).", type=float)
    parser.add_argument("-outfile", help="Output file name.")
    parser.add_argument("-batchsize", help="Size of batch.", type=int)
    parser.add_argument("-dim", help="Size of input.", type=int)
    parser.add_argument("-memsize", help="Size of classification layer.",
                        type=int)
    parser.add_argument("-wordfile", help="Word embedding file.")
    parser.add_argument("-layersize", help="Size of output layers in models.", type=int)
    parser.add_argument("-updatewords", help="Whether to update the word embeddings")
    parser.add_argument("-wordstem", help="Nickname of word embeddings used.")
    parser.add_argument("-save", help="Whether to pickle the model.")
    parser.add_argument("-traindata", help="Training data file.")
    parser.add_argument("-devdata", help="Development data file.")
    parser.add_argument("-testdata", help="Testing data file.")
    parser.add_argument("-peephole", help="Whether to use peephole connections in LSTM.")
    parser.add_argument("-outgate", help="Whether to use output gate in LSTM.")
    parser.add_argument("-nonlinearity", help="Type of nonlinearity in projection and DAN model.",
                        type=int)
    parser.add_argument("-nntype", help="Type of neural network.")
    parser.add_argument("-evaluate", help="Whether to evaluate the model during training.")
    parser.add_argument("-epochs", help="Number of epochs in training.", type=int)
    parser.add_argument("-regfile", help="Path to model file that we want to regularize towards.")
    parser.add_argument("-minval", help="Min rating possible in scoring.", type=int)
    parser.add_argument("-maxval", help="Max rating possible in scoring.", type=int)
    parser.add_argument("-LRW", help="Lambda for word embeddings (regularization training).", type=float)
    parser.add_argument("-LRC", help="Lambda for composition parameters (regularization training).", type=float)
    parser.add_argument("-traintype", help="Either normal, reg, or rep.")
    parser.add_argument("-clip", help="Threshold for gradient clipping.",type=int)
    parser.add_argument("-eta", help="Learning rate.", type=float)
    parser.add_argument("-learner", help="Either AdaGrad or Adam.")
    parser.add_argument("-task", help="Either sim, ent, or sentiment.")
    parser.add_argument("-numlayers", help="Number of layers in DAN Model.", type=int)

    args = parser.parse_args()

    params.LW = args.LW
    params.LC = args.LC
    params.outfile = args.outfile
    params.batchsize = args.batchsize
    params.hiddensize = args.dim
    params.memsize = args.memsize
    params.wordfile = args.wordfile
    params.nntype = args.nntype
    params.layersize = args.layersize
    params.updatewords = str2bool(args.updatewords)
    params.wordstem = args.wordstem
    params.save = str2bool(args.save)
    params.traindata = args.traindata
    params.devdata = args.devdata
    params.testdata = args.testdata
    params.usepeep = str2bool(args.peephole)
    params.useoutgate = str2bool(args.outgate)
    params.epochs = args.epochs
    params.traintype = args.traintype
    params.evaluate = str2bool(args.evaluate)
    params.LRW = args.LRW
    params.LRC = args.LRC
    params.learner = learner2bool(args.learner)
    params.task = args.task
    params.numlayers = args.numlayers

    # Only override the learning rate when one is given on the command line.
    if args.eta:
        params.eta = args.eta

    # A threshold of 0 (or no threshold at all) disables gradient clipping.
    # The previous zero check was nested under "if args.clip:" and could
    # therefore never fire.
    params.clip = args.clip if args.clip else None

    params.regfile = args.regfile
    params.minval = args.minval
    params.maxval = args.maxval

    # Numeric codes accepted by -nonlinearity; any other value leaves
    # params.nonlinearity untouched.
    nonlinearity_by_code = {
        1: lasagne.nonlinearities.linear,
        2: lasagne.nonlinearities.tanh,
        3: lasagne.nonlinearities.rectify,
        4: lasagne.nonlinearities.sigmoid,
    }
    if args.nonlinearity in nonlinearity_by_code:
        params.nonlinearity = nonlinearity_by_code[args.nonlinearity]

    (words, We) = ppdb_utils.getWordmap(params.wordfile)

    # The training data is loaded (and the task name validated) even though
    # this script only runs predictions afterwards.
    if args.task == "sim" or args.task == "ent":
        train_data = ppdb_utils.getSimEntDataset(params.traindata,words,params.task)
    elif args.task == "sentiment":
        train_data = ppdb_utils.getSentimentDataset(params.traindata,words)
    else:
        raise ValueError('Task should be ent, sim, or sentiment.')

    # print sys.argv

    model = lstm_model_sentiment(We, params)
    print(words['UUUNKKK'])
    x=[
        "This movie is not bad",
        "This movie is not bad .",
        "This movie is not good",
        "This movie is not sad",
        "This movie is sad",
        "This movie is good",
        "This movie is bad",
        "This movie is happy"
    ]


    predict(model,words,x)
