###############
#   Package   #
###############
import os
import time
import math
import logging
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from typing import List, Dict, Tuple, Optional
from torch import Tensor
from sklearn.utils import resample

#######################
# project packages    #
#######################
from metric.concordance_index import ConcordanceIndex

#############
#   Class   #
#############
class TreeBasedTester():
    """Evaluate a fitted tree-based classifier on data served by a DataLoader.

    The model is restored from ``resume_model_path`` in ``__init__``. Each
    test batch is flattened into a 2-D NumPy feature matrix, which is passed
    to ``model.predict_proba``. The bootstrap helpers estimate half the width
    of the 2.5%-97.5% percentile interval of a metric over resampled test sets.
    """

    def __init__(self,
                 model,
                 loss_fn: torch.nn.Module,
                 testing_dataloader: torch.utils.data.DataLoader,
                 logger: logging.Logger,
                 resume_model_path: Optional[str] = None,
                 device: torch.device = torch.device("cpu"),
                 *args,
                 **kwargs
                ):
        """Restore the model from disk and store the evaluation resources.

        Args:
            model: Model object. ``model.load_model(resume_model_path)`` is
                tried first. If it raises, ``self.model`` is replaced by the
                object unpickled from ``resume_model_path``.
            loss_fn: Stored as ``self.loss_fn``; not used by this class.
            testing_dataloader: Iterable of 7-tuples
                ``(x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask, target)``.
            logger: Stored as ``self.logger``.
            resume_model_path: Path to an existing saved-model file. A path is
                required despite the ``None`` default.
            device: Device the prediction tensor is moved to in :meth:`test`.
            *args, **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            AssertionError: If ``resume_model_path`` is ``None`` or is not an
                existing file.
        """
        # Use an explicit check instead of ``assert`` so it still runs under
        # ``python -O``. It also guards ``None``, for which ``os.path.isfile``
        # raises TypeError. AssertionError is kept for existing callers.
        if resume_model_path is None or not os.path.isfile(resume_model_path):
            raise AssertionError('resumed model file error.')

        self.model = model
        try:
            # Try the model's own loader first (presumably a library-native
            # format; confirm with callers).
            self.model.load_model(resume_model_path)
        except Exception:
            # Fall back to pickle. Catch ``Exception`` rather than using a bare
            # ``except`` so KeyboardInterrupt/SystemExit still propagate.
            # SECURITY: unpickling executes arbitrary code, so only load
            # trusted model files.
            with open(resume_model_path, 'rb') as model_file:
                self.model = pickle.load(model_file)
        self.loss_fn = loss_fn
        self.testing_dataloader = testing_dataloader

        self.logger = logger

        self.device = device

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Convert a batch element (tensor or array-like) to a host NumPy array."""
        if isinstance(value, Tensor):
            # Detach from autograd and copy to CPU so ``.numpy()`` also works
            # for tensors that require grad or live on an accelerator.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def _get_test_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Collect the whole test set as NumPy arrays.

        For every batch, the numeric values, categorical values, numeric mask
        and categorical mask are concatenated along the last axis, in that
        order. The index elements (``x_num_idx``, ``x_cat_idx``) are ignored.

        Returns:
            ``(X_test, y_test)``. ``X_test`` is reshaped to ``(N, -1)``, so all
            trailing axes are flattened. ``y_test`` holds the concatenated
            targets with size-1 axes squeezed out.

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        feature_batches = []
        target_batches = []
        for (_x_num_idx, x_num, x_num_mask,
             _x_cat_idx, x_cat, x_cat_mask, target) in self.testing_dataloader:
            feature_batches.append(np.concatenate(
                [self._to_numpy(x_num), self._to_numpy(x_cat),
                 self._to_numpy(x_num_mask), self._to_numpy(x_cat_mask)],
                axis=-1,
            ))
            target_batches.append(self._to_numpy(target))

        if not feature_batches:
            raise ValueError("testing dataloader yielded no batches.")

        X_test = np.concatenate(feature_batches)
        y_test = np.concatenate(target_batches)

        X_test = X_test.reshape(X_test.shape[0], -1)
        y_test = np.squeeze(y_test)
        return (X_test, y_test)

    def test(self, do_bootstrape: bool = False) -> None:
        """Predict on the test set and print the column-1 probability of the first sample.

        Args:
            do_bootstrape: Currently ignored; kept for interface compatibility.
        """
        testing_x, _ = self._get_test_data()
        # Column 1 of an sklearn-style ``predict_proba`` is the second class in
        # the model's class order (presumably the positive class; confirm).
        positive_proba = self.model.predict_proba(testing_x)[:, 1]
        testing_outputs = torch.from_numpy(positive_proba).to(self.device)
        print("Predicted Probability = {:.6f}".format(testing_outputs[0].item()))

    def _sampler(self, number_of_samples: int, sample_times: int = 1000):
        """Draw ``sample_times`` bootstrap index arrays and cache them in ``self.sample``.

        Each entry comes from ``sklearn.utils.resample`` over
        ``range(number_of_samples)``. No ``random_state`` is passed, so the
        draws are reproducible only if NumPy's global RNG is seeded.
        """
        all_indices = np.arange(number_of_samples)
        self.sample = [resample(all_indices) for _ in range(sample_times)]

    @staticmethod
    def _as_float(value) -> float:
        """Convert a scalar metric (0-d tensor, NumPy scalar or number) to a Python float."""
        if isinstance(value, Tensor):
            return value.detach().cpu().item()
        return float(value)

    def _bootsrapping(self, testing_outputs: Tensor, testing_targets: Tensor, measurer = None, day_delta = None, sample_times: int = 1000):
        """Return half the width of the 2.5%-97.5% bootstrap percentile interval of a metric.

        Bootstrap index arrays are cached in ``self.sample`` and reused across
        calls, so several metrics can share the same resamples. They are
        redrawn when missing, or when they were drawn for a different test-set
        size or a different number of repetitions.

        Args:
            testing_outputs: Model outputs, indexable by an integer index array.
            testing_targets: Targets aligned with ``testing_outputs``.
            measurer: Metric callable. It is called as ``measurer(outputs, targets)``
                when ``day_delta`` is ``None``. Otherwise it is called as
                ``measurer(day_delta, targets, outputs)`` with row slices. It
                must return a scalar (0-d tensor or number).
            day_delta: Optional extra per-sample input. When given, all three
                inputs are indexed as ``[idx, :]``, so they must be at least 2-D.
            sample_times: Number of bootstrap resamples.

        Returns:
            ``(q97.5 - q2.5) / 2`` of the per-resample metric values.

        Raises:
            AssertionError: If ``measurer`` is ``None``.
            ValueError: If ``sample_times`` is less than 1.
        """
        if measurer is None:
            # Raise explicitly rather than using ``assert`` so the check
            # survives ``python -O``.
            raise AssertionError("Measurer Wrong.")
        if sample_times < 1:
            raise ValueError("sample_times must be at least 1.")

        number_of_samples = len(testing_outputs)
        cached = getattr(self, 'sample', None)
        # Redraw stale resamples. Indices drawn for a smaller set would
        # silently cover only part of the data; indices drawn for a larger
        # set would fail.
        if (cached is None
                or len(cached) != sample_times
                or len(cached[0]) != number_of_samples):
            self._sampler(number_of_samples, sample_times)

        metric_value_record = []
        with torch.no_grad():
            for sample_idx in self.sample:
                if day_delta is None:
                    metric_value = measurer(testing_outputs[sample_idx], testing_targets[sample_idx])
                else:
                    metric_value = measurer(day_delta[sample_idx, :], testing_targets[sample_idx, :], testing_outputs[sample_idx, :])
                metric_value_record.append(self._as_float(metric_value))

        metric_value_record = np.asarray(metric_value_record, dtype=float)
        upper = np.quantile(metric_value_record, q=0.975)
        lower = np.quantile(metric_value_record, q=0.025)
        return (upper - lower) / 2

if __name__ == "__main__":
    # This module only defines TreeBasedTester; running it directly does nothing.
    ...
