'''
Script that compares proxy to true labels
'''
import argparse
import logging
import time
import json
import os
import numpy as np
from tqdm import tqdm
import datasets
import evaluation
from sklearn import metrics
from scipy import stats
import mimic_proxy
from log_reg import read_bows, construct_X_Y
from utils import *

def main(args):
    """Compare proxy scores against the true labels and report metrics.

    Loads the lookup tables via ``datasets.load_lookups``, reads the true
    labels with ``read_mimic``, reads the proxy scores from the JSON file at
    ``args.proxy_scores``, converts them to a matrix with
    ``scores_to_matrix`` and passes both to ``calc_classification_metrics``.

    Args:
        args: parsed command-line namespace. This function reads
            ``data_path``, ``version``, ``Y``, ``proxy_scores`` and
            ``thresholds`` from it; the whole namespace is also forwarded to
            ``datasets.load_lookups``.

    NOTE(review): the CLI help for ``--proxy_scores`` mentions a fallback to
    ``mimic_proxy.compile_nice_json()``, but no such fallback happens here --
    ``args.proxy_scores`` is handed to ``read_json_preds`` as-is. Confirm
    whether the option is effectively required.
    """
    lookups = datasets.load_lookups(args)

    print('reading true labels')
    # Only the label part of the pair is needed for the comparison.
    _features, true_labels = read_mimic(
        args.data_path, lookups, args.version, args.Y
    )

    print('reading proxy')
    proxy_scores = read_json_preds(args.proxy_scores)

    print('converting proxy to matrix')
    index_to_code = lookups['ind2c']
    # Rebind the same name so the raw JSON scores are not kept alive.
    proxy_scores = scores_to_matrix(proxy_scores, index_to_code)

    calc_classification_metrics(
        true_labels,
        proxy_scores,
        index_to_code,
        thresholds_path=args.thresholds,
    )


if __name__ == "__main__":
    # Command-line interface: two positional paths plus optional settings.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument('data_path')
    cli_parser.add_argument('vocab')
    cli_parser.add_argument(
        '--proxy_scores',
        help="Path to json scores. Defaults to running mimic_proxy.compile_nice_json()",
    )
    cli_parser.add_argument(
        '--thresholds',
        help="Path to thresholds per class. Defaults to 0.5",
    )
    cli_parser.add_argument('--Y', default='full')
    cli_parser.add_argument('--version', default='mimic3')
    # store_const (not store_true): the attribute stays None when the flag
    # is absent and becomes True when it is given.
    cli_parser.add_argument(
        "--public-model",
        dest="public_model",
        action="store_const",
        required=False,
        const=True,
        help="optional flag for testing pre-trained models from the public github",
    )
    args = cli_parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format='%(levelname)s - %(message)s',
    )

    # Time the whole run and report the wall-clock duration at INFO level.
    started_at = time.time()
    main(args)
    elapsed = time.time() - started_at
    logging.info(f'Time to run script: {elapsed} secs')