from scipy.stats import spearmanr
from scipy.stats import pearsonr
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import ppdb_utils
import numpy as np
import utils
import cPickle
from ppdb_utils import getWordmap
from params import params
from ppdb_utils import getPPDBData
import lasagne
import random
import numpy as np
import sys
import argparse
import ppdb_utils
import theano
import numpy as np
from theano import tensor as T
from theano import config
from lasagne_lstm_nooutput import lasagne_lstm_nooutput
import lasagne

class ppdb_lstm_model(object):
    """LSTM sentence encoder with a margin-based cosine-similarity objective.

    The constructor builds a Lasagne network (embedding lookup followed by an
    LSTM, of which index -1 along axis 1 is kept as the sentence
    representation), loads the supplied parameter values into it and compiles
    four Theano functions:

    * ``feedforward_function(indices, mask)`` -- representation of one batch.
    * ``scoring_function(g1, g2, g1mask, g2mask)`` -- row-wise cosine
      similarity between the representations of two batches.
    * ``cost_function(g1, g2, p1, p2, g1mask, g2mask, p1mask, p2mask)`` --
      the regularized hinge cost.
    * ``train_function(...)`` -- same inputs as ``cost_function``; returns the
      cost and applies one update step of ``params.learner``.

    ``self.all_params`` holds the parameters that ``train_function`` updates.
    """

    @staticmethod
    def _cosine(a, b):
        """Return the row-wise cosine similarity of two symbolic matrices."""
        dot = (a * b).sum(axis=1)
        norm = T.sqrt(T.sum(a ** 2, axis=1)) * T.sqrt(T.sum(b ** 2, axis=1))
        return dot / norm

    def __init__(self, We_initial, params, all_params):
        """Build the network and compile its Theano functions.

        Args:
            We_initial: initial word-embedding matrix (array-like, converted
                to ``config.floatX``).
            params: settings object; this constructor reads ``outgate``,
                ``layersize``, ``peephole``, ``margin``, ``LC``, ``LW``,
                ``updatewords``, ``clip``, ``learner`` and ``eta``.
            all_params: sequence of objects exposing ``get_value()`` whose
                values are loaded into the trainable parameters of the
                network (embedding included).
        """
        # Two separate shared copies: `We` is trained, `initial_We` stays
        # fixed and anchors the word regularizer below.
        initial_We = theano.shared(np.asarray(We_initial, dtype=config.floatX))
        We = theano.shared(np.asarray(We_initial, dtype=config.floatX))

        g1batchindices = T.imatrix()
        g2batchindices = T.imatrix()
        p1batchindices = T.imatrix()
        p2batchindices = T.imatrix()
        g1mask = T.matrix()
        g2mask = T.matrix()
        p1mask = T.matrix()
        p2mask = T.matrix()

        l_in = lasagne.layers.InputLayer((None, None, 1))
        l_mask = lasagne.layers.InputLayer(shape=(None, None))
        vocab_size, emb_dim = We.get_value().shape
        l_emb = lasagne.layers.EmbeddingLayer(l_in, input_size=vocab_size,
                                              output_size=emb_dim, W=We)

        # Standard LSTM when the output gate is requested, otherwise the
        # project's variant without one.
        lstm_cls = lasagne.layers.LSTMLayer if params.outgate else lasagne_lstm_nooutput
        l_lstm = lstm_cls(l_emb, params.layersize, peepholes=params.peephole,
                          learn_init=False, mask_input=l_mask)

        print(all_params)
        all_param_values = [p.get_value() for p in all_params]
        lasagne.layers.set_all_param_values(l_lstm, all_param_values, trainable=True)
        l_out = lasagne.layers.SliceLayer(l_lstm, -1, 1)

        embg1 = lasagne.layers.get_output(l_out, {l_in: g1batchindices, l_mask: g1mask})
        embg2 = lasagne.layers.get_output(l_out, {l_in: g2batchindices, l_mask: g2mask})
        embp1 = lasagne.layers.get_output(l_out, {l_in: p1batchindices, l_mask: p1mask})
        embp2 = lasagne.layers.get_output(l_out, {l_in: p2batchindices, l_mask: p2mask})

        g1g2 = self._cosine(embg1, embg2)
        p1g1 = self._cosine(embp1, embg1)
        p2g2 = self._cosine(embp2, embg2)

        # Hinge terms: sim(g1, g2) must exceed sim(p1, g1) and sim(p2, g2)
        # by at least `margin`; multiplying by the comparison zeroes the
        # term when the margin is already satisfied.
        costp1g1 = params.margin - g1g2 + p1g1
        costp1g1 = costp1g1 * (costp1g1 > 0)

        costp2g2 = params.margin - g1g2 + p2g2
        costp2g2 = costp2g2 * (costp2g2 > 0)

        cost = costp1g1 + costp2g2

        every_param = lasagne.layers.get_all_params(l_lstm, trainable=True)
        # Everything except the first trainable parameter (presumably the
        # embedding matrix `We`, which has its own regularizer below).
        network_params = every_param[1:]

        l2 = 0.5 * params.LC * sum(lasagne.regularization.l2(x) for x in network_params)
        cost = T.mean(cost) + l2
        if params.updatewords:
            # Penalize drift of the embeddings away from their initial value.
            cost = cost + 0.5 * params.LW * lasagne.regularization.l2(We - initial_We)

        batch_inputs = [g1batchindices, g2batchindices, p1batchindices, p2batchindices,
                        g1mask, g2mask, p1mask, p2mask]

        self.feedforward_function = theano.function([g1batchindices, g1mask], embg1)
        self.cost_function = theano.function(batch_inputs, cost)

        prediction = g1g2

        self.scoring_function = theano.function([g1batchindices, g2batchindices,
                                                 g1mask, g2mask], prediction)

        # When the embeddings are frozen only the non-embedding parameters
        # are trained; the rest of the training setup is identical.
        self.all_params = every_param if params.updatewords else network_params
        grads = theano.gradient.grad(cost, self.all_params)
        if params.clip:
            grads = [lasagne.updates.norm_constraint(grad, params.clip, range(grad.ndim))
                     for grad in grads]
        updates = params.learner(grads, self.all_params, params.eta)
        self.train_function = theano.function(batch_inputs, cost, updates=updates)

