import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import argparse
import time

from datasets import Dataset, Sequence
from models.medair_models.bertha.modelling_bertha import BerthaForSequenceClassification
from models.medair_models.bertha.configuration_bertha import BerthaConfig
from torch.optim import Adam
from transformers import get_constant_schedule, Trainer, TrainingArguments, AutoTokenizer
from evaluate import load
from collators import DataCollatorForMultiLabelsDiseasePrediction

import torch
import numpy as np


# Command-line options: VSA embedding choices plus a few fine-tuning hyperparameters.
parser = argparse.ArgumentParser(description='choose VSA embedding variables')
parser.add_argument(
    '--snomed_group',
    type=str,
    choices=['ignore', 'group_zero', 'group_vectors'],
    default='ignore',
)
parser.add_argument(
    '--dp_type',
    type=str,
    choices=['words', 'words_rv', 'atomic'],
    default='atomic',
)
parser.add_argument(
    '--dp_composition',
    type=str,
    choices=['icd_name', 'snomed_name', 'snomed_all'],
    default='snomed_all',
)
parser.add_argument('--lr', type=float, default=4e-5)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--model_checkpoint', type=str, default=None)
parser.add_argument('--positive_class_weight', type=float, default=1)
parser.add_argument('--num_trials', type=int, choices=range(1, 15), default=3)

# --use_vsa and --use_random_embeddings write to the same `use_vsa` flag and cannot
# be combined. Choosing random embeddings overrides the VSA options above.
use_vsa_parser = parser.add_mutually_exclusive_group(required=False)
use_vsa_parser.add_argument('--use_vsa', dest='use_vsa', action='store_true')
use_vsa_parser.add_argument('--use_random_embeddings', dest='use_vsa', action='store_false')
parser.set_defaults(use_vsa=True)

# Unrecognised command-line arguments are collected in `unknown` instead of raising.
args, unknown = parser.parse_known_args()

# Use the GPU whenever torch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Allow TF32 for CUDA matmuls and for cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# ---- Hyperparameters ----
BATCH_SIZE = 80
LEARNING_RATE = args.lr  # --lr, 4e-5 unless overridden
EPOCHS = args.epochs  # --epochs, 10 unless overridden
DROPOUT_PROB = 0.3
DIM = 768
INITIALIZER_STD = 0.02
WEIGHT_DECAY = 4e-6

def compute_metrics(eval_pred):
    """
    Score evaluation predictions with the module-level ``accuracy_metric``.

    Parameters
    ----------
    eval_pred:
        Pair of (predictions, labels). Only the first element of the predictions
        is scored; both sides are flattened before the comparison.

    Returns
    -------
    metrics: dict
        Result of ``accuracy_metric.compute`` on the flattened predictions and labels
    """
    predictions, references = eval_pred
    # presumably index 0 holds the thresholded ids produced by
    # preprocess_logits_for_metrics, which returns a (pred_ids, labels) tuple — TODO confirm
    flat_predictions = predictions[0].flatten()
    flat_references = references.flatten()
    return accuracy_metric.compute(predictions=flat_predictions, references=flat_references)

# def miou(pred,label):
#     overlap = pred*label # Logical AND
#     union = pred + label # Logical OR
#     IOU = overlap.sum()/float(union.sum()) 
#     return IOU

