import argparse
import re
from tqdm import tqdm
from files2rouge import files2rouge
from pathlib import Path
import pandas as pd
import csv
import sys
import numpy as np
from pprint import pprint
import datasets
from compare_explanations import load_plausibility_annotations, match_annotations_to_hadm

def filter_rouge(output_string):
    """Extract the average F-scores from a ROUGE text report.

    Every line of ``output_string`` is searched for the pattern
    ``ROUGE-<1|2|L> Average_<R|P|F>: <number>``. Only the ``F`` lines are
    kept; recall and precision lines are ignored.

    Args:
        output_string: Full text of the report, one metric per line.

    Returns:
        A dict mapping ``'rouge-1'``, ``'rouge-2'`` and/or ``'rouge-l'`` to
        the average F-score as a float. ROUGE types without an
        ``Average_F`` line are absent. If a type is reported more than once,
        the last value wins.
    """
    # Raw string so the backslashes reach the regex engine untouched, and the
    # decimal point is escaped so it only matches a literal '.'.
    pattern = re.compile(r"ROUGE-(1|2|L) Average_(R|P|F): (\d\.\d+)")
    scores = {}
    for line in output_string.splitlines():
        match = pattern.search(line)
        if match is None:
            continue
        # The F-score is recorded unconditionally, so the result does not
        # depend on the R/P lines having been seen first.
        if match.group(2) != 'F':
            continue
        rouge_type = f'rouge-{match.group(1)}'.lower()  # rouge-{1,2,l}
        # float() instead of eval(): the text is a plain decimal number.
        scores[rouge_type] = float(match.group(3))
    return scores

def calc_rouge(hyp_path, ref_path, outfile='tmp.txt'):
    """Run files2rouge on two text files and return the parsed scores.

    Args:
        hyp_path: Path to the hypothesis file.
        ref_path: Path to the reference file.
        outfile: Path passed to files2rouge as ``saveto``; it is read back
            afterwards to obtain the report text.

    Returns:
        Whatever ``filter_rouge`` extracts from the saved report.
    """
    files2rouge.run(hyp_path, ref_path, ignore_empty=True, saveto=outfile)
    with open(outfile) as report_file:
        report_text = report_file.read()
    return filter_rouge(report_text)

def get_rouge_from_df(hyp_df, ref_df, text_col):
    """Compute ROUGE between the matching rows of two explanation frames.

    The frames are inner-joined on ``HADM_ID`` and ``LABEL``; the
    ``text_col`` values of each side are written, one row per line, to
    ``hyp.txt`` and ``ref.txt`` in the current working directory and scored
    with ``calc_rouge``. The scores are pretty-printed and returned.

    Args:
        hyp_df: DataFrame with the hypothesis explanations.
        ref_df: DataFrame with the reference explanations.
        text_col: Name of the text column. It must exist in both frames so
            that the merge suffixes it with ``_hyp`` / ``_ref``.

    Returns:
        The score dict produced by ``calc_rouge``.
    """
    merged = pd.merge(hyp_df, ref_df, how='inner', on=['HADM_ID', 'LABEL'],
                      suffixes=("_hyp", "_ref"))
    for suffix, out_path in (('_hyp', 'hyp.txt'), ('_ref', 'ref.txt')):
        # One example per line: a line break inside a text would shift every
        # following hypothesis/reference pair, so flatten it to a space.
        flat_texts = [' '.join(text.splitlines())
                      for text in merged[text_col + suffix]]
        with open(out_path, 'w') as out_file:
            out_file.write('\n'.join(flat_texts))
    rouge_score = calc_rouge('hyp.txt', 'ref.txt')
    pprint(rouge_score)
    return rouge_score

def read_our_format_exp(caml_path, notesfile, c2ind, four_gram_only=False):
    """Read explanations stored one note per line, one column per label.

    The notes file and the explanations file are read in lockstep: after the
    header of the notes file is skipped, the i-th remaining notes line is
    paired with the i-th line of the explanations file (which is read from
    its first line). Only explanations for the labels listed on the notes
    line are returned.

    Both files are split on bare commas, so fields must not contain commas.

    Args:
        caml_path: Path to the explanations file. For a label with index
            ``k`` in ``c2ind`` the explanation is taken from column ``k + 1``.
        notesfile: Path to the notes CSV. Column 0 is used as SUBJECT_ID,
            column 1 as HADM_ID and column 3 as the ``;``-separated labels.
        c2ind: Mapping from label code to its column index.
        four_gram_only: When true, drop the first five and the last five
            whitespace-separated tokens of every explanation.
            NOTE(review): presumably this strips context around a 4-gram --
            confirm against the writer of the explanations file.

    Returns:
        DataFrame with columns SUBJECT_ID, HADM_ID (int64), LABEL and NGRAM,
        one row per (note, label) whose explanation is non-empty.
    """
    out_rows = []
    # Context manager so both handles are closed even if parsing fails.
    with open(caml_path) as caml, open(notesfile) as notes:
        next(notes, None)  # skip the header; tolerate an empty notes file
        for note, exps in tqdm(zip(notes, caml)):
            note = note.strip().split(',')
            exps = exps.strip().split(',')
            subject_id, hadm_id = note[0], note[1]
            for label in note[3].split(';'):
                if not label:
                    # A note without labels yields [''] after the split;
                    # there is no true label to explain.
                    continue
                exp = exps[c2ind[label] + 1]
                if four_gram_only:
                    exp = ' '.join(exp.split()[5:-5])
                if exp != '':
                    out_rows.append([subject_id, hadm_id, label, exp])

    df = pd.DataFrame(out_rows, columns=['SUBJECT_ID', 'HADM_ID', 'LABEL', 'NGRAM'])
    df['HADM_ID'] = df['HADM_ID'].astype(np.int64)
    return df

