# Reimplementation of the cosine similarity baseline for explanations

import re
import csv
import json
import cdifflib

import numpy as np
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import constants
import compare_explanations


def get_train_notes():
  """Yield every text used to fit the TF-IDF vocabulary.

  First yields the third column of each row of
  ``{constants.MIMIC_3_DIR}/test_full.csv``, then the second column of each
  row of the tab-separated ``{constants.MIMIC_3_DIR}/ICD9_descriptions`` file
  (the first row of that file is skipped).

  NOTE(review): despite the function name, the notes are read from
  ``test_full.csv`` -- confirm this is the intended split. The first row of
  that file is not skipped, so a header row (if there is one) is yielded as a
  document as well.

  Yields:
    str: one document per row.

  Raises:
    ValueError: if a notes row does not have exactly five columns, or a
      description row does not have exactly two.
  """
  # newline='' hands newline handling to the csv module, which is needed for
  # quoted fields that contain line breaks.
  with open(f"{constants.MIMIC_3_DIR}/test_full.csv", newline='') as notes_file:
    for row in csv.reader(notes_file):
      _, _, text, _, _ = row
      yield text
  with open(f"{constants.MIMIC_3_DIR}/ICD9_descriptions", newline='') as desc_file:
    reader = csv.reader(desc_file, delimiter='\t')
    # Skip the first row; the default keeps an empty file from leaking a
    # StopIteration out of this generator (which would become RuntimeError).
    next(reader, None)
    for row in reader:
      _, description = row
      yield description



def rank_similar_annotations():
  """Print a TF-IDF cosine-similarity ranking of each annotation's n-grams.

  Fits a TfidfVectorizer on get_train_notes(). Then, for every annotation of
  compare_explanations.AnnotationHandler whose id is a key of
  handler.annotations_to_hadm, each explanation n-gram is scored against the
  annotation's code description, and one line per annotation is printed:

    <hadm_id>,<n-grams joined by ';'>,<explanation indices joined by ';'>

  The indices are ordered from most to least similar; ties keep the original
  explanation order. Any number of explanations per annotation is supported.
  All lines are printed only after every annotation has been processed.
  """
  tfidf = TfidfVectorizer()
  tfidf.fit(get_train_notes())
  handler = compare_explanations.AnnotationHandler()
  annotations = [a for a in handler.annotations
                 if a['id'] in handler.annotations_to_hadm]

  rows = []
  for annotation in annotations:
    ngrams = [exp["ngram"] for exp in annotation['explanations']]
    if ngrams:
      code_vec = tfidf.transform([annotation['code_description']])
      # One batched call: row 0 of the result holds the similarity of the
      # code description to each n-gram, as plain floats after tolist().
      sims = cosine_similarity(code_vec, tfidf.transform(ngrams))[0].tolist()
    else:
      # Nothing to vectorize; an empty ranking is emitted for this annotation.
      sims = []

    # sorted() is stable, so equally similar n-grams keep their input order.
    ranking = sorted(range(len(sims)), key=sims.__getitem__, reverse=True)
    hadm_id = handler.annotations_to_hadm[annotation['id']]
    rows.append((hadm_id, ";".join(ngrams), ";".join(map(str, ranking))))

  for hadm_id, joined_ngrams, joined_ranking in rows:
    print("{},{},{}".format(hadm_id, joined_ngrams, joined_ranking))


def main(ngram_size=4):
  """Count annotations whose best-matching text n-gram is an annotated one.

  Fits a TfidfVectorizer on get_train_notes(). For every annotation that
  compare_explanations.match_annotations_to_hadm maps to an admission id, the
  full text of that admission is split on whitespace into all contiguous
  n-grams of `ngram_size` tokens. The n-gram with the highest cosine
  similarity to the annotation's code description is printed (the first one
  wins on ties), or "no match!" when no n-gram has a similarity above zero.
  A match is counted when that n-gram equals the 'ngram' of one of the
  annotation's explanations. The total is printed at the end.

  Args:
    ngram_size: number of whitespace-separated tokens per candidate n-gram.
  """
  tfidf = TfidfVectorizer()
  tfidf.fit(get_train_notes())
  annotations, counts, text_regexes = compare_explanations.load_plausibility_annotations()
  annotation_to_hadm, full_texts = compare_explanations.match_annotations_to_hadm(
      annotations, counts, text_regexes)

  matches = 0
  for annotation in annotations:
    hadm_id = annotation_to_hadm.get(annotation["id"])
    if hadm_id is None:
      continue

    tokens = full_texts[hadm_id].split()
    # "+ 1" so that the n-gram ending on the last token is included.
    text_ngrams = [" ".join(tokens[i: i + ngram_size])
                   for i in range(len(tokens) - ngram_size + 1)]

    best_i = -1
    if text_ngrams:
      code_vec = tfidf.transform([annotation['code_description']])
      # Row 0 holds the similarity of the code description to every n-gram.
      sims = cosine_similarity(code_vec, tfidf.transform(text_ngrams))[0]
      # argmax returns the first maximum, i.e. the earliest n-gram on ties.
      top = int(np.argmax(sims))
      if sims[top] > 0:
        best_i = top

    if best_i == -1:
      # No candidate shares any weighted term with the description; there is
      # no best n-gram to compare against the explanations.
      print("no match!")
      continue

    best_ngram = text_ngrams[best_i]
    print(f"best ngram: {best_ngram}")
    if any(best_ngram == exp['ngram'] for exp in annotation['explanations']):
      matches += 1

  print(f"{matches} matches")


if __name__ == "__main__":
  # Script entry point: only the n-gram ranking is run. The call to main()
  # below is disabled; uncomment it to run that experiment as well.
  # main()
  rank_similar_annotations()
