import argparse
import re
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import csv
import numpy as np
import torch
import pickle
from torch.autograd import Variable
from functools import reduce
import datasets
from learn import tools
from compare_explanations import load_plausibility_annotations, match_annotations_to_hadm


# Leading CSV columns produced below: SUBJECT_ID,HADM_ID,TEXT,LABELS,length.
# get_ranking_caml writes this frame to CSV and feeds it to
# datasets.data_generator, so this layout presumably matches what that
# generator expects -- confirm before changing the column order.
def get_plausibility_annotations(subsetby=None):
    """Build a DataFrame of plausibility-annotated explanation n-grams.

    Annotations from ``load_plausibility_annotations`` are matched to an
    admission id and full note text with ``match_annotations_to_hadm``.
    Annotations that cannot be matched are reported on stdout and skipped.

    Columns:
        SUBJECT_ID: the annotation id (stored here, not a real subject id).
        HADM_ID: matched admission id, cast to int64.
        TEXT: full note text.
        LABELS: the annotated code.
        length: number of whitespace-separated tokens in TEXT.
        NGRAM-0 .. NGRAM-3: the ``ngram`` of each of the four explanations.

    Args:
        subsetby: optional filter: '+' or '++' keeps rows whose 'ANNOTATION'
            equals that value, '+/++' keeps rows with any non-null
            'ANNOTATION'. NOTE(review): this function never creates an
            'ANNOTATION' column, so filtering currently cannot succeed.

    Returns:
        pandas.DataFrame with the columns listed above.

    Raises:
        ValueError: if an annotation does not have exactly four explanations.
        KeyError: if ``subsetby`` is given but there is no 'ANNOTATION' column.
    """
    num_explanations = 4  # downstream scorers iterate NGRAM-0..NGRAM-3
    ngram_columns = [f'NGRAM-{i}' for i in range(num_explanations)]
    annotations, counts, text_regexes = load_plausibility_annotations()
    annotation_to_hadm, full_texts = match_annotations_to_hadm(annotations, counts, text_regexes)
    out_rows = []
    for annotation in annotations:
        # Keep the try narrow: only the admission/text lookup may legitimately
        # miss. Other KeyErrors (malformed annotation) must not be mislabelled.
        try:
            hadm = annotation_to_hadm[annotation['id']]
            full_text = full_texts[hadm]
        except KeyError:
            print(f"Annotation ID {annotation['id']} not in annotation_to_hadm")
            continue
        ngrams = [exp['ngram'] for exp in annotation['explanations']]
        if len(ngrams) != num_explanations:
            raise ValueError(
                f"Annotation ID {annotation['id']} has {len(ngrams)} explanations; "
                f"expected {num_explanations}"
            )
        out_rows.append([
            str(annotation['id']),
            hadm,
            full_text,
            annotation['code'],
            len(full_text.split()),
            *ngrams,
        ])
    df = pd.DataFrame(
        out_rows,
        columns=['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length', *ngram_columns],
    )
    df['HADM_ID'] = df['HADM_ID'].astype(np.int64)
    if subsetby is not None:
        if 'ANNOTATION' not in df.columns:
            raise KeyError(
                f"Cannot subset by {subsetby!r}: no 'ANNOTATION' column in the annotations frame"
            )
        if subsetby == '+/++':
            df = df[~df['ANNOTATION'].isna()]
        else:
            df = df[df['ANNOTATION'] == subsetby]
    return df

def find_sub_list(sl, l):
    """Locate the first occurrence of ``sl`` as a contiguous run inside ``l``.

    Args:
        sl: the sub-sequence to look for (e.g. a list of tokens).
        l: the sequence to search; must be sliceable and comparable to ``sl``.

    Returns:
        ``(start, end)`` inclusive indices of the first match, or ``None`` when
        ``sl`` is empty or does not occur in ``l``.
    """
    if not sl:
        # An empty pattern has no meaningful position; previously this raised
        # IndexError on sl[0] unless l was also empty.
        return None
    width = len(sl)
    first = sl[0]
    for start, item in enumerate(l):
        # Cheap first-element check before comparing the whole slice.
        if item == first and l[start:start + width] == sl:
            return start, start + width - 1
    return None