def str2bool(v):
    """Convert a command-line flag string to a boolean.

    ``None`` (flag not given) maps to ``False``. Matching is
    case-insensitive. Raises ``ValueError`` for any unrecognized string.
    """
    if v is None:
        return False

    truthy = ("yes", "true", "t", "1")
    falsy = ("no", "false", "f", "0")

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise ValueError('A type that was supposed to be boolean is not boolean.')

def learner2bool(v):
    """Map a learner name to the matching ``lasagne.updates`` function.

    Despite its name this returns an update function, not a boolean.
    ``None`` selects AdaGrad; names are matched case-insensitively.
    Raises ``ValueError`` for anything other than "adagrad" or "adam".
    """
    name = "adagrad" if v is None else v.lower()
    if name == "adagrad":
        return lasagne.updates.adagrad
    elif name == "adam":
        return lasagne.updates.adam
    else:
        raise ValueError('A type that was supposed to be a learner is not.')


def getSeqs(p1,p2,words):
    """Turn two whitespace-separated phrases into lists of word indices.

    Each token is resolved with ``ppdb_utils.lookupIDX(words, token)``.
    Returns the pair ``(ids_for_p1, ids_for_p2)``.
    """
    first_ids = [ppdb_utils.lookupIDX(words, token) for token in p1.split()]
    second_ids = [ppdb_utils.lookupIDX(words, token) for token in p2.split()]
    return first_ids, second_ids

def getSeq(p1,words):
    """Turn one whitespace-separated phrase into a list of word indices.

    Each token is resolved with ``ppdb_utils.lookupIDX(words, token)``.
    """
    return [ppdb_utils.lookupIDX(words, token) for token in p1.split()]

def getCorrelation(model,words,f):
    """Score every sentence pair in a file and correlate with gold scores.

    Args:
        model: object exposing ``scoring_function(x1, x2, m1, m2)``.
        words: word-to-index mapping handed to ``getSeqs``.
        f: path of a tab-separated file whose first three columns are
            phrase 1, phrase 2 and a numeric gold score.

    Returns:
        ``(pearson_r, spearman_rho)`` between the model scores and the gold
        scores.
    """
    golds = []
    seq1 = []
    seq2 = []
    # The context manager closes the file even if a line fails to parse.
    with open(f, 'r') as fh:
        for line in fh:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split("\t")
            X1, X2 = getSeqs(fields[0], fields[1], words)
            seq1.append(X1)
            seq2.append(X2)
            golds.append(float(fields[2]))
    # The whole file is scored as a single batch.
    x1, m1 = utils.prepare_data(seq1)
    x2, m2 = utils.prepare_data(seq2)
    scores = model.scoring_function(x1, x2, m1, m2)
    preds = np.squeeze(scores)
    return pearsonr(preds, golds)[0], spearmanr(preds, golds)[0]

