import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report
from transformers import AutoModel, AutoTokenizer, AdamW
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import argparse
import logging
import time
import sys
import os
from tqdm import tqdm
import datetime
from clinical_bert.bert import BERT_Arch
import datasets
import evaluation
from sklearn import metrics
from scipy import stats
import mimic_proxy
from log_reg import read_bows, construct_X_Y
from utils import *
from constants import *

def labels_to_ind(c2ind, labels):
    """Convert per-document label strings into a multi-hot indicator matrix.

    Args:
        c2ind: mapping from label code to column index.
        labels: sequence with one entry per document. Each entry is a
            ';'-separated string of codes, or a non-string value (e.g. NaN
            produced by ``pd.read_csv`` for an empty cell) meaning the
            document has no labels.

    Returns:
        ``np.ndarray`` of shape ``(len(labels), len(c2ind))`` (float64) with a
        1 in column ``c2ind[code]`` for every code listed for a document.

    Raises:
        KeyError: if a document lists a code that is not in ``c2ind``.
    """
    label_idxs = np.zeros((len(labels), len(c2ind)))
    for doc_idx, doc_labels in enumerate(labels):
        try:
            codes = doc_labels.split(';')
        except AttributeError:  # non-string (NaN/None): document has no labels
            continue
        for code in codes:
            # Skip empty tokens from '' or a trailing ';' instead of failing
            # with KeyError on c2ind[''].
            if code:
                label_idxs[doc_idx, c2ind[code]] = 1
    return label_idxs

