###############
#   Package   #
###############
import os
import time
import math
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from typing import List, Dict, Tuple, Optional
from torch import Tensor
from sklearn.utils import resample

#######################
# local packages      #
#######################
from metric.concordance_index import ConcordanceIndex

#############
#   Class   #
#############
class BaseTester():
    """Run a trained model over a test DataLoader and summarise its predictions.

    The constructor restores the model weights from ``resume_model_path``.
    Each batch yielded by ``testing_dataloader`` is unpacked as the 7-tuple
    ``(x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask, target)``.
    The first six items are passed to the model, and ``target`` is collected
    alongside the model output.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 loss_fn: torch.nn.Module,
                 testing_dataloader: torch.utils.data.DataLoader,
                 logger: logging.Logger,
                 resume_model_path: Optional[str] = None,
                 device: torch.device = torch.device("cpu"),
                ):
        """Restore ``model`` from a checkpoint and store the test components.

        Args:
            model: Network whose ``state_dict`` is loaded from ``resume_model_path``.
            loss_fn: Stored as ``self.loss_fn``; not used by this class itself.
            testing_dataloader: Iterable of 7-tuples (see the class docstring).
            logger: Stored as ``self.logger``.
            resume_model_path: Path to a ``.pth`` file containing a state dict.
            device: Device the model and batches are moved to in ``_test_epoch``.

        Raises:
            FileNotFoundError: If ``resume_model_path`` is ``None`` or is not an
                existing file.
        """
        # Raise explicitly instead of using ``assert``: the check must survive
        # ``python -O``, and a ``None`` path must not surface as a TypeError
        # from os.stat.
        if resume_model_path is None or not os.path.isfile(resume_model_path):
            raise FileNotFoundError(
                'resumed model file error: {!r} is not an existing file.'.format(resume_model_path))

        # define the variables of the tester
        self.model = model
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # machines. load_state_dict copies the values into the model's own
        # parameters, and the model is moved to ``device`` in _test_epoch.
        self.model.load_state_dict(torch.load(resume_model_path, map_location="cpu"))
        self.loss_fn = loss_fn
        self.testing_dataloader = testing_dataloader

        self.logger = logger

        self.device = device

    def _test_epoch(self) -> Tuple[Tensor, Tensor]:
        """Run the model over ``testing_dataloader`` without gradient tracking.

        Returns:
            ``(outputs, targets)``: the per-batch model outputs and targets, each
            concatenated along dimension 0. Both are seeded with an empty
            float32 tensor. As a result, integer targets are promoted to float32,
            and an empty loader yields two empty float32 tensors.
        """
        with torch.no_grad():
            self.model.eval()
            self.model.to(self.device)

            # Empty 1-D float32 seed: torch.cat skips it shape-wise but still
            # takes part in dtype promotion, matching the previous behaviour.
            seed = torch.empty(0, dtype=torch.float32, device=self.device)
            output_chunks = [seed]
            target_chunks = [seed]

            for batch in self.testing_dataloader:
                # put all variables on the target device (exactly 7 items expected)
                (x_num_idx, x_num, x_num_mask,
                 x_cat_idx, x_cat, x_cat_mask, target) = (item.to(self.device) for item in batch)

                output = self.model(x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask)

                output_chunks.append(output)
                target_chunks.append(target)

            # Concatenate once at the end. Calling torch.cat per batch
            # re-copied everything collected so far, which is quadratic.
            outputs = torch.cat(output_chunks)
            targets = torch.cat(target_chunks)

        return (outputs, targets)

    def test(self, do_bootstrape: bool = False) -> None:
        """Run inference and print the first model output.

        Args:
            do_bootstrape: Currently unused; kept for interface compatibility.
        """
        # TODO: report more than the first prediction (the original author
        # marked this as "this should be changed").
        testing_outputs, testing_targets = self._test_epoch()
        if testing_outputs.numel() == 0:
            # Indexing [0] below would raise on an empty result.
            self.logger.warning("Testing dataloader produced no outputs; nothing to report.")
            return
        print("Predicted Probability = {:.6f}".format(testing_outputs[0].item()))

    def _sampler(self, number_of_samples: int, sample_times: int = 1000):
        """Store ``sample_times`` bootstrap index arrays in ``self.sample``.

        Each entry is ``resample(np.arange(number_of_samples))``: with sklearn's
        default ``replace=True``, that is ``number_of_samples`` indices drawn
        with replacement.
        """
        self.sample = [resample(np.arange(number_of_samples)) for _ in range(sample_times)]

    def _bootsrapping(self, testing_outputs: Tensor, testing_targets: Tensor, measurer = None, sample_times: int = 1000):
        """Return half the width of the 2.5%-97.5% bootstrap percentile interval.

        ``measurer(outputs, targets)`` is evaluated on every bootstrap resample
        (indexed along dimension 0). The result is
        ``(quantile(values, 0.975) - quantile(values, 0.025)) / 2``.

        Resample indices are cached in ``self.sample``. Later calls reuse them
        while the number of predictions and ``sample_times`` are unchanged;
        otherwise they are redrawn.

        Args:
            testing_outputs: Model outputs; any rank >= 1, resampled along dim 0.
            testing_targets: Targets aligned with ``testing_outputs`` along dim 0.
            measurer: Callable returning a scalar (0-d / 1-element tensor or a
                Python / NumPy number).
            sample_times: Number of bootstrap resamples.

        Returns:
            Half the distance between the 97.5% and 2.5% quantiles (NumPy float).

        Raises:
            ValueError: If ``measurer`` is None or ``sample_times`` < 1.
        """
        if measurer is None:
            raise ValueError("Measurer Wrong.")
        if sample_times < 1:
            raise ValueError("sample_times must be at least 1.")

        # Redraw instead of silently reusing indices that belong to a
        # differently sized prediction set or a different sample count.
        number_of_samples = len(testing_outputs)
        cached = getattr(self, "sample", None)
        if (not isinstance(cached, list)
                or len(cached) != sample_times
                or any(len(idx) != number_of_samples for idx in cached)):
            self._sampler(number_of_samples, sample_times)

        # NOTE(review): if measurer is a stateful torchmetrics Metric, each call
        # also accumulates into its internal state; reset it if it is reused.
        with torch.no_grad():
            metric_value_record = np.array([
                float(measurer(testing_outputs[idx], testing_targets[idx]))
                for idx in self.sample
            ])

        lower, upper = np.quantile(metric_value_record, q=[0.025, 0.975])
        half_confidence_interval_length = (upper - lower) / 2

        return half_confidence_interval_length


if __name__ == '__main__':
    # This module only defines BaseTester; running it directly does nothing.
    ...