def acc(preds,scores):
    """Return the accuracy of ``preds`` against textual entailment labels.

    ``scores`` holds the gold labels as strings; after stripping whitespace
    each must be CONTRADICTION (0), NEUTRAL (1) or ENTAILMENT (2), otherwise
    ``ValueError`` is raised. The result comes from ``accuracy_score``.
    """
    label_ids = {"CONTRADICTION": 0, "NEUTRAL": 1, "ENTAILMENT": 2}
    golds = []
    for raw_label in scores:
        label = raw_label.strip()
        if label not in label_ids:
            raise ValueError('Something wrong with data...')
        golds.append(label_ids[label])
    return accuracy_score(golds, preds)

def accSentiment(preds,scores):
    """Return the accuracy of ``preds`` against binary sentiment labels.

    ``scores`` holds the gold labels as strings; after stripping whitespace
    each must be "0" or "1", otherwise ``ValueError`` is raised. The result
    comes from ``accuracy_score``.
    """
    label_ids = {"0": 0, "1": 1}
    golds = []
    for raw_label in scores:
        label = raw_label.strip()
        if label not in label_ids:
            raise ValueError('Something wrong with data...')
        golds.append(label_ids[label])
    return accuracy_score(golds, preds)

def getAcc(model,words,f,batchsize=100):
    """Score the sentence pairs of a labelled file and return the accuracy.

    Args:
        model: object exposing ``scoring_function(x1, x2, m1, m2)``.
        words: word-to-index mapping handed to ``getSeqs``.
        f: path of a tab-separated file whose first three columns are
            phrase 1, phrase 2 and a textual label understood by ``acc``.
        batchsize: number of pairs scored per call to the model
            (defaults to the previously hard-coded 100).

    Returns:
        The value of ``acc(predictions, gold_labels)``.

    NOTE(review): ``acc`` compares the predictions with class ids 0/1/2, so
    the model passed here presumably has a ``scoring_function`` that returns
    class predictions rather than similarities -- confirm with callers.
    """
    preds = []
    golds = []
    seq1 = []
    seq2 = []

    def flush():
        # Score the pending pairs and reset the buffers in place.
        x1, m1 = utils.prepare_data(seq1)
        x2, m2 = utils.prepare_data(seq2)
        scores = model.scoring_function(x1, x2, m1, m2)
        # atleast_1d: squeezing a one-element batch yields a 0-d array whose
        # tolist() is a bare number, which list.extend() cannot consume.
        preds.extend(np.atleast_1d(np.squeeze(scores)).tolist())
        del seq1[:]
        del seq2[:]

    # The context manager closes the file even if scoring fails midway.
    with open(f, 'r') as fh:
        for line in fh:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split("\t")
            X1, X2 = getSeqs(fields[0], fields[1], words)
            seq1.append(X1)
            seq2.append(X2)
            golds.append(fields[2])
            if len(seq1) == batchsize:
                flush()
    if seq1:
        flush()
    return acc(preds, golds)

def getAccSentiment(model,words,f,batchsize=100):
    """Score the sentences of a labelled file and return the accuracy.

    Args:
        model: object exposing ``scoring_function(x1, m1)``.
        words: word-to-index mapping handed to ``getSeq``.
        f: path of a tab-separated file whose first two columns are a
            sentence and a label understood by ``accSentiment``.
        batchsize: number of sentences scored per call to the model
            (defaults to the previously hard-coded 100).

    Returns:
        The value of ``accSentiment(predictions, gold_labels)``.

    NOTE(review): ``ppdb_lstm_model.scoring_function`` in this file takes
    four inputs; the two-argument call below assumes a different model
    whose ``scoring_function`` accepts (indices, mask) -- confirm.
    """
    preds = []
    golds = []
    seq1 = []

    def flush():
        # Score the pending sentences and reset the buffer in place.
        x1, m1 = utils.prepare_data(seq1)
        scores = model.scoring_function(x1, m1)
        # atleast_1d: squeezing a one-element batch yields a 0-d array whose
        # tolist() is a bare number, which list.extend() cannot consume.
        preds.extend(np.atleast_1d(np.squeeze(scores)).tolist())
        del seq1[:]

    # The context manager closes the file even if scoring fails midway.
    with open(f, 'r') as fh:
        for line in fh:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split("\t")
            seq1.append(getSeq(fields[0], words))
            golds.append(fields[1])
            if len(seq1) == batchsize:
                flush()
    if seq1:
        flush()
    return accSentiment(preds, golds)

