# Load in the saved explanations from each method
#   and produce the plausibility scores in Section 5 of the paper

import csv
import logging
import numpy as np
from collections import defaultdict

from sklearn.linear_model import LogisticRegression
from scipy.sparse import hstack
from scipy.sparse.base import spmatrix

import constants
import compare_explanations
import bootstrap_annotations

# Names of the explanation methods; main() loads and reports them in this order.
MODELS = ["caml", "drcaml", "cosine", "logreg", "cnn", "proxy"]


def load_cnn(**kwargs):
  """Load the saved CNN explanations for the plausibility annotations.

  Reads ``cnn_plausibility_annotations.csv`` under ``constants.PROXY_DIR``.

  Args:
    **kwargs: Ignored. Accepted so that every loader can be called with the
      same keyword arguments, as ``main`` does.

  Returns:
    A dict mapping ``(hadm_id, ngram_key)`` to the row's ``14GRAM`` text.
    ``ngram_key`` is the distinct values of the ``NGRAM-0`` .. ``NGRAM-3``
    columns, sorted and joined with ``";"``. When several rows produce the
    same key, the last row wins.

  Raises:
    ValueError: If one of the required columns is missing from the header.
  """
  infn = f"{constants.PROXY_DIR}/cnn_plausibility_annotations.csv"
  # The csv module does its own newline handling and asks for newline="" so
  # that quoted fields containing line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    ngram_idxs = [headers.index(f"NGRAM-{i}") for i in range(4)]
    hadm_idx = headers.index("HADM_ID")
    explanation_idx = headers.index("14GRAM")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      expl = row[explanation_idx]

      # Deduplicate and sort so the key does not depend on column order.
      ngrams = {row[i] for i in ngram_idxs}
      data[(hadm_id, ";".join(sorted(ngrams)))] = expl

  return data


def load_caml(agg="avg", **kwargs):
  """Load the saved CAML explanations for the plausibility annotations.

  Reads ``CAML_mimic3_full/plausibility_annotations_explanations-{agg}.csv``
  under ``constants.PREDS_DIR``.

  Args:
    agg: Suffix selecting which saved file to read; must be ``"avg"`` or
      ``"first"``.
    **kwargs: Ignored. Accepted so that every loader can be called with the
      same keyword arguments, as ``main`` does.

  Returns:
    A dict mapping ``(hadm_id, ngram_key)`` to the row's ``14GRAM`` text.
    ``ngram_key`` is the distinct values of the ``NGRAM-0`` .. ``NGRAM-3``
    columns, sorted and joined with ``";"``. When several rows produce the
    same key, the last row wins.

  Raises:
    AssertionError: If ``agg`` is not one of the two accepted values.
    ValueError: If one of the required columns is missing from the header.
  """
  assert agg in ["avg", "first"]

  fn_base = "CAML_mimic3_full/plausibility_annotations_explanations"
  infn = f"{constants.PREDS_DIR}/{fn_base}-{agg}.csv"
  # The csv module does its own newline handling and asks for newline="" so
  # that quoted fields containing line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    ngram_idxs = [headers.index(f"NGRAM-{i}") for i in range(4)]
    hadm_idx = headers.index("HADM_ID")
    explanation_idx = headers.index("14GRAM")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      expl = row[explanation_idx]

      # Deduplicate and sort so the key does not depend on column order.
      ngrams = {row[i] for i in ngram_idxs}
      data[(hadm_id, ";".join(sorted(ngrams)))] = expl

  return data


