# Analyses for Appendix "Reproducing plausibility scores"
#   we extract each (re)trained model's score for the four explanations
#   in the original clinical evaluation and assign each model
#   to its highest-ranked explanation (allowing conflicts)
#   This produces the "Ours" column in the final table.

import csv
import numpy as np
from collections import defaultdict

import constants
import compare_explanations


def load_caml():
  """Load the top-ranked n-gram per example from the DRCAML_mimic3_full run.

  Reads ``plausibility_annotations_scored.csv`` below ``constants.PREDS_DIR``.
  The header must contain ``HADM_ID``, ``NGRAM-0`` .. ``NGRAM-3`` and
  ``RANKING``; ``RANKING`` is a ``;``-separated list of integers aligned
  with the four n-gram columns.

  Returns:
    dict mapping ``(hadm_id, key)`` to the n-gram whose ranking is 0, where
    ``key`` is the row's unique n-grams, sorted and joined with ``;``.
    Rows sharing a key overwrite earlier ones.

  Raises:
    ValueError: if a required header column is missing, a ranking is not an
      integer, or a row has no ranking equal to 0.
  """
  infn = f"{constants.PREDS_DIR}/DRCAML_mimic3_full/plausibility_annotations_scored.csv"
  # newline="" is what the csv module expects, so that quoted fields with
  # embedded line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    ngram_idxs = [headers.index(f"NGRAM-{i}") for i in range(4)]
    rankings_idx = headers.index("RANKING")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      ngrams = [row[i] for i in ngram_idxs]
      rankings = [int(r) for r in row[rankings_idx].split(";")]
      # Ranking 0 marks the n-gram this model rates highest.
      best_ngram = ngrams[rankings.index(0)]

      # The key is order-independent so the same n-gram set can be matched
      # against other sources regardless of column order.
      key = (hadm_id, ";".join(sorted(set(ngrams))))
      data[key] = best_ngram

  return data


def load_logreg():
  """Load the top-ranked n-gram per example from the log_reg_Jan_13_12:44 run.

  Reads ``plausibility_annotations_scored.csv`` below ``constants.MODEL_DIR``.
  The header must contain ``HADM_ID``, ``NGRAM-0`` .. ``NGRAM-3`` and
  ``RANKING``; ``RANKING`` is a ``;``-separated list of integers aligned
  with the four n-gram columns.

  Returns:
    dict mapping ``(hadm_id, key)`` to the n-gram whose ranking is 0, where
    ``key`` is the row's unique n-grams, sorted and joined with ``;``.
    Rows sharing a key overwrite earlier ones.

  Raises:
    ValueError: if a required header column is missing, a ranking is not an
      integer, or a row has no ranking equal to 0.
  """
  infn = f"{constants.MODEL_DIR}/log_reg_Jan_13_12:44/plausibility_annotations_scored.csv"
  # newline="" is what the csv module expects, so that quoted fields with
  # embedded line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    ngram_idxs = [headers.index(f"NGRAM-{i}") for i in range(4)]
    rankings_idx = headers.index("RANKING")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      ngrams = [row[i] for i in ngram_idxs]
      rankings = [int(r) for r in row[rankings_idx].split(";")]
      # Ranking 0 marks the n-gram this model rates highest.
      best_ngram = ngrams[rankings.index(0)]

      # The key is order-independent so the same n-gram set can be matched
      # against other sources regardless of column order.
      key = (hadm_id, ";".join(sorted(set(ngrams))))
      data[key] = best_ngram

  return data


def load_cosine():
  """Load the top-ranked n-gram per example from the cosine rankings file.

  Reads ``cosine_rankings.csv`` below ``constants.PROXY_DIR``. The header
  must contain ``HADM_ID``, ``NGRAMS`` and ``RANKING``; both ``NGRAMS`` and
  ``RANKING`` are ``;``-separated lists that are aligned with each other.

  Returns:
    dict mapping ``(hadm_id, key)`` to the n-gram whose ranking is 0, where
    ``key`` is the row's unique n-grams, sorted and joined with ``;``.
    Rows sharing a key overwrite earlier ones.

  Raises:
    ValueError: if a required header column is missing, a ranking is not an
      integer, or a row has no ranking equal to 0.
  """
  infn = f"{constants.PROXY_DIR}/cosine_rankings.csv"
  # newline="" is what the csv module expects, so that quoted fields with
  # embedded line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    ngrams_idx = headers.index("NGRAMS")
    rankings_idx = headers.index("RANKING")

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      ngrams = row[ngrams_idx].split(";")
      rankings = [int(r) for r in row[rankings_idx].split(";")]
      # Ranking 0 marks the n-gram rated highest.
      best_ngram = ngrams[rankings.index(0)]

      # The key is order-independent so the same n-gram set can be matched
      # against other sources regardless of column order.
      key = (hadm_id, ";".join(sorted(set(ngrams))))
      data[key] = best_ngram

  return data

