from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.preprocessing import MinMaxScaler

from .base_editor import BaseEditor

class Likelihood(BaseEditor):
    """Per-label log2 likelihoods for every sample.

    Categorical features are scored directly from their empirical value
    frequencies; numerical (continuous) features are first digitized into
    quantile bins. All features are assumed to be independent, so a sample's
    log-likelihood for a label is the sum of its per-feature log2
    probabilities.

    Attributes:
        X: A dataframe of the entire dataset.
        X_train: A dataframe of the training dataset.
        y: A series of labels.
        likelihoods: log2 probability tables. Keys are ``'y_<label>'`` (class
            log-prior, a float), ``'y_<label>_<col>'`` (per-label mapping of
            feature value -> log2 probability) and ``'<col>'``
            (label-independent mapping of feature value -> log2 probability).
        continuous_cols: A list of numerical features to digitize.
        bins: Quantile bin edges for each numerical feature, keyed like
            ``likelihoods`` (``'y_<label>_<col>'`` or ``'<col>'``).
        n_bins: The number of quantile bins used for numerical features.
        pl_criteria: A dataframe of log-likelihoods for each sample of ``X``
            (indexed like ``X``, one column per label in ``y.unique()`` order).
        scaler: A min-max scaler fitted on the training log-likelihoods.
    """
    def __init__(self,
                 X: pd.DataFrame,
                 X_train: pd.DataFrame,
                 y: pd.Series,
                 continuous_cols: List[str],
                 n_jobs: int,
                 n_bins: int = 10,
                 ) -> None:
        """Set up the likelihood tables and delegate to ``BaseEditor``.

        Args:
            X: The entire dataset.
            X_train: The training subset of the data.
            y: Labels for the given data.
            continuous_cols: Names of the numerical columns to digitize.
            n_jobs: The number of CPUs to use.
            n_bins: The number of quantile bins per numerical column.
        """
        # Initialised before super().__init__, which presumably calls
        # init_criteria (the method that fills these).
        # NOTE(review): confirm in BaseEditor.
        self.likelihoods = {}
        self.bins = {}
        self.n_bins = n_bins

        super().__init__(X, X_train, y, continuous_cols, n_jobs)

    @staticmethod
    def _log_frequencies(values) -> dict:
        """Map each distinct value to the log2 of its relative frequency."""
        uniques, counts = np.unique(values, return_counts=True)
        return dict(zip(uniques, np.log2(counts / counts.sum())))

    def init_criteria(self) -> None:
        """Compute class log-priors and per-feature log2 probability tables.

        Fills ``self.likelihoods`` (and ``self.bins`` for numerical columns)
        from ``X_train`` and ``y``. Each column gets one table per label plus
        one label-independent table.
        """
        labels = self.y.unique()

        # NOTE(review): '%d' key formatting assumes integer-like labels.
        for label in labels:
            self.likelihoods['y_%d' % label] = np.log2((self.y == label).sum() / len(self.y))

        for col in self.X_train.columns:
            is_continuous = col in self.continuous_cols

            for label in labels:
                key = 'y_%d_' % label + col
                values = self.X_train[self.y == label][col]
                if is_continuous:
                    values, edges = self.digitize(values)
                    self.bins[key] = edges
                self.likelihoods[key] = self._log_frequencies(values)

            values = self.X_train[col]
            if is_continuous:
                values, edges = self.digitize(values)
                self.bins[col] = edges
            self.likelihoods[col] = self._log_frequencies(values)

    def init_scaler(self) -> None:
        """Fit ``self.scaler`` on the per-label log-likelihoods of ``X_train``."""
        train_scores = self.X_train.apply(self.get_parallel_likelihood, axis=1, result_type='expand')
        self.scaler.fit(train_scores.values)

    def digitize(self,
                 data: pd.Series,
                 bin: Optional[np.ndarray] = None,
                 ) -> Tuple[np.ndarray, np.ndarray]:
        """Map numerical values to quantile-bin indices.

        Args:
            data: Values of one numerical feature (a series, array or scalar).
            bin: Bin edges to use. If ``None``, ``n_bins`` edges are computed
                from ``data`` at the quantile levels
                0, 1/n_bins, ..., (n_bins - 1)/n_bins.

        Returns:
            The bin indices produced by ``np.digitize`` and the edges used.
        """
        # The parameter keeps the name ``bin`` (it shadows the builtin) because
        # callers pass it by keyword.
        if bin is None:
            # linspace yields exactly n_bins levels; np.arange with a float step
            # can overshoot (e.g. n_bins=49 yields 50 levels).
            levels = np.linspace(0, 1, self.n_bins, endpoint=False)
            bin = np.quantile(data, levels)
        return np.digitize(data, bin), bin

    def get_likelihood(self,
                       data: pd.Series,
                       label: int,
                       ) -> float:
        """Sum the per-feature log2 probabilities of one sample under ``label``.

        Feature values never seen with ``label`` during training get the log2
        probability ``1 / len(X_train)``. That value is also stored in
        ``self.likelihoods`` so later lookups hit it directly.

        Args:
            data: One sample (a row of ``X``).
            label: The label whose tables are used.

        Returns:
            The sample's log2 likelihood for ``label``. The class log-prior
            ``likelihoods['y_<label>']`` is not added.
        """
        likelihood = 0

        for col in self.X_train.columns:
            key = 'y_%d_' % label + col
            value = data[col]
            if col in self.continuous_cols:
                value, _ = self.digitize(value, bin=self.bins[key])

            table = self.likelihoods[key]
            # Only an unseen *value* falls back to the smoothing term. A missing
            # table (unknown label/column) raises KeyError instead of being masked.
            # NOTE(review): when run inside pandarallel workers, this cache
            # insert only updates the worker's copy of ``self``.
            if value not in table:
                table[value] = np.log2(1 / len(self.X_train))
            likelihood += table[value]

        return likelihood

    def get_parallel_likelihood(self,
                                data: pd.Series,
                                ) -> List[float]:
        """Compute a sample's log2 likelihood for every label.

        Args:
            data: One sample (a row of ``X`` or ``X_train``).

        Returns:
            One log2 likelihood per label, in ``self.y.unique()`` order.
        """
        return [self.get_likelihood(data=data, label=label) for label in self.y.unique()]

    def caching_criteria(self) -> None:
        """Store the per-label log-likelihoods of every sample of ``X`` in ``pl_criteria``.

        Uses pandarallel's ``DataFrame.parallel_apply``.
        NOTE(review): pandarallel must already be initialised, presumably by
        the base class or the caller.
        """
        scores = self.X.parallel_apply(self.get_parallel_likelihood, axis=1, result_type='expand')
        self.pl_criteria = pd.DataFrame(scores.values, index=self.X.index)

    def scale_criteria(self,
                       likelihood: np.ndarray,
                       ) -> np.ndarray:
        """Scale log-likelihoods with the scaler fitted in ``init_scaler``."""
        return self.scaler.transform(likelihood)
        
        