def scores_from_attn(attn, data, output, annot, dicts):
    """Score an annotation's four explanation n-grams by model attention.

    The score of an n-gram is the attention value for the annotated label at
    the position of the n-gram's first token in the whitespace-tokenized note.

    Args:
        attn: model attention, indexed as ``attn[0][label_idx][token_pos]``
            with elements supporting ``.item()`` (presumably a tensor of shape
            batch x labels x tokens -- TODO confirm against the model).
        data: unused; kept for interface compatibility.
        output: unused; kept for interface compatibility.
        annot: row providing 'TEXT', 'LABELS' and 'NGRAM-0' .. 'NGRAM-3'.
        dicts: lookups; ``dicts['c2ind']`` maps a code to its label index.

    Returns:
        ``(scores, ranking)``: ``scores[i]`` is the score of NGRAM-i and
        ``ranking[i]`` its 0-based position in the descending sort of the
        scores (tied scores share the same rank).

    Raises:
        ValueError: if an n-gram does not occur in the note text.
    """
    label_idx = dicts['c2ind'][annot['LABELS']]
    tokens = annot['TEXT'].split()
    scores = []
    for i in range(4):
        ngram = annot[f'NGRAM-{i}']
        span = find_sub_list(ngram.split(), tokens)
        if span is None:
            # Without this check the failure surfaces as an opaque
            # "cannot unpack non-iterable NoneType" TypeError.
            raise ValueError(f'NGRAM-{i} {ngram!r} not found in note text')
        start = span[0]
        try:
            score = attn[0][label_idx][start].item()
        except IndexError:
            score = 0.0  # n-gram starts after the model's max token length
        scores.append(score)
    ranked = sorted(scores, reverse=True)
    ranking = [ranked.index(s) for s in scores]
    return scores, ranking

def get_ranking_caml(args):
    """Score plausibility-annotated n-grams with an attention model and save them.

    Loads lookups and the model through the project helpers, streams every
    annotated note through the model one at a time (batch size 1), scores its
    four explanation n-grams with ``scores_from_attn``, and writes the
    annotations (without SUBJECT_ID) plus ';'-joined SCORES and RANKING
    columns to ``args.outpath``.
    """
    import tempfile

    print('loading lookups..')
    dicts = datasets.load_lookups(args, desc_embed=True)
    print('loading model...')
    model = tools.pick_model(args, dicts)
    # Inference only: switch off dropout etc. so attention is deterministic.
    model.eval()
    desc_embed = model.lmbda > 0
    print('loading annotations...')
    plausibility_annotations = get_plausibility_annotations(subsetby=args.subsetby)
    num_labels = len(dicts['ind2c'])
    scores = []
    rankings = []
    # Use a private temporary file instead of clobbering ./tmp.csv. The
    # generator is consumed inside the `with`, before the directory is removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = str(Path(tmp_dir) / 'tmp.csv')
        plausibility_annotations.to_csv(csv_path, index=False)
        print('getting data generator...')
        gen = datasets.data_generator(csv_path, dicts, 1, num_labels, version=args.version, desc_embed=desc_embed)
        print('getting scores...')
        with torch.no_grad():
            # Batch size 1, so batch_idx is used as the row index; this assumes
            # the generator yields rows in file order without skipping any.
            for batch_idx, tup in tqdm(enumerate(gen)):
                annot = plausibility_annotations.iloc[batch_idx]
                data, target, hadm_ids, _, descs = tup
                data, target = Variable(torch.LongTensor(data)), Variable(torch.FloatTensor(target))
                if args.gpu:
                    data = data.cuda()
                    target = target.cuda()
                desc_data = descs if desc_embed else None
                output, loss, alpha = model(data, target, desc_data=desc_data, get_attention=True)
                score, ranking = scores_from_attn(alpha, data, output, annot, dicts)
                scores.append(';'.join(str(s) for s in score))
                rankings.append(';'.join(str(r) for r in ranking))
    plausibility_annotations.drop(columns='SUBJECT_ID', inplace=True)
    plausibility_annotations['SCORES'] = scores
    plausibility_annotations['RANKING'] = rankings
    plausibility_annotations.to_csv(args.outpath, index=False)

