# This sets up the proxy model for DR-CAML.
# save_dataset() and load_data() set up the problem: X is a bag-of-words representation of the discharge summary, "real" is the continuous prediction output of DR-CAML, and "binary" is the thresholded predictions of DR-CAML. We use fpzip to compress the real-valued outputs.
# The proxy model is trained with prediction_regression(), which trains an sklearn SGDRegressor for each code independently.


# import pdb; pdb.set_trace()

import argparse
import csv
import fpzip
import gzip
import json
import logging
import os
import random
import time
import warnings

from collections import defaultdict
from nltk.tokenize import RegexpTokenizer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDRegressor
from sklearn.exceptions import ConvergenceWarning

import numpy as np
import sklearn.metrics as metrics
import scipy.sparse
import scipy.stats as stats

import utils
import constants


class Vocab:
  '''
  Term -> integer id table; ids are handed out consecutively from 0 in the
  order terms are first seen.
  '''

  def __init__(self):
    self.d = {}          # term -> id
    self.count = 0       # id that the next new term will receive
    self.frozen = False  # once True, idx() never adds new terms

  def idx(self, word, frozen=False):
    '''
    Return the id of `word`.

    An unseen word is registered and given the next free id, unless either
    the `frozen` argument or `self.frozen` is truthy, in which case -1 is
    returned and the table is left untouched.
    '''
    try:
      return self.d[word]
    except KeyError:
      pass

    if self.frozen or frozen:
      return -1

    assigned = self.count
    self.d[word] = assigned
    self.count = assigned + 1
    return assigned


def load_vocab():
  '''
  Build a frozen Vocab from `{constants.MIMIC_3_DIR}/vocab.csv`.

  Every line of the file, stripped of surrounding whitespace, is registered
  as one term, so ids follow the order of the file. The returned Vocab is
  frozen: looking up an unknown term afterwards yields -1.
  '''
  vocab_path = f"{constants.MIMIC_3_DIR}/vocab.csv"
  with open(vocab_path) as handle:
    terms = [raw_line.strip() for raw_line in handle]

  table = Vocab()
  for term in terms:
    table.idx(term)
  table.frozen = True
  return table

def load_c2ind():
  '''
  Load the code -> index mapping from `{constants.PROXY_DIR}/c2ind.json`.

  Returns a pair (c2ind, ind2c) where ind2c is the inverse mapping
  (index -> code) derived from c2ind.
  '''
  mapping_path = f"{constants.PROXY_DIR}/c2ind.json"
  with open(mapping_path) as handle:
    code_to_index = json.load(handle)

  index_to_code = {}
  for code, index in code_to_index.items():
    index_to_code[index] = code
  return code_to_index, index_to_code


def get_prediction_function(labeli):
  '''
  Build a prediction callable for the saved proxy model of one code.

  Loads the vocabulary and the regressor stored in
  `{constants.PROXY_DIR}/clf.{labeli}.json.gz` once, and returns a closure
  that featurizes raw text and applies that regressor.

  Args:
    labeli: index of the code; only used to locate the saved regressor file.

  Returns:
    predict(docs): takes one string or an iterable of strings. Each document
    is tokenized on r'\\w+', purely numeric tokens are dropped, the rest are
    lower-cased and turned into a binary bag-of-words row over the
    vocabulary. The regressor output is exponentiated to give p, and the
    rows [1 - p, p] are stacked into one array (one row per document,
    assuming the regressor returns one value per input row). An empty
    `docs` yields an array of shape (0, 2).
  '''
  # n_features = 140796
  vocab = load_vocab()
  n_features = len(vocab.d)
  tokenizer = RegexpTokenizer(r'\w+')
  with gzip.open(f"{constants.PROXY_DIR}/clf.{labeli}.json.gz") as inf:
    clf = utils.sgd_classifier_from_json(SGDRegressor, inf.readline().decode())

  def predict(docs):
    # isinstance (not an exact type comparison) so str subclasses are also
    # treated as a single document rather than iterated character by character.
    if isinstance(docs, str):
      docs = [docs]
    yhat_raw = []
    for text in docs:
      row = np.zeros([n_features])
      for term in tokenizer.tokenize(text):
        if term.isnumeric():
          continue
        word = vocab.idx(term.lower())
        if word >= 0:  # -1 marks an out-of-vocabulary term
          row[word] = 1
      # The regressor is applied in log space (see np.exp below).
      log_prob = clf.predict(row.reshape(1, -1))
      prob = np.exp(log_prob)
      # NOTE(review): prob is not clipped, so it can exceed 1 and make the
      # first column negative - confirm callers tolerate that.
      yhat_raw.append(np.array([1-prob, prob]).T)
    if not yhat_raw:
      # np.concatenate raises on an empty list; return an empty result instead.
      return np.zeros((0, 2))
    return np.concatenate(yhat_raw, axis=0)
  return predict


