import os
import sys
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import re
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import csv
import numpy as np
from functools import reduce
import datasets
from compare_explanations import load_plausibility_annotations, match_annotations_to_hadm


def read_our_format_exp(caml_path, notesfile, c2ind, four_gram_only=False):
    """Read per-label n-gram explanations written in "our" CSV format.

    The explanation file and the notes file are read line-by-line in
    lockstep: line i of ``caml_path`` is paired with the i-th data line of
    ``notesfile`` (the notes header is skipped). Only explanations for the
    note's own (true) labels are returned.

    Args:
        caml_path: path to the explanations file (e.g. ``.../explanations_0.csv``).
            The name of its parent directory is returned as the source name.
        notesfile: path to the notes CSV. Column 0 is read as SUBJECT_ID,
            column 1 as HADM_ID and column 3 as ';'-separated label codes.
        c2ind: mapping from label code to index. The explanation for a label
            is read from column ``c2ind[label] + 1`` of the explanation row
            (column 0 is skipped -- presumably an ID column; confirm).
        four_gram_only: if True, drop the first 5 and last 5
            whitespace-separated tokens of each explanation.

    Returns:
        ``(name, df)`` where ``df`` has columns SUBJECT_ID, HADM_ID (int64),
        LABEL and NGRAM. Empty explanations are omitted.

    Raises:
        KeyError: if a note carries a label code that is missing from ``c2ind``.
    """
    out_rows = []
    # Context managers guarantee both files are closed, even on error.
    with open(caml_path) as caml, open(notesfile) as notes:
        next(notes)  # skip the notes header row
        for note, exps in tqdm(zip(notes, caml)):
            # NOTE(review): a plain split(',') assumes no field contains a quoted comma.
            note = note.strip().split(',')
            exps = exps.strip().split(',')
            subject_id, hadm_id = note[0], note[1]
            for label in note[3].split(';'):
                exp = exps[c2ind[label] + 1]
                if four_gram_only:
                    exp = ' '.join(exp.split()[5:-5])
                if exp != '':
                    out_rows.append([subject_id, hadm_id, label, exp])

    df = pd.DataFrame(out_rows, columns=['SUBJECT_ID', 'HADM_ID', 'LABEL', 'NGRAM'])
    df['HADM_ID'] = df['HADM_ID'].astype(np.int64)
    # Parent directory name (separator-agnostic, unlike split('/')).
    name = Path(caml_path).parent.name
    return name, df

def get_plausibility_annotations(subsetby=None):
    """Load the human plausibility annotations as one row per explanation.

    Annotations whose id cannot be matched to an admission (missing from the
    mapping returned by ``match_annotations_to_hadm``) are reported and
    skipped.

    Args:
        subsetby: optional filter on the ANNOTATION column. ``'+/++'`` keeps
            rows whose annotation is not NA; any other value keeps rows whose
            annotation equals it exactly; ``None`` keeps every row.

    Returns:
        ``(name, df)`` where ``df`` has columns HADM_ID (int64), LABEL, NGRAM,
        TEXT and ANNOTATION. ``name`` is ``'plausibility_annoations'`` with
        ``subsetby`` appended when given.
    """
    annotations, counts, text_regexes = load_plausibility_annotations()
    annotation_to_hadm, _full_texts = match_annotations_to_hadm(annotations, counts, text_regexes)
    out_rows = []
    for annotation in annotations:
        # Resolve the admission once per annotation; keep the try body minimal.
        try:
            hadm_id = annotation_to_hadm[annotation['id']]
        except KeyError:
            print(f"Annotation ID {annotation['id']} not in annotation_to_hadm")
            continue
        # Emit one row per explanation: 'HADM_ID','LABEL','NGRAM','TEXT','ANNOTATION'.
        out_rows.extend(
            [hadm_id, annotation['code'], exp['ngram'], exp['text'], exp['annotation']]
            for exp in annotation['explanations']
        )
    df = pd.DataFrame(out_rows, columns=['HADM_ID', 'LABEL', 'NGRAM', 'TEXT', 'ANNOTATION'])
    df['HADM_ID'] = df['HADM_ID'].astype(np.int64)
    # NOTE: the misspelled 'annoations' is kept on purpose; main() selects a
    # column named after it ('plausibility_annoations_NGRAM').
    if subsetby is not None:
        if subsetby == '+/++':
            df = df[~df['ANNOTATION'].isna()]
        else:
            df = df[df['ANNOTATION'] == subsetby]
        name = 'plausibility_annoations' + subsetby
    else:
        name = 'plausibility_annoations'
    return name, df

