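'''
Script that measures how much the model's attention (feature) weight for each
vocabulary word varies across documents, for a single label (LABEL).

Stage 1 (get_variance) runs the model over the test split and records, for every
token occurrence, the attention weight assigned to that token for LABEL; the raw
weights are saved to drcaml-featureweights-<LABEL>.json.
Stage 2 (analyze_variance) summarizes each word's weights (variance, min, max and
variance / |mean|) into drcaml-featurevariance-<LABEL>.json.
'''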
import argparse
import logging
import time
import json
import numpy as np
from tqdm import tqdm
import datasets
import evaluation
from sklearn import metrics
from scipy import stats
import torch
from torch.autograd import Variable
import numpy as np
from learn import tools
import mimic_proxy
import datasets

# Label whose attention weights are analysed. It must be a key of the label
# lookup (c2ind) and is also embedded in the output file names.
# NOTE(review): presumably an ICD-9 diagnosis code — confirm against the label set.
LABEL = "995.92"

def get_variance(args):
    """Record the attention weight the model assigns to every token for LABEL.

    Runs the model returned by ``tools.pick_model`` over the test split (the
    path obtained by replacing 'train' with 'test' in ``args.data_path``), one
    document per batch. For each vocabulary index it collects the list of
    attention weights that index received, one entry per occurrence. The mapping
    is written to ``drcaml-featureweights-<LABEL>.json`` in the current
    directory; JSON turns the integer word indices into string keys.

    Args:
        args: parsed command-line namespace; forwarded to the ``datasets`` and
            ``tools`` helpers and read for ``data_path``, ``version`` and ``gpu``.
    """
    dicts = datasets.load_lookups(args, desc_embed=True)
    label_idx = dicts['c2ind'][LABEL]
    num_labels = len(dicts['ind2c'])
    model = tools.pick_model(args, dicts)
    # nn.Module starts in training mode; switch to inference mode so dropout
    # (presumably configured from args.dropout) does not perturb the weights.
    model.eval()
    gen = datasets.data_generator(args.data_path.replace('train', 'test'), dicts,
                                  1, num_labels,
                                  version=args.version, desc_embed=True)
    feature_weights = {}

    # Pure inference: avoid building an autograd graph for every document.
    with torch.no_grad():
        for data, target, _, _, descs in tqdm(gen):
            data = torch.LongTensor(data)
            target = torch.FloatTensor(target)
            if args.gpu:
                data = data.cuda()
                target = target.cuda()

            _, _, alpha = model(data, target, desc_data=descs, get_attention=True)
            seq_len = data.shape[1]
            # One host transfer per tensor instead of one .item() sync per token.
            # alpha is indexed as [batch][label][position], as in the original;
            # only the first seq_len positions are paired with input tokens.
            word_ids = data[0].tolist()
            weights = alpha[0][label_idx][:seq_len].tolist()
            for wind, w_weight in zip(word_ids, weights):
                feature_weights.setdefault(wind, []).append(w_weight)

    with open(f'drcaml-featureweights-{LABEL}.json', 'w') as fout:
        json.dump(feature_weights, fout, indent=2)

def analyze_variance(args):
    """Summarize the per-word attention weights saved by ``get_variance``.

    Reads ``drcaml-featureweights-<LABEL>.json``, a mapping from word index
    (string key) to a list of weights. Each index is mapped back to its word
    via ``ind2w``; indices not present in ``ind2w`` are skipped. The function
    then writes ``drcaml-featurevariance-<LABEL>.json``, which maps each word to
    a dict with the keys 'var', 'min', 'max' and 'var/mean'. The 'var/mean'
    entry is the variance divided by the absolute mean. When the mean is
    exactly zero, that ratio is undefined and is written as ``null`` instead of
    a non-standard NaN/Infinity.

    Args:
        args: parsed command-line namespace; forwarded to ``datasets.load_lookups``.
    """
    dicts = datasets.load_lookups(args, desc_embed=True)
    ind2w = dicts['ind2w']
    with open(f'drcaml-featureweights-{LABEL}.json') as fin:
        feature_weights = json.load(fin)

    word_variance = {}
    for word_idx, raw_weights in feature_weights.items():
        idx = int(word_idx)  # JSON object keys are always strings
        if idx not in ind2w:
            continue
        # Convert once and reuse for every statistic.
        weights = np.asarray(raw_weights, dtype=float)
        var = float(np.var(weights))
        abs_mean = abs(float(np.mean(weights)))
        word_variance[ind2w[idx]] = {
            'var': var,
            'min': float(np.min(weights)),
            'max': float(np.max(weights)),
            'var/mean': var / abs_mean if abs_mean != 0 else None,
        }
    with open(f'drcaml-featurevariance-{LABEL}.json', 'w') as fout:
        json.dump(word_variance, fout, indent=2)

def main(args):
    """Run both stages in order: collect per-token weights, then summarize them."""
    for stage in (get_variance, analyze_variance):
        stage(args)



def build_parser():
    """Build the command-line parser for this script.

    Many options are not read here directly. They exist because the whole
    ``args`` namespace is handed to project helpers (``datasets.load_lookups``,
    ``datasets.data_generator``, ``tools.pick_model``), which read attributes
    from it.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('data_path')
    parser.add_argument('vocab')
    parser.add_argument("--model", default="conv_attn", type=str, choices=["cnn_vanilla", "rnn", "conv_attn", "multi_conv_attn", "logreg", "saved"], help="model")
    parser.add_argument('--Y', default='full')
    # NOTE: --version and --dataset both write args.version; whichever appears
    # last on the command line wins.
    parser.add_argument('--version', default='mimic3')
    # parser.add_argument("--rnn-layers", type=int, required=False, dest="rnn_layers", default=1,
    #                     help="number of layers for RNN models (default: 1)")
    parser.add_argument("--embed-file", type=str, required=False, dest="embed_file",
                        help="path to a file holding pre-trained embeddings")
    parser.add_argument("--embed-size", type=int, required=False, dest="embed_size", default=100,
                        help="size of embedding dimension. (default: 100)")
    # type=str so multi_conv_attn can receive a comma-separated list; the int
    # default is left unconverted by argparse.
    parser.add_argument("--filter-size", type=str, required=False, dest="filter_size", default=4,
                        help="size of convolution filter to use. (default: %(default)s) For multi_conv_attn, give comma separated integers, e.g. 3,4,5")
    parser.add_argument("--num-filter-maps", type=int, required=False, dest="num_filter_maps", default=50,
                        help="size of conv output (default: 50)")
    parser.add_argument("--pool", choices=['max', 'avg'], required=False, dest="pool", help="which type of pooling to do (logreg model only)")
    parser.add_argument("--code-emb", type=str, required=False, dest="code_emb",
                        help="point to code embeddings to use for parameter initialization, if applicable")
    parser.add_argument("--weight-decay", type=float, required=False, dest="weight_decay", default=0,
                        help="coefficient for penalizing l2 norm of model weights (default: 0)")
    parser.add_argument("--lr", type=float, required=False, dest="lr", default=1e-3,
                        help="learning rate for Adam optimizer (default=1e-3)")
    parser.add_argument("--batch-size", type=int, required=False, dest="batch_size", default=16,
                        help="size of training batches")
    parser.add_argument("--dropout", dest="dropout", type=float, required=False, default=0.5,
                        help="optional specification of dropout (default: 0.5)")
    parser.add_argument("--lmbda", type=float, required=False, dest="lmbda", default=0.2,
                        help="hyperparameter to tradeoff BCE loss and similarity embedding loss (default: %(default)s). A value of 0 won't create/use the description embedding module at all.")
    parser.add_argument("--ngram", dest="ngram_size", required=False, type=int, help="ngram size for explanations, defaults to filter size")
    parser.add_argument("--dataset", type=str, choices=['mimic2', 'mimic3'], dest="version", default='mimic3', required=False,
                        help="version of MIMIC in use (default: mimic3)")
    parser.add_argument("--test-model", type=str, dest="test_model", required=False, help="path to a saved model to load and evaluate")
    parser.add_argument("--criterion", type=str, default='f1_micro', required=False, dest="criterion",
                        help="which metric to use for early stopping (default: f1_micro)")
    parser.add_argument("--patience", type=int, default=3, required=False,
                        help="how many epochs to wait for improved criterion metric before early stopping (default: 3)")
    parser.add_argument("--fold", dest="fold", choices=['train', 'dev', 'test', 'eval'],
                        required=False, help="which fold to predict run on")
    parser.add_argument("--gpu", dest="gpu", action="store_const", required=False, const=True,
                        help="optional flag to use GPU if available")
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    parser.add_argument("--stack-filters", dest="stack_filters", action="store_const", required=False, const=True,
                        help="optional flag for multi_conv_attn to instead use concatenated filter outputs, rather than pooling over them")
    parser.add_argument("--samples", dest="samples", action="store_const", required=False, const=True,
                        help="optional flag to save samples of good / bad predictions")
    parser.add_argument("--quiet", dest="quiet", action="store_const", required=False, const=True,
                        help="optional flag not to print so much during training")
    parser.add_argument("--save-features-path", dest="save_features_path", required=False, help="path to save features to")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')

    # perf_counter is monotonic, so wall-clock adjustments can't skew the timing.
    start = time.perf_counter()
    main(args)
    end = time.perf_counter()
    logging.info(f'Time to run script: {end-start} secs')