from pathlib import Path
import sys
sys.path.append(str(Path().absolute().parent))
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.linear_model import BayesianRidge
import pandas as pd
from typing import List

class Converter(object):
    """Label-encode DataFrame columns to integer codes and decode them back.

    :meth:`encode` fits one ``LabelEncoder`` per column and stores it in
    ``self.encoder`` (a dict keyed by column name). :meth:`decode` uses those
    fitted encoders, so it must be called after :meth:`encode`.
    """

    def __init__(self):
        pass

    def encode(self, data: pd.DataFrame) -> pd.DataFrame:
        """Replace every column of ``data`` with its label-encoded codes.

        Missing entries (``NaN``/``None``) stay missing. Their positions are
        set back to ``NaN`` after encoding, so such columns become
        ``float64``. Columns without missing values keep the integer dtype
        returned by ``LabelEncoder.transform``.

        ``data`` is modified in place and also returned. A fresh
        ``self.encoder`` dict is built on every call.
        """
        self.encoder = {}
        for col in tqdm(data.columns):
            encoder = LabelEncoder()
            # NOTE(review): the encoder is fitted on the raw column, so with a
            # NaN-aware scikit-learn the missing value may appear as an extra
            # entry in classes_ -- confirm against the installed version.
            encoder.fit(data[col])
            nan_mask = data[col].isna().to_numpy()
            codes = encoder.transform(data[col])
            if nan_mask.any():
                # Build the float column explicitly rather than using the
                # chained ``data[col][mask] = nan`` form. pandas may apply that
                # form to a temporary copy (SettingWithCopy / Copy-on-Write),
                # and an integer column cannot hold NaN anyway.
                codes = codes.astype(np.float64)
                codes[nan_mask] = np.nan
            data[col] = codes
            self.encoder[col] = encoder
        return data

    def decode(self, data: pd.DataFrame) -> pd.DataFrame:
        """Map label codes in ``data`` back to the original labels.

        Codes are rounded to the nearest integer and cast to ``int64`` before
        ``LabelEncoder.inverse_transform``, because it cannot index its
        classes with floats. ``NaN`` codes decode to ``NaN``. ``data`` is
        modified in place and also returned.

        Raises:
            KeyError: a column was not seen by :meth:`encode`.
            ValueError: a code lies outside the fitted classes (raised by
                ``inverse_transform``).
        """
        for col in tqdm(data.columns):
            encoder = self.encoder[col]
            missing = data[col].isna().to_numpy()
            # Give missing rows a placeholder code so the int cast is
            # well-defined; they are masked back to NaN below.
            rounded = np.round(data[col].fillna(0).to_numpy(dtype=np.float64))
            labels = encoder.inverse_transform(rounded.astype(np.int64))
            if missing.any():
                labels = pd.Series(labels, index=data.index).where(~missing)
            data[col] = labels
        return data

