"""
    Reads (or writes) BOW-formatted notes and performs scikit-learn logistic regression
"""
import csv
import numpy as np
import os
import pickle
import sys
import time
import argparse

from collections import Counter, defaultdict
from scipy.sparse import csr_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from tqdm import tqdm
import pandas as pd

from constants import *
import datasets
import evaluation
from learn import tools
import persistence
from utils import read_bows, write_bows, construct_X_Y

import nltk

# Hyperparameters for the per-label LogisticRegression built in main().
C: float = 1.0  # passed as LogisticRegression(C=...): inverse regularization strength
MAX_ITER: int = 20  # passed as LogisticRegression(max_iter=...): solver iteration cap

def _load_or_build_bows(fname, Y, w2ind, c2ind, ind2c, version):
    """
    Return (X, yy, hids) for ``fname``, reading its cached ``_bows.csv`` if present.

    When no cached BOW file exists, the features are constructed from ``fname``
    and written out with ``write_bows`` so later runs can reuse them.
    """
    bows_fname = fname.replace('.csv', '_bows.csv')
    if os.path.exists(bows_fname):
        return read_bows(Y, bows_fname, c2ind, version)
    X, yy, hids = construct_X_Y(fname, Y, w2ind, c2ind, version)
    write_bows(fname, X, hids, yy, ind2c)
    return X, yy, hids


def _expand_label_columns(pred, full_shape, label_cols):
    """
    Scatter predictions over the trained label subset back into the full label space.

    Args:
        pred: (n_samples, len(label_cols)) predictions or probabilities whose
            columns are in the same order as ``label_cols``. Sparse matrices are
            densified first.
        full_shape: shape of the full label matrix, (n_samples, n_labels).
        label_cols: ascending indices of the label columns the model was trained on.

    Returns:
        Float ndarray of shape ``full_shape``; columns not in ``label_cols`` are 0.
    """
    if hasattr(pred, "toarray"):
        pred = pred.toarray()
    full = np.zeros(full_shape)
    full[:, label_cols] = pred
    return full


def main(args):
    """
    Train a one-vs-rest logistic regression on BOW features and evaluate it.

    Loads (or builds and caches) BOW matrices for the train and dev splits, fits
    one LogisticRegression per label that has at least one positive training
    example, evaluates on dev and on train, and saves the pickled model, the dev
    predictions and the metrics under a timestamped directory in MODEL_DIR. When
    ``args.n > 0`` it also writes the top-scoring n-gram per (note, label).

    Args:
        args: parsed command-line namespace with ``Y``, ``data_path``, ``version``
            and ``n``; it is also passed whole to ``datasets.load_lookups``.
    """
    Y, train_fname, version, n = args.Y, args.data_path, args.version, args.n
    # the dev split is assumed to share the train file's name with 'train' -> 'dev'
    dev_fname = train_fname.replace("train", "dev")

    # need to handle really large text fields
    csv.field_size_limit(sys.maxsize)

    # get lookups from non-BOW data
    dicts = datasets.load_lookups(args)
    w2ind, ind2c, c2ind = dicts['w2ind'], dicts['ind2c'], dicts['c2ind']

    # Construct X, y (reusing cached BOW files when available)
    X, yy_tr, hids_tr = _load_or_build_bows(train_fname, Y, w2ind, c2ind, ind2c, version)
    X_dv, yy_dv, hids_dv = _load_or_build_bows(dev_fname, Y, w2ind, c2ind, ind2c, version)

    print("X.shape: " + str(X.shape))
    print("yy_tr.shape: " + str(yy_tr.shape))
    print("X_dv.shape: " + str(X_dv.shape))
    print("yy_dv.shape: " + str(yy_dv.shape))

    # Labels without any positive training example can't be fit: drop those
    # columns from the targets, remember which columns were kept, and predict 0
    # for the dropped ones afterwards. nonzero() yields the kept indices ascending.
    labels_with_examples = yy_tr.sum(axis=0).nonzero()[0]
    yy = yy_tr[:, labels_with_examples]

    # build the classifier
    clf = OneVsRestClassifier(LogisticRegression(C=C, max_iter=MAX_ITER, solver='sag'), n_jobs=-1)

    print("training...")
    clf.fit(X.copy(), yy)

    print("predicting...")
    yhat = clf.predict(X_dv)
    yhat_raw = clf.predict_proba(X_dv)

    print("reshaping output to deal with labels missing from train set")
    yhat_full = _expand_label_columns(yhat, yy_dv.shape, labels_with_examples)
    yhat_full_raw = _expand_label_columns(yhat_raw, yy_dv.shape, labels_with_examples)

    # evaluate on dev
    metrics = evaluation.all_metrics(yhat_full, yy_dv, k=[8, 15], yhat_raw=yhat_full_raw)
    evaluation.print_metrics(metrics)

    # save model and dev predictions
    model_dir = os.path.join(MODEL_DIR, '_'.join(["log_reg", time.strftime('%b_%d_%H:%M', time.localtime())]))
    os.mkdir(model_dir)
    print("saving model")
    with open(os.path.join(model_dir, "model.pkl"), 'wb') as f:
        pickle.dump(clf, f)
    print("saving predictions")
    persistence.write_preds(yhat_full, model_dir, hids_dv, 'dev', ind2c, yhat_raw=yhat_full_raw)

    print("sanity check on train")
    yhat_tr_full = _expand_label_columns(clf.predict(X), yy_tr.shape, labels_with_examples)
    yhat_tr_full_raw = _expand_label_columns(clf.predict_proba(X), yy_tr.shape, labels_with_examples)

    metrics_tr = evaluation.all_metrics(yhat_tr_full, yy_tr, k=[8, 15], yhat_raw=yhat_tr_full_raw)
    evaluation.print_metrics(metrics_tr)

    if n > 0:
        print("generating top important ngrams")
        if 'bows' in dev_fname:
            dev_fname = dev_fname.replace('_bows', '')
        print("calculating top ngrams using file %s" % dev_fname)
        calculate_top_ngrams(dev_fname, clf, c2ind, w2ind, set(labels_with_examples), n)

    print("saving metrics")
    metrics_hist = defaultdict(list)
    metrics_hist_tr = defaultdict(list)
    for name, value in metrics.items():
        metrics_hist[name].append(value)
    for name, value in metrics_tr.items():
        metrics_hist_tr[name].append(value)
    # NOTE(review): dev metrics fill both of the first two slots; presumably
    # save_metrics expects (dev, test, train) and there is no test run here — confirm.
    metrics_hist_all = (metrics_hist, metrics_hist, metrics_hist_tr)
    persistence.save_metrics(metrics_hist_all, model_dir)

