from .base_prior import BasePrior
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from pandarallel import pandarallel
import pandas as pd
from typing import Tuple, List

class Likelihood(BasePrior):
    """A set of empirical log-likelihoods for each sample per label.

    For categorical features, the likelihood is computed directly from value
    frequencies. For numerical features, values are digitized into quantile
    bins first. All calculations assume the features are independent, so the
    per-feature log2 likelihoods are summed.

    Attributes:
        X: A dataframe of the entire dataset.
        X_train: A dataframe of the training dataset.
        y: A series of labels.
        likelihoods: Log2 likelihoods keyed by ``'y_<label>'`` (label
            frequency), ``'y_<label>_<col>'`` (per-label value frequencies)
            and ``'<col>'`` (value frequencies over all training rows).
        continuous_cols: A list of numerical features to digitize.
        bins: A dictionary of bin edges for each numerical feature, keyed
            like ``likelihoods``.
        n_bins: The number of quantile levels used to build bin edges.
        pl_priors: A dataframe of likelihoods for each sample per label.
        scaler: A scaler for likelihoods (not created in this class;
            presumably provided by ``BasePrior`` -- confirm there).
    """

    def __init__(self,
                 X: pd.DataFrame,
                 X_train: pd.DataFrame,
                 y: pd.Series,
                 continuous_cols: List[str],
                 n_jobs: int,
                 n_bins: int = 10,
                 ) -> None:
        """Set up the empirical likelihood.

        Args:
            X: The entire data.
            X_train: The training subset of the entire data.
            y: Labels for the given data.
            continuous_cols: Names of the numerical columns.
            n_jobs: The number of cpus to use.
            n_bins: The number of quantile levels used when digitizing
                numerical columns.
        """
        # State is created before delegating to the base class.
        # NOTE(review): presumably BasePrior.__init__ calls init_prior(),
        # which needs these attributes -- confirm against BasePrior.
        self.likelihoods = {}
        self.bins = {}
        self.n_bins = n_bins

        super().__init__(X, X_train, y, continuous_cols, n_jobs)

    @staticmethod
    def _feature_key(label: int, col: str) -> str:
        """Build the dictionary key for a (label, column) pair."""
        return 'y_%d_' % label + col

    @staticmethod
    def _log_frequencies(data) -> dict:
        """Map each unique value in ``data`` to the log2 of its relative frequency."""
        uq, cnt = np.unique(data, return_counts=True)
        return dict(zip(uq, np.log2(cnt / cnt.sum())))

    def init_prior(self) -> None:
        """Compute the empirical log2 likelihood tables from the training data."""
        # The label set does not change while building the tables.
        labels = self.y.unique()

        # Log2 relative frequency of each label.
        for label in labels:
            self.likelihoods['y_%d' % label] = np.log2((self.y == label).sum() / len(self.y))

        for col in self.X_train.columns:
            is_continuous = col in self.continuous_cols

            # Per-label tables: bins are fitted on the rows of that label only.
            for label in labels:
                key = self._feature_key(label, col)
                data = self.X_train[self.y == label][col]
                if is_continuous:
                    data, edges = self.digitize(data)
                    self.bins[key] = edges
                self.likelihoods[key] = self._log_frequencies(data)

            # Table over all training rows, regardless of label.
            data = self.X_train[col]
            if is_continuous:
                data, edges = self.digitize(data)
                self.bins[col] = edges
            self.likelihoods[col] = self._log_frequencies(data)

    def init_scaler(self) -> None:
        """Fit the scaler on the per-label likelihoods of the training rows."""
        self.scaler.fit(self.X_train.apply(self.get_parallel_prior, axis=1, result_type='expand').values)

    def digitize(self,
                 data: pd.Series,
                 bin: np.ndarray = None
                 ) -> Tuple[np.ndarray, np.ndarray]:
        """Digitize a numerical feature.

        Args:
            data: Values of a numerical feature (a series, or a single value).
            bin: Bin edges to use. If it is None, the edges are computed as
                the quantiles of ``data`` at ``n_bins`` evenly spaced levels
                in ``[0, 1)``.
        Returns:
            The digitized feature and the bin edges used for it.
        """
        if bin is None:
            # linspace yields exactly n_bins levels; arange with a float step
            # can return one extra element because of rounding.
            levels = np.linspace(0, 1, self.n_bins, endpoint=False)
            bin = np.quantile(data, levels)
        digitized = np.digitize(data, bin)

        return digitized, bin

    def calculate_prior(self,
                        data: pd.Series,
                        label: int
                        ) -> float:
        """Calculate prior knowledge for a given sample for a given label.

        Args:
            data: A data sample to calculate prior knowledge.
            label: A target label to calculate prior knowledge.
        Returns:
            The sum over all training columns of the sample's log2
            likelihoods for the given label.
        """
        likelihood = 0

        for col in self.X_train.columns:
            key = self._feature_key(label, col)
            data_x = data[col]

            if col in self.continuous_cols:
                data_x, _ = self.digitize(data[col], bin=self.bins[key])

            table = self.likelihoods[key]
            try:
                likelihood += table[data_x]
            except KeyError:
                # Value never seen for this label in training: fall back to
                # log2(1 / number of training rows) and remember it in the table.
                table[data_x] = np.log2(1 / len(self.X_train))
                likelihood += table[data_x]

        return likelihood

    def get_parallel_prior(self,
                           data: pd.Series
                           ) -> List[float]:
        """Calculate prior knowledge for a given sample for all labels.

        Args:
            data: A Series to calculate prior knowledge.

        Returns:
            Prior knowledge of the sample, one value per label in the order
            of ``self.y.unique()``.
        """
        return [self.calculate_prior(data=data, label=label) for label in self.y.unique()]

    def caching_prior(self) -> None:
        """Compute and store prior knowledge of every sample in X for all labels."""
        self.pl_priors = pd.DataFrame(self.X.parallel_apply(self.get_parallel_prior, axis=1, result_type='expand').values, index=self.X.index)

    def scale_prior(self,
                    likelihood: np.ndarray
                    ) -> np.ndarray:
        """Scale prior knowledge with the fitted scaler."""
        return self.scaler.transform(likelihood)
        
        