import argparse
import re
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import pickle
from torch.autograd import Variable
from functools import reduce
import datasets
from learn import tools
from compare_explanations import load_plausibility_annotations, match_annotations_to_hadm


'''
Leading columns of each row built by get_plausibility_annotations, in order:
SUBJECT_ID,HADM_ID,TEXT,LABELS,length
'''
def get_plausibility_annotations(subsetby=None):
    """Build a DataFrame of plausibility annotations joined to their full note text.

    Each annotation returned by ``load_plausibility_annotations`` is matched to
    an admission via ``match_annotations_to_hadm``. One row is emitted per
    matched annotation with the columns::

        SUBJECT_ID  the annotation id, as a string
        HADM_ID     the matched admission id (cast to int64)
        TEXT        the full text for that admission
        LABELS      the annotation's code
        length      number of whitespace-separated tokens in TEXT
        NGRAM-0..3  the ``ngram`` of each of the annotation's explanations

    Annotations for which a lookup raises ``KeyError`` are reported on stdout
    and skipped.

    Args:
        subsetby: ``None`` to keep every row; ``'+/++'`` to keep rows whose
            ``ANNOTATION`` value is not NaN; any other value to keep rows whose
            ``ANNOTATION`` equals it.

    Returns:
        pandas.DataFrame with the columns listed above.

    Raises:
        KeyError: if ``subsetby`` is given but the frame has no ``ANNOTATION``
            column.
        ValueError: if an annotation carries more explanations than there are
            NGRAM columns.
    """
    ngram_columns = ['NGRAM-0', 'NGRAM-1', 'NGRAM-2', 'NGRAM-3']
    columns = ['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length'] + ngram_columns

    annotations, counts, text_regexes = load_plausibility_annotations()
    annotation_to_hadm, full_texts = match_annotations_to_hadm(annotations, counts, text_regexes)
    out_rows = []
    for annotation in annotations:
        try:
            hadm = annotation_to_hadm[annotation['id']]
            full_text = full_texts[hadm]
            row = [
                f'{annotation["id"]}',
                hadm,
                full_text,
                annotation['code'],
                len(full_text.split()),
            ]
            ngrams = [exp['ngram'] for exp in annotation['explanations']]
        except KeyError:
            print(f"Annotation ID {annotation['id']} not in annotation_to_hadm")
            continue
        if len(ngrams) > len(ngram_columns):
            raise ValueError(
                f"Annotation ID {annotation['id']} has {len(ngrams)} explanations; "
                f"at most {len(ngram_columns)} are supported"
            )
        # Pad short rows so every row has exactly one cell per column; without
        # this the DataFrame constructor fails when no row has all four n-grams.
        ngrams.extend([None] * (len(ngram_columns) - len(ngrams)))
        out_rows.append(row + ngrams)

    df = pd.DataFrame(out_rows, columns=columns)
    df['HADM_ID'] = df['HADM_ID'].astype(np.int64)
    if subsetby is not None:
        # NOTE(review): nothing in this function creates an 'ANNOTATION'
        # column, so subsetting cannot work until one is populated from the
        # annotation records. Fail with an explicit message instead of the
        # bare KeyError pandas would raise.
        if 'ANNOTATION' not in df.columns:
            raise KeyError(
                "cannot subset plausibility annotations: "
                "the DataFrame has no 'ANNOTATION' column"
            )
        if subsetby == '+/++':
            df = df[~df['ANNOTATION'].isna()]
        else:
            df = df[df['ANNOTATION'] == subsetby]
    return df

def find_sub_list(sl, l):
    """Locate the first occurrence of the sequence ``sl`` inside ``l``.

    Args:
        sl: the sub-sequence to search for.
        l: the sequence to search in; slices of it are compared to ``sl``
            with ``==``, so both should be of the same sequence type.

    Returns:
        A tuple ``(start, end)`` of inclusive indices into ``l`` for the first
        match, or ``None`` if ``sl`` is empty or does not occur in ``l``.
    """
    sll = len(sl)
    if sll == 0:
        # An empty needle has no first element to anchor on; report "not found"
        # rather than raising IndexError.
        return None
    first = sl[0]
    # Only start positions that leave room for the whole sub-sequence can match.
    for ind in range(len(l) - sll + 1):
        if l[ind] == first and l[ind:ind + sll] == sl:
            return ind, ind + sll - 1
    return None

