import os
import argparse
import time
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import utils
from models.bert_labeler import bert_labeler
from datasets.impressions_dataset import ImpressionsDataset
from constants import *

def collate_fn_labels(sample_list):
     """Collate ImpressionsDataset samples into a single padded batch

     @param sample_list (List): samples as returned by ImpressionsDataset.__getitem__;
                                each is a dict with keys 'imp', 'label' and 'len'

     @returns batch (dictionary): dict with the same three keys, where 'imp' holds the
                                  impressions padded with PAD_IDX up to the longest one,
                                  batch size as the first dimension; 'label' holds the
                                  per-sample labels stacked along a new first (batch)
                                  dimension; and 'len' is a plain list of each sample's
                                  sequence length
     """
     imps, labels, lengths = [], [], []
     for sample in sample_list:
          imps.append(sample['imp'])
          labels.append(sample['label'])
          lengths.append(sample['len'])

     padded_imps = torch.nn.utils.rnn.pad_sequence(imps, batch_first=True,
                                                   padding_value=PAD_IDX)
     return {'imp': padded_imps,
             'label': torch.stack(labels, dim=0),
             'len': lengths}

def load_data(train_csv_path, train_list_path, dev_csv_path,
              dev_list_path, train_weights=None, batch_size=BATCH_SIZE,
              shuffle=True, num_workers=NUM_WORKERS):
     """ Create DataLoaders over ImpressionsDataset objects for the train and dev sets
     @param train_csv_path (string): path to training csv file containing labels
     @param train_list_path (string): path to list of encoded impressions for train set
     @param dev_csv_path (string): same as train_csv_path but for dev set
     @param dev_list_path (string): same as train_list_path but for dev set
     @param train_weights (torch.Tensor): Tensor of shape (train_set_size) containing weights
                                          for each training example, for the purposes of batch
                                          sampling with replacement
     @param batch_size (int): the batch size. As per the BERT repository, the max batch size
                              that can fit on a TITAN XP is 6 if the max sequence length
                              is 512, which is our case. We have 3 TITAN XP's
     @param shuffle (bool): Whether to shuffle the training data before each epoch, ignored
                            if train_weights is not None. The dev loader is never shuffled.
     @param num_workers (int): How many worker processes to use to load data

     @returns dataloaders (tuple): (train_loader, dev_loader), both torch DataLoader objects
                                   that batch samples with collate_fn_labels
     """
     train_dset = ImpressionsDataset(train_csv_path, train_list_path)
     dev_dset = ImpressionsDataset(dev_csv_path, dev_list_path)
     # Settings shared by both loaders
     loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                          collate_fn=collate_fn_labels)

     if train_weights is None:
          train_loader = torch.utils.data.DataLoader(train_dset, shuffle=shuffle,
                                                     **loader_kwargs)
     else:
          # DataLoader rejects a sampler combined with shuffle=True, so the weighted
          # sampler alone decides the (with-replacement) training order here.
          sampler = torch.utils.data.WeightedRandomSampler(weights=train_weights,
                                                           num_samples=len(train_weights),
                                                           replacement=True)
          train_loader = torch.utils.data.DataLoader(train_dset, sampler=sampler,
                                                     **loader_kwargs)

     # The dev set is only evaluated, never trained on, so keep its order fixed
     # (matching load_test_data's default) instead of reshuffling it on every pass.
     dev_loader = torch.utils.data.DataLoader(dev_dset, shuffle=False, **loader_kwargs)
     return (train_loader, dev_loader)

def load_test_data(test_csv_path, test_list_path, batch_size=BATCH_SIZE,
                   num_workers=NUM_WORKERS, shuffle=False):
     """ Wrap the test split in an ImpressionsDataset and return a DataLoader over it
     @param test_csv_path (string): path to the test csv file that holds the labels
     @param test_list_path (string): path to the list of encoded impressions
     @param batch_size (int): samples per batch. As per the BERT repository, the max batch
                              size that fits on a TITAN XP is 6 when the max sequence
                              length is 512, which is our case. We have 3 TITAN XP's
     @param num_workers (int): number of worker processes used for loading
     @param shuffle (bool): reshuffle the samples every epoch if True (default False)

     @returns test_loader (dataloader): DataLoader for the test set, batched with
                                        collate_fn_labels
     """
     test_dset = ImpressionsDataset(test_csv_path, test_list_path)
     return torch.utils.data.DataLoader(test_dset,
                                        batch_size=batch_size,
                                        shuffle=shuffle,
                                        num_workers=num_workers,
                                        collate_fn=collate_fn_labels)