def save_dataset():
  '''
  Build the proxy-model dataset files from the notes and DR-CAML predictions.

  Reads `{constants.MIMIC_3_DIR}/notes_labeled.csv` (rows of four fields, of
  which the second is the hadm_id and the third the note text) and, for each
  of the train/dev/test splits, `pred_all_scores_{split}.json` and
  `preds_{split}.psv` from `{constants.PREDS_DIR}/DRCAML_mimic3_full`.

  For each split, writes into constants.PROXY_DIR:
    X_{split}_csr.npz       binary bag-of-words CSR matrix, one row per
                            admission, one column per vocabulary term
    y_bin_{split}.json      per admission, the list of code indices taken
                            from the .psv row
    y_real_{split}.json.gz  per admission, the scores of all codes ordered
                            by code index

  Admissions missing from either the predictions or the notes are skipped.
  Returns None.
  '''
  # This must be an f-string: without the prefix the path would contain the
  # literal text "{constants.PREDS_DIR}".
  preds_dir = f"{constants.PREDS_DIR}/DRCAML_mimic3_full"
  notes_fn = f"{constants.MIMIC_3_DIR}/notes_labeled.csv"
  c2ind, _ = load_c2ind()
  vocab = load_vocab()

  # hadm_id -> set of vocabulary ids occurring in the note.
  notes = {}
  with open(notes_fn) as inf:
    for _, hadm_id, note, _codes in csv.reader(inf, delimiter=","):
      token_ids = (vocab.idx(token) for token in note.split())
      # idx() returns -1 for out-of-vocabulary tokens; drop those.
      notes[hadm_id] = {word for word in token_ids if word >= 0}

  # Codes ordered by their index; fixes the column order of the real targets.
  all_codes = sorted(c2ind, key=lambda x: c2ind[x])

  for split in ["train", "dev", "test"]:
    indptr = [0]
    indices = []
    y_bin_split = []
    y_real_split = []

    with open(os.path.join(preds_dir, "pred_all_scores_{}.json".format(split))) as inf:
      preds = json.load(inf)

    with open(os.path.join(preds_dir, "preds_{}.psv".format(split))) as inf:
      for line in csv.reader(inf, delimiter="|"):
        if not line:
          # csv yields [] for blank lines; there is no hadm_id to read.
          continue
        hadm_id = line[0]
        if hadm_id not in preds or hadm_id not in notes:
          continue

        word_ids = notes[hadm_id]
        indptr.append(indptr[-1] + len(word_ids))
        indices.extend(word_ids)

        # Codes unknown to c2ind are silently dropped from the binary labels.
        y_bin_split.append([c2ind[code] for code in line[1:] if code in c2ind])
        y_real_split.append([preds[hadm_id][code] for code in all_codes])

    # Every stored entry of the bag-of-words matrix is 1.
    data = [1] * len(indices)
    X_split = scipy.sparse.csr_matrix((data, indices, indptr),
                                      shape=(len(indptr) - 1, len(vocab.d)))
    with open(f"{constants.PROXY_DIR}/X_{split}_csr.npz", "wb") as outf:
      scipy.sparse.save_npz(outf, X_split)

    with open(f"{constants.PROXY_DIR}/y_bin_{split}.json", "w") as outf:
      outf.write(json.dumps(y_bin_split))

    with gzip.open(f"{constants.PROXY_DIR}/y_real_{split}.json.gz", "wb") as outf:
      outf.write(json.dumps(y_real_split).encode("ascii"))