def evaluate(model,words,file,params):
    """Run the evaluation selected by ``params.task`` on ``file``.

    "sim" returns the ``(p, s)`` pair from ``getCorrelation``; "ent" returns
    the result of ``getAcc``; "sentiment" returns the result of
    ``getAccSentiment``. Any other task raises ``ValueError``.
    """
    task = params.task
    if task not in ("sim", "ent", "sentiment"):
        raise ValueError('Task should be ent, sim, or sentiment')

    if task == "sim":
        pearson, spearman = getCorrelation(model, words, file)
        return pearson, spearman
    if task == "ent":
        return getAcc(model, words, file)
    return getAccSentiment(model, words, file)

def evaluate_all(model,words):
    """Print the correlations of ``model`` on the fixed similarity datasets.

    For every dataset under ``../data/`` the output line contains
    "<pearson> <spearman> <dataset> | ", all on a single line.
    """
    prefix = "../data/"
    datasets = ["annotated-ppdb-dev",
                "annotated-ppdb-test",
                "sicktest"]

    # Evaluate everything first, then format, so a failing dataset prevents
    # any partial report from being printed.
    results = [getCorrelation(model, words, prefix + name) for name in datasets]

    report = "".join(str(pearson) + " " + str(spearman) + " " + name + " | "
                     for (pearson, spearman), name in zip(results, datasets))

    print(report)

    