def _print_dev_metrics(metrics):
     """Print the dev-set metrics returned by utils.evaluate and return the mean kappa
     @param metrics (dictionary): output of utils.evaluate; the entries 'kappa', 'blank',
                                  'negation', 'uncertain' and 'positive' are read, each
                                  indexed by the position of a condition in CONDITIONS

     @returns metric_avg (float): mean of metrics['kappa'] (via np.mean)
     """
     kappas = metrics['kappa']
     for j, condition in enumerate(CONDITIONS):
          print('%s kappa: %.3f' % (condition, kappas[j]))
     metric_avg = np.mean(kappas)
     print('average: %.3f' % (metric_avg))

     for j, condition in enumerate(CONDITIONS):
          print('%s blank_f1:  %.3f, negation_f1: %.3f, uncertain_f1: %.3f, positive: %.3f' % (condition,
                                                                                               metrics['blank'][j],
                                                                                               metrics['negation'][j],
                                                                                               metrics['uncertain'][j],
                                                                                               metrics['positive'][j]))
     return metric_avg

def train(save_path, dataloaders, f1_weights, model=None, device=None,
          optimizer=None, lr=LEARNING_RATE, log_every=LOG_EVERY,
          valid_niter=VALID_NITER, best_metric=0.0):
     """ Main training loop for the labeler
     @param save_path (string): Directory in which model weights are stored; it is
                                created if it does not exist
     @param dataloaders (tuple): tuple of dataloader objects as returned by load_data;
                                 index 0 is the train loader, index 1 the dev loader
     @param f1_weights (dictionary): maps conditions to weights for blank, negation,
                                     uncertain and positive f1 task averaging; passed
                                     through to utils.evaluate
     @param model (nn.Module): the labeler model to train, if applicable. If None, a new
                               bert_labeler is built from PRETRAIN_PATH, moved to
                               cuda:0 (or cpu) and wrapped in nn.DataParallel when more
                               than one GPU is visible
     @param device (torch.device): device for the model. If model is not None, this
                                   parameter is required; it is ignored when model is None
     @param optimizer (torch.optim.Optimizer): the optimizer to use, if applicable. If
                                               None, Adam over model.parameters() is used
     @param lr (float): learning rate to use in the optimizer, ignored if optimizer
                        is not None
     @param log_every (int): print the average per-example loss every log_every iterations
     @param valid_niter (int): number of iterations after which to evaluate the model on
                               the dev set and save it if it beats best_metric
     @param best_metric (float): save checkpoints only if the mean dev-set kappa is higher
                                 than best_metric

     Runs NUM_EPOCHS epochs; iteration counters restart at every epoch. Checkpoints are
     dicts with 'epoch', 'model_state_dict' and 'optimizer_state_dict' keys, saved to
     save_path/model_epoch<epoch>_iter<iter>. Returns None.
     """
     if model is not None and device is None:
          print("train function error: Model specified but not device")
          return

     if model is None:
          model = bert_labeler(pretrain_path=PRETRAIN_PATH)
          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          if torch.cuda.device_count() > 1:
               print("Using", torch.cuda.device_count(), "GPUs!")
               model = nn.DataParallel(model) #to utilize multiple GPU's
          model = model.to(device)
     model.train()   #put the model into train mode

     if optimizer is None:
          optimizer = torch.optim.Adam(model.parameters(), lr=lr)

     # Fail fast on an unusable output directory instead of at the first torch.save,
     # which may only happen after a long stretch of training.
     if save_path:
          os.makedirs(save_path, exist_ok=True)

     begin_time = time.time()
     report_examples = 0
     report_loss = 0.0
     train_ld = dataloaders[0]
     dev_ld = dataloaders[1]
     loss_func = nn.CrossEntropyLoss(reduction='sum')

     print('begin labeler training')
     for epoch in range(NUM_EPOCHS):
          for i, data in enumerate(train_ld, 0):
               batch = data['imp'].to(device) #(batch_size, max_len)
               label = data['label'].permute(1, 0).to(device) #(14, batch_size) after permute
               src_len = data['len']
               batch_size = batch.shape[0]
               attn_mask = utils.generate_attention_masks(batch, src_len, device)

               optimizer.zero_grad()
               out = model(batch, attn_mask) #list of 14 tensors

               # Summed cross-entropy over every condition head; normalized per example below
               batch_loss = 0.0
               for j, logits in enumerate(out):
                    batch_loss += loss_func(logits, label[j])

               # Accumulate a detached copy: adding the live tensor would chain every
               # batch's loss into one ever-growing autograd graph until the next log.
               report_loss += batch_loss.detach()
               report_examples += batch_size
               loss = batch_loss / batch_size
               loss.backward()
               optimizer.step()

               if (i+1) % log_every == 0:
                    print('epoch %d, iter %d, avg_loss %.3f, time_elapsed %.3f sec' % (epoch+1, i+1, float(report_loss)/report_examples,
                                                                                       time.time() - begin_time))
                    report_loss = 0.0
                    report_examples = 0

               if (i+1) % valid_niter == 0:
                    print('\n begin validation')
                    metrics = utils.evaluate(model, dev_ld, device, f1_weights)
                    # NOTE(review): utils.evaluate's body is not visible here and may leave
                    # the model in eval mode; re-enable train mode (idempotent otherwise).
                    model.train()
                    metric_avg = _print_dev_metrics(metrics)

                    if metric_avg > best_metric: #new best network
                         print("saving new best network!\n")
                         best_metric = metric_avg
                         path = os.path.join(save_path, "model_epoch%d_iter%d" % (epoch+1, i+1))
                         torch.save({'epoch': epoch+1,
                                     'model_state_dict': model.state_dict(),
                                     'optimizer_state_dict': optimizer.state_dict()},
                                    path)