def _coefficient_matrix(clf):
    """
    Return ``clf``'s per-label coefficient matrix as a dense array.

    ``OneVsRestClassifier.coef_`` was removed in newer scikit-learn releases, so
    when it is unavailable the matrix is rebuilt by stacking the ``coef_`` of each
    per-label estimator (one row per fitted label). Sparse matrices are densified.
    """
    coef = getattr(clf, "coef_", None)
    if coef is None:
        coef = np.vstack([est.coef_ for est in clf.estimators_])
    if hasattr(coef, "toarray"):
        coef = coef.toarray()
    return coef


def calculate_top_ngrams(inputfile, clf, c2ind, w2ind, labels_with_examples, n, outdir=DATA_DIR, proxy=False):
    """
    Write the highest-scoring n-gram for every (note, label) pair to ``<outdir>/top_ngrams.csv``.

    An n-gram's score is the sum of the label's logistic-regression coefficients
    over its words; words not in ``w2ind`` (or outside the coefficient row)
    contribute 0.

    Args:
        inputfile: CSV with a header row; columns are read positionally as
            subject id, hadm id, text, and ';'-separated labels.
        clf: fitted classifier providing per-label coefficients.
        c2ind: dict mapping label code -> label index.
        w2ind: dict mapping word -> feature index.
        labels_with_examples: label indices the classifier was trained on.
        n: n-gram size.
        outdir: directory that receives ``top_ngrams.csv``.
        proxy: if True, use the coefficient matrix as-is instead of re-expanding
            it to one row per label index.

    Labels not present in ``c2ind`` (e.g. an empty label field) and notes with
    fewer than ``n`` tokens produce no output row.
    """
    mat = _coefficient_matrix(clf)
    if proxy:
        mat_full = mat
    else:
        # Re-insert all-zero rows for labels without training examples so row i of
        # mat_full belongs to label index i. Assumes classifier rows follow
        # ascending label index (the order main() keeps target columns in).
        n_labels = max(c2ind.values(), default=-1) + 1
        kept = np.array(sorted(i for i in set(labels_with_examples) if i < n_labels), dtype=int)
        mat_full = np.zeros((n_labels, mat.shape[1]))
        mat_full[kept] = mat[:len(kept)]

    with open(os.path.join(outdir, "top_ngrams.csv"), 'w', newline='') as out_f:
        writer = csv.writer(out_f, delimiter=',')
        writer.writerow(['SUBJECT_ID', 'HADM_ID', 'LABEL', 'INDEX', 'NGRAM', 'SCORE'])

        with open(inputfile, 'r', newline='') as notesfile:
            reader = csv.reader(notesfile)
            next(reader)  # skip header row

            for row in tqdm(reader, 'Calculating top N grams'):
                subject_id, hadm_id, text = row[0], row[1], row[2]
                tokens = text.split()
                # feature index of every token (None when out of vocabulary),
                # looked up once per note rather than once per label and n-gram
                token_ids = [w2ind.get(word) for word in tokens]

                for label in row[3].split(';'):
                    if label not in c2ind:
                        # no coefficient row to score with (e.g. empty label field)
                        continue
                    word_weights = mat_full[c2ind[label]]
                    n_weights = len(word_weights)
                    # per-token coefficient; unknown or out-of-range tokens add 0
                    token_weights = [word_weights[idx] if idx is not None and idx < n_weights else 0
                                     for idx in token_ids]
                    # score of every window of n consecutive tokens
                    scores = [sum(token_weights[start:start + n])
                              for start in range(len(tokens) - n + 1)]
                    if not scores:
                        # note has fewer than n tokens: no n-gram to report
                        continue

                    best_score = max(scores)
                    best_start = scores.index(best_score)  # first window with the top score
                    writer.writerow([subject_id, hadm_id, label, best_start,
                                     " ".join(tokens[best_start:best_start + n]), best_score])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="train a neural network on some clinical documents")
    parser.add_argument("Y", type=str, help="size of label space")
    parser.add_argument("data_path", type=str,
                        help="path to a file containing sorted train data. dev/test splits assumed to have same name format with 'train' replaced by 'dev' and 'test'")
    parser.add_argument("vocab", type=str, help="path to a file holding vocab word list for discretizing words")
    parser.add_argument("version", type=str, help="which mimic you are using", choices=['mimic2','mimic3'])
    parser.add_argument("n", type=int, help="n-gram size")
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    args = parser.parse_args()
    main(args)

