# import pdb; pdb.set_trace()

import csv
import datetime
import gzip
import inspect
import json
import os
import re
import sys
import numpy as np
from tqdm import tqdm
from collections import namedtuple
from sklearn.tree import _tree
from sklearn.metrics import roc_auc_score
from scipy.sparse import csr_matrix
from collections import defaultdict
import sklearn.metrics as metrics
import scipy.stats as stats
import constants
import evaluation

# Three-field record holding the train / dev / test parts of a dataset split.
Splits = namedtuple('Splits', 'train dev test')
# Record bundling the feature matrix ``X``, the ``real`` / ``binary`` / ``truth``
# fields and the per-split sizes. NOTE: the generated class's typename is
# 'data' (not 'MimicData'), so instances repr as ``data(...)``.
MimicData = namedtuple('data', 'X real binary truth n_train n_dev n_test')

class NumpySerializer(json.JSONEncoder):
  """JSON encoder that converts NumPy values (and ``range``) to plain Python.

  Pass it as ``cls=NumpySerializer`` to ``json.dump`` / ``json.dumps``.
  Supported conversions:

  * ``np.bool_``    -> ``bool``
  * ``np.integer``  -> ``int``
  * ``np.floating`` -> ``float``
  * ``np.ndarray``  -> nested ``list`` (via ``tolist()``)
  * ``range``       -> ``list``

  Anything else is delegated to ``json.JSONEncoder.default``, which raises
  ``TypeError``.
  """

  def default(self, obj):
    """Return a JSON-serializable equivalent of ``obj``.

    Called by the encoder only for objects it cannot serialize natively.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch;
    # without it NumPy boolean scalars would fall through to TypeError.
    if isinstance(obj, np.bool_):
      return bool(obj)
    if isinstance(obj, np.integer):
      return int(obj)
    if isinstance(obj, np.floating):
      return float(obj)
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    if isinstance(obj, range):
      return list(obj)
    return super().default(obj)


def sgd_classifier_to_json(classifier):
  """Serialize a classifier's constructor arguments and fitted state to JSON.

  The result is a JSON object with two entries:

  * ``"args"``: every parameter name of ``type(classifier).__init__`` that
    is also present as an attribute on ``classifier``, with its value.
  * ``"params"``: whichever of ``classes_``, ``coef_``, ``intercept_`` and
    ``t_`` exist on ``classifier``.

  NumPy values are converted through ``NumpySerializer``.

  Returns
  -------
  str : the JSON document.
  """
  init_params = inspect.signature(type(classifier).__init__).parameters
  constructor_args = {
      name: getattr(classifier, name)
      for name in init_params
      if hasattr(classifier, name)
  }
  fitted_state = {
      name: getattr(classifier, name)
      for name in ('classes_', 'coef_', 'intercept_', 't_')
      if hasattr(classifier, name)
  }
  payload = {'args': constructor_args, 'params': fitted_state}
  return json.dumps(payload, cls=NumpySerializer)


def sgd_classifier_from_json(cls, full_json):
  """Rebuild a classifier from a JSON document with ``args`` and ``params``.

  Parameters
  ----------
  cls : class to instantiate.
  full_json : str
      JSON object with an ``"args"`` mapping (constructor keyword arguments)
      and a ``"params"`` mapping (attributes to restore).

  Only ``args`` entries whose name appears in the signature of
  ``cls.__init__`` are passed to the constructor. Every ``params`` entry is
  set on the new instance after being wrapped in ``np.array``.

  Returns
  -------
  The new ``cls`` instance.
  """
  payload = json.loads(full_json)
  accepted = inspect.signature(cls.__init__).parameters
  init_kwargs = {}
  for name, value in payload['args'].items():
    if name in accepted:
      init_kwargs[name] = value
  restored = cls(**init_kwargs)
  for attr, raw in payload['params'].items():
    setattr(restored, attr, np.array(raw))
  return restored


def load_data(keys):
  """Load ``data/<key>.json.gz`` for each key into a float32 array.

  Only the first line of each gzip file is read; it is decoded, parsed as
  JSON and converted with ``np.array(..., dtype=np.float32)``.

  Parameters
  ----------
  keys : iterable of str
      Base names of the files inside the ``data`` directory (relative to the
      current working directory).

  Returns
  -------
  dict mapping each key to its array.

  Any exception is re-raised after printing the offending file name.
  """
  loaded = {}
  for name in keys:
    path = os.path.join("data", "{}.json.gz".format(name))
    try:
      with gzip.open(path) as handle:
        first_line = handle.readline()
        parsed = json.loads(first_line.decode())
        loaded[name] = np.array(parsed, dtype=np.float32)
    except Exception:
      print("failing on {}".format(path))
      raise

  return loaded

def ranking_precision_score(y_true, y_score, k=10):
  """Precision at rank k for binary relevance labels.

  Parameters
  ----------
  y_true : array-like, shape = [n_samples]
      Ground truth (true relevance labels). At most two distinct values are
      allowed; the larger one is treated as the positive label.
  y_score : array-like, shape = [n_samples]
      Predicted scores.
  k : int
      Rank.

  Returns
  -------
  precision @k : float
      Number of positives among the k highest-scored samples, divided by
      ``min(number of positives, k)``.

  Raises
  ------
  ValueError
      If ``y_true`` has more than two distinct values.

  From: https://gist.github.com/mblondel/7337391
  """
  levels = np.unique(y_true)
  if len(levels) > 2:
      raise ValueError("Only supported for two relevance levels.")

  # np.unique returns sorted values, so index 1 is the larger (positive) label.
  positive = levels[1]
  total_positives = np.sum(y_true == positive)

  # Sample indices ordered from highest to lowest score.
  ranking = np.argsort(y_score)[::-1]
  top_k_labels = np.take(y_true, ranking[:k])
  hits = np.sum(top_k_labels == positive)

  # Normalizing by min(total_positives, k) makes the best achievable score 1.0.
  denominator = min(total_positives, k)
  return float(hits) / denominator

def tree_to_code(tree, feature_names, outf=sys.stdout):
    """Write a decision tree to ``outf`` as nested if/else pseudo-code.

    Parameters
    ----------
    tree : object exposing a ``tree_`` attribute with ``feature``,
        ``threshold``, ``children_left``, ``children_right`` and ``value``
        arrays indexed by node id.
    feature_names : sequence
        Name for each feature index referenced by ``tree_.feature``.
    outf : writable text stream
        Destination; the default is the ``sys.stdout`` object that existed
        when this function was defined.

    Split nodes are written as ``if <name> <= <threshold>:`` / ``else:``
    pairs; nodes whose feature is ``_tree.TREE_UNDEFINED`` are written as
    ``return <value>``. Output starts at one indentation level (two spaces).
    """
    inner = tree.tree_
    undefined = _tree.TREE_UNDEFINED

    # One display name per node; nodes without a split feature get a placeholder.
    node_labels = []
    for feat in inner.feature:
        node_labels.append("undefined!" if feat == undefined else feature_names[feat])

    def emit(node, depth):
        pad = "  " * depth
        if inner.feature[node] == undefined:
            outf.write("{}return {}\n".format(pad, inner.value[node]))
            return
        label = node_labels[node]
        cutoff = inner.threshold[node]
        outf.write("{}if {} <= {}:\n".format(pad, label, cutoff))
        emit(inner.children_left[node], depth + 1)
        outf.write("{}else:  # if {} > {}\n".format(pad, label, cutoff))
        emit(inner.children_right[node], depth + 1)

    emit(0, 1)

def read_json_preds(fname):
    """Return predictions parsed from the JSON file ``fname``.

    When ``fname`` is ``None`` no file is read; the result of
    ``compile_nice_json(split="dev")`` is returned instead.
    """
    if fname is not None:
        with open(fname) as handle:
            return json.load(handle)
    return compile_nice_json(split="dev")

def read_mimic(fname, dicts, version, Y):
    """Return ``(X, yy)`` for the notes file ``fname``, caching BOWs on disk.

    The cache path is ``fname`` with every ``'.csv'`` replaced by
    ``'_bows.csv'``. If that file exists it is loaded with ``read_bows``;
    otherwise the data is built with ``construct_X_Y`` and then written out
    with ``write_bows``.

    ``dicts`` must provide the ``"c2ind"`` key, plus ``"w2ind"`` and
    ``"ind2c"`` when the cache has to be built. ``version`` and ``Y`` are
    forwarded to the helpers unchanged.
    """
    bow_path = fname.replace('.csv', '_bows.csv')
    if not os.path.exists(bow_path):
        X, yy, hids = construct_X_Y(fname, Y, dicts["w2ind"], dicts["c2ind"], version)
        write_bows(fname, X, hids, yy, dicts["ind2c"])
        return X, yy
    # The admission ids from the cached file are not needed by the caller.
    X, yy, _ = read_bows(Y, bow_path, dicts["c2ind"], version)
    return X, yy

def scores_to_matrix(scores, ind2c):
    """Convert nested per-example scores into a 2-D array.

    Parameters
    ----------
    scores : mapping ``example id -> {code: score}``.
    ind2c : mapping ``label index -> code``.

    Returns
    -------
    np.ndarray with one row per key of ``scores`` (in iteration order) and
    one column per key of ``ind2c`` (in iteration order).
    """
    rows = [
        [scores[hid][ind2c[label_idx]] for label_idx in ind2c]
        for hid in tqdm(scores)
    ]
    return np.array(rows)

def probs_to_preds(probs, thresholds_path, ind2c):
    """Binarize a probability matrix using global or per-code thresholds.

    Parameters
    ----------
    probs : np.ndarray, shape (n_examples, n_codes)
        Predicted probabilities.
    thresholds_path : str or None
        Path to a JSON file mapping each code to its decision threshold.
        When ``None``, a fixed threshold of 0.5 is used for every code.
    ind2c : mapping ``column index -> code``
        Used to look up the threshold for each column of ``probs``.

    Returns
    -------
    np.ndarray of the same shape as ``probs``: an integer 0/1 array when
    ``thresholds_path`` is ``None``, otherwise a float 0.0/1.0 array. An
    entry is 1 when its probability is strictly greater than its threshold.

    Raises
    ------
    KeyError
        If a column's code is missing from ``ind2c`` or from the thresholds.
    """
    if thresholds_path is None:
        return np.where(probs > 0.5, 1, 0)

    # Context manager so the thresholds file is closed deterministically.
    with open(thresholds_path) as f:
        thresholds = json.load(f)

    print('converting probs to predictions')
    if probs.size == 0:
        # Nothing to compare; avoids threshold lookups for an empty matrix.
        return np.zeros(probs.shape)

    # One cutoff per column, broadcast across all rows in a single comparison
    # instead of a Python-level loop over every (example, code) pair.
    cutoffs = np.array(
        [thresholds[ind2c[code_idx]] for code_idx in range(probs.shape[1])]
    )
    return np.where(probs > cutoffs, 1.0, 0.0)

def calc_classification_metrics(y_true, y_pred, ind2c, thresholds_path=None):
    """Binarize ``y_pred`` and print classification metrics against ``y_true``.

    ``y_pred`` is thresholded with ``probs_to_preds`` (per-code thresholds
    from ``thresholds_path`` if given, else 0.5). ``y_true`` is rounded
    before comparison. Metrics come from ``evaluation.all_metrics`` with
    ``k=[8, 15]`` and the raw ``y_pred`` scores, and are printed with
    ``evaluation.print_metrics``. Nothing is returned.
    """
    binary_preds = probs_to_preds(y_pred, thresholds_path, ind2c)
    print('======= Classification metrics ======')
    rounded_truth = np.round(y_true)
    results = evaluation.all_metrics(
        binary_preds, rounded_truth, k=[8, 15], yhat_raw=y_pred
    )
    evaluation.print_metrics(results)

def calc_regression_metrics(caml, proxy, outfile):
    """Compare two score matrices label by label and report agreement.

    Parameters
    ----------
    caml : np.ndarray, shape (n_examples, n_labels)
        Reference scores.
    proxy : np.ndarray, shape (n_examples, n_labels)
        Scores to compare against ``caml``.
    outfile : str
        Path of the text report; one line is written per label column.

    For every label column the function computes MSE, Pearson, Spearman and
    Kendall-tau between the raw columns, plus ROC AUC, F1 and precision
    between ``round(caml)`` and ``proxy > 0.5``. ROC AUC falls back to 0.0
    when sklearn raises ``ValueError``. After the loop, the mean and standard
    deviation of the four regression metrics across labels are printed, with
    NaNs replaced by ``np.nan_to_num``.

    NOTE: in the report line, the value labelled "avg precision" is
    ``metrics.precision_score`` and "dev auc" is computed on thresholded
    predictions; the labels are kept as-is for output compatibility.

    Returns
    -------
    list of dict, one per label, with keys ``mse``, ``pearson``,
    ``spearman`` and ``kendalltau``. Empty when there are no label columns.
    """
    scores = []
    # `with` guarantees the report is flushed and closed even if a metric
    # computation raises part-way through the loop.
    with open(outfile, 'w') as fout:
        for i in tqdm(range(caml.shape[1])):
            c, p = caml[:, i], proxy[:, i]
            cr = np.round(c)
            cp = np.where(p > 0.5, 1, 0)

            mse = metrics.mean_squared_error(c, p)
            pearson = stats.pearsonr(c, p)[0]
            spearman = stats.spearmanr(c, p).correlation
            kt = stats.kendalltau(c, p).correlation
            try:
                auc = metrics.roc_auc_score(cr, cp)
            except ValueError:
                # Raised by sklearn for this column; report 0 instead of aborting.
                auc = 0.
            f1 = metrics.f1_score(cr, cp, zero_division=0)
            precision = metrics.precision_score(cr, cp, zero_division=0)

            scores.append({
                "mse": mse,
                "pearson": pearson,
                "spearman": spearman,
                "kendalltau": kt
            })

            st = f"mse: {mse}, pearson: {pearson}, spearman: {spearman}, kendalltau: {kt} " \
                    f"dev auc: {auc}, f1 {f1}, avg precision: {precision}"
            fout.write(st + '\n')

    print('======= Regression metrics ======')
    if not scores:
        # No label columns: nothing to aggregate (avoids IndexError on scores[0]).
        return scores

    for metric in scores[0]:
        m = np.nan_to_num([label[metric] for label in scores])
        mean = np.mean(m)
        std = np.std(m)
        print(f'Macro {metric}: {mean} avg; {std} std')

    return scores

def write_bows(data_fname, X, hadm_ids, y, ind2c):
    """Write a bag-of-words matrix and its labels to ``<data_fname>_bows.csv``.

    The output path is the part of ``data_fname`` before the first ``'.csv'``
    with ``'_bows.csv'`` appended. The file has a ``HADM_ID,BOW,LABELS``
    header followed by one row per row of ``X``:

    * ``HADM_ID``: ``str(hadm_ids[i])``
    * ``BOW``: space-separated ``index:count`` pairs for the non-zero
      entries of row ``i`` of ``X``
    * ``LABELS``: ``';'``-joined codes (via ``ind2c``) for the non-zero
      entries of ``y[i]``

    Parameters
    ----------
    data_fname : str
        Path of the source CSV; used only to derive the output path.
    X : sparse matrix supporting row indexing and ``toarray()``.
    hadm_ids : sequence indexed in parallel with the rows of ``X``.
    y : 2-D array of label indicators, one row per row of ``X``.
    ind2c : mapping ``label index -> code``.
    """
    out_name = data_fname.split('.csv')[0] + '_bows.csv'
    # newline='' is required by the csv module; without it the writer's own
    # '\r\n' terminators get translated again on platforms that use '\r\n',
    # producing blank rows between records.
    with open(out_name, 'w', newline='') as of:
        w = csv.writer(of)
        w.writerow(['HADM_ID', 'BOW', 'LABELS'])
        for i in range(X.shape[0]):
            # Densify the row so duplicate sparse entries are summed into counts.
            bow = X[i].toarray()[0]
            inds = bow.nonzero()[0]
            bow_str = ' '.join('%d:%d' % (ind, bow[ind]) for ind in inds)
            code_str = ';'.join(ind2c[ind] for ind in y[i].nonzero()[0])
            w.writerow([str(hadm_ids[i]), bow_str, code_str])

def read_bows(Y, bow_fname, c2ind, version):
    """Load a bag-of-words CSV with ``HADM_ID,BOW,LABELS`` columns.

    Parameters
    ----------
    Y : unused; kept for interface compatibility.
    bow_fname : str
        Path of the CSV file. The first row is skipped as a header.
    c2ind : mapping ``code -> label index``
        Also determines the number of label columns (``len(c2ind)``).
    version : unused; kept for interface compatibility.

    Returns
    -------
    X : scipy.sparse.csr_matrix
        One row per data row of the file; the column count is the largest
        word index seen plus one (0 if no word was seen).
    y : np.ndarray
        0/1 label indicators, one row per data row, ``len(c2ind)`` columns.
    hids : list of int
        The ``HADM_ID`` of each row.

    Raises
    ------
    KeyError
        If a code in the ``LABELS`` column is missing from ``c2ind``.
    """
    num_labels = len(c2ind)
    data = []
    row_ind = []
    col_ind = []
    hids = []
    y = []
    # newline='' lets the csv module do its own newline handling, as its
    # documentation requires.
    with open(bow_fname, 'r', newline='') as f:
        r = csv.reader(f)
        # header
        next(r)
        for i, row in tqdm(enumerate(r), desc='Reading BOW'):
            hid = int(row[0])
            bow_str = row[1]
            code_str = row[2]
            for pair in bow_str.split():
                split = pair.split(':')
                ind, count = split[0], split[1]
                data.append(int(count))
                row_ind.append(i)
                col_ind.append(int(ind))
            if code_str == '':
                label_set = set()
            else:
                label_set = {c2ind[c] for c in code_str.split(';')}
            y.append([1 if j in label_set else 0 for j in range(num_labels)])
            hids.append(hid)

    # Pin the row count explicitly: when the shape is inferred from the
    # indices, trailing admissions with an empty BOW are dropped and X ends
    # up with fewer rows than y (and a file with no BOW entries cannot be
    # built at all). The column count matches what inference would produce.
    n_rows = len(hids)
    n_cols = max(col_ind) + 1 if col_ind else 0
    X = csr_matrix((data, (row_ind, col_ind)), shape=(n_rows, n_cols))
    return X, np.array(y), hids


def construct_X_Y(notefile, Y, w2ind, c2ind, version):
    """
        Build bag-of-words features and label indicators from a notes CSV.

        Each data row of the file describes one admission: column 1 is the
        admission id (parsed as int), column 2 the note text, and column 3 a
        ';'-separated list of labels. The first row is skipped as a header.

        INPUTS:
            notefile: path to file containing note data
            Y: ignored; the label count is always taken from len(c2ind)
            w2ind: dictionary from words to integers for discretizing
            c2ind: dictionary from labels to integers for discretizing
            version: which (MIMIC) dataset (not used here)
        OUTPUTS:
            csr_matrix where each row is a BOW; every token occurrence is
                stored as a 1 at its word index. Words mapped to index 0
                (padding) are skipped and out-of-vocabulary words use index
                len(w2ind).
            numpy array of 0/1 label indicators, one row per admission and
                len(c2ind) columns; labels missing from c2ind are ignored
            list of admission ids (ints), one per row
    """
    n_labels = len(c2ind)
    oov_index = len(w2ind)

    label_rows = []
    admission_ids = []
    # CSR building blocks: row start offsets and the column index of every
    # stored entry (each entry's value is 1).
    row_starts = []
    col_indices = []

    with open(notefile, 'r') as notesfile:
        reader = csv.reader(notesfile)
        next(reader)

        for _, row in tqdm(enumerate(reader), desc='Constructing X,y'):
            active = {c2ind[code] for code in str(row[3]).split(';') if code in c2ind}
            label_rows.append([int(j in active) for j in range(n_labels)])

            row_starts.append(len(col_indices))
            for token in row[2].split():
                if token not in w2ind:
                    # OOV
                    col_indices.append(oov_index)
                    continue
                word_index = w2ind[token]
                if word_index != 0:
                    # index 0 is padding and is not counted
                    col_indices.append(word_index)

            admission_ids.append(int(row[1]))
        # Closing offset so the last row's extent is defined.
        row_starts.append(len(col_indices))

    values = [1] * len(col_indices)
    return csr_matrix((values, col_indices, row_starts)), np.array(label_rows), admission_ids

def compile_nice_json(split="dev"):
  """Collect per-code prediction files into ``{hadm_id: {code: pred}}``.

  Reads, from ``constants.PROXY_DIR``:

  * ``<split>_hadm.txt``: one integer admission id per line
  * ``c2ind.json``: mapping ``code -> index``
  * ``<split>.<index>.npy.gz``: one gzip-compressed array per code index,
    paired element by element with the admission ids

  Returns
  -------
  defaultdict(dict) mapping each admission id to its per-code predictions.
  """
  base = constants.PROXY_DIR

  with open(f"{base}/{split}_hadm.txt") as inf:
    hadm_ids = [int(line) for line in inf]

  with open(f"{base}/c2ind.json") as inf:
    c2ind = json.load(inf)
  ind2c = dict((index, code) for code, index in c2ind.items())

  nested = defaultdict(dict)
  for code_idx in ind2c:
    code = ind2c[code_idx]
    with gzip.open(f"{base}/{split}.{code_idx}.npy.gz") as inf:
      preds = np.load(inf)
      # zip stops at the shorter of the two sequences.
      for hadm_id, pred in zip(hadm_ids, preds):
        nested[hadm_id][code] = pred

  return nested