def one_epoch(model, train_dataloader, cross_entropy, optimizer, device=None):
    """Run one training pass over ``train_dataloader``, updating ``model`` in place.

    Args:
        model: module called as ``model(sent_id, mask)``; its raw outputs are
            passed to ``cross_entropy`` as the first argument.
        train_dataloader: yields ``(sent_id, mask, labels)`` tensor batches.
        cross_entropy: loss callable ``(outputs, float_labels) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so the function does not depend on a
            module-level ``device`` global that only exists when the file is
            run as a script.

    Returns:
        ``(avg_loss, preds)``: the mean per-batch loss, and the detached raw
        model outputs for every example concatenated along axis 0 in
        dataloader order. For an empty dataloader, ``(0.0, empty array)``.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    n_batches = len(train_dataloader)
    total_loss = 0.0
    total_preds = []

    for step, batch in enumerate(train_dataloader):
        # progress update after every 50 batches.
        if step % 50 == 0 and step != 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches))

        sent_id, mask, labels = (t.to(device) for t in batch)
        model.zero_grad()
        probs = model(sent_id, mask)
        loss = cross_entropy(probs, labels.float())
        total_loss += loss.item()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_preds.append(probs.detach().cpu().numpy())

    if not total_preds:
        # Empty dataloader: avoid ZeroDivisionError and np.concatenate([]) ValueError.
        return 0.0, np.empty((0,))
    avg_loss = total_loss / n_batches
    return avg_loss, np.concatenate(total_preds, axis=0)

def evaluate(model, dataloader, cross_entropy, device=None):
    """Compute the mean loss and raw outputs of ``model`` over ``dataloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``, so
    no parameters are updated.

    Args:
        model: module called as ``model(sent_id, mask)``.
        dataloader: yields ``(sent_id, mask, labels)`` tensor batches.
        cross_entropy: loss callable ``(outputs, float_labels) -> scalar tensor``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so the function does not depend on a
            module-level ``device`` global that only exists when the file is
            run as a script.

    Returns:
        ``(avg_loss, preds)``: the mean per-batch loss, and the raw model
        outputs concatenated along axis 0 in dataloader order, with shape
        (number of samples, number of classes). For an empty dataloader,
        ``(0.0, empty array)``.
    """
    print("\nEvaluating...")
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    n_batches = len(dataloader)
    total_loss = 0.0
    outputs = []

    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            if step % 50 == 0 and step != 0:
                print('  Batch {:>5,}  of  {:>5,}.'.format(step, n_batches))
            sent_id, mask, labels = (t.to(device) for t in batch)
            logits = model(sent_id, mask)
            total_loss += cross_entropy(logits, labels.float()).item()
            outputs.append(logits.cpu().numpy())

    if not outputs:
        # Empty dataloader: avoid ZeroDivisionError and np.concatenate([]) ValueError.
        return 0.0, np.empty((0,))
    return total_loss / n_batches, np.concatenate(outputs, axis=0)

def _encoded_dataloader(tokenizer, frame, c2ind, max_seq_len, batch_size, sampler_cls):
    """Tokenize ``frame['TEXT']``, multi-hot encode ``frame['LABELS']``, and wrap both in a DataLoader."""
    tokens = tokenizer.batch_encode_plus(
        frame['TEXT'].tolist(),
        max_length=max_seq_len,
        truncation=True,
        padding=True,
    )
    seq = torch.tensor(tokens['input_ids'])
    mask = torch.tensor(tokens['attention_mask'])
    y = torch.tensor(labels_to_ind(c2ind, frame['LABELS'].tolist())).to(dtype=torch.long)
    dataset = TensorDataset(seq, mask, y)
    return DataLoader(dataset, sampler=sampler_cls(dataset), batch_size=batch_size)


def train(bert, tokenizer, data_path, model_path, dicts,
        max_seq_len=512,batch_size=32,
        lr=1e-5, nepochs=10):
    """Fine-tune ``BERT_Arch`` on the training CSV and keep the best checkpoint.

    The dev CSV path is ``data_path`` with ``'train'`` replaced by ``'dev'``.
    After each epoch the dev loss is computed. Whenever it improves, the weights
    are written to ``model_path``. They are always saved from the bare
    ``BERT_Arch`` (never the ``nn.DataParallel`` wrapper), so the keys carry no
    ``'module.'`` prefix and load directly into a plain ``BERT_Arch``.

    Args:
        bert: pretrained encoder handed to ``BERT_Arch``.
        tokenizer: HuggingFace tokenizer used to encode the ``TEXT`` column.
        data_path: training CSV with ``TEXT`` and ``LABELS`` columns.
        model_path: destination file for the best ``state_dict``.
        dicts: lookups; ``dicts['c2ind']`` maps label code -> column index.
        max_seq_len: tokenizer truncation length.
        batch_size: examples per batch for both splits.
        lr: AdamW learning rate.
        nepochs: number of training epochs.

    Returns:
        ``(train_losses, valid_losses)``: per-epoch mean losses.
    """
    c2ind = dicts['c2ind']
    train_dataloader = _encoded_dataloader(
        tokenizer, pd.read_csv(data_path), c2ind, max_seq_len, batch_size, RandomSampler)
    # NOTE(review): str.replace rewrites *every* 'train' in the path, directory
    # names included; confirm paths only contain it in the file name.
    val_dataloader = _encoded_dataloader(
        tokenizer, pd.read_csv(data_path.replace('train','dev')), c2ind,
        max_seq_len, batch_size, SequentialSampler)

    model = BERT_Arch(bert, len(c2ind))
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)
    # Checkpoint the unwrapped module: DataParallel's state_dict prefixes every
    # key with 'module.', which a plain BERT_Arch cannot load.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model

    optimizer = AdamW(model.parameters(), lr=lr)
    cross_entropy = F.binary_cross_entropy_with_logits

    best_valid_loss = float('inf')
    train_losses = []
    valid_losses = []

    for epoch in range(nepochs):
        print('\n Epoch {:} / {:}'.format(epoch + 1, nepochs))
        train_loss, _ = one_epoch(model, train_dataloader, cross_entropy, optimizer)
        valid_loss, _ = evaluate(model, val_dataloader, cross_entropy)

        # save the best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(bare_model.state_dict(), model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        print(f'\nTraining Loss: {train_loss:.3f}')
        print(f'Validation Loss: {valid_loss:.3f}')

    return train_losses, valid_losses

def test(bert, tokenizer, data_path, model_path, dicts, outdir,
        max_seq_len=512, batch_size=32,
        threshold=0.5):
    """Score the test split with a saved checkpoint and report metrics.

    The test CSV path is ``data_path`` with ``'train'`` replaced by ``'test'``.
    Per-document, per-label probabilities are written to
    ``<outdir>/pred_all_scores_test.json``, keyed by ``str(HADM_ID)`` and then
    by label code. Then ``calc_classification_metrics`` is called with the gold
    multi-hot labels and the probability matrix.

    Args:
        bert: pretrained encoder handed to ``BERT_Arch``.
        tokenizer: HuggingFace tokenizer used to encode the ``TEXT`` column.
        data_path: training CSV path (the test path is derived from it).
        model_path: ``state_dict`` file to load. Checkpoints saved from an
            ``nn.DataParallel`` wrapper (``'module.'``-prefixed keys) are accepted.
        dicts: lookups with ``'c2ind'`` (code -> index) and ``'ind2c'`` (index -> code).
        outdir: directory for the JSON scores. If None, the checkpoint's directory is used.
        max_seq_len: tokenizer truncation length.
        batch_size: documents scored per forward pass.
        threshold: currently unused; kept so existing callers keep working.
    """
    import json  # used below; not imported at the top of this file

    c2ind, ind2c = dicts['c2ind'], dicts['ind2c']
    if outdir is None:
        # --test_outdir is optional on the command line.
        outdir = os.path.dirname(model_path) or '.'

    model = BERT_Arch(bert, len(c2ind))
    state_dict = torch.load(model_path, map_location=device)
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        # Saved from nn.DataParallel: strip the wrapper prefix.
        state_dict = {key[len(prefix):]: val for key, val in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    # Disable dropout so test predictions are deterministic.
    model.eval()

    test_data = pd.read_csv(data_path.replace('train', 'test'))
    tokens_test = tokenizer.batch_encode_plus(
        test_data['TEXT'].tolist(),
        max_length=max_seq_len,
        truncation=True,
        padding=True,
    )
    test_seq = torch.tensor(tokens_test['input_ids'])
    test_mask = torch.tensor(tokens_test['attention_mask'])
    test_y = torch.tensor(labels_to_ind(c2ind, test_data['LABELS'].tolist())).to(dtype=torch.long)

    test_set = TensorDataset(test_seq, test_mask)
    test_loader = DataLoader(test_set, sampler=SequentialSampler(test_set), batch_size=batch_size)

    batch_probs = []
    with torch.no_grad():
        for sent_id, mask in tqdm(test_loader):
            logits = model(sent_id.to(device), mask.to(device))
            # Training minimizes binary_cross_entropy_with_logits on these
            # outputs, so the matching per-label probability is sigmoid(logit).
            batch_probs.append(torch.sigmoid(logits).cpu().numpy())
    if batch_probs:
        probs_matrix = np.concatenate(batch_probs, axis=0).astype(np.float64)
    else:
        probs_matrix = np.zeros((0, len(c2ind)))

    all_probs = {}
    for hadm_id, row in zip(test_data['HADM_ID'], probs_matrix):
        all_probs[str(hadm_id)] = {ind2c[j]: float(row[j]) for j in range(row.shape[0])}
    with open(os.path.join(outdir, 'pred_all_scores_test.json'), 'w') as fout:
        json.dump(all_probs, fout, indent=2)

    calc_classification_metrics(test_y.numpy(), probs_matrix, ind2c)

def main(args):
    """Load lookups plus the Bio_ClinicalBERT tokenizer/encoder and dispatch on ``args.mode``.

    ``'train'`` fine-tunes and checkpoints to ``args.model_path``.
    ``'test'`` loads that checkpoint and scores the test split into ``args.test_outdir``.
    Any other mode does nothing after loading.
    """
    pretrained_name = "emilyalsentzer/Bio_ClinicalBERT"
    dicts = datasets.load_lookups(args)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    bert = AutoModel.from_pretrained(
        pretrained_name,
        num_labels=len(dicts['c2ind']),
        output_attentions=False,
        output_hidden_states=False,
    )

    mode = args.mode
    if mode == 'train':
        train(bert, tokenizer, args.data_path, args.model_path, dicts,
              max_seq_len=512, batch_size=args.bsz,
              lr=args.lr, nepochs=args.nepochs)
    elif mode == 'test':
        test(bert, tokenizer, args.data_path, args.model_path, dicts,
             args.test_outdir, max_seq_len=512, batch_size=args.bsz)

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=['train','test'])
    parser.add_argument('model_name')
    parser.add_argument('data_path')
    parser.add_argument('vocab')
    parser.add_argument('--Y', default='full')
    parser.add_argument('--batch_size', '--bsz', default=32, dest='bsz', type=int)
    # type=float: with type=int, argparse rejected any explicit value such as '2e-5'.
    parser.add_argument('--learning_rate', '--lr', default=1e-5, dest='lr', type=float)
    parser.add_argument('--nepochs', default=10, type=int)
    parser.add_argument('--test_outdir')
    parser.add_argument('--version', default='mimic3')
    parser.add_argument("--public-model", dest="public_model", action="store_const", required=False, const=True,
                        help="optional flag for testing pre-trained models from the public github")
    parser.add_argument('--gpu', action='store_true', default=False)
    parser.add_argument('--notes')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')

    start = time.time()
    # Module-level global: train() and test() read `device` directly.
    device = torch.device("cuda" if args.gpu else "cpu")
    date = datetime.datetime.now().strftime("%b_%d_%H:%M")
    model_dir = os.path.join(MODEL_DIR, f'clinical_bert_{date}')
    os.makedirs(model_dir)
    if args.mode == 'test' and os.path.isfile(args.model_name):
        # model_dir was just created and is empty, so in test mode a checkpoint
        # given as an existing (relative) file path must be used as-is.
        args.model_path = args.model_name
    else:
        args.model_path = os.path.join(model_dir, args.model_name)
    if args.notes:
        with open(os.path.join(model_dir, 'notes.txt'), 'w') as fout:
            fout.write(args.notes)
    main(args)
    end = time.time()
    logging.info(f'Time to run script: {end-start} secs')