def load_drcaml(agg="avg", **kwargs):
  """Load the saved DRCAML explanations for the plausibility annotations.

  Reads ``DRCAML_mimic3_full/plausibility_annotations_explanations-{agg}.csv``
  under ``constants.PREDS_DIR``.

  Args:
    agg: Suffix selecting which saved file to read; must be ``"avg"`` or
      ``"first"``.
    **kwargs: Ignored. Accepted so that every loader can be called with the
      same keyword arguments, as ``main`` does.

  Returns:
    A dict mapping ``(hadm_id, ngram_key)`` to the row's ``14GRAM`` text.
    ``ngram_key`` is the distinct values of the ``NGRAM-0`` .. ``NGRAM-3``
    columns, sorted and joined with ``";"``. When several rows produce the
    same key, the last row wins.

  Raises:
    AssertionError: If ``agg`` is not one of the two accepted values.
    ValueError: If one of the required columns is missing from the header.
  """
  assert agg in ["avg", "first"]

  fn_base = "DRCAML_mimic3_full/plausibility_annotations_explanations"
  infn = f"{constants.PREDS_DIR}/{fn_base}-{agg}.csv"
  # The csv module does its own newline handling and asks for newline="" so
  # that quoted fields containing line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    ngram_idxs = [headers.index(f"NGRAM-{i}") for i in range(4)]
    explanation_idx = headers.index("14GRAM")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      expl = row[explanation_idx]

      # Deduplicate and sort so the key does not depend on column order.
      ngrams = {row[i] for i in ngram_idxs}
      data[(hadm_id, ";".join(sorted(ngrams)))] = expl

  return data


def load_logreg(agg="avg", **kwargs):
  """Load the saved explanations used for the "logreg" method.

  Reads ``CAML_mimic3_full/plausibility_annotations_explanations-{agg}.csv``
  under ``constants.PREDS_DIR``.

  NOTE(review): this is exactly the file ``load_caml`` reads, and unlike
  ``load_caml`` the ``agg`` value is not validated here. This looks like it
  may be a copy-paste leftover -- confirm where the logistic-regression
  explanations are actually stored.

  Args:
    agg: Suffix selecting which saved file to read.
    **kwargs: Ignored. Accepted so that every loader can be called with the
      same keyword arguments, as ``main`` does.

  Returns:
    A dict mapping ``(hadm_id, ngram_key)`` to the row's ``14GRAM`` text.
    ``ngram_key`` is the distinct values of the ``NGRAM-0`` .. ``NGRAM-3``
    columns, sorted and joined with ``";"``. When several rows produce the
    same key, the last row wins.

  Raises:
    ValueError: If one of the required columns is missing from the header.
  """
  fn_base = "CAML_mimic3_full/plausibility_annotations_explanations"
  infn = f"{constants.PREDS_DIR}/{fn_base}-{agg}.csv"
  # The csv module does its own newline handling and asks for newline="" so
  # that quoted fields containing line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    ngram_idxs = [headers.index(f"NGRAM-{i}") for i in range(4)]
    explanation_idx = headers.index("14GRAM")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      expl = row[explanation_idx]

      # Deduplicate and sort so the key does not depend on column order.
      ngrams = {row[i] for i in ngram_idxs}
      data[(hadm_id, ";".join(sorted(ngrams)))] = expl

  return data


def load_cosine(handler=None, **kwargs):
  """Load the saved cosine-similarity explanations.

  Reads ``cosine_sim_exps.csv`` under ``constants.PROXY_DIR``. Each row names
  an annotation by id; the key's ngram part is built from that annotation's
  own explanations, not from columns of the CSV.

  Args:
    handler: Object exposing ``annotations``, an iterable of dicts with at
      least the keys ``"id"`` and ``"explanations"``. When ``None``, a new
      ``compare_explanations.AnnotationHandler()`` is created.
    **kwargs: Ignored. Accepted so that every loader can be called with the
      same keyword arguments, as ``main`` does.

  Returns:
    A dict mapping ``(hadm_id, ngram_key)`` to the row's ``EXPLANATION``
    text. ``ngram_key`` is the distinct ``"ngram"`` values of the
    annotation's explanations, sorted and joined with ``";"``. When several
    rows produce the same key, the last row wins.

  Raises:
    ValueError: If a required column is missing from the header, or an
      ``ANNOTATION_ID`` value is not an integer.
    KeyError: If a row refers to an annotation id the handler does not have.
  """
  if handler is None:
    handler = compare_explanations.AnnotationHandler()

  annotations_dict = {a['id']: a for a in handler.annotations}

  infn = f"{constants.PROXY_DIR}/cosine_sim_exps.csv"
  # The csv module does its own newline handling and asks for newline="" so
  # that quoted fields containing line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    annotation_idx = headers.index("ANNOTATION_ID")
    explanation_idx = headers.index("EXPLANATION")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      annotation_id = int(row[annotation_idx])
      expl = row[explanation_idx]

      annotation = annotations_dict[annotation_id]
      # Deduplicate and sort so the key does not depend on explanation order.
      ngrams = {exp["ngram"] for exp in annotation["explanations"]}
      data[(hadm_id, ";".join(sorted(ngrams)))] = expl

  return data