class CustomTrainer(Trainer):
    """
    Trainer variant that logs a per-batch accuracy from ``compute_loss`` and optimises
    with Adam under a constant learning rate.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Run a forward pass, log the batch accuracy and return the loss computed by the model.

        The loss is taken from the model output as is; it is not recomputed here.
        Accuracy is obtained by thresholding the logits at 0, flattening predictions and
        labels, and passing them to the module-level ``accuracy_metric`` (created in this
        script's ``__main__`` section).

        Parameters
        ----------
        model: BERT-type model
            Called as ``model(**inputs)``; its output must expose ``loss`` and ``logits``

        inputs: dict-like
            Batch passed to the model. Its ``labels`` entry is the accuracy reference

        return_outputs: bool
            Whether to return outputs in a tuple with loss. Defaults to False

        num_items_in_batch: optional
            Accepted because newer ``Trainer`` versions pass this keyword to
            ``compute_loss``. It is not used: the model's own loss is returned unchanged

        Returns
        -------
        loss, or ``(loss, outputs)`` when ``return_outputs`` is True

        Raises
        ------
        ValueError
            If the model output does not contain a loss
        """
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = outputs.get("loss")

        if loss is None:
            # Fail here with a clear message rather than handing None back to the Trainer
            raise ValueError(
                "The model did not return a loss. Make sure the batch contains 'labels' "
                f"(received keys: {','.join(inputs.keys())})."
            )

        # Accuracy needs both sides; skip the logging instead of crashing if one is missing
        if labels is not None and logits is not None:
            acc = accuracy_metric.compute(
                predictions=(logits.detach().cpu().numpy() > 0).flatten().astype(int),
                references=labels.detach().cpu().numpy().flatten(),
            )
            # NOTE(review): the Trainer may also call compute_loss while evaluating, in which
            # case this value is logged for evaluation batches too — confirm that is intended
            self.log(acc)
        return (loss, outputs) if return_outputs else loss

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Set ``self.optimizer`` and ``self.lr_scheduler`` for training.

        The optimizer is Adam over all model parameters, using the learning rate and
        weight decay from ``self.args``. The scheduler keeps the learning rate constant.

        Parameters
        ----------
        num_training_steps: int
            Total number of steps expected for training. Unused, because the schedule
            is constant
        """
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        self.lr_scheduler = get_constant_schedule(optimizer=self.optimizer)


def preprocess_logits_for_metrics(logits, labels):
    """
    Turn raw logits into 0/1 integer predictions before they are accumulated for
    compute_metrics, so that only the thresholded values are kept.

    Parameters
    ----------
    logits: tensor
        Raw model outputs; an entry becomes 1 when it is strictly greater than 0, else 0
    labels: array-like
        Corresponding labels, passed through untouched

    Returns
    -------
    pred_ids: tensor
        Integer tensor with the same shape as ``logits``
    labels: array-like
        The ``labels`` argument, unchanged
    """
    is_positive = logits.gt(0)
    return is_positive.int(), labels


if __name__ == "__main__":
    # Script entry point: load tokenizer and data, build the model config, then
    # fine-tune once per trial, each trial with its own seed and output folders.

    # Allow TF32 for cuDNN and for CUDA matmuls
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained('./tokenizer-mimic-iv-icd-final')
    dataset_train = Dataset.load_from_disk('./data/Finetuning_Disease_Prediction/train')
    dataset_test = Dataset.load_from_disk('./data/Finetuning_Disease_Prediction/test')

    # Only 'max_length' is read below (as the model's max_position_embeddings)
    tokenization_params = dict(
        max_length=128,
        truncation=True,
        padding='max_length',
        is_split_into_words=True,
        return_special_tokens_mask=True,
    )

    # VSA embedding options taken from the command line
    vsa_args = dict(
        dp_composition=args.dp_composition,
        dp_type=args.dp_type,
        snomed_groups=args.snomed_group,
    )

    # NOTE(review): num_labels is not set here, while the checkpoint branch of model_init
    # forces num_labels=22 and multi-label classification — confirm the BerthaConfig
    # defaults agree with that
    config = BerthaConfig(
        # vocabulary and sequence length
        vocab_size=tokenizer.vocab_size,
        tokenizer_vocab=tokenizer.vocab,
        max_position_embeddings=tokenization_params['max_length'],
        type_vocab_size=100,
        # regularisation and initialisation
        hidden_dropout_prob=DROPOUT_PROB,
        attention_probs_dropout_prob=DROPOUT_PROB,
        initializer_range=INITIALIZER_STD,
        # architecture switches
        layer_norm_position="post",
        position_embedding_type="absolute",
        apply_ffn_to_embeddings=None,
        freeze_word_embeddings=False,
        normalize_word_embeddings=False,
        # VSA embeddings
        path_to_vsa_parser_data="./data/icd_to_vsa_data.pkl",
        use_vsa=args.use_vsa,
        **vsa_args
    )

    # Read as a module-level global by compute_metrics and CustomTrainer.compute_loss
    accuracy_metric = load("accuracy")

    def model_init():
        """Return the classifier on `device`: restored from --model_checkpoint if given, else new from `config`."""
        if args.model_checkpoint is not None:
            print(f"Loading pre-trained model from {args.model_checkpoint}")
            model = BerthaForSequenceClassification.from_pretrained(
                args.model_checkpoint,
                problem_type="multi_label_classification",
                num_labels=22,
            )
        else:
            print("Instantiating a new model...")
            model = BerthaForSequenceClassification(config)
        model.to(device)
        return model

    # One seed per trial; --num_trials is limited to 1..14 by the argument parser
    seeds = [42, 123, 500, 87, 11, 5, 55, 57, 100, 301, 32, 77, 97, 117, 3]

    for trial in range(args.num_trials):
        output_log_folder = '~/output/disease_finetune_experiments_rood/'
        timestr = time.strftime("%Y_%m_%d-%H%M%S_")

        # Folder name is <timestamp><VSA settings>, or <timestamp>no_vsa for random embeddings
        if args.use_vsa:
            run_label = f'{args.snomed_group}_{args.dp_type}_{args.dp_composition}'
        else:
            run_label = "no_vsa"
        run_root = output_log_folder + timestr + run_label
        output_dir = run_root + '/output'
        logging_dir = run_root + '/logs'

        training_args = TrainingArguments(
            # locations and reproducibility
            output_dir=output_dir,
            logging_dir=logging_dir,
            seed=seeds[trial],
            # what to run
            do_train=True,
            do_eval=True,
            # evaluate and checkpoint once per epoch
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            load_best_model_at_end=True,
            # optimisation
            learning_rate=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY,
            num_train_epochs=EPOCHS,
            per_device_train_batch_size=BATCH_SIZE,
            per_device_eval_batch_size=8,
            gradient_accumulation_steps=1,
            eval_accumulation_steps=20,
            # logging
            logging_strategy="steps",
            logging_steps=32,
            # numeric precision
            fp16=False,
            bf16=False,
            tf32=True,
            ddp_find_unused_parameters=False,
        )

        collator = DataCollatorForMultiLabelsDiseasePrediction(tokenizer)
        trainer = CustomTrainer(
            model_init=model_init,
            args=training_args,
            train_dataset=dataset_train,
            eval_dataset=dataset_test,
            tokenizer=tokenizer,
            data_collator=collator,
            compute_metrics=compute_metrics,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        trainer.train()