def load_data(real_compression='fpzip16'):
  '''
  Load the train/dev/test proxy dataset from constants.PROXY_DIR.

  Per split this reads:
    X_{split}_csr.npz                      the sparse feature matrix
    y_bin_{split}.json                     per-row lists of code indices,
                                           turned into a CSC indicator matrix
                                           with len(c2ind) columns (or an
                                           empty list if no row has any code)
    y_real_{split}.{real_compression}.gz   the real-valued targets

  Args:
    real_compression: 'fpzip16', 'npy' or 'json'; selects both the file name
      and how the gzip payload is decoded.

  Returns:
    utils.MimicData built from Splits of X, real and binary, an empty
    `truth` dict, and the row counts n_train / n_dev / n_test.
  '''
  datadir = constants.PROXY_DIR
  c2ind, _ = load_c2ind()

  features = {}
  real_targets = {}
  binary_targets = {}
  truth = {}
  row_counts = {}

  for split in ("train", "dev", "test"):
    with open("{}/X_{}_csr.npz".format(datadir, split), "rb") as inf:
      features[split] = scipy.sparse.load_npz(inf)
    row_counts["n_{}".format(split)] = features[split].shape[0]

    with open("{}/y_bin_{}.json".format(datadir, split)) as inf:
      # Duplicated code indices within a row are collapsed by the set.
      code_rows = [set(row) for row in json.loads(inf.readline())]

    row_ends = [0]
    col_ids = []
    for codes in code_rows:
      col_ids.extend(codes)
      row_ends.append(row_ends[-1] + len(codes))

    if col_ids:
      ones = [1] * len(col_ids)
      as_csr = scipy.sparse.csr_matrix((ones, col_ids, row_ends),
                                       shape=(len(row_ends) - 1, len(c2ind)))
      # Stored column-major; later code slices out one code column at a time.
      binary_targets[split] = scipy.sparse.csc_matrix(as_csr)
    else:
      binary_targets[split] = []

    real_path = "{}/y_real_{}.{}.gz".format(datadir, split, real_compression)
    with gzip.open(real_path, "rb") as inf:
      if real_compression == 'fpzip16':
        decoded = fpzip.decompress(inf.read(), order="C")
        # Keep only the last two axes of the decompressed array.
        real_targets[split] = decoded[0, 0, :, :]
      elif real_compression == 'npy':
        real_targets[split] = np.load(inf)
      elif real_compression == 'json':
        real_targets[split] = np.array(json.loads(inf.readline().decode()),
                                       dtype=np.float32)

  return utils.MimicData(X=utils.Splits(**features),
                         truth=truth,
                         real=utils.Splits(**real_targets),
                         binary=utils.Splits(**binary_targets),
                         **row_counts)


def evaluate(labels, preds, name='', printout=True):
  '''
  Score real-valued predictions against real-valued labels.

  Computes Pearson, Spearman and Kendall-tau correlations on the raw
  predictions, and the mean squared error after clipping the predictions to
  [0, 1]. `name` is accepted but not used.

  Returns a dict with keys 'pearson', 'spearman', 'kendalltau' and 'mse';
  when `printout` is true the same numbers are also printed on one line.
  '''
  scores = {
      'pearson': stats.pearsonr(preds, labels)[0],
      'spearman': stats.spearmanr(preds, labels).correlation,
      'kendalltau': stats.kendalltau(preds, labels).correlation,
  }

  mse = metrics.mean_squared_error(np.clip(preds, 0, 1), labels)

  if printout:
    fields = ["mse: {:.3e}".format(mse)]
    for key, val in scores.items():
      fields.append("{}: {:.3f}".format(key, val))
    print(", ".join(fields))

  scores['mse'] = mse
  return scores


