from abc import ABC, abstractmethod

import numpy as np
from sklearn.preprocessing import MinMaxScaler
from pandarallel import pandarallel
import pandas as pd
from typing import Tuple, List
from numpy.typing import NDArray

class BasePrior(ABC):
    """A set of prior knowledge for each sample per label.

    Computes and caches prior knowledge (density or likelihood) of samples
    for each class. Subclasses implement the abstract hooks. ``__init__``
    calls ``init_prior``, ``init_scaler`` and ``caching_prior`` in that order.

    Attributes:
        X: A dataframe of the entire dataset (stored by reference, not copied).
        X_train: A copy of the training dataframe.
        y: A copy of the label series.
        pl_priors: Cached prior knowledge for each sample per label. It is
            ``None`` until a subclass populates it (presumably in
            ``caching_prior``). ``get_prior`` indexes it with
            ``.loc[idx, :]``, so it must be a DataFrame-like object by then.
        continuous_cols: A list of continuous features to digitize.
        scaler: A ``MinMaxScaler`` for likelihoods. Fitting and applying it
            is left to subclasses.
        alpha: Initialized to ``None`` here; its meaning is defined by
            subclasses.
    """

    def __init__(self,
                 X: pd.DataFrame,
                 X_train: pd.DataFrame,
                 y: pd.Series,
                 continuous_cols: List[str],
                 n_jobs: int
    ) -> None:
        """Store the data, configure pandarallel and build the prior cache.

        Args:
            X: Entire dataset (kept by reference).
            X_train: Training dataset (copied).
            y: Labels (copied).
            continuous_cols: Names of continuous features to digitize.
            n_jobs: Worker count passed to ``pandarallel.initialize`` as
                ``nb_workers``.
        """
        self.X = X
        self.X_train = X_train.copy()
        self.y = y.copy()
        self.pl_priors = None
        self.continuous_cols = continuous_cols
        self.scaler = MinMaxScaler()
        self.alpha = None

        # NOTE(review): pandarallel's configuration is process-wide rather
        # than per instance. Every construction re-initializes it.
        pandarallel.initialize(progress_bar=False, nb_workers=n_jobs, use_memory_fs=False)

        # Subclass hooks always run in this fixed order.
        self.init_prior()
        self.init_scaler()
        self.caching_prior()

    @abstractmethod
    def init_prior(self) -> None:
        """First hook run by ``__init__``; prepares subclass prior state."""

    @abstractmethod
    def init_scaler(self) -> None:
        """Second hook run by ``__init__``.

        Presumably this fits ``self.scaler``. Confirm in subclasses.
        """

    @abstractmethod
    def caching_prior(self) -> None:
        """Third hook run by ``__init__``.

        It is expected to populate ``self.pl_priors``, which ``get_prior``
        reads.
        """

    @abstractmethod
    def calculate_prior(self, data: pd.Series, label: int) -> pd.Series:
        """Compute the prior of ``data`` for class ``label``.

        Only the signature is fixed here; semantics are defined by
        subclasses.
        """

    @abstractmethod
    def get_parallel_prior(self, data: pd.Series) -> List[float]:
        """Return a list of prior values for ``data``.

        Presumably there is one value per label and it is used with
        pandarallel. Confirm in subclasses.
        """

    @abstractmethod
    def scale_prior(self, prior: NDArray[np.float32]) -> NDArray[np.float32]:
        """Scale a block of cached priors.

        NOTE(review): ``get_prior`` passes ``self.pl_priors.loc[idx, :]``,
        i.e. a pandas object rather than a raw ndarray, despite the
        annotation.
        """

    def get_prior(self, idx: NDArray) -> NDArray[np.float32]:
        """Return the scaled cached priors for the rows labelled ``idx``.

        Args:
            idx: Row labels of ``self.pl_priors``, passed to ``.loc``.

        Returns:
            Whatever ``scale_prior`` returns for the selected rows.

        Raises:
            RuntimeError: If ``self.pl_priors`` has not been populated.
        """
        # Fail with a clear message instead of an opaque
        # "'NoneType' object has no attribute 'loc'".
        if self.pl_priors is None:
            raise RuntimeError(
                f"{type(self).__name__}.pl_priors is None; caching_prior() "
                "must populate it before get_prior() is called."
            )
        return self.scale_prior(self.pl_priors.loc[idx, :])