def load_cnn():
  """Load the top-ranked n-gram per example from the cnn_vanilla run.

  Reads ``cnn_rankings.csv`` below ``constants.MODEL_DIR``. The header must
  contain ``HADM_ID``. After the first ``off`` columns, each row is read as
  three consecutive groups of equal width: n-grams, a middle group that is
  not used here (presumably scores), and integer rankings.

  Returns:
    dict mapping ``(hadm_id, key)`` to the n-gram whose ranking is 0, where
    ``key`` is the row's unique n-grams, sorted and joined with ``;``.
    Rows sharing a key overwrite earlier ones.

  Raises:
    ValueError: if ``HADM_ID`` is missing from the header, a ranking is not
      an integer, or a row has no ranking equal to 0.
  """
  infn = f"{constants.MODEL_DIR}/cnn_vanilla_Mar_03_17:51:53/cnn_rankings.csv"
  # newline="" is what the csv module expects, so that quoted fields with
  # embedded line breaks are parsed correctly.
  with open(infn, newline="") as inf:
    reader = csv.reader(inf)
    headers = next(reader)
    hadm_idx = headers.index("HADM_ID")
    # Number of leading columns that precede the n-gram group.
    off = 3
    # Number of equal-width groups following the leading columns.
    n_groups = 3

    data = {}
    for row in reader:
      hadm_id = row[hadm_idx]
      n_ngrams = (len(row) - off) // n_groups
      ngrams = row[off:off + n_ngrams]
      # Rankings are the third group. The end of the slice is derived from
      # the group width, not from `off`, so it stays correct if the number
      # of leading columns ever changes.
      rank_start = off + 2 * n_ngrams
      rankings = [int(r) for r in row[rank_start:rank_start + n_ngrams]]
      # Ranking 0 marks the n-gram this model rates highest.
      best_ngram = ngrams[rankings.index(0)]

      # The key is order-independent so the same n-gram set can be matched
      # against other sources regardless of column order.
      key = (hadm_id, ";".join(sorted(set(ngrams))))
      data[key] = best_ngram

  return data


def merge(handler):
  """Tally how often each model's top n-gram was annotated "+" or "++".

  Loads the top-ranked n-gram per example for the four models (cnn, caml,
  logreg, cosine) and, for every annotation in ``handler``, looks up each
  model's pick under the key ``(hadm_id, sorted unique n-grams)``.

  Args:
    handler: object exposing ``annotations`` (an iterable of dicts with an
      ``'id'`` and an ``"explanations"`` list whose items carry ``"ngram"``
      and ``"annotation"``) and ``annotations_to_hadm`` (a mapping from
      annotation id to HADM id). Annotations without a HADM id are skipped.

  Prints:
    For each processed annotation, the number of annotated explanations
    minus the number of model picks found, then the total of those offsets,
    then one line per model with its "good" ("+" or "++") and "great"
    ("++" only) counts.

  Raises:
    KeyError: if a model's pick is not among the annotation's n-grams.
  """
  names = ["cnn", "caml", "logreg", "cosine"]
  # One {key: best_ngram} dict per model, in the same order as `names`.
  picks_by_model = [load_cnn(), load_caml(), load_logreg(), load_cosine()]

  good = defaultdict(int)
  great = defaultdict(int)
  total_offset = 0

  for annotation in handler.annotations:
    hadm_id = handler.annotations_to_hadm.get(annotation['id'])
    if hadm_id is None:
      continue

    explanations = annotation["explanations"]
    ngrams = [exp["ngram"] for exp in explanations]
    # For repeated n-grams the last annotation wins.
    scores = {exp["ngram"]: exp["annotation"] for exp in explanations}

    # How many times each n-gram appears among the annotated explanations;
    # each model pick found below consumes one occurrence.
    counts = defaultdict(int)
    for ngram in ngrams:
      counts[ngram] += 1

    key = (hadm_id, ";".join(sorted(set(ngrams))))
    # Look each model up by its own index. Collecting the picks into one
    # shared list would shift positions whenever a model has no entry for
    # this key, crediting the result to the wrong model.
    for i, picks in enumerate(picks_by_model):
      if key not in picks:
        continue
      ngram = picks[key]
      counts[ngram] -= 1
      label = scores[ngram]
      if label == "++":
        good[i] += 1
        great[i] += 1
      elif label == "+":
        good[i] += 1

    offset = sum(counts.values())
    total_offset += offset
    print(offset, end=', ')
  print()

  print(total_offset)
  for i, name in enumerate(names):
    print(name, good[i], great[i])