def prediction_regression(data, model='linreg', seed=0, use_cv=False,
                          printout=True, save=False, job_id=None, num_jobs=None):
  '''
  For each code, predict the model's predicted probability for that label
    Requires prediction probabilities for each code, not binary labels

  One regressor is fit per code on log(data.real.train[:, code]). Its
  exponentiated predictions are scored on dev and test as a regression
  against the real-valued targets and, after thresholding at the best-F1
  threshold found on dev, as a classifier against the binary targets.

  Args:
    data: dataset as returned by load_data(); data.X, data.real and
      data.binary are read, each through its train/dev/test members.
    model: only 'linreg' (an SGDRegressor) is supported.
    seed: seeds both `random` and `np.random`.
    use_cv: if True, wrap the regressor in a RandomizedSearchCV over penalty
      and alpha; otherwise use penalty='l1', alpha=1e-4.
    printout: print per-code metrics to stdout.
    save: write the dev/test predictions, the fitted regressor, and one
      results line per code under constants.PROXY_DIR.
    job_id, num_jobs: when num_jobs is given, only codes with
      code_idx % num_jobs == job_id are processed. A job_id equal to
      num_jobs is treated as 0.

  Raises:
    ValueError: if `model` is not 'linreg'.
  '''
  random.seed(seed)
  np.random.seed(seed)

  if model == 'linreg':
    if use_cv:
      clf = SGDRegressor()
      param_dist = {'penalty': ['elasticnet', 'l1', 'l2'],
                    # 'max_iter': [1000, 2000],
                    'alpha': np.power(10., np.arange(-7, 0))}
      cv = RandomizedSearchCV(clf, n_iter=16, param_distributions=param_dist)
    else:
      clf = SGDRegressor(penalty='l1', alpha=1e-4)
      cv = clf
  else:
    raise ValueError("unk model {}".format(model))

  dir_ = constants.PROXY_DIR

  c2ind, ind2c = load_c2ind()

  # Lets the last job of a 1-based numbering take the residue-0 codes.
  if job_id == num_jobs:
    job_id = 0
  logging.warning(f"job {job_id} of {num_jobs} jobs")
  for code_idx in range(data.real.train.shape[1]):
    # Shard the codes across jobs by residue.
    if num_jobs is not None and code_idx % num_jobs != job_id:
      continue

    code_name = ind2c[code_idx]
    results = {"code": code_name, "dev": {}, "test": {}}
    # Targets are regressed in log space and mapped back with exp().
    # NOTE(review): assumes every real-valued target is > 0; a zero would
    # become -inf here - confirm against the stored data.
    cv.fit(data.X.train, np.log(data.real.train[:, code_idx]))

    dev_preds = np.exp(cv.predict(data.X.dev))
    test_preds = np.exp(cv.predict(data.X.test))

    # Flat ndarray, as in prediction_classification. todense() would give an
    # np.matrix, which recent scikit-learn releases refuse as input.
    dev_binary = data.binary.dev[:, code_idx].toarray().reshape(-1)
    prec, rec, thresholds = metrics.precision_recall_curve(dev_binary, dev_preds)
    f1 = [2 * p * r / (p + r) for p, r in zip(prec, rec) if p + r > 0]
    best_f1_threshold = thresholds[np.argmax(f1)]

    dev_results = evaluate(data.real.dev[:, code_idx], dev_preds,
                           name=model + "_", printout=printout)
    results["dev"]["regression"] = dev_results
    test_results = evaluate(data.real.test[:, code_idx], test_preds,
                            name=model + "_", printout=printout)
    results["test"]["regression"] = test_results

    if hasattr(cv, 'cv_results_'):
      # With a CV search, report the winning parameters and keep the refit
      # best estimator for saving.
      best_cv = cv.cv_results_['rank_test_score'].tolist().index(1)
      print("best cv params: {0}".format(cv.cv_results_['params'][best_cv]))
      clf = cv.best_estimator_
    else:
      clf = cv

    if save:
      with gzip.open(f"{dir_}/dev.{code_idx}.npy.gz", "w") as outf:
        np.save(outf, dev_preds)

      with gzip.open(f"{dir_}/test.{code_idx}.npy.gz", "w") as outf:
        np.save(outf, test_preds)

      with gzip.open(f"{dir_}/clf.{code_idx}.json.gz", "w") as outf:
        clf_json = utils.sgd_classifier_to_json(clf)
        outf.write(clf_json.encode("ascii"))

    test_binary = data.binary.test[:, code_idx].toarray().reshape(-1)
    splits = {"dev": [dev_binary, dev_preds], "test": [test_binary, test_preds]}
    for split in splits:
      binary, preds = splits[split]
      # Accuracy of always predicting the majority class.
      baseline = np.mean(binary)
      baseline = max(baseline, 1 - baseline)
      # AUC and F1 are only computed when both classes occur in the labels.
      if np.unique(binary).shape[0] > 1:
        prediction_auc = metrics.roc_auc_score(binary, preds)
        # NOTE(review): precision_recall_curve scores a threshold with
        # `score >= threshold`, while a strict `>` is applied here - confirm
        # that excluding the boundary sample is intended.
        binary_preds = (preds > best_f1_threshold).astype(np.int32)
        accuracy = metrics.accuracy_score(binary, binary_preds)
        f1 = metrics.f1_score(binary, binary_preds)
        confusion = metrics.confusion_matrix(binary, binary_preds)

        results[split]["classification"] = {
            "baseline": baseline, "accuracy": accuracy, "f1": f1,
            "auc_score": prediction_auc, "confusion": confusion}
        if printout:
          print(" {:.1f}% 1s  auc: {:.3f}, f1 {:.3f}".format(
            100 * baseline, prediction_auc, f1))
      else:
        results[split]["classification"] = {"baseline": baseline}
        if printout:
          print(" {:.1f}% 1s".format(100 * baseline))

    if save:
      # One JSON line per code, appended so parallel jobs share one file.
      with open(f"{constants.PROXY_DIR}/results.json", "a") as outf:
        outf.write("{}\n".format(json.dumps(results, cls=utils.NumpySerializer)))


