# This loads the clinical plausibility annotations of explanations from Mullenbach et al.'s models.
#   Note: we do not yet have permission to share mimic_plausibility.jsonl.
# This is a helper script used by merge_explanations and bootstrap_annotations.

import re
import csv
import json
import cdifflib

from collections import defaultdict

import constants


def load_plausibility_annotations(proxy_dir=None):
  """Load and parse the clinical plausibility annotations.

  Reads ``c2ind.json`` (code -> index mapping) and
  ``mimic_plausibility.jsonl`` from ``proxy_dir``.  Each JSONL record must
  have a ``"code"`` field like ``"Code: <code> Full descriptions: <text>"``,
  an ``"id"``, and four explanation strings under ``"A"``..``"D"`` shaped
  like ``"A) [<marker> ]<explanation text>"``, where the optional marker is a
  run of ``+`` optionally followed by ``/-`` (e.g. ``++`` or ``+/-``).
  Blank lines in the JSONL file are skipped.

  Args:
    proxy_dir: Directory holding both input files. Defaults to
      ``constants.PROXY_DIR``.

  Returns:
    ``(annotations, counts, text_regexes)`` where
      * ``annotations`` is a list with one dict per record, with keys
        ``code``, ``code_ind`` (``c2ind[code]``), ``code_description``,
        ``id`` and ``explanations``; each explanation is a dict with
        ``annotation`` (stripped marker, or None), ``text`` and ``ngram``.
      * ``text_regexes`` maps each distinct compiled explanation text to the
        id of the first annotation that contained it.
      * ``counts`` maps an annotation id to how many entries of
        ``text_regexes`` point at it.

  Raises:
    ValueError: If a ``"code"`` or explanation field does not match the
      expected format.
    KeyError: If a code is missing from ``c2ind.json``.
  """
  if proxy_dir is None:
    proxy_dir = constants.PROXY_DIR

  with open(f"{proxy_dir}/c2ind.json") as inf:
    c2ind = json.load(inf)

  with open(f"{proxy_dir}/mimic_plausibility.jsonl") as inf:
    # Skip blank lines (e.g. a trailing empty line) instead of failing in json.
    raw = [json.loads(line) for line in inf if line.strip()]

  # Raw strings: the patterns contain backslash escapes meant for `re`.
  code_regex = re.compile(r"Code: (\S+) Full descriptions: (.*)$")
  exp_regex = re.compile(r"^([A-D]\)) (\+*(?:\/-)? )?(.*)$")

  annotations = []
  for obj in raw:
    code_search = code_regex.search(obj["code"])
    if code_search is None:
      raise ValueError(
          f"Unparseable code field in record {obj.get('id')!r}: {obj['code']!r}")
    code, code_description = code_search.groups()
    new_obj = {"code": code, "code_ind": c2ind[code],
               "code_description": code_description,
               "id": obj["id"], "explanations": []}
    for key in "ABCD":
      match = exp_regex.search(obj[key])
      if match is None:
        raise ValueError(
            f"Unparseable explanation {key!r} in record {obj.get('id')!r}: "
            f"{obj[key]!r}")
      _, annotation, explanation = match.groups()
      if annotation is not None:
        annotation = annotation.strip()

      words = explanation.split()
      # Presumably the text is a 4-gram with surrounding context: a 9-word
      # text is taken to start with the 4-gram, otherwise the 4-gram is
      # words[5:9]. TODO confirm against the annotation export.
      if len(words) == 9:
        ngram = " ".join(words[:4])
      else:
        ngram = " ".join(words[5:9])
      new_obj["explanations"].append(
          {"annotation": annotation,
           "text": explanation,
           "ngram": ngram})

    annotations.append(new_obj)

  text_regexes = {}
  counts = defaultdict(int)
  for annotation in annotations:
    for exp in annotation["explanations"]:
      # NOTE(review): the explanation text is compiled verbatim, so any regex
      # metacharacters in it are interpreted; this assumes the text is plain
      # whitespace-separated tokens.
      regex = re.compile(exp["text"])
      # Only the first annotation containing a given text claims it.
      if regex not in text_regexes:
        idx = annotation["id"]
        text_regexes[regex] = idx
        counts[idx] += 1

  return annotations, counts, text_regexes