def get_plausibility_annotations(subsetby=None):
    """Load the plausibility annotations as a flat DataFrame.

    Every explanation of every annotation becomes one row. Explanations
    whose row cannot be built because of a missing key are reported on
    stdout and skipped.

    Args:
        subsetby: Optional filter on the ANNOTATION column. ``None`` keeps
            all rows, ``'+/++'`` keeps the rows whose annotation is not
            missing, any other value keeps the rows whose annotation equals
            that value.

    Returns:
        DataFrame with columns HADM_ID (int64), LABEL, NGRAM, TEXT and
        ANNOTATION.
    """
    annotations, counts, text_regexes = load_plausibility_annotations()
    annotation_to_hadm, _ = match_annotations_to_hadm(annotations, counts, text_regexes)
    out_rows = []
    for annotation in annotations:
        # 'HADM_ID','LABEL','NGRAM', 'TEXT', 'ANNOTATION'
        for exp in annotation['explanations']:
            try:
                row = [
                    annotation_to_hadm[annotation['id']],
                    annotation['code'],
                    exp['ngram'],
                    exp['text'],
                    exp['annotation'],
                ]
            except KeyError:
                print(f"Annotation ID {annotation['id']} not in annotation_to_hadm")
                continue
            # Append inside the inner loop: one row per explanation, and
            # never a stale row left over from a previous iteration.
            out_rows.append(row)
    df = pd.DataFrame(out_rows, columns=['HADM_ID', 'LABEL', 'NGRAM', 'TEXT', 'ANNOTATION'])
    df['HADM_ID'] = df['HADM_ID'].astype(np.int64)
    if subsetby is not None:
        if subsetby == '+/++':
            df = df[~df['ANNOTATION'].isna()]
        else:
            df = df[df['ANNOTATION'] == subsetby]
    return df

def read_caml_format_exp(exp_path):
    """Load an explanations CSV at ``exp_path`` into a DataFrame, unmodified."""
    return pd.read_csv(exp_path)

# Figures out which format the explanations are in and loads them accordingly
def read_exp(path, data_path, c2ind, four_gram_only=False, subsetby=None):
    """Load explanations, choosing the reader from ``path``.

    Args:
        path: Either the literal string ``"plausibility_annotations"`` or a
            file path ending in ``top_ngrams.csv`` or ``explanations_0.csv``.
        data_path: Notes file; only used for ``explanations_0.csv`` input.
        c2ind: Label-to-index mapping; only used for ``explanations_0.csv``
            input.
        four_gram_only: Forwarded to ``read_our_format_exp``.
        subsetby: Forwarded to ``get_plausibility_annotations``.

    Returns:
        The DataFrame produced by the selected reader.

    Raises:
        ValueError: If ``path`` matches none of the supported formats.
    """
    if path.endswith('top_ngrams.csv'):
        return read_caml_format_exp(path)
    if path == "plausibility_annotations":
        return get_plausibility_annotations(subsetby=subsetby)
    if path.endswith('explanations_0.csv'):
        return read_our_format_exp(path, data_path, c2ind, four_gram_only=four_gram_only)
    # A bare string cannot be raised; wrap the message in a real exception
    # so the caller sees it instead of a TypeError about the raise itself.
    raise ValueError(f'Wrong file format: {path}')

def main(args):
    """Load hypothesis and reference explanations and print their ROUGE overlap.

    Args:
        args: Parsed command-line namespace. It is handed to
            ``datasets.load_lookups`` and must provide ``hyp_path``,
            ``ref_path``, ``data_path``, ``four_gram_only`` and ``subsetby``.
    """
    lookups = datasets.load_lookups(args)

    def load_explanations(exp_path):
        # Both sides are read with identical options; only the path differs.
        return read_exp(
            exp_path,
            args.data_path,
            lookups['c2ind'],
            four_gram_only=args.four_gram_only,
            subsetby=args.subsetby,
        )

    print(f'reading hyp {args.hyp_path}')
    hypotheses = load_explanations(args.hyp_path)
    print(f'reading ref {args.ref_path}')
    references = load_explanations(args.ref_path)
    get_rouge_from_df(hypotheses, references, 'NGRAM')

if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    # The whole namespace is also passed to datasets.load_lookups, so the
    # options that are not read in this file (vocab, --Y, --version,
    # --public-model) are presumably consumed there -- verify before removing.
    parser = argparse.ArgumentParser(description="compute ROUGE overlap between two sets of explanations")
    parser.add_argument("hyp_path", type=str, help="path to proxy explanations or \"plausibility_annotations\"")
    parser.add_argument("ref_path", type=str, help="path to caml explanations")
    parser.add_argument('vocab')
    parser.add_argument('data_path', help="path to the notes file read alongside an explanations_0.csv file")
    parser.add_argument('--Y', default='full')
    parser.add_argument('--version', default='mimic3')
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    parser.add_argument("--four_gram_only", action="store_true", required=False, default=False,
                        help="optional flag: drop the first and last five tokens of each explanations_0.csv explanation before scoring")
    parser.add_argument('--subsetby', choices=['+', '++', '+/++'], help='subset plausibility explanations by which annotation it received')
    args = parser.parse_args()
    main(args)
