# This script trains our classifier that labels explanations as plausible.
#   We load the clinical evaluation and train a simple classifier
#   to predict whether the clinician would have labeled a given explanation
#   as informative (plausible). We validate this classifier and use it in our plausibility evaluations.

from collections import defaultdict

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate
from sklearn.model_selection import LeaveOneOut, ShuffleSplit
from sklearn.metrics import precision_score, recall_score, accuracy_score, roc_auc_score
from sklearn.exceptions import ConvergenceWarning
# from sklearn.model_selection import 
from scipy.sparse import vstack, hstack, csr_matrix
from scipy.sparse.base import spmatrix
from gensim.models import KeyedVectors
# import sent2vec
import fasttext

import compare_explanations as ce
import cosine_sim
import constants
import warnings


def load_tfidf():
  """Fit a TF-IDF vectorizer on the training notes and return its transform.

  Returns:
    The bound ``transform`` method of a ``TfidfVectorizer`` fitted on
    ``cosine_sim.get_train_notes()``. It maps a list of documents to a
    sparse document-term matrix.
  """
  vectorizer = TfidfVectorizer()
  training_corpus = cosine_sim.get_train_notes()
  vectorizer.fit(training_corpus)
  return vectorizer.transform


def load_word2vec_vector():
  """Load BioWordVec word vectors and return a document featurizer.

  The returned callable embeds a document as the concatenation of the
  element-wise minimum and the element-wise maximum over the vectors of its
  in-vocabulary tokens. Tokens are obtained by lower-casing and whitespace
  splitting.

  Returns:
    A function ``word2vec(document)``. It accepts a string, or a list whose
    first element is the string, and returns a 1-D numpy array of length
    ``2 * model.vector_size``. Documents with no in-vocabulary tokens map to
    an all-zero vector.
  """
  infn = f"{constants.PROXY_DIR}/BioWordVec_PubMed_MIMICIII_d200.vec.bin"
  model = KeyedVectors.load_word2vec_format(infn, binary=True)
  # min/max over an empty set of vectors is undefined (np.min raises), so
  # documents with no known tokens fall back to an all-zero feature.
  empty_features = np.zeros(2 * model.vector_size, dtype=np.float32)

  def word2vec(document):
    # Callers such as ``build`` pass ``[text]``; unwrap that form.
    if isinstance(document, list):
      document = document[0]
    word_vectors = []
    for word in document.lower().split():
      try:
        word_vectors.append(model.get_vector(word))
      except KeyError:
        # Out-of-vocabulary token: skip it.
        pass
    if not word_vectors:
      return empty_features.copy()
    min_ = np.min(word_vectors, axis=0)
    max_ = np.max(word_vectors, axis=0)
    return np.concatenate([min_, max_], axis=0)

  return word2vec


def load_word2vec_model():
  """Load the BioWordVec fastText model and return a sentence embedder.

  Returns:
    A function ``word2vec_transform(document)``. It accepts either a string
    or a list whose first element is the string, and returns
    ``model.get_sentence_vector`` for that string.
  """
  model_path = f"{constants.PROXY_DIR}/BioWordVec_PubMed_MIMICIII_d200.bin"
  model = fasttext.load_model(model_path)

  def word2vec_transform(document):
    # Unwrap the single-element list form used by callers, e.g. ``[text]``.
    text = document[0] if type(document) == list else document
    return model.get_sentence_vector(text)

  return word2vec_transform


def build(annotations, transform, heldout=None):
  """Featurize annotated explanations into a [code | explanation] design matrix.

  Each explanation becomes one row. The row is the transformed code
  description followed by the transformed explanation text. Its label is 1
  when the clinician annotated it ``'++'`` or ``'+'``, and 0 otherwise.

  Args:
    annotations: iterable of dicts with a ``'code_description'`` string and an
      ``'explanations'`` list of dicts carrying ``'text'`` and ``'annotation'``.
    transform: callable applied to ``[text]``. It returns either a scipy
      sparse row or a dense numpy vector; all calls must return the same kind.
    heldout: optional sequence where ``heldout[i]`` is the index of the
      explanation of annotation ``i`` that goes to the test split.

  Returns:
    ``(X, y)`` when ``heldout`` is None, otherwise ``(X, X_test, y, y_test)``.
  """
  train_code, train_text, train_y = [], [], []
  test_code, test_text, test_y = [], [], []

  for q_idx, annotation in enumerate(annotations):
    code_features = transform([annotation['code_description']])
    for e_idx, explanation in enumerate(annotation['explanations']):
      text_features = transform([explanation['text']])
      # six '+/-' annotations we treat as 0
      label = 1 if explanation['annotation'] in ['++', '+'] else 0

      if heldout is not None and heldout[q_idx] == e_idx:
        bucket = (test_code, test_text, test_y)
      else:
        bucket = (train_code, train_text, train_y)
      bucket[0].append(code_features)
      bucket[1].append(text_features)
      bucket[2].append(label)

  sparse_features = isinstance(train_code[0], spmatrix)

  def combine(code_rows, text_rows):
    # Place code and explanation features side by side, one row per explanation.
    if sparse_features:
      return hstack([vstack(code_rows), vstack(text_rows)])
    return np.concatenate(
        [np.stack(code_rows, axis=0), np.stack(text_rows, axis=0)], axis=1)

  X = combine(train_code, train_text)
  if heldout is None:
    return X, train_y
  return X, combine(test_code, test_text), train_y, test_y