def _read_keyed_csv_rows(path, keys):
  """Return ``{row[0]: row[1:]}`` for CSV rows whose first field is in ``keys``.

  Blank lines (which ``csv.reader`` yields as empty rows) are skipped; a later
  row with the same first field overwrites an earlier one.
  """
  rows = {}
  # newline="" is what the csv module requires for correct quoting handling.
  with open(path, newline="") as inf:
    for row in csv.reader(inf):
      if row and row[0] in keys:
        rows[row[0]] = row[1:]
  return rows


def load_explanations(full_texts, model="DRCAML", preds_dir=None):
  """Load a model's per-admission explanations and explanation indices.

  Reads ``test_explanations_0.csv`` and ``test_explanations_idxs_0.csv`` from
  ``<preds_dir>/<model>_mimic3_full/``, keeping only rows whose first column
  (the HADM_ID) is a key of ``full_texts``.

  Args:
    full_texts: Mapping (or any container) of HADM_IDs to keep.
    model: Model name used to build the predictions sub-directory.
    preds_dir: Root predictions directory; defaults to ``constants.PREDS_DIR``.

  Returns:
    ``(explanations, explanations_idxs)``: each maps HADM_ID to the list of
    remaining CSV fields (strings) of that row.
  """
  if preds_dir is None:
    preds_dir = constants.PREDS_DIR
  model_dir = f"{preds_dir}/{model}_mimic3_full"

  explanations = _read_keyed_csv_rows(
      f"{model_dir}/test_explanations_0.csv", full_texts)
  explanations_idxs = _read_keyed_csv_rows(
      f"{model_dir}/test_explanations_idxs_0.csv", full_texts)

  return explanations, explanations_idxs


def check(annotations, annotation_to_hadm, full_texts, explanations,
          explanations_idxs, output_path="annotation_check.txt"):
  """Write a report comparing model explanations with annotated explanations.

  For every annotation that was matched to an admission, writes the model's
  explanation for the annotated code, then one line per annotated
  explanation prefixed with the length (in words) of its longest common run
  with the model's explanation.  When that run is longer than one word, it is
  shown in ``[brackets]``.  A blank line separates annotations.

  Args:
    annotations: Records from ``load_plausibility_annotations``.
    annotation_to_hadm: Annotation id -> HADM_ID; unmatched annotations are
      skipped.
    full_texts: Not used by the report; kept for interface compatibility.
    explanations: HADM_ID -> list of explanation strings indexed by code
      index (``annotation["code_ind"]``).
    explanations_idxs: Not used by the report; kept for interface
      compatibility.
    output_path: File the report is written to.
  """
  with open(output_path, "w") as outf:
    for annotation in annotations:
      hadm_id = annotation_to_hadm.get(annotation["id"])
      if hadm_id is None:
        continue
      our_explanation = explanations[hadm_id][annotation["code_ind"]]
      our_words = our_explanation.split()
      outf.write(f"{annotation['id']}: '{our_explanation}'\n")

      for explanation in annotation["explanations"]:
        words = explanation["text"].split()
        matcher = cdifflib.CSequenceMatcher(None, our_words, words)
        _, b_pos, size = matcher.find_longest_match(
            0, len(our_words), 0, len(words))

        if size > 1:
          before = " ".join(words[:b_pos])
          match = " ".join(words[b_pos:b_pos + size])
          after = " ".join(words[b_pos + size:])
          output = f"{size:2} {before} [{match}] {after}\n"
        else:
          text = " ".join(words)
          output = f"{size:2} {text}\n"
        outf.write(output)

      outf.write("\n")