class Imputer(object):
    """Impute missing values in binary, numeric and categorical columns.

    The caller must say which columns are binary, numeric and categorical
    (see :meth:`impute`). Each group is imputed separately by a scikit-learn
    ``IterativeImputer``, using the regression estimator chosen in
    ``__init__``.
    """

    def __init__(self,
                 estimator: str = "BayesianRidge",
                 n_estimators: int = None,
                 random_state: int = 0,
        ) -> None:
        """Choose the per-feature regression model.

        Args:
            estimator: ``"BayesianRidge"`` (default) or ``"ExtraTree"``.
            n_estimators: number of trees for ``"ExtraTree"``. ``None`` keeps
                scikit-learn's default, because passing ``None`` explicitly
                is rejected. Ignored for ``"BayesianRidge"``.
            random_state: seed forwarded to ``IterativeImputer`` and to
                ``ExtraTreesRegressor``.

        Raises:
            ValueError: ``estimator`` is not a supported name.
        """
        self.converter = Converter()
        self.random_state = random_state
        if estimator == "ExtraTree":
            tree_kwargs = {} if n_estimators is None else {"n_estimators": n_estimators}
            self.estimator = ExtraTreesRegressor(random_state=self.random_state, **tree_kwargs)
        elif estimator == "BayesianRidge":
            self.estimator = BayesianRidge()
        else:
            raise ValueError(
                f"Unknown estimator {estimator!r}; expected 'ExtraTree' or 'BayesianRidge'."
            )

    def _fit_transform(self, values, min_value, max_value) -> np.ndarray:
        """Fit an ``IterativeImputer`` on ``values`` and return the imputed array.

        ``min_value``/``max_value`` are scalars or per-column sequences that
        bound the imputed entries. Observed entries are returned unchanged.
        """
        imputer = IterativeImputer(
            estimator=self.estimator,
            min_value=min_value,
            max_value=max_value,
            random_state=self.random_state,
        )
        # fit_transform imputes the training data directly, avoiding a second
        # imputation pass that ``fit(...)`` followed by ``transform(...)`` would run.
        return imputer.fit_transform(values)

    def category_impute(self, data: pd.DataFrame) -> pd.DataFrame:
        """Impute categorical columns and return them with their original labels.

        Labels are encoded to integer codes. The codes are imputed as bounded
        continuous values, rounded to the nearest valid code and decoded
        back. ``data`` is overwritten in place by the converter.
        """
        encoded = self.converter.encode(data)
        # Count only real classes. The fitted classes_ may contain the missing
        # value as an extra entry; the old ``len(classes_) - 2`` assumed that
        # entry was always there and dropped the top class otherwise.
        n_classes = [
            sum(not pd.isna(label) for label in self.converter.encoder[col].classes_)
            for col in encoded.columns
        ]
        # IterativeImputer requires max_value > min_value, so a single-class
        # column gets an upper bound of 1. The clip below maps it back to 0.
        upper = [max(n - 1, 1) for n in n_classes]
        imputed = self._fit_transform(encoded.to_numpy(), 0, upper)
        # Round to the nearest code (not truncate) and keep it in range.
        codes = np.clip(np.rint(imputed), 0, np.asarray(n_classes) - 1).astype(np.int64)
        for i, col in enumerate(encoded.columns):
            encoded[col] = codes[:, i]
        return self.converter.decode(encoded)

    def numeric_impute(self, data: pd.DataFrame) -> pd.DataFrame:
        """Impute numeric columns within each column's observed [min, max] range.

        The imputed columns are written back into ``data`` (in place) as
        ``float32``, and ``data`` is returned.

        NOTE: ``IterativeImputer`` requires ``max_value > min_value`` for each
        feature, so a column with a single distinct observed value raises.
        """
        max_value = [data[col].max() for col in data.columns]
        min_value = [data[col].min() for col in data.columns]
        imputed = self._fit_transform(data.to_numpy(), min_value, max_value)
        for i, col in enumerate(data.columns):
            data[col] = imputed[:, i].astype(np.float32)
        return data

    def binary_impute(self, data: pd.DataFrame) -> pd.DataFrame:
        """Impute 0/1 columns and threshold the result back to integers.

        Imputed values (bounded to ``[0, 1]``) of at least 0.5 become 1, and
        the rest become 0. The threshold is applied to the continuous values:
        casting to an integer first would truncate every imputed value below
        1 to 0. Columns are written into ``data`` as ``int64``, and ``data``
        is returned.
        """
        imputed = self._fit_transform(data.to_numpy(), 0, 1)
        for i, col in enumerate(data.columns):
            data[col] = (imputed[:, i] >= 0.5).astype(np.int64)
        return data

    def impute(self,
               data: pd.DataFrame,
               binary_cols: List[str],
               numeric_cols: List[str],
               category_cols: List[str]
        ) -> pd.DataFrame:
        """Impute every column group and write the results back into ``data``.

        Args:
            data: frame with missing values; modified in place and returned.
            binary_cols: columns holding 0/1 values.
            numeric_cols: continuous columns.
            category_cols: categorical columns, label-encoded internally.

        An empty (or ``None``) column list skips that group, because
        ``IterativeImputer`` cannot fit on zero features. Every group is
        imputed from the original values before any result is written back.
        """
        groups = (
            (binary_cols, self.binary_impute),
            (numeric_cols, self.numeric_impute),
            (category_cols, self.category_impute),
        )
        results = []
        for cols, impute_group in groups:
            cols = [] if cols is None else list(cols)
            if not cols:
                continue
            # Hand each group imputer an explicit copy so it can write into it
            # freely (no SettingWithCopy warnings, ``data`` untouched for now).
            results.append((cols, impute_group(data[cols].copy())))

        for cols, imputed in results:
            for col in cols:
                data[col] = imputed[col]

        return data