class QuestionDataset:
  """Per-question view of the annotated explanations for held-out training.

  Every explanation becomes one row of ``X`` (transformed code description
  concatenated with the transformed explanation text) and one label in ``y``
  (1 for ``'++'``/``'+'``, else 0). The rows of one question are contiguous,
  so a whole question can be held out as a slice.

  Attributes:
    n_annotations: number of questions (annotations) loaded.
    question_mapping: question index -> ``(start, end)`` row slice in X / y.
    ngram_mapping: question index -> {ngram group id -> list of explanation
      positions, relative to ``start``, that share that ngram}. Group ids are
      assigned in order of first appearance within the question.
    X: ``scipy.sparse.csr_matrix`` of features.
    y: list of int labels.
  """

  def __init__(self, annotations, transform):
    """Featurize ``annotations`` with ``transform``.

    Args:
      annotations: iterable of dicts with ``'code_description'`` and an
        ``'explanations'`` list whose items carry ``'text'``, ``'ngram'`` and
        ``'annotation'``.
      transform: callable applied to ``[text]``. It returns a scipy sparse
        row or a dense numpy vector.

    Raises:
      ValueError: if ``annotations`` contains no explanations at all.
    """
    X_code = []
    X_text = []
    y = []
    question_mapping = {}
    ngram_mapping = {}
    self.n_annotations = 0

    for i, annotation in enumerate(annotations):
      self.n_annotations += 1
      code_vec = transform([annotation['code_description']])

      start_index = len(y)
      ngrams = []
      for exp in annotation['explanations']:
        ngrams.append(exp["ngram"])
        explanation_vec = transform([exp['text']])
        # six '+/-' annotations we treat as 0
        y.append(1 if exp['annotation'] in ['++', '+'] else 0)
        X_code.append(code_vec)
        X_text.append(explanation_vec)

      question_mapping[i] = (start_index, len(y))
      # Group this question's explanations by ngram. Group ids follow the
      # order in which each ngram first appears.
      ngram_idxs = {}
      ngram_map = defaultdict(list)
      for j, ngram in enumerate(ngrams):
        group = ngram_idxs.setdefault(ngram, len(ngram_idxs))
        ngram_map[group].append(j)
      ngram_mapping[i] = dict(ngram_map)

    if not X_code:
      raise ValueError("QuestionDataset requires at least one explanation")

    if isinstance(X_code[0], spmatrix):
      X = hstack([vstack(X_code), vstack(X_text)])
    else:
      X = np.concatenate(
          [np.stack(X_code, axis=0), np.stack(X_text, axis=0)], axis=1)

    self.question_mapping = question_mapping
    self.ngram_mapping = ngram_mapping
    self.X = csr_matrix(X)
    self.y = y

  def get_X_y(self, idx, ngram_idx=None):
    """Split the data so that question ``idx`` (or one ngram group of it) is held out.

    Args:
      idx: question index, a key of ``question_mapping``.
      ngram_idx: optional ngram group id, a key of ``ngram_mapping[idx]``.
        When given, only that group's explanations are held out. The
        question's other explanations are appended to the training split.

    Returns:
      ``(X_train, X_test, y_train, y_test)``.
    """
    (start, end) = self.question_mapping[idx]
    X_train = vstack([self.X[:start], self.X[end:]])
    X_test = self.X[start:end]
    y_train = np.concatenate([self.y[:start], self.y[end:]], axis=0)
    y_test = self.y[start:end]

    if ngram_idx is None:
      return X_train, X_test, y_train, y_test

    same_ngram = self.ngram_mapping[idx][ngram_idx]
    # Use the question's actual explanation count rather than assuming four.
    other_ngram = [k for k in range(end - start) if k not in same_ngram]
    X_train2 = [X_test[k] for k in other_ngram]
    y_train2 = [y_test[k] for k in other_ngram]
    X_test = vstack([X_test[k] for k in same_ngram])
    y_test = np.array([y_test[k] for k in same_ngram])

    X_train = vstack([X_train, *X_train2])
    y_train = np.concatenate([y_train, y_train2], axis=0)

    return X_train, X_test, y_train, y_test

  def enumerate_ngram_splits(self):
    """Yield ``get_X_y(i, ngram_idx=j)`` for every question and ngram group."""
    for i in self.question_mapping:
      for j in sorted(self.ngram_mapping[i]):
        yield self.get_X_y(i, ngram_idx=j)

  def sample_new_labels(self, label_probs):
    """Resample binary labels with one Bernoulli draw per (question, ngram group).

    Args:
      label_probs: 2-D array-like indexed ``[question][explanation]`` that
        holds positive-label probabilities.

    Returns:
      An int32 array shaped like ``label_probs``. Every explanation in an
      ngram group receives the label drawn from the probability of the
      group's first member. Draws use numpy's global random state.
    """
    if isinstance(label_probs, np.ndarray):
      label_probs = label_probs.tolist()
    new_labels = np.zeros_like(label_probs, dtype=np.int32)
    for i in self.question_mapping:
      for j in sorted(self.ngram_mapping[i]):
        idxs = self.ngram_mapping[i][j]
        # sample a new label from the first probability in the list
        new_label = np.random.binomial(1, label_probs[i][idxs[0]])
        for idx in idxs:
          new_labels[i, idx] = new_label

    return new_labels