def match_annotations_to_hadm(annotations, counts, text_regexes, path=None):
  """Find the admission (HADM_ID) each annotation was drawn from.

  An annotation matches a CSV row when every one of its explanation patterns
  (``counts[idx]`` of them) is found in that row's text.  An annotation that
  matches more than one row is ambiguous and is dropped permanently (#63);
  a note text is kept only while some annotation still points at it.

  Args:
    annotations: Not used; kept for interface compatibility.
    counts: Annotation id -> number of patterns in ``text_regexes`` for it.
    text_regexes: Compiled pattern -> annotation id.
    path: Five-column CSV whose 2nd and 3rd columns are the HADM_ID and the
      note text; defaults to ``{constants.MIMIC_3_DIR}/test_full.csv``.

  Returns:
    ``(annotation_to_hadm, full_texts)``: annotation id -> HADM_ID, and
    HADM_ID -> note text for every HADM_ID still referenced.
  """
  if path is None:
    path = f"{constants.MIMIC_3_DIR}/test_full.csv"

  annotation_to_hadm = {}
  full_texts = {}
  # Annotations found to match several admissions; never re-added.
  dropped = set()
  with open(path, newline="") as inf:
    for row in csv.reader(inf):
      _, hadm_id, text, _, _ = row
      matches = defaultdict(int)
      for regex, idx in text_regexes.items():
        if regex.search(text):
          matches[idx] += 1

      for idx, n_matched in matches.items():
        # Require all of the annotation's patterns to be present.
        if n_matched != counts[idx] or idx in dropped:
          continue
        if idx in annotation_to_hadm:
          # TODO Could try to resolve the mismatch, but for now just dropping
          # this annotation (#63)
          prev_hadm = annotation_to_hadm.pop(idx)
          dropped.add(idx)
          print(idx)
          # Only discard the text once no other annotation needs it.
          if prev_hadm not in annotation_to_hadm.values():
            del full_texts[prev_hadm]
        else:
          full_texts[hadm_id] = text
          annotation_to_hadm[idx] = hadm_id

  return annotation_to_hadm, full_texts


class AnnotationHandler:
  """Loads plausibility annotations and indexes them by admission.

  Attributes:
    annotations: Records from ``load_plausibility_annotations``.
    counts: Annotation id -> number of explanation patterns owned by it.
    text_regexes: Compiled explanation pattern -> annotation id.
    annotations_to_hadm: Annotation id -> matched HADM_ID.
    full_texts: HADM_ID -> note text, as returned by
      ``match_annotations_to_hadm``.
    hadm_to_annotations: HADM_ID -> list of annotation ids matched to it.
    ngram_spans: Annotation id -> list of ``(start, end)`` word-index spans
      (end exclusive) where each explanation's ngram occurs as whole words in
      the matched admission's text.  Unmatched annotations have no entry.
  """

  def __init__(self):
    annotations, counts, text_regexes = load_plausibility_annotations()
    annotations_to_hadm, full_texts = match_annotations_to_hadm(
        annotations, counts, text_regexes)
    self.annotations = annotations
    self.counts = counts
    self.text_regexes = text_regexes
    self.annotations_to_hadm = annotations_to_hadm
    self.full_texts = full_texts

    self.hadm_to_annotations = defaultdict(list)
    for annotation_id, hadm_id in self.annotations_to_hadm.items():
      self.hadm_to_annotations[hadm_id].append(annotation_id)

    self.ngram_spans = {}
    for annotation in self.annotations:
      # annotations_to_hadm is keyed by annotation["id"], not list position.
      annotation_id = annotation["id"]
      hadm_id = self.annotations_to_hadm.get(annotation_id)
      if hadm_id is None:
        continue
      tokens = self.full_texts[hadm_id].split()
      spans = []
      for exp in annotation["explanations"]:
        spans.extend(self._word_spans(exp["ngram"].split(), tokens))
      self.ngram_spans[annotation_id] = spans

  @staticmethod
  def _word_spans(ngram_words, tokens):
    """Return ``(start, end)`` spans (end exclusive) where ``ngram_words``
    occurs as a contiguous run of whole tokens in ``tokens``.

    Matching whole tokens avoids hits inside longer words, and an empty
    ngram yields no spans.
    """
    width = len(ngram_words)
    if width == 0:
      return []
    return [(start, start + width)
            for start in range(len(tokens) - width + 1)
            if tokens[start:start + width] == ngram_words]


def main():
  """Run the pipeline: load annotations, match admissions, write the CAML check."""
  loaded = load_plausibility_annotations()
  annotations = loaded[0]
  hadm_by_annotation, texts = match_annotations_to_hadm(*loaded)
  preds, pred_idxs = load_explanations(texts, model="CAML")
  check(annotations, hadm_by_annotation, texts, preds, pred_idxs)


if __name__ == "__main__":
  main()