def prediction_classification(data, seed=0, use_cv=True):
  '''
  For each code, predict whether the model will predict that code
    Requires binary labels for the codes, not prediction probabilities

  Fits one saga LogisticRegression per column of data.binary.train
  (optionally inside a RandomizedSearchCV over penalty, C and max_iter) and
  prints test accuracy, the majority-class baseline and, when defined, the
  test AUC. Nothing is returned or saved.

  Args:
    data: dataset as returned by load_data(); data.X and data.binary are
      read through their train/test members.
    seed: seeds both `random` and `np.random`.
    use_cv: whether to run the randomized hyper-parameter search.
  '''
  random.seed(seed)
  np.random.seed(seed)

  if use_cv:
    clf = LogisticRegression(solver='saga')
    param_dist = {'penalty': ['l1', 'l2'], 'C': np.power(10., np.arange(-5, 1)),
                  'max_iter': [100, 500, 1000]}
    cv = RandomizedSearchCV(clf, n_iter=32, param_distributions=param_dist)
  else:
    cv = LogisticRegression(solver='saga')

  for label in range(data.binary.train.shape[1]):
    train_binary_label = data.binary.train[:, label].toarray().reshape(-1)
    test_binary_label = data.binary.test[:, label].toarray().reshape(-1)

    # A classifier cannot be fit when only one class is present; skip the
    # code instead of aborting the whole run.
    if np.unique(train_binary_label).shape[0] < 2:
      print("{} skipped: only one class in the training labels".format(label))
      continue

    start = time.time()
    with warnings.catch_warnings():
      # Short max_iter settings routinely stop before convergence.
      warnings.filterwarnings("ignore", category=ConvergenceWarning)
      cv.fit(data.X.train, train_binary_label)
    test_bin_preds = cv.predict_proba(data.X.test)[:, 1]
    test_acc = cv.score(data.X.test, test_binary_label)
    print("fitting took {:.1f}m".format((time.time() - start) / 60))

    # report the best parameters found by the search, if one was run
    if hasattr(cv, 'cv_results_'):
      best_cv = cv.cv_results_['rank_test_score'].tolist().index(1)
      print("best cv params: {0}".format(cv.cv_results_['params'][best_cv]))

    # Accuracy of always predicting the majority class.
    baseline = np.mean(test_binary_label)
    baseline = max(baseline, 1 - baseline)

    # Same guard as prediction_regression: ROC AUC needs both classes in the
    # reference labels.
    if np.unique(test_binary_label).shape[0] > 1:
      prediction_auc = metrics.roc_auc_score(test_binary_label, test_bin_preds)
      print("{} model gets test auc: {:.5f} test acc {:.1f} compared to baseline of {:.1f}".format(
        label, prediction_auc, 100 * test_acc, 100 * baseline))
    else:
      print("{} model gets test acc {:.1f} compared to baseline of {:.1f} "
            "(auc undefined: only one class in the test labels)".format(
              label, 100 * test_acc, 100 * baseline))


