import os, multiprocessing, csv, tqdm, ast, json
import pandas as pd
from multiprocessing.pool import Pool
from quickumls import QuickUMLS


# Per-process QuickUMLS matcher, created lazily by _get_matcher() so each
# pool worker builds it once instead of once per note.
_MATCHER = None


def _get_matcher():
    """Return this process's shared QuickUMLS matcher, creating it on first call."""
    global _MATCHER
    if _MATCHER is None:
        # Default QuickUMLS parameters apart from the similarity threshold.
        _MATCHER = QuickUMLS("QuickUMLS", threshold=0.5)
    return _MATCHER


def match(params, matcher=None):
    """Run QuickUMLS on one note and group candidate concepts by matched span.

    Args:
        params: ``(note_id, text)`` pair, passed as a single tuple so the
            function can be used directly with ``Pool.imap``.
        matcher: optional object exposing QuickUMLS's ``match`` method. When
            omitted, a per-process ``QuickUMLS("QuickUMLS", threshold=0.5)``
            is created on first use and reused for later notes.

    Returns:
        dict with keys ``"note_id"``, ``"text"`` and ``"UMLS"``. ``"UMLS"`` is
        a list of ``[(ngram, start, end), [[term, cui, similarity], ...]]``
        entries ordered by start index; each candidate list is sorted by
        descending similarity.
    """
    nid, text = params
    if matcher is None:
        matcher = _get_matcher()
    umls_result = matcher.match(text, best_match=True, ignore_syntax=False)

    # (ngram, start, end) -> [[term, cui, similarity], ...]
    spans = {}
    for candidates in umls_result:
        head = candidates[0]
        ngram, sidx, eidx = head["ngram"], head["start"], head["end"]
        term_cui_sim_list = [[head["term"], head["cui"], head["similarity"]]]

        for cand in candidates[1:]:
            # Every candidate in a group is expected to share the head's
            # ngram and offsets; report (but keep) any that do not.
            if (cand["ngram"], cand["start"], cand["end"]) != (ngram, sidx, eidx):
                print(f"Unmatched ngram/indices error for note {nid} with ngram {ngram}")
            term_cui_sim_list.append([cand["term"], cand["cui"], cand["similarity"]])

        # Best-scoring candidate first.
        term_cui_sim_list.sort(key=lambda x: x[2], reverse=True)
        spans[(ngram, sidx, eidx)] = term_cui_sim_list

    if len(spans) != len(umls_result):
        # Groups with an identical (ngram, start, end) collapse into one key;
        # the later group overwrites the earlier one.
        print(f"Unmatched number of ngrams error for note {nid}")

    # Order spans by their start index.
    ordered = sorted(spans.items(), key=lambda item: item[0][1])
    return {
        "note_id": nid,
        "text": text,
        "UMLS": [[k, v] for k, v in ordered],
    }

if __name__ == '__main__':
    # Input CSVs (each read for "note_id" and "text" columns) and the JSONL
    # file receiving one record per note, paired by list position.
    datasets = [
        '../data/snomed-ct-entity-challenge/1.0.0/mimic-iv_notes_training_set.csv',
        '../data/mimic-iv-note/2.2/note/discharge_ICD10_excl.csv',
    ]
    output_files = [
        "../data/intrim/snomed_ct_entity_linking_UMLS.jsonl",
        "../data/intrim/discharge_ICD10_excl_UMLS.jsonl",
    ]

    # Leave 8 cores free as before, but never request fewer than 1 worker:
    # Pool raises ValueError for processes < 1 on machines with <= 8 cores.
    n_workers = max(1, multiprocessing.cpu_count() - 8)

    for dataset, outfile in zip(datasets, output_files):
        print('-'*20, 'Processing ', dataset, '-'*20,'\n')
        df = pd.read_csv(dataset)
        # Iterating columns (instead of df.loc[i, ...] over range(len(df)))
        # works for any index and yields Python scalars json can serialize.
        todo = list(zip(df['note_id'], df['text']))

        out_dir = os.path.dirname(outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Output is appended, so re-running adds duplicate lines to an
        # existing file. The file is opened once rather than per result.
        with Pool(n_workers) as pool, open(outfile, "a", encoding="utf-8") as out_f:
            for record in tqdm.tqdm(pool.imap(match, todo), total=len(todo)):
                # default=list serializes otherwise-unsupported iterables
                # (e.g. sets) as JSON arrays.
                out_f.write(json.dumps(record, default=list) + "\n")
            pool.close()
            pool.join()

        print("Done with identifying UMLS terms.")