###############
#   Package   #
###############
import os
import time
import math
import logging
import numpy as np
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from typing import List, Dict, Tuple, Optional
from torch import Tensor

#########################
# project-local modules #
#########################
from utils.util import ConsumingTime

#############
#   Class   #
#############
class PreTrainer:
    """Epoch-based pre-training loop with per-epoch validation and periodic checkpoints.

    Every batch produced by ``training_dataloader`` / ``validation_dataloader``
    must unpack into six tensors
    ``(x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask)``.
    ``model`` is called with those six tensors (after moving them to ``device``)
    and must return four tensors
    ``(x_num_idx_reconst, x_num_reconst, x_cat_idx_reconst, x_cat_reconst)``.
    ``loss_fn`` is called with originals, reconstructions and masks and must
    return five scalar tensors: the total loss (the one back-propagated)
    followed by the num-index, num-value, cat-index and cat-value terms.

    Any iterable of such batches works in place of a ``DataLoader``.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 loss_fn: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_dataloader: torch.utils.data.DataLoader,
                 validation_dataloader: torch.utils.data.DataLoader,
                 logger: logging.Logger,
                 lr_scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
                 checkpoint_save_path: str = '',
                 checkpoint_period: int = 5,
                 device: torch.device = torch.device("cpu"),
                 *args,
                 **kwargs,
                 ):
        """Store the training components.

        Args:
            model: network being trained; moved to ``device`` at the start of
                every training/validation epoch.
            loss_fn: callable returning the five losses described in the class
                docstring.
            optimizer: ``zero_grad()`` / ``step()`` are called around every
                training batch.
            training_dataloader: iterable of training batches.
            validation_dataloader: iterable of validation batches.
            logger: receives one summary message per epoch.
            lr_scheduler: optional scheduler; ``step()`` is called once after
                each training epoch.
            checkpoint_save_path: existing directory for ``state_dict``
                checkpoints. ``''`` (the default) means the current working
                directory.
            checkpoint_period: save a checkpoint every this many epochs;
                ``<= 0`` disables checkpointing.
            device: device the model and every batch are moved to.
            *args, **kwargs: accepted for call compatibility and ignored.

        Raises:
            NotADirectoryError: if ``checkpoint_save_path`` is not an existing
                directory.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        # '' is checked as os.curdir because os.path.join('', name) == name, i.e.
        # checkpoints then land in the current working directory.
        if not os.path.isdir(checkpoint_save_path or os.curdir):
            raise NotADirectoryError(
                f'checkpoint saving directory does not exist: {checkpoint_save_path!r}')

        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer

        self.training_dataloader = training_dataloader
        self.validation_dataloader = validation_dataloader

        self.lr_scheduler = lr_scheduler

        self.logger = logger
        self.ckp_save_path = checkpoint_save_path
        self.ckp_period = checkpoint_period

        self.device = device

    def _forward_losses(self, batch) -> tuple:
        """Move one six-tensor batch to ``self.device``, run the model and return ``loss_fn``'s losses."""
        x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask = (
            tensor.to(self.device) for tensor in batch)

        x_num_idx_reconst, x_num_reconst, x_cat_idx_reconst, x_cat_reconst = self.model(
            x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask)

        return self.loss_fn(x_num_idx, x_num_idx_reconst, x_num, x_num_reconst,
                            x_cat_idx, x_cat_idx_reconst, x_cat, x_cat_reconst,
                            x_num_mask, x_cat_mask)

    def _run_epoch(self, dataloader, optimize: bool) -> tuple:
        """Make one pass over ``dataloader`` and return the five losses averaged over its batches.

        When ``optimize`` is true every batch also runs zero_grad / backward /
        step on the total loss. Returns five NaNs when ``dataloader`` is empty.
        """
        running = None  # per-term sums, kept as tensors so .item() (a device sync) runs once per epoch
        n_batches = 0
        for batch in dataloader:
            if optimize:
                self.optimizer.zero_grad()

            losses = self._forward_losses(batch)

            if optimize:
                total_loss = losses[0]
                total_loss.backward()
                self.optimizer.step()

            detached = [term.detach() for term in losses]
            running = detached if running is None else [acc + term for acc, term in zip(running, detached)]
            n_batches += 1

        if running is None:
            # Empty dataloader: there are no batch losses to average.
            return (math.nan,) * 5
        return tuple((acc / n_batches).item() for acc in running)

    def _train_epoch(self) -> tuple:
        """Train for one epoch and return the batch-averaged (total, num_idx, num, cat_idx, cat) losses."""
        self.model.train()
        self.model.to(self.device)

        epoch_losses = self._run_epoch(self.training_dataloader, optimize=True)

        # The scheduler is stepped once per epoch, after all optimizer steps.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return epoch_losses

    def _valid_epoch(self) -> tuple:
        """Evaluate without gradients and return the batch-averaged (total, num_idx, num, cat_idx, cat) losses."""
        self.model.eval()
        self.model.to(self.device)
        with torch.no_grad():
            return self._run_epoch(self.validation_dataloader, optimize=False)

    def train(self, epoch: int = 100):
        """Run ``epoch`` rounds of training + validation, logging the losses after each.

        Logged losses are means over the epoch's batches. Every
        ``checkpoint_period`` epochs (never, if the period is ``<= 0``) the
        model's ``state_dict()`` is saved as ``ckt_ep_<epoch>.pth`` inside
        ``checkpoint_save_path``.
        """
        for ep_idx in range(1, epoch + 1):
            training_losses = self._train_epoch()
            val_losses = self._valid_epoch()

            # Two-line summary; line 2 is padded so its "| " lines up under line 1's.
            msg_line_1 = f'Epoch [{ep_idx}/{epoch}] | '
            msg_line_2 = " " * (len(msg_line_1) - 2) + "| "
            msg_line_1 += '(train) total loss = {:.6f}, num_idx_loss = {:.6f}, num_loss = {:.6f}, cat_idx_loss = {:.6f}, cat_loss = {:.6f}\n'.format(*training_losses)
            msg_line_2 += '(valid) total loss = {:.6f}, num_idx_loss = {:.6f}, num_loss = {:.6f}, cat_idx_loss = {:.6f}, cat_loss = {:.6f}'.format(*val_losses)
            msg = '\n' + msg_line_1 + msg_line_2

            # NOTE(review): summaries are emitted at WARNING; INFO would be the usual
            # level, but switching could hide them under the logger's current threshold.
            self.logger.warning(msg)

            if self.ckp_period > 0 and ep_idx % self.ckp_period == 0:
                torch.save(self.model.state_dict(),
                           os.path.join(self.ckp_save_path, f'ckt_ep_{ep_idx}.pth'))

if __name__ == '__main__':
    # No script entry point: executing this file directly performs no work.
    ...