def read_caml_format_exp(exp_path):
    """Load a CAML-style ``top_ngrams.csv`` and name it by its path.

    The source name is ``'proxy'`` if the path contains "proxy" (checked
    first), otherwise ``'log_reg'`` if it contains "log_reg", otherwise
    ``'CAML'``. The CSV is returned exactly as pandas parses it.
    """
    frame = pd.read_csv(exp_path)
    for marker in ('proxy', 'log_reg'):
        if marker in exp_path:
            return marker, frame
    return 'CAML', frame

# Figures out which format the explanations is in
def read_exp(path, data_path, c2ind, four_gram_only=False, subsetby=None):
    """Load one explanation source, choosing the reader from ``path``.

    Dispatch rules:
      * ``...top_ngrams.csv``      -> ``read_caml_format_exp(path)``
      * ``"plausibility_annotations"`` (literal) -> ``get_plausibility_annotations``
        (``subsetby`` is forwarded)
      * ``...explanations_0.csv``  -> ``read_our_format_exp`` (``data_path``,
        ``c2ind`` and ``four_gram_only`` are forwarded)

    Returns:
        ``(name, df)`` as returned by the selected reader.

    Raises:
        ValueError: if ``path`` matches none of the known formats.
    """
    if path.endswith('top_ngrams.csv'):
        return read_caml_format_exp(path)
    if path == "plausibility_annotations":
        return get_plausibility_annotations(subsetby=subsetby)
    if path.endswith('explanations_0.csv'):
        return read_our_format_exp(path, data_path, c2ind, four_gram_only=four_gram_only)
    # Raising a bare string is itself a TypeError; raise a real exception.
    raise ValueError(f'Wrong file format: {path}')

def main(args):
    """Merge several explanation sources on (HADM_ID, LABEL) and write a CSV.

    Each entry of ``args.e`` is loaded with ``read_exp`` and its NGRAM column
    is renamed to ``<name>_NGRAM``. The sources are inner-joined on HADM_ID
    and LABEL, so only (admission, label) pairs present in every source are
    kept. The result is written to ``args.outfile`` without an index.
    """
    keys = ['HADM_ID', 'LABEL']
    dicts = datasets.load_lookups(args)
    exps = []
    for efile in args.e:
        print(f'reading {efile}')
        name, df = read_exp(efile,
                            args.data_path,
                            dicts['c2ind'],
                            four_gram_only=args.four_gram_only,
                            subsetby=args.subsetby)
        exps.append(df.rename(columns={'NGRAM': f'{name}_NGRAM'}))

    def merge_pair(left, right):
        # Keep the left frame's copy of any shared non-key column, so pandas
        # never produces suffixed _x/_y duplicates (which would otherwise
        # have to be dropped, losing both copies).
        shared = [col for col in right.columns if col in left.columns and col not in keys]
        return pd.merge(left, right.drop(columns=shared), how='inner', on=keys)

    merged_exps = reduce(merge_pair, exps)

    # Preferred column order. Columns that this set of sources does not
    # produce are skipped instead of raising KeyError, and any other *_NGRAM
    # columns (e.g. 'plausibility_annoations+_NGRAM' under --subsetby) are
    # appended so no explanation source is silently dropped.
    preferred = [
        'SUBJECT_ID',
        'HADM_ID',
        'LABEL',
        'ANNOTATION',
        'TEXT',
        'log_reg_NGRAM',
        'plausibility_annoations_NGRAM',
        'proxy_NGRAM',
        'CAML_mimic3_full_NGRAM',
    ]
    columns = [col for col in preferred if col in merged_exps.columns]
    columns += [col for col in merged_exps.columns
                if col.endswith('_NGRAM') and col not in columns]
    merged_exps[columns].to_csv(args.outfile, index=False)

if __name__ == "__main__":
    # Command-line entry point: merge explanation sources into one CSV.
    parser = argparse.ArgumentParser(
        description="merge per-label explanations from several sources into one CSV")
    parser.add_argument('vocab')
    parser.add_argument('data_path', help='path to {dev/test}_full.csv')
    parser.add_argument('outfile', help='file to save explanations to')
    parser.add_argument("--exp_file", "-e", dest='e', action='append', default=[], required=True,
                        help="path to explanations or \"plausibility_annotations\"")
    parser.add_argument('--Y', default='full')
    parser.add_argument('--version', default='mimic3')
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    parser.add_argument("--four_gram_only", action="store_true", required=False, default=False,
                        help="optional flag: drop the first and last 5 tokens of each "
                             "explanation read from explanations_0.csv files")
    parser.add_argument('--subsetby', choices=['+', '++', '+/++'], help='subset plausibility explanations by which annotation it received')
    args = parser.parse_args()
    main(args)
