import numpy as np
import os
import pickle
import glob
import argparse
import gzip
import json
from tqdm import tqdm
from sklearn.linear_model import SGDRegressor
import datasets
from log_reg import calculate_top_ngrams
import utils
import mimic_proxy
import constants

class Proxy:
    """Lightweight container exposing a precomputed coefficient array as ``coef_``.

    Only the ``coef_`` attribute is provided; no fitting or prediction logic.
    """

    def __init__(self, coef):
        # Store under scikit-learn's fitted-attribute name; presumably so code
        # written for fitted linear estimators (e.g. calculate_top_ngrams) can
        # read it unchanged — TODO confirm against that consumer.
        weights = coef
        self.coef_ = weights

def _clf_file_sort_key(path):
    """Order ``clf.<id>.json.gz`` paths: decimal ids numerically first, any others lexically after."""
    stem = os.path.basename(path)[len("clf."):-len(".json.gz")]
    if stem.isdecimal():
        return (0, int(stem), stem)
    return (1, 0, stem)


def get_proxy_weights(model_dir):
    """Load every ``clf.*.json.gz`` regressor in *model_dir* and stack their coefficients.

    Each file's first line is decoded and deserialized into an ``SGDRegressor``
    via ``utils.sgd_classifier_from_json``; the regressors' ``coef_`` arrays are
    stacked along a new leading axis (row i comes from the i-th file in sorted order).

    Args:
        model_dir: directory containing the ``clf.*.json.gz`` files.

    Returns:
        A ``Proxy`` whose ``coef_`` is the stacked coefficient array.

    Raises:
        ValueError: if *model_dir* contains no ``clf.*.json.gz`` files.
    """
    # glob's result order is filesystem-dependent; sort so the row order of the
    # stacked coefficients is deterministic. Presumably the wildcard is a label
    # index (clf.0, clf.1, ...) — TODO confirm — hence numeric, not lexical, order.
    files = sorted(glob.glob(os.path.join(model_dir, "clf.*.json.gz")),
                   key=_clf_file_sort_key)
    if not files:
        # np.stack on an empty list would fail with an opaque message; keep the
        # same exception type but say what is actually missing.
        raise ValueError(f"no clf.*.json.gz files found in {model_dir!r}")

    coefs = []
    for fn in tqdm(files):
        with gzip.open(fn) as inf:
            clf = utils.sgd_classifier_from_json(SGDRegressor, inf.readline().decode())
        coefs.append(clf.coef_)
    return Proxy(np.stack(coefs))

def get_proxy_dicts():
    """Read the proxy vocabulary from ``<PROXY_DIR>/vocab.json.gz``.

    Only the first line of the gzipped file is read; it is decoded and parsed
    as JSON.

    Returns:
        The parsed JSON object (``main`` stores it as the ``'w2ind'`` lookup).
    """
    vocab_path = f"{constants.PROXY_DIR}/vocab.json.gz"
    with gzip.open(vocab_path) as stream:
        raw_line = stream.readline()
    return json.loads(raw_line.decode())

def main(args):
    """Load lookups, a model and label data, then run ``calculate_top_ngrams``.

    Two modes, selected by ``args.proxy``:

    * proxy: word index from ``get_proxy_dicts()``, code index (``'c2ind'``)
      from ``datasets.load_lookups(args)``, weights stacked from the
      ``clf.*.json.gz`` files in ``args.model_path``, and labels from the dev
      split returned by ``mimic_proxy.load_data()``.
    * default: lookups from ``datasets.load_lookups(args)``, a pickled model
      from ``args.model_path``, and labels from ``utils.read_mimic``.

    Only labels with at least one positive example are passed on.
    """
    dicts = {}
    if args.proxy:
        dicts['w2ind'] = get_proxy_dicts()
        caml_dicts = datasets.load_lookups(args)
        dicts['c2ind'] = caml_dicts['c2ind']
        model = get_proxy_weights(args.model_path)
        dataset = mimic_proxy.load_data()
        yy = dataset.binary.dev.todense()
    else:
        dicts = datasets.load_lookups(args)
        # SECURITY: pickle.load can execute arbitrary code; only load model
        # files from a trusted source.
        with open(args.model_path, 'rb') as model_file:
            model = pickle.load(model_file)
        _, yy = utils.read_mimic(args.data_path, dicts, args.version, args.Y)

    # NOTE(review): in proxy mode the labels come from mimic_proxy, yet
    # args.data_path is still what gets printed and passed below — confirm intended.
    print("calculating top ngrams using file %s" % args.data_path)
    # Per-label example counts. .todense() (and scipy sparse .sum(axis=0))
    # yield a 2-D 1xN matrix, whose .nonzero()[0] would be all-zero *row*
    # indices; flatten to 1-D so nonzero()[0] gives the label column indices.
    labels_with_examples = np.asarray(yy.sum(axis=0)).ravel().nonzero()[0]
    # NOTE(review): proxy=True is passed even when args.proxy is False — looks
    # like it should be proxy=args.proxy; confirm against calculate_top_ngrams.
    calculate_top_ngrams(args.data_path, model, dicts["c2ind"], dicts["w2ind"], labels_with_examples,
                         args.n, outdir=args.outdir, proxy=True)

    print()

if __name__ == "__main__":
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser(
        description="compute the top-scoring n-grams per label for a trained model on clinical documents")
    parser.add_argument("data_path", type=str,
                        help="path to a file containing sorted data.")
    parser.add_argument("vocab", type=str, help="path to a file holding vocab word list for discretizing words")
    parser.add_argument("version", type=str, help="which mimic you are using", choices=['mimic2','mimic3'])
    parser.add_argument("model_path", type=str, help="path to trained model")
    parser.add_argument("outdir", type=str, help="path to save ngrams")
    parser.add_argument("n", type=int, help="n-gram size")
    parser.add_argument('--Y', default='full',
                        help="label set to use; passed through to the data/lookup loaders")
    # NOTE(review): shares dest "version" with the required positional above, so
    # this default is never used (and a value given here bypasses the choices
    # check); kept only so existing command lines that pass it still parse.
    parser.add_argument('--version', default='mimic3')
    parser.add_argument("--proxy", dest="proxy", action="store_true", default=False,
                        help="use the proxy model (clf.*.json.gz weights in model_path) and the proxy vocabulary")
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    args = parser.parse_args()
    main(args)
