from .base_prior import BasePrior

import numpy as np
from pandarallel import pandarallel
import pandas as pd
from typing import Tuple, List
from statsmodels.nonparametric.kernel_density import KDEMultivariate

class Density(BasePrior):
    """Per-label kernel density estimates used as prior knowledge.

    One ``KDEMultivariate`` estimator is fitted on the training rows of each
    label. Columns listed in ``continuous_cols`` are modelled as continuous
    variables (``var_type='c'``, Gaussian kernel in statsmodels); every other
    column is modelled as an unordered discrete variable (``var_type='u'``,
    Aitchison-Aitken kernel in statsmodels).

    Attributes:
        X: A dataframe of the entire dataset.
        X_train: A dataframe of the training dataset.
        y: A series of labels used to select the training rows of each label.
        continuous_cols: Names of the columns treated as continuous variables.
        density_estimator: Maps ``'y_<label>'`` to the estimator fitted on
            that label's training rows.
        pl_priors: A dataframe of densities for every row of ``X`` (index) and
            every label (positional columns, in ``y.unique()`` order).
        scaler: Scaler fitted on the scaled training densities; it is not
            created in this class (presumably set up by ``BasePrior``).
        __scale_factor: Multiplier applied to raw densities before scaling,
            to mitigate numerical problems caused by extremely small values.
    """

    def __init__(self,
                 X: pd.DataFrame,
                 X_train: pd.DataFrame,
                 y: pd.Series,
                 continuous_cols: List[str],
                 n_jobs: int
                 ) -> None:
        """Set up the density estimator.

        Args:
            X: The entire dataset.
            X_train: The training subset of the entire dataset.
            y: Labels for the training data.
            continuous_cols: Names of the numerical (continuous) columns.
            n_jobs: The number of CPUs to use.
        """
        # Initialised before super().__init__, presumably because
        # BasePrior.__init__ triggers init_prior / caching — confirm.
        self.density_estimator = {}

        self.__scale_factor = 100

        super().__init__(X, X_train, y, continuous_cols, n_jobs)

    def init_scaler(self) -> None:
        """Fit ``self.scaler`` on the scaled per-label densities of ``X_train``."""
        train_densities = self.X_train.apply(
            self.get_parallel_prior, axis=1, result_type='expand'
        ).values
        self.scaler.fit(train_densities * self.__scale_factor)

    def init_prior(self) -> None:
        """Fit one ``KDEMultivariate`` estimator per label on ``X_train``.

        When a label has no more training rows than there are features, its
        rows are replicated until the sample count exceeds the feature count
        before the estimator is fitted.

        Raises:
            ValueError: If a label from ``y`` selects no rows of ``X_train``
                (for example a NaN label). Without this check the replication
                step would never terminate.
        """
        continuous = set(self.continuous_cols)
        var_type = ''.join(
            'c' if col in continuous else 'u' for col in self.X.columns
        )
        n_features = len(var_type)

        for label in self.y.unique():
            samples = self.X_train[self.y == label].values
            n_samples = len(samples)
            if n_samples == 0:
                raise ValueError(
                    f"No training rows found for label {label!r}; "
                    "cannot fit a density estimator."
                )
            if n_samples <= n_features:
                # Stack whole copies of the rows: the smallest number of copies
                # (at least 2) whose total row count exceeds n_features.
                samples = np.tile(samples, (n_features // n_samples + 1, 1))
            self.density_estimator['y_%d' % label] = KDEMultivariate(
                data=samples, var_type=var_type
            )

    def calculate_prior(self,
                        data: pd.Series,
                        label: int
                        ) -> float:
        """Evaluate the density of one sample under one label's estimator.

        Args:
            data: A single sample (one row) to evaluate.
            label: The label whose estimator is used.

        Returns:
            The estimated density as a Python float.
        """
        return self.density_estimator['y_%d' % label].pdf(data.values).item()

    def get_parallel_prior(self,
                           data: pd.Series
                           ) -> List[float]:
        """Evaluate the density of one sample under every label's estimator.

        Args:
            data: A single sample (one row) to evaluate.

        Returns:
            One density per label, in ``self.y.unique()`` order.
        """
        return [self.calculate_prior(data=data, label=label)
                for label in self.y.unique()]

    def caching_prior(self) -> None:
        """Compute and store per-label densities for every row of ``X``.

        Uses pandarallel's ``parallel_apply``; ``pandarallel.initialize`` is
        not called here (presumably done by ``BasePrior``).
        """
        densities = self.X.parallel_apply(
            self.get_parallel_prior, axis=1, result_type='expand'
        ).values
        self.pl_priors = pd.DataFrame(densities, index=self.X.index)

    def scale_prior(self,
                    density: np.ndarray
                    ) -> np.ndarray:
        """Multiply densities by the scale factor and transform with the scaler."""
        return self.scaler.transform(density * self.__scale_factor)