def load_proxy(handler=None, **kwargs):
  """Load the saved proxy-model explanations.

  Reads ``top_proxy_ngrams.csv`` under ``constants.PROXY_DIR``. A row is used
  only when its ``HADM_ID`` has annotations in the handler and its ``LABEL``
  equals the ``"code"`` of one of those annotations.

  Args:
    handler: Object exposing ``annotations`` (an iterable of dicts with at
      least the keys ``"id"``, ``"code"`` and ``"explanations"``) and
      ``hadm_to_annotations`` (a mapping from HADM id to annotation ids).
      When ``None``, a new ``compare_explanations.AnnotationHandler()`` is
      created.
    **kwargs: Ignored. Accepted so that every loader can be called with the
      same keyword arguments, as ``main`` does.

  Returns:
    A dict mapping ``(hadm_id, ngram_key)`` to the row's ``NGRAM`` text.
    ``ngram_key`` is the distinct ``"ngram"`` values of the matching
    annotation's explanations, sorted and joined with ``";"``. When several
    rows produce the same key, the last row wins.

  Raises:
    ValueError: If one of the required columns is missing from the header.
    KeyError: If ``hadm_to_annotations`` lists an id missing from
      ``annotations``.
  """
  if handler is None:
    handler = compare_explanations.AnnotationHandler()

  annotations_dict = {a['id']: a for a in handler.annotations}

  infn = f"{constants.PROXY_DIR}/top_proxy_ngrams.csv"
  # The csv module does its own newline handling and asks for newline="" so
  # that quoted fields containing line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    label_idx = headers.index("LABEL")
    explanation_idx = headers.index("NGRAM")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      label = row[label_idx]

      # Rows for admissions without annotations are not needed.
      if hadm_id not in handler.hadm_to_annotations:
        continue

      for idx in handler.hadm_to_annotations[hadm_id]:
        annotation = annotations_dict[idx]
        if label != annotation["code"]:
          continue
        expl = row[explanation_idx]
        # Deduplicate and sort so the key does not depend on explanation order.
        ngrams = {exp["ngram"] for exp in annotation["explanations"]}
        data[(hadm_id, ";".join(sorted(ngrams)))] = expl

  return data


def get_keys_of_caml_table1(keys):
  """Find the keys whose ngram string contains each of three fixed phrases.

  Args:
    keys: Iterable of ``(hadm_id, ngrams)`` pairs, where ``ngrams`` is the
      ";"-joined ngram string used as the second key component by the loaders.

  Returns:
    A three-element list, one slot per phrase in the order listed below.
    Each slot holds the key whose ngram string contains that phrase
    (substring match), or ``None`` when no key does.

  Raises:
    AssertionError: If more than one key contains the same phrase.
  """
  phrases = ("large mucus plug on",
             "hepatic artery branch pseudoaneurysm",
             "systolic dysfunction with ef")
  found = [None] * len(phrases)
  for candidate in keys:
    _, joined_ngrams = candidate
    for slot, phrase in enumerate(phrases):
      if phrase not in joined_ngrams:
        continue
      # Each phrase is expected to identify exactly one key.
      assert found[slot] is None
      found[slot] = candidate
  return found