###
#   We want to train a model for each question and predict new labels for
#   its four explanations. Then we want to be able to compare the predicted
#   labels against the model's preferred ngram.

def build_per_question_models(question_dataset, continuous=False):
  """Leave-one-question-out logistic regression over a QuestionDataset.

  For each question, a logistic regression is fit on every other question's
  explanations and used to score the held-out question's explanations. The
  scores are then thresholded at the percentile that makes the predicted
  positive rate match the positive rate of ``question_dataset.y``. The
  accuracy of the thresholded labels and the ROC AUC of the raw scores are
  printed.

  Args:
    question_dataset: a ``QuestionDataset``.
    continuous: when True, return the raw positive-class probabilities;
      otherwise return the thresholded 0/1 labels as int32.

  Returns:
    A 2-D numpy array indexed ``[question, explanation]``. Building it with
    ``np.array`` assumes every question has the same number of explanations.
  """
  classifier = LogisticRegression(max_iter=1000)
  held_out_truth = []
  held_out_scores = []

  for q in range(question_dataset.n_annotations):
    X_train, X_held, y_train, y_held = question_dataset.get_X_y(q)
    classifier.fit(X_train, y_train)
    held_out_truth.append(y_held)
    held_out_scores.append(classifier.predict_proba(X_held)[:, 1].tolist())

  scores = np.array(held_out_scores)
  positive_rate = np.mean(question_dataset.y)
  cutoff = np.percentile(scores, 100 * (1 - positive_rate))

  flat_truth = np.array(held_out_truth).reshape(-1)
  flat_scores = scores.reshape(-1)
  accuracy = accuracy_score(flat_truth, flat_scores > cutoff)
  roc_auc = roc_auc_score(flat_truth, flat_scores)
  print(f"acc: {accuracy:0.3f}, auc: {roc_auc:0.3f}")

  if continuous:
    return scores
  return (scores > cutoff).astype(np.int32)


def build_per_ngram_models(question_dataset, continuous=False):
  """Leave-one-ngram-group-out logistic regression over a QuestionDataset.

  For every (question, ngram group) pair, a logistic regression is fit on all
  other explanations, including the question's other ngram groups. It then
  scores the held-out group's explanations. Scores are thresholded at the
  percentile that makes the predicted positive rate match the positive rate
  of ``question_dataset.y``. Accuracy and ROC AUC over all explanations are
  printed.

  Args:
    question_dataset: a ``QuestionDataset``.
    continuous: when True, return raw probabilities; otherwise return 0/1
      int32 labels.

  Returns:
    An array of shape ``(n_annotations, max explanations per question)``,
    indexed ``[question, explanation]``. Cells beyond a question's own
    explanation count (possible only when questions differ in size) stay 0
    and are excluded from the printed metrics.
  """
  model = LogisticRegression(max_iter=1000)

  n = question_dataset.n_annotations
  # Size the table from the data instead of assuming four explanations.
  width = max(
      (end - start for (start, end) in question_dataset.question_mapping.values()),
      default=0)
  true_labels = np.zeros([n, width])
  new_labels = np.zeros([n, width])
  # Marks cells that correspond to a real explanation (not padding).
  valid = np.zeros([n, width], dtype=bool)

  for i in question_dataset.question_mapping:
    for j in question_dataset.ngram_mapping[i]:
      (X, X_test, y, y_test) = question_dataset.get_X_y(i, j)
      model.fit(X, y)
      y_hat = model.predict_proba(X_test)[:, 1]

      idxs = question_dataset.ngram_mapping[i][j]
      for k, idx in enumerate(idxs):
        new_labels[i][idx] = y_hat[k]
        true_labels[i][idx] = y_test[k]
        valid[i][idx] = True

  scores = new_labels[valid]
  truth = true_labels[valid]
  cutoff = np.percentile(scores, 100 * (1 - np.mean(question_dataset.y)))
  accuracy = accuracy_score(truth, scores > cutoff)
  # Score all explanations as one pooled binary problem, as
  # build_per_question_models does. Passing the 2-D arrays would make sklearn
  # treat each column as a separate label and macro-average the AUCs.
  roc_auc = roc_auc_score(truth, scores)
  print(f"acc: {accuracy:0.3f}, auc: {roc_auc:0.3f}")
  if not continuous:
    new_labels = (new_labels > cutoff).astype(np.int32)

  return new_labels