def scores_from_attn(attn, data, output, annot, dicts, agg='first'):
    """Find the text position with the highest attention score for one label.

    The attention values for the annotation's label are read from
    ``attn[0][label_idx]``, where ``label_idx = dicts['c2ind'][annot['LABELS']]``.
    Only the first ``min(number of tokens in annot['TEXT'], attn.shape[2])``
    positions are scored.

    Args:
        attn: attention tensor indexed as ``attn[0][label][position]``.
        data: unused; kept for interface compatibility.
        output: unused; kept for interface compatibility.
        annot: mapping with ``'LABELS'`` (a code) and ``'TEXT'`` (a string).
        dicts: lookups; only ``dicts['c2ind']`` is read.
        agg: how to score position ``i``:
            ``'first'`` uses the attention at ``i``;
            ``'sum'`` adds the attention over positions ``i..i+3``;
            ``'avg'`` is that sum divided by 4 (always 4, even when the
            window is cut short at the end of the attention row).

    Returns:
        ``(max_score, max_idx)``: the best score as a Python float and the
        first position at which it occurs.

    Raises:
        Exception: if ``agg`` is not one of ``'first'``, ``'avg'``, ``'sum'``.
        ValueError: if there are no positions to score.
    """
    # Validate once up front instead of on every loop iteration.
    if agg not in ('first', 'avg', 'sum'):
        raise Exception('Unknown aggregate method.')

    label_idx = dicts['c2ind'][annot['LABELS']]
    num_positions = min(len(annot['TEXT'].split()), attn.shape[2])
    if num_positions == 0:
        raise ValueError('no positions to score: empty text or empty attention')

    # The label's attention row does not change across positions.
    label_attn = attn[0][label_idx]
    if agg == 'first':
        scores = [label_attn[i].item() for i in range(num_positions)]
    elif agg == 'avg':
        scores = [sum(label_attn[i:i + 4]).item() / 4 for i in range(num_positions)]
    else:
        scores = [sum(label_attn[i:i + 4]).item() for i in range(num_positions)]

    max_score = max(scores)
    max_idx = scores.index(max_score)
    return max_score, max_idx