def bootstrap(scores, k=1000):
  """Print a bootstrap 95% interval of the positive count for each model.

  Every score is used as the success probability of one Bernoulli trial. In
  each of ``k`` rounds, and for each model, one trial is drawn per score and
  the successes are summed. The 2.5th and 97.5th percentiles of those ``k``
  sums are then printed, one line per model.

  Args:
    scores: Dict mapping a model name to a sequence of probabilities (values
      must lie in [0, 1], as required by the binomial sampler).
    k: Number of resampling rounds.

  Returns:
    None. The table is written to stdout.
  """
  # A private generator seeded with 42 produces the same draws as seeding the
  # global NumPy RNG would, but leaves the caller's random state untouched.
  rng = np.random.RandomState(42)
  counts = defaultdict(list)
  for _ in range(k):
    for name, score_list in scores.items():
      counts[name].append(np.sum(rng.binomial(1, score_list)))

  print("{:6s} {:4s} {:5s}".format("model", "2.5%", "97.5%"))
  for name, sums in counts.items():
    p2_5, p97_5 = np.percentile(sums, [2.5, 97.5])
    print("{:6s} {:4.1f} {:5.1f}".format(name, p2_5, p97_5))


def main(handler, qd, transform, agg='avg'):
  """Score each method's explanation for every annotation and print a tally.

  A logistic-regression model is fit on ``qd.X`` / ``qd.y``. For every
  annotation, the code description and each method's explanation are passed
  through ``transform``, joined side by side, and scored by that model. For
  each method the number of scores above a cutoff is printed.

  Args:
    handler: Object exposing ``annotations`` and ``annotations_to_hadm``; it
      is also passed on to the loaders.
    qd: Object with attributes ``X`` and ``y`` used to fit the model.
    transform: Callable applied to a one-element list of text; its result is
      reshaped to a single row. Presumably returns a dense array or a SciPy
      sparse matrix -- confirm against the caller.
    agg: Forwarded to every loader as the ``agg`` keyword.

  Returns:
    Dict mapping each model name to the list of positive-class probabilities
    of its explanations, in ``handler.annotations`` order.

  Raises:
    KeyError: If a method has no explanation stored for an annotation's key.
  """
  classifier = LogisticRegression()
  classifier.fit(qd.X, qd.y)
  train_probs = classifier.predict_proba(qd.X)[:, 1]
  # Cutoff at the (1 - mean(y)) quantile of the training probabilities.
  positive_rate = np.mean(qd.y)
  cutoff = np.percentile(train_probs, 100 * (1 - positive_rate))

  loaders = {"caml": load_caml, "drcaml": load_drcaml,
             "logreg": load_logreg, "cosine": load_cosine,
             "proxy": load_proxy, "cnn": load_cnn}
  model_names = list(MODELS)
  # One explanation table per model, kept parallel to model_names.
  tables = [loaders[name](agg=agg, handler=handler) for name in model_names]
  scores = {name: [] for name in model_names}

  # Keys in annotation order; handy as input to get_keys_of_caml_table1 when
  # inspecting individual explanations.
  all_keys = []
  for annotation in handler.annotations:
    hadm_id = handler.annotations_to_hadm[annotation['id']]
    distinct_ngrams = {exp["ngram"] for exp in annotation["explanations"]}
    key = (hadm_id, ";".join(sorted(distinct_ngrams)))
    all_keys.append(key)

    code_vec = transform([annotation['code_description']]).reshape(1, -1)

    for name, table in zip(model_names, tables):
      exp_vec = transform([table[key]]).reshape(1, -1)

      # Sparse and dense rows need different side-by-side concatenation.
      if isinstance(exp_vec, spmatrix):
        features = hstack([code_vec, exp_vec])
      else:
        features = np.concatenate([code_vec, exp_vec], axis=1)
      positive_prob = classifier.predict_proba(features)[:, 1][0]
      scores[name].append(positive_prob)

  for name in model_names:
    above_cutoff = np.sum(np.array(scores[name]) > cutoff)
    print(name, above_cutoff)

  return scores

# if __name__ == "__main__":
#   import importlib, bootstrap_annotations as ba, merge_explanations as me, compare_explanations as ce
#   transform = importlib.reload(ba.load_word2vec_model())
#   handler = importlib.reload(ce.AnnotationHandler())
#   qd = importlib.reload(ba.QuestionDataset(handler.annotations, transform))
#   scores = importlib.reload(me.main(handler, qd, transform))
#   importlib.reload(me.bootstrap(scores))