def validate(X, y, model='logreg', C=1.0, test_sizes=(64, 32, 16)):
  """Cross-validate a plausibility classifier and print a results table.

  For each entry of ``test_sizes``, the data is split into
  ``len(y) // size`` folds (leave-one-out when the size is 1). The table
  reports the mean ± std of accuracy and ROC AUC across folds. A run that
  emits a plain UserWarning, or that has fewer than two folds, is reported
  as 'N/A'. The last ConvergenceWarning seen, if any, is printed at the end.

  Args:
    X: feature matrix (dense or sparse).
    y: binary labels.
    model: ``'mlp'`` (case-insensitive) selects a small MLP; anything else
      selects logistic regression.
    C: inverse regularization strength for logistic regression.
    test_sizes: approximate number of examples per test fold, one table row
      each.
  """
  np.random.seed(42)

  n = len(y)
  scoring = ['accuracy', 'roc_auc']
  if model.lower() == 'mlp':
    model = MLPClassifier(max_iter=1000, hidden_layer_sizes=(8,))
  else:
    model = LogisticRegression(C=C, max_iter=5000)

  def cv_check(test_size, split_func=None):
    """Run one cross-validation and return formatted scores.

    When the run is unusable, the returned dict holds only the
    ``convergence_raised`` entry.
    """
    cv_scoring = scoring
    results = {"convergence_raised": None}

    if split_func == "shuffle":
      split = ShuffleSplit(100, test_size=test_size)
    elif split_func == "loo" or test_size == 1:
      split = LeaveOneOut()
      # ROC AUC is undefined on single-sample test folds.
      cv_scoring = ['accuracy']
    else:
      split = n // test_size
      if split < 2:
        # k-fold needs at least two folds; report N/A instead of crashing.
        return results

    with warnings.catch_warnings(record=True) as warns:
      warnings.filterwarnings("always")
      cv = cross_validate(model, X, y, cv=split, scoring=cv_scoring)

    for warn in warns:
      if warn.category == UserWarning:
        return results
      elif warn.category == ConvergenceWarning:
        results["convergence_raised"] = warn.message

    for score in cv_scoring:
      vals = cv[f"test_{score}"]
      results[score] = f"{np.mean(vals):.2f} \u00B1 {np.std(vals):.2f}"
    return results

  rows = ['test_size'] + scoring
  width = 0

  convergence_raised = None
  results = defaultdict(list)
  for i, test_size in enumerate(test_sizes):
    results['test_size'].append(str(test_size))
    cv = cv_check(test_size)
    if cv["convergence_raised"] is not None:
      convergence_raised = cv["convergence_raised"]
    for score in scoring:
      results[score].append(cv.get(score, 'N/A'))
    vals = [results[row][i] for row in rows]
    # Columns must fit both the header names and every value; reprint the
    # header whenever the column width grows.
    new_width = max([width] + [len(row) for row in rows] + [len(val) for val in vals])
    if new_width != width:
      width = new_width
      print(" ".join([f"{row:{width}s}" for row in rows]))
    print(" ".join([f"{val:{width}s}" for val in vals]))

  if convergence_raised is not None:
    print(convergence_raised)


if __name__ == "__main__":
  # Featurize with fastText sentence vectors, then fit one model per held-out
  # question and threshold its predictions.
  transform = load_word2vec_model()
  annotations = ce.AnnotationHandler().annotations
  qd = QuestionDataset(annotations, transform)
  new_labels = build_per_question_models(qd, continuous=False)