def _strip_data_parallel_prefix(state_dict, prefix='module.'):
     """Return state_dict with the nn.DataParallel key prefix removed, if every key has it
     @param state_dict (dict): a model state dict, as produced by nn.Module.state_dict()
     @param prefix (string): key prefix added by the wrapper; must end with '.'

     @returns (dict): state_dict itself if it is empty or some key lacks prefix, otherwise
                      a new OrderedDict with prefix stripped from every key
     """
     from collections import OrderedDict

     if not state_dict or not all(key.startswith(prefix) for key in state_dict):
          return state_dict
     stripped = OrderedDict((key[len(prefix):], value) for key, value in state_dict.items())
     metadata = getattr(state_dict, '_metadata', None)
     if metadata is not None:
          # load_state_dict reads per-module version info from _metadata, keyed by module
          # path: '' is the wrapper itself, 'module' the wrapped model and 'module.<name>'
          # its submodules. Drop the wrapper's entry and re-root the rest.
          wrapped_root = prefix[:-1]
          stripped._metadata = OrderedDict(
               (key[len(prefix):], value) for key, value in metadata.items()
               if key == wrapped_root or key.startswith(prefix))
     return stripped

def model_from_ckpt(model, ckpt_path):
     """Load up model checkpoint
     @param model (nn.Module): the module to be loaded. It is moved to cuda:0 (or cpu)
                               and wrapped in nn.DataParallel when more than one GPU is
                               visible
     @param ckpt_path (string): path to a checkpoint dict with 'model_state_dict' and
                                'optimizer_state_dict' entries (as saved by train()).
                                Must not be None

     @return (tuple): tuple containing the model, an Adam optimizer whose state is
                      restored from the checkpoint, and the device
     """
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if torch.cuda.device_count() > 1:
          print("Using", torch.cuda.device_count(), "GPUs!")
          model = nn.DataParallel(model) #to utilize multiple GPU's
     model = model.to(device)
     optimizer = torch.optim.Adam(model.parameters())

     # map_location lets a checkpoint written on a GPU machine be restored on CPU
     checkpoint = torch.load(ckpt_path, map_location=device)
     # Keys carry a 'module.' prefix iff the saving run had wrapped its model in
     # nn.DataParallel, which depends on that machine's GPU count. Loading into the
     # unwrapped module with any such prefix stripped works in every combination.
     bare_model = model.module if isinstance(model, nn.DataParallel) else model
     bare_model.load_state_dict(_strip_data_parallel_prefix(checkpoint['model_state_dict']))
     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

     return (model, optimizer, device)

if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train BERT-base model on task of labeling 14 medical conditions.')
     # Each flag takes exactly one value, so argparse rejects a bare flag such as a lone
     # "--train_csv" up front rather than storing None that only fails during data loading.
     parser.add_argument('--train_csv', type=str, required=True,
                         help='path to csv containing train reports.')
     parser.add_argument('--dev_csv', type=str, required=True,
                         help='path to csv containing dev reports.')
     parser.add_argument('--train_imp_list', type=str, required=True,
                         help='path to list of tokenized train set report impressions')
     parser.add_argument('--dev_imp_list', type=str, required=True,
                         help='path to list of tokenized dev set report impressions')
     parser.add_argument('--output_dir', type=str, required=True,
                         help='path to output directory where checkpoints will be saved')
     parser.add_argument('--checkpoint', type=str, default=None,
                         help='path to existing checkpoint to initialize weights from')
     args = parser.parse_args()

     # Resume from a checkpoint if one was given; otherwise train() builds a fresh
     # model, optimizer and device itself.
     if args.checkpoint:
          model, optimizer, device = model_from_ckpt(bert_labeler(), args.checkpoint)
     else:
          model, optimizer, device = None, None, None
     f1_weights = utils.get_weighted_f1_weights(args.dev_csv)
     dataloaders = load_data(args.train_csv, args.train_imp_list,
                             args.dev_csv, args.dev_imp_list)
     train(save_path=args.output_dir,
           dataloaders=dataloaders,
           model=model,
           optimizer=optimizer,
           device=device,
           f1_weights=f1_weights)
     