def get_ranking_logreg(args):
    """Score plausibility-annotated n-grams with a pickled linear model and save them.

    The score of an n-gram is the sum, over its in-vocabulary words, of the
    coefficients in the annotated label's row of ``model.coef_``. Words not
    in ``w2ind`` contribute 0. Labels with no coefficient row (not in the
    lookups or beyond the trained rows) get all-zero scores. The annotations
    (without SUBJECT_ID) plus ';'-joined SCORES and RANKING columns are
    written to ``args.outpath``.
    """
    print('loading lookups..')
    dicts = datasets.load_lookups(args, desc_embed=True)
    w2ind, c2ind = dicts['w2ind'], dicts['c2ind']
    print('loading model...')
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files from a trusted source.
    with open(args.test_model, 'rb') as model_file:
        model = pickle.load(model_file)
    # Indexed as coef[label_idx][word_idx] below (presumably sklearn's
    # (n_classes, n_features) coef_ layout -- confirm for the saved model).
    coef = model.coef_
    print('loading annotations...')
    plausibility_annotations = get_plausibility_annotations(subsetby=args.subsetby)
    scores = []
    rankings = []
    print('getting scores...')
    for _, annot in tqdm(plausibility_annotations.iterrows()):
        c_idx = c2ind.get(annot['LABELS'])
        if c_idx is not None and c_idx < len(coef):
            word_weights = coef[c_idx]
            score = []
            for i in range(4):
                sum_weights = 0
                # Sum the weight of every word in the n-gram (any length);
                # out-of-vocabulary words add 0.
                for word in annot[f'NGRAM-{i}'].split():
                    if word in w2ind:
                        sum_weights = sum_weights + word_weights[w2ind[word]]
                score.append(sum_weights)
        else:
            # Label unknown to the lookups or not in the trained model.
            score = [0, 0, 0, 0]
        ranked = sorted(score, reverse=True)
        ranking = [ranked.index(x) for x in score]
        scores.append(';'.join(str(s) for s in score))
        rankings.append(';'.join(str(r) for r in ranking))
    plausibility_annotations.drop(columns='SUBJECT_ID', inplace=True)
    plausibility_annotations['SCORES'] = scores
    plausibility_annotations['RANKING'] = rankings
    plausibility_annotations.to_csv(args.outpath, index=False)

def main(args):
    """Run the n-gram scorer selected by ``args.model``.

    Args:
        args: parsed CLI namespace; ``args.model`` is 'conv_attn' or 'logreg'.

    Raises:
        ValueError: if ``args.model`` is not a supported model type.
    """
    if args.model == 'conv_attn':
        get_ranking_caml(args)
    elif args.model == 'logreg':
        get_ranking_logreg(args)
    else:
        # Raising a plain str is itself a TypeError in Python 3.
        raise ValueError(f'Unknown model type: {args.model}')
def build_parser():
    """Return the command-line parser for this explanation-scoring script."""
    parser = argparse.ArgumentParser(description="score plausibility-annotated explanation n-grams with a trained model")
    parser.add_argument('vocab')
    parser.add_argument('data_path', help='path to {dev/test}_full.csv')
    parser.add_argument('test_model', help='path to model weights (a pickled model for logreg)')
    parser.add_argument('model', choices=['conv_attn', 'logreg'])
    parser.add_argument('outpath', help='path to save scores to')
    parser.add_argument('--Y', default='full')
    parser.add_argument('--version', default='mimic3')
    parser.add_argument("--embed-file", type=str, required=False, dest="embed_file",
                        help="path to a file holding pre-trained embeddings")
    # No type=: the value is passed through to the model-loading code as-is.
    parser.add_argument('--filter-size', default=10)
    parser.add_argument("--num-filter-maps", type=int, required=False, dest="num_filter_maps", default=50,
                        help="size of conv output (default: 50)")
    parser.add_argument("--lmbda", type=float, required=False, dest="lmbda", default=0.01,
                        help="hyperparameter to tradeoff BCE loss and similarity embedding loss. "
                             "A value of 0 won't create/use the description embedding module at all. (default: 0.01)")
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    parser.add_argument("--gpu", dest="gpu", action="store_const", required=False, const=True,
                        help="optional flag to use GPU if available")
    parser.add_argument("--embed-size", type=int, required=False, dest="embed_size", default=100,
                        help="size of embedding dimension. (default: 100)")
    parser.add_argument("--dropout", dest="dropout", type=float, required=False, default=0.2,
                        help="optional specification of dropout (default: 0.2)")
    parser.add_argument("--code-emb", type=str, required=False, dest="code_emb",
                        help="point to code embeddings to use for parameter initialization, if applicable")

    parser.add_argument("--four_gram_only", action="store_true", required=False, default=False,
                        help="optional flag only calculating rouge overlap")
    parser.add_argument('--subsetby', choices=['+', '++', '+/++'], help='subset plausibility explanations by which annotation it received')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)
