from .base_editor import BaseEditor

import numpy as np
from pandarallel import pandarallel
import pandas as pd
from typing import Tuple, List
from scipy.spatial.distance import cdist
from numpy.linalg import inv
from numpy import cov
class Mahalanobis(BaseEditor):
    """Editor whose per-sample criteria are Mahalanobis distance gaps.

    For every sample and every label in ``y`` the criterion is::

        |nanmean(mahalanobis(sample, X_train[y == label]))
         - nanmean(mahalanobis(sample, X_train[y != label]))|

    Distances come from ``scipy.spatial.distance.cdist``. The criteria are
    scaled with ``self.scaler``, which is not created in this class
    (presumably set up by ``BaseEditor`` — confirm).
    """

    def __init__(self,
                 X: pd.DataFrame,
                 X_train: pd.DataFrame,
                 y: pd.Series,
                 continuous_cols: List[str],
                 n_jobs: int
        ) -> None:
        """Forward all arguments unchanged to ``BaseEditor``.

        Args:
            X: Samples whose criteria are cached by ``caching_criteria``.
            X_train: Reference samples the distances are measured against.
            y: Labels used to split ``X_train`` into per-label groups.
            continuous_cols: Passed through to ``BaseEditor``.
            n_jobs: Passed through to ``BaseEditor`` (presumably the number of
                parallel workers — confirm in the base class).
        """
        super().__init__(X, X_train, y, continuous_cols, n_jobs)

    def init_criteria(self) -> None:
        """No precomputation: distances are computed on demand per sample."""
        pass

    def init_scaler(self) -> None:
        """Fit ``self.scaler`` on the criteria of every ``X_train`` row."""
        self.scaler.fit(self.X_train.apply(self.get_parallel_criteria, axis=1, result_type='expand').values)

    def get_mahalanobis_dist(self,
                             data: pd.Series,
                             label: int
        ) -> float:
        """Return the Mahalanobis distance gap of one sample for one label.

        Args:
            data: A single sample (one row of features).
            label: The label whose ``X_train`` rows form the "positive" group;
                all remaining rows form the "negative" group.

        Returns:
            ``|p - n|``, a scalar. ``p`` is the NaN-ignoring mean Mahalanobis
            distance from ``data`` to the positive group, and ``n`` is the same
            quantity for the negative group.
        """
        # cdist expects 2-D input: one row, one column per feature.
        sample = data.values.reshape((1, -1))
        in_group = self.y == label
        out_group = self.y != label

        positives = self.X_train[in_group]
        negatives = self.X_train[out_group]

        # NOTE(review): no ``VI`` is passed. By scipy's documented default, the
        # inverse covariance is inv(cov(vstack([sample, group]).T)).T, computed
        # separately for each call. The two distances therefore use different
        # covariance matrices, and both include the query sample itself. scipy
        # also raises ValueError when the stacked rows do not exceed the number
        # of features (singular covariance).
        p = np.nanmean(cdist(sample, positives, 'mahalanobis'))
        n = np.nanmean(cdist(sample, negatives, 'mahalanobis'))
        return np.abs(p - n)

    def get_parallel_criteria(self,
                              data: pd.Series
        ) -> List[float]:
        """Return the distance-gap criterion of one sample for every label.

        Args:
            data: A single sample (one row of features).

        Returns:
            One ``get_mahalanobis_dist`` value per label. The labels are taken
            in ``self.y.unique()`` order, i.e. order of first appearance.
        """
        return [self.get_mahalanobis_dist(data=data, label=label)
                for label in self.y.unique()]

    def caching_criteria(self) -> None:
        """Compute and cache the criteria of every row of ``self.X``.

        The result is stored in ``self.pl_criteria``. It is a DataFrame indexed
        like ``self.X``, with one positional column per label.
        """
        # ``parallel_apply`` is the pandarallel accessor; it presumably needs
        # ``pandarallel.initialize()`` to have run beforehand — confirm.
        criteria = self.X.parallel_apply(self.get_parallel_criteria, axis=1, result_type='expand')
        self.pl_criteria = pd.DataFrame(criteria.values, index=self.X.index)

    def scale_criteria(self,
                       criteria: np.ndarray
        ) -> np.ndarray:
        """Scale criteria with the scaler fitted in ``init_scaler``."""
        return self.scaler.transform(criteria)