def get_exp_caml(args):
    """Extract the highest-attention n-gram for each plausibility annotation.

    Loads the lookups and the model selected by ``args``, writes the
    plausibility annotations to ``tmp.csv`` in the current working directory,
    and feeds that file through ``datasets.data_generator`` one example at a
    time. For every example the model is run with ``get_attention=True`` and
    ``scores_from_attn`` picks the best-scoring start position for the
    annotation's label. The result is written to ``args.outpath`` as CSV with
    the ``SUBJECT_ID`` column dropped and these columns added:

        SCORES    best score for the annotation's label
        STARTIDX  token index where the best score starts
        4GRAM     the 4 tokens starting at STARTIDX
        14GRAM    up to 5 tokens before STARTIDX plus the 9 tokens from it

    Args:
        args: parsed command-line namespace; this function reads ``subsetby``,
            ``version``, ``gpu``, ``agg`` and ``outpath`` and passes the whole
            namespace to ``datasets.load_lookups`` and ``tools.pick_model``.
    """
    print('loading lookups..')
    dicts = datasets.load_lookups(args, desc_embed=True)
    print('loading model...')
    # NOTE(review): nothing here switches the model to eval mode; confirm that
    # tools.pick_model does, otherwise dropout may perturb the attention.
    model = tools.pick_model(args, dicts)
    desc_embed = model.lmbda > 0
    print('loading annotations...')
    plausibility_annotations = get_plausibility_annotations(subsetby=args.subsetby)
    # The data generator reads from a file path, so round-trip through a CSV.
    plausibility_annotations.to_csv('tmp.csv', index=False)
    num_labels = len(dicts['ind2c'])
    print('getting data generator...')
    gen = datasets.data_generator('tmp.csv', dicts, 1, num_labels, version=args.version, desc_embed=desc_embed)
    scores = []
    start_idxs = []
    four_grams = []
    fourteen_grams = []
    print('getting scores...')
    for batch_idx, tup in tqdm(enumerate(gen)):
        # Batch size is 1, so the batch index is used as the annotation row
        # position (assumes the generator preserves CSV row order).
        annot = plausibility_annotations.iloc[batch_idx]
        full_text = annot['TEXT'].split()
        data, target, hadm_ids, _, descs = tup
        data, target = Variable(torch.LongTensor(data)), Variable(torch.FloatTensor(target))
        if args.gpu:
            data = data.cuda()
            target = target.cuda()
        model.zero_grad()
        desc_data = descs if desc_embed else None
        output, loss, alpha = model(data, target, desc_data=desc_data, get_attention=True)
        max_score, max_idx = scores_from_attn(alpha, data, output, annot, dicts, agg=args.agg)
        scores.append(max_score)
        start_idxs.append(max_idx)
        four_grams.append(' '.join(full_text[max_idx:max_idx+4]))
        # Clamp the left edge at 0: a negative slice start would wrap around
        # to the end of the token list and give an empty or wrong window.
        context_start = max(max_idx - 5, 0)
        fourteen_grams.append(' '.join(full_text[context_start:max_idx+9]))

    plausibility_annotations.drop(columns='SUBJECT_ID', inplace=True)
    plausibility_annotations['SCORES'] = scores
    plausibility_annotations['STARTIDX'] = start_idxs
    plausibility_annotations['4GRAM'] = four_grams
    plausibility_annotations['14GRAM'] = fourteen_grams

    plausibility_annotations.to_csv(args.outpath, index=False)

def get_exp_logreg(args):
    """Extract the highest-weighted n-gram for each plausibility annotation.

    Loads the lookups and a pickled model from ``args.test_model`` and uses
    the row of ``model.coef_`` for the annotation's code as per-word weights.
    Every token position ``i`` in the text is scored according to ``args.agg``:

        'first'  weight of the word at ``i``
        'sum'    summed weight of the words at ``i..i+3``
        'avg'    that sum divided by 4 (always 4, even for a short window)

    Words missing from ``w2ind`` contribute nothing. Only strictly positive
    scores can win; if none is positive, or the code's index is outside the
    coefficient matrix, the score and start index are both 0.

    The result is written to ``args.outpath`` as CSV with the ``SUBJECT_ID``
    column dropped and ``SCORES``, ``STARTIDX``, ``4GRAM`` and ``14GRAM``
    columns added.

    Args:
        args: parsed command-line namespace; this function reads
            ``test_model``, ``subsetby``, ``agg`` and ``outpath`` and passes
            the whole namespace to ``datasets.load_lookups``.

    Raises:
        Exception: if ``args.agg`` is not one of 'first', 'sum', 'avg'.
    """
    print('loading lookups..')
    dicts = datasets.load_lookups(args, desc_embed=True)
    w2ind, c2ind = dicts['w2ind'], dicts['c2ind']
    print('loading model...')
    # NOTE(security): unpickling can execute arbitrary code; only load model
    # files from a trusted source.
    with open(args.test_model, 'rb') as model_file:
        model = pickle.load(model_file)
    mat = model.coef_
    print('loading annotations...')
    plausibility_annotations = get_plausibility_annotations(subsetby=args.subsetby)
    print('getting data generator...')

    if args.agg not in ('first', 'sum', 'avg'):
        raise Exception('Unknown aggregate method.')
    # 'first' scores a single word; 'sum' and 'avg' score a 4-word window.
    window = 1 if args.agg == 'first' else 4

    scores = []
    start_idxs = []
    four_grams = []
    fourteen_grams = []
    print('getting scores...')
    for _, annot in tqdm(plausibility_annotations.iterrows()):
        full_text = annot['TEXT'].split()
        c_idx = c2ind[annot['LABELS']]
        max_score, max_idx = 0, 0
        # Codes whose index is beyond the coefficient matrix keep score 0.
        if c_idx < len(mat):
            word_weights = mat[c_idx]
            for i in range(len(full_text)):
                sum_weights = 0.0
                for word in full_text[i:i + window]:
                    if word in w2ind:
                        sum_weights += word_weights[w2ind[word]]
                if args.agg == 'avg':
                    sum_weights /= 4
                if sum_weights > max_score:
                    max_score = sum_weights
                    max_idx = i
        scores.append(max_score)
        start_idxs.append(max_idx)
        four_grams.append(' '.join(full_text[max_idx:max_idx+4]))
        # Clamp the left edge at 0: a negative slice start would wrap around
        # to the end of the token list and give an empty or wrong window.
        context_start = max(max_idx - 5, 0)
        fourteen_grams.append(' '.join(full_text[context_start:max_idx+9]))

    plausibility_annotations.drop(columns='SUBJECT_ID', inplace=True)
    plausibility_annotations['SCORES'] = scores
    plausibility_annotations['STARTIDX'] = start_idxs
    plausibility_annotations['4GRAM'] = four_grams
    plausibility_annotations['14GRAM'] = fourteen_grams
    plausibility_annotations.to_csv(args.outpath, index=False)

def main(args):
    """Dispatch to the explanation extractor matching ``args.model``.

    Args:
        args: parsed command-line namespace; ``args.model`` selects
            ``get_exp_caml`` ('conv_attn') or ``get_exp_logreg`` ('logreg').

    Raises:
        ValueError: if ``args.model`` is neither 'conv_attn' nor 'logreg'.
    """
    if args.model == 'conv_attn':
        get_exp_caml(args)
    elif args.model == 'logreg':
        get_exp_logreg(args)
    else:
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception that carries the message.
        raise ValueError(f'Unknown model type: {args.model!r}')
# Command-line entry point: parse the arguments and run the extractor chosen
# by the positional `model` argument.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="extract the highest-scoring n-gram explanation for each plausibility annotation")
    parser.add_argument('vocab')
    parser.add_argument('data_path', help='path to {dev/test}_full.csv')
    parser.add_argument('test_model', help='path to model weights')
    parser.add_argument('model', choices=['conv_attn', 'logreg'])
    parser.add_argument('outpath', help='path to save scores to')
    parser.add_argument('agg', choices=['first','sum','avg'])
    parser.add_argument('--Y', default='full')
    parser.add_argument('--version', default='mimic3')
    parser.add_argument("--embed-file", type=str, required=False, dest="embed_file",
                        help="path to a file holding pre-trained embeddings")
    # No `type=` here: the default is the int 10, but a value given on the
    # command line arrives as a string.
    parser.add_argument('--filter-size', default=10)
    parser.add_argument("--num-filter-maps", type=int, required=False, dest="num_filter_maps", default=50,
                        help="size of conv output (default: 50)")
    parser.add_argument("--lmbda", type=float, required=False, dest="lmbda", default=0.01,
                        help="hyperparameter to tradeoff BCE loss and similarity embedding loss (default: 0.01). a value of 0 won't create/use the description embedding module at all. ")
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    parser.add_argument("--gpu", dest="gpu", action="store_const", required=False, const=True,
                        help="optional flag to use GPU if available")
    parser.add_argument("--embed-size", type=int, required=False, dest="embed_size", default=100,
                        help="size of embedding dimension. (default: 100)")
    parser.add_argument("--dropout", dest="dropout", type=float, required=False, default=0.2,
                        help="optional specification of dropout (default: 0.2)")
    parser.add_argument("--code-emb", type=str, required=False, dest="code_emb", 
                        help="point to code embeddings to use for parameter initialization, if applicable")

    parser.add_argument("--four_gram_only", action="store_true", required=False, default=False,
                        help="optional flag only calculating rouge overlap")
    parser.add_argument('--subsetby', choices=['+', '++', '+/++'], help='subset plausibility explanations by which annotation it received')
    args = parser.parse_args()
    main(args)