def recompute_thresholds(data):
  '''
  Recompute a decision threshold per code from the training split.

  For every code in c2ind, loads the saved regressor from
  `{constants.PROXY_DIR}/clf.{code_idx}.json.gz`, predicts on data.X.train
  (exponentiating the regressor output) and picks the threshold with the
  best F1 on the precision-recall curve against data.binary.train. Codes
  with no positive training label get 0.5.

  Args:
    data: dataset as returned by load_data().

  Returns:
    dict mapping code -> threshold.
  '''
  c2ind, _ = load_c2ind()

  thresholds = {}
  for i, (code, code_idx) in enumerate(c2ind.items(), start=1):
    if i % 1000 == 0:
      # Progress marker. logging.warn is a deprecated alias, so use
      # logging.warning as the rest of the file does.
      logging.warning(i)

    with gzip.open(f"{constants.PROXY_DIR}/clf.{code_idx}.json.gz") as inf:
      clf = utils.sgd_classifier_from_json(SGDRegressor, inf.readline().decode())

    train_preds = np.exp(clf.predict(data.X.train))
    # Flat ndarray instead of the np.matrix that todense() returns, which
    # recent scikit-learn releases refuse as input.
    train_binary = data.binary.train[:, code_idx].toarray().reshape(-1)
    if np.mean(train_binary) > 0:
      prec, rec, thresh = metrics.precision_recall_curve(
          train_binary, train_preds)
      f1 = [2 * p * r / (p + r) for p, r in zip(prec, rec) if p + r > 0]
      thresholds[code] = thresh[np.argmax(f1)]
    else:
      # No positives: the precision-recall curve is not meaningful.
      thresholds[code] = 0.5

  return thresholds
 

def compile_nice_json(split="dev"):
  '''
  Gather the saved per-code predictions of one split into a nested mapping.

  Reads the admission ids (one integer per line) from
  `{constants.PROXY_DIR}/{split}_hadm.txt` and, for every code index, the
  array stored in `{split}.{code_idx}.npy.gz`. Predictions are paired with
  admission ids by position.

  Returns a defaultdict: hadm_id -> {code: prediction}.
  '''
  dir_ = constants.PROXY_DIR
  with open(f"{dir_}/{split}_hadm.txt") as handle:
    hadm_ids = list(map(int, handle))
  c2ind, ind2c = load_c2ind()

  by_admission = defaultdict(dict)
  for code_idx, code in ind2c.items():
    # if not os.path.exists(f"{dir_}/{split}.{code_idx}.npy.gz"):
    #   continue
    with gzip.open(f"{dir_}/{split}.{code_idx}.npy.gz") as handle:
      code_preds = np.load(handle)
    for hadm_id, pred in zip(hadm_ids, code_preds):
      by_admission[hadm_id][code] = pred

  return by_admission


def main():
  '''
  Command-line entry point: load the dataset and run the per-code proxy
  regression with saving enabled and printing disabled.

  Optional integer flags --job_id and --num_jobs are forwarded to
  prediction_regression().
  '''
  parser = argparse.ArgumentParser()
  for flag in ('--job_id', '--num_jobs'):
    parser.add_argument(flag, type=int, default=None)
  args = parser.parse_args()

  prediction_regression(load_data(),
                        save=True,
                        printout=False,
                        job_id=args.job_id,
                        num_jobs=args.num_jobs)


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
  # Uncomment to (re)write the X_*, y_bin_* and y_real_*.json.gz dataset files:
  # save_dataset()
  main()
