###############
#   Package   #
###############
import os
import time
import math
import logging
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from typing import List, Dict, Tuple, Optional
from torch import Tensor

#######################
# package from myself #
#######################
from utils.util import ConsumingTime

#############
#   Class   #
#############
class TreeBasedTrainer():
    """Fit a tree-based classifier once and report train/validation metrics.

    The wrapped ``model`` is driven through ``fit`` and ``predict_proba``
    (and ``save_model`` when it provides one).  Data are read from two
    dataloaders whose batches are 7-tuples
    ``(x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask, target)``;
    the index entries are ignored and the remaining feature tensors are
    concatenated into one flat design matrix.
    """

    def __init__(self,
                 model,
                 loss_fn: torch.nn.Module,
                 training_dataloader: torch.utils.data.DataLoader,
                 validation_dataloader: torch.utils.data.DataLoader,
                 logger: logging.Logger,
                 checkpoint_save_path: str = '',
                 plot_probability_distribution: bool = False,
                 plot_save_path: str = '',
                 device: torch.device = torch.device("cpu"),
                 *args,
                 **kwargs,
                ):
        """Store the training configuration after validating the output paths.

        Args:
            model: Estimator exposing ``fit(X, y)`` and ``predict_proba(X)``.
            loss_fn: Callable applied to ``(probabilities, targets)`` tensors.
            training_dataloader: Source of the training batches.
            validation_dataloader: Source of the validation batches.
            logger: Logger that receives the metric summary.
            checkpoint_save_path: Existing directory where the model is saved.
            plot_probability_distribution: Whether to save probability
                histograms after fitting.
            plot_save_path: Existing directory for the histograms; only
                checked when ``plot_probability_distribution`` is true.
            device: Device on which the loss and metrics are evaluated.
            *args: Ignored.
            **kwargs: Ignored.

        Raises:
            AssertionError: If a required output directory does not exist.
        """
        # Raise explicitly instead of using ``assert`` so the validation is
        # not stripped under ``python -O``; the exception type is kept for
        # callers that already catch AssertionError.
        if not os.path.isdir(checkpoint_save_path):
            raise AssertionError('checkpoint saving directory does not exist.')
        if plot_probability_distribution and not os.path.isdir(plot_save_path):
            raise AssertionError('plot saving directory does not exist.')

        # define the variables of the training step
        self.model = model
        self.loss_fn = loss_fn
        self.training_dataloader = training_dataloader
        self.validation_dataloader = validation_dataloader

        self.logger = logger
        self.ckp_save_path = checkpoint_save_path

        self.plot_probability_distribution = plot_probability_distribution
        self.plot_save_path = plot_save_path

        self.device = device

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Convert a tensor (on any device) or array-like to a NumPy array."""
        if isinstance(value, torch.Tensor):
            # ``.cpu()`` is a no-op for CPU tensors and avoids a failure of
            # ``.numpy()`` when the dataloader yields tensors on another device.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    @classmethod
    def _collect_data(cls, dataloader) -> Tuple[np.ndarray, np.ndarray]:
        """Gather every batch of ``dataloader`` into ``(X, y)`` NumPy arrays.

        For each batch the features are concatenated along the last axis in
        the order ``x_num, x_cat, x_num_mask, x_cat_mask``.  After stacking all
        batches, ``X`` is flattened to shape ``(N, -1)`` and ``y`` is squeezed
        (but kept at least 1-D so a single-sample set still yields ``(1,)``).

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        features = []
        labels = []
        for _x_num_idx, x_num, x_num_mask, _x_cat_idx, x_cat, x_cat_mask, target in dataloader:
            parts = [cls._to_numpy(part) for part in (x_num, x_cat, x_num_mask, x_cat_mask)]
            features.append(np.concatenate(parts, axis=-1))
            labels.append(cls._to_numpy(target))

        if not features:
            raise ValueError('dataloader yielded no batches.')

        X = np.concatenate(features)
        y = np.concatenate(labels)

        X = X.reshape(X.shape[0], -1)
        y = np.atleast_1d(np.squeeze(y))
        return (X, y)

    def _get_train_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the whole training set as ``(X_train, y_train)`` arrays."""
        return self._collect_data(self.training_dataloader)

    def _get_valid_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the whole validation set as ``(X_val, y_val)`` arrays."""
        return self._collect_data(self.validation_dataloader)

    def _probability_distribution_plot(self, prediction: np.ndarray, targets: np.ndarray, fig_name: str = "valid") -> None:
        """Save density histograms of ``prediction`` split by binary target.

        Samples with target 0 are drawn in green and target 1 in red; the
        figure is written to ``<plot_save_path>/<fig_name>.png``.
        """
        prediction_0 = prediction[(targets == 0)]
        prediction_1 = prediction[(targets == 1)]
        # 100 equal-width bins covering the probability range [0, 1].
        bins = [0.01*x for x in range(101)]

        fig = plt.figure()
        sns.histplot(prediction_0, stat='density', bins=bins, edgecolor="none", kde=True, color='green', label=f'label 0 ({len(prediction_0)} samples)')
        sns.histplot(prediction_1, stat='density', bins=bins, edgecolor="none", kde=True, color='red', label=f'label 1 ({len(prediction_1)} samples)')
        plt.ylabel('Density')
        plt.xlabel('Probability')
        plt.legend()
        plt.title(fig_name)
        fig.savefig(os.path.join(self.plot_save_path, fig_name+'.png'))
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

    def _save_model(self) -> None:
        """Persist the fitted model to ``<ckp_save_path>/model.txt``.

        ``model.save_model`` is used when the model provides it; if it is
        missing or fails, the model object is pickled to the same path.
        """
        ckp_file = os.path.join(self.ckp_save_path, 'model.txt')
        save_model = getattr(self.model, 'save_model', None)
        if callable(save_model):
            try:
                save_model(ckp_file)
                return
            except Exception as err:
                # Best-effort fallback kept from the original behavior, but
                # no longer silent and no longer catching KeyboardInterrupt.
                self.logger.warning('model.save_model failed (%s); falling back to pickle.', err)
        with open(ckp_file, 'wb') as ckp:
            pickle.dump(self.model, ckp)

    def train(self, epoch: int = 100):
        """Fit the model once, log train/validation metrics and save it.

        Loss, accuracy, AUROC and AUPRC are computed on both the training and
        the validation set and logged through ``self.logger``.  Optionally the
        predicted-probability histograms are saved, then the model is written
        to the checkpoint directory.

        Args:
            epoch: Not used; the model is fitted with a single ``fit`` call.
                Kept so the signature stays compatible with existing callers.
        """
        # we only track loss, accuracy, AUROC and AUPRC of the training set and the validation set.
        ### c_index not implemented.
        key_list = ['loss', 'val_loss', 'accuracy', 'val_accuracy', 'AUROC', 'val_AUROC', 'AUPRC', 'val_AUPRC']
        record = {key: [] for key in key_list}

        # There is exactly one fit, reported as "Epoch [1/1]"; the same index
        # names the plots (this name was previously undefined -> NameError).
        ep_idx = 1

        # define metric measurer
        accuracy_measurer = torchmetrics.classification.BinaryAccuracy().to(self.device)
        AUROC_measurer = torchmetrics.classification.BinaryAUROC().to(self.device)
        AUPRC_measurer = torchmetrics.classification.BinaryAveragePrecision().to(self.device)

        # get training data and validation data
        training_x, training_y = self._get_train_data()
        val_x, val_y = self._get_valid_data()

        # train model
        self.model.fit(training_x, training_y)

        # predicted probability of the positive class (column 1) for both sets
        # NOTE(review): predict_proba output keeps its NumPy dtype here; if
        # loss_fn requires matching dtypes with the targets, confirm them.
        training_outputs = torch.from_numpy(self.model.predict_proba(training_x)[:, 1]).to(self.device)
        training_targets = torch.from_numpy(training_y).to(self.device)
        val_outputs = torch.from_numpy(self.model.predict_proba(val_x)[:, 1]).to(self.device)
        val_targets = torch.from_numpy(val_y).to(self.device)

        # Evaluate the training split first, then validation, storing results
        # under the un-prefixed and the ``val_``-prefixed keys respectively.
        splits = (
            ('', training_outputs, training_targets),
            ('val_', val_outputs, val_targets),
        )
        for prefix, outputs, targets in splits:
            record[prefix + 'loss'].append(self.loss_fn(outputs, targets).cpu().item())
            record[prefix + 'accuracy'].append(accuracy_measurer(outputs, targets).cpu().item())
            record[prefix + 'AUROC'].append(AUROC_measurer(outputs, targets).cpu().item())
            # AUPRC is evaluated against integer targets.
            record[prefix + 'AUPRC'].append(AUPRC_measurer(outputs, targets.long()).cpu().item())

        # message builder
        msg_line_1 = 'Epoch [1/1] | '
        # second line is padded so its "| " separator lines up with the first
        msg_line_2 = " "*(len(msg_line_1)-2) + "| "
        msg_line_1 += '(train) loss = {:.6f}, accuracy = {:.6f}, AUROC = {:.6f}, AUPRC = {:.6f}\n'.format(record['loss'][-1], record['accuracy'][-1], record['AUROC'][-1], record['AUPRC'][-1])
        msg_line_2 += '(valid) loss = {:.6f}, accuracy = {:.6f}, AUROC = {:.6f}, AUPRC = {:.6f}'.format(record['val_loss'][-1], record['val_accuracy'][-1], record['val_AUROC'][-1], record['val_AUPRC'][-1])
        msg = '\n' + msg_line_1 + msg_line_2

        self.logger.warning(msg)

        if self.plot_probability_distribution:
            self._probability_distribution_plot(training_outputs.cpu().detach().numpy(), training_targets.cpu().detach().numpy(), f'Ep_{ep_idx}_train')
            self._probability_distribution_plot(val_outputs.cpu().detach().numpy(), val_targets.cpu().detach().numpy(), f'Ep_{ep_idx}_valid')

        self._save_model()

# Script entry point: intentionally a no-op, so running this module directly
# does nothing beyond the imports and the class definition above.
if __name__ == '__main__':
    pass