def getWordmap(textfile):
    """Load a text embedding file into a vocabulary and an embedding matrix.

    Each non-blank line is ``<word> <float> <float> ...`` separated by
    whitespace. Note that this definition replaces the ``getWordmap``
    imported from ``ppdb_utils`` at the top of the file.

    Args:
        textfile: path of the embedding file.

    Returns:
        ``(words, We)`` where ``words`` maps each word to its row index and
        ``We`` is a NumPy array with one row per word. If a word occurs more
        than once, ``words`` keeps the index of its last occurrence.
    """
    words = {}
    We = []
    # The context manager closes the file even if a value fails to parse.
    with open(textfile, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines instead of failing on fields[0].
                continue
            # len(We) is the row this vector is about to occupy, so the
            # index stays aligned with `We` even when lines were skipped.
            words[fields[0]] = len(We)
            We.append([float(value) for value in fields[1:]])
    return (words, np.array(We))
   
    
def predict(model,words,text):
    """Score a single sentence with ``model`` and print the intermediates.

    Prints, in order: the word indices of ``text``, the prepared index
    array, the prepared mask, and the list of scores. Returns ``None``.

    Args:
        model: object exposing ``scoring_function(x1, m1)``.
        words: word-to-index mapping handed to ``getSeq``.
        text: whitespace-separated sentence to score.

    NOTE(review): ``ppdb_lstm_model.scoring_function`` in this file takes
    four inputs; the two-argument call below assumes a model whose
    ``scoring_function`` accepts (indices, mask) -- confirm.
    """
    X1 = getSeq(text, words)
    print(X1)
    x1, m1 = utils.prepare_data([X1])
    print(x1)
    print(m1)
    scores = model.scoring_function(x1, m1)
    # A one-sentence batch squeezes to a 0-d array; atleast_1d keeps the
    # result a list so it can be reported like a normal batch.
    preds = np.atleast_1d(np.squeeze(scores)).tolist()
    print(preds)



# Script entry point: parse the command line into a `params` object, load the
# word embeddings, the training examples and a pickled set of LSTM parameters,
# build the model and score one example sentence.
if __name__ == '__main__':
    random.seed(1)
    np.random.seed(1)

    # Rebinds the module-level name `params` from the class to an instance.
    params = params()

    parser = argparse.ArgumentParser()
    parser.add_argument("-LW", help="Lambda for word embeddings (normal training).", type=float)
    parser.add_argument("-LC", help="Lambda for composition parameters (normal training).", type=float)
    parser.add_argument("-outfile", help="Output file name.")
    parser.add_argument("-batchsize", help="Size of batch.", type=int)
    parser.add_argument("-dim", help="Size of input.", type=int)
    parser.add_argument("-wordfile", help="Word embedding file.")
    parser.add_argument("-layersize", help="Size of output layers in models.", type=int)
    parser.add_argument("-updatewords", help="Whether to update the word embeddings")
    parser.add_argument("-wordstem", help="Nickname of word embeddings used.")
    parser.add_argument("-save", help="Whether to pickle the model.")
    parser.add_argument("-train", help="Training data file.")
    parser.add_argument("-margin", help="Margin in objective function.", type=float)
    parser.add_argument("-samplingtype", help="Type of sampling used.")
    parser.add_argument("-peephole", help="Whether to use peephole connections in LSTM.")
    parser.add_argument("-outgate", help="Whether to use output gate in LSTM.")
    parser.add_argument("-nonlinearity", help="Type of nonlinearity in projection and DAN model.",
                        type=int)
    parser.add_argument("-nntype", help="Type of neural network.")
    parser.add_argument("-evaluate", help="Whether to evaluate the model during training.")
    parser.add_argument("-epochs", help="Number of epochs in training.", type=int)
    parser.add_argument("-clip", help="Threshold for gradient clipping.",type=int)
    parser.add_argument("-eta", help="Learning rate.", type=float)
    parser.add_argument("-learner", help="Either AdaGrad or Adam.")
    parser.add_argument("-add_rnn", help="Whether to keep RNN close to an addition model.")
    parser.add_argument("-numlayers", help="Number of layers in DAN Model.", type=int)
    parser.add_argument("-num_examples", help="Number of examples to use in training. If not set, will use all examples.", type=int)

    args = parser.parse_args()

    params.LW = args.LW
    params.LC = args.LC
    params.outfile = args.outfile
    params.batchsize = args.batchsize
    params.hiddensize = args.dim
    params.wordfile = args.wordfile
    params.nntype = args.nntype
    params.layersize = args.layersize
    params.updatewords = str2bool(args.updatewords)
    params.wordstem = args.wordstem
    params.save = str2bool(args.save)
    params.train = args.train
    params.margin = args.margin
    params.type = args.samplingtype
    params.peephole = str2bool(args.peephole)
    params.outgate = str2bool(args.outgate)
    params.epochs = args.epochs
    params.evaluate = str2bool(args.evaluate)
    params.learner = learner2bool(args.learner)
    params.add_rnn = str2bool(args.add_rnn)
    params.numlayers = args.numlayers

    # Only override the default learning rate when one was supplied.
    if args.eta:
        params.eta = args.eta

    # A threshold of 0 (or no threshold) means "no clipping".
    params.clip = args.clip or None

    # Integer code -> nonlinearity; unknown or missing codes leave the
    # default on `params` untouched.
    nonlinearities = {
        1: lasagne.nonlinearities.linear,
        2: lasagne.nonlinearities.tanh,
        3: lasagne.nonlinearities.rectify,
        4: lasagne.nonlinearities.sigmoid,
    }
    if args.nonlinearity in nonlinearities:
        params.nonlinearity = nonlinearities[args.nonlinearity]

    (words, We) = getWordmap(params.wordfile)
    examples = getPPDBData(params.train, words)

    if args.num_examples:
        examples = examples[0:args.num_examples]

    # str() keeps this message from failing when -outfile was not given.
    print("Saving to: " + str(params.outfile))

    print(sys.argv)

    # Fail early with a clear message instead of continuing with no model.
    if params.nntype != 'lstm':
        raise ValueError("Error no type specified")

    # Pickles are binary data, so the file is opened in binary mode.
    # NOTE: unpickling executes arbitrary code; only load trusted model files.
    with open("./lstm-outgate/lstm-model-outgate40.pickle", "rb") as f:
        all_params = cPickle.load(f)
    model = ppdb_lstm_model(We, params, all_params)

    # predict() prints its own results and returns nothing.
    predict(model, words, "This is a good book")