from .data_module import DataModule
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder
from preprocessing.imputation import Imputer
import os
import pickle
from types import SimpleNamespace
from typing import Tuple, List
import pandas as pd

class SickDataModule(DataModule):
    """Data module for the OpenML "sick" dataset.

    Fetches the dataset with ``fetch_openml`` (cached under ``./data_cache``),
    label-encodes the target and the non-numeric feature columns, drops the
    configured columns and imputes ``age`` and ``sex``.
    """

    def __init__(self, config: SimpleNamespace) -> None:
        """Forward ``config`` to the base ``DataModule`` constructor."""
        super().__init__(config)

    def load_data(self) -> Tuple[pd.DataFrame, pd.Series]:
        """Fetch the raw feature frame and the integer-encoded target.

        Returns:
            ``(data, label)``: ``data`` is the feature DataFrame returned by
            ``fetch_openml``; ``label`` is a ``pd.Series`` of integer codes
            produced by ``LabelEncoder`` fitted on the target values.
        """
        # as_frame=True: the rest of this module relies on DataFrame columns.
        sick = fetch_openml(
            data_id=self.config.data.data_id,
            data_home='./data_cache',
            as_frame=True,
        )
        data = sick.data

        # Encode the object-dtype copy of the target (the previous version built
        # this copy but then encoded ``sick.target`` instead, so the cast was dead).
        raw_label = sick.target.astype(object)
        label = pd.Series(LabelEncoder().fit_transform(raw_label))

        return data, label

    def prepare_data(self) -> Tuple[pd.DataFrame, pd.Series, List[str], List[str]]:
        """Return the preprocessed dataset, reading it from the pickle cache if present.

        When ``config.data.dataset_path`` exists, the pickled dict stored there
        is returned unchanged. Otherwise the raw data is loaded,
        ``config.data.drop_cols`` are removed, every column not listed in
        ``config.data.numeric_cols`` is label-encoded, ``age`` and ``sex`` are
        imputed, and the result is saved through ``self.save_data`` when
        ``config.runner_option.save_data`` is truthy.

        Returns:
            ``(data, label, numeric_cols, category_cols)``. On the freshly
            computed path both column lists only name columns present in
            ``data``.
        """
        dataset_path = self.config.data.dataset_path
        if os.path.exists(dataset_path):
            # NOTE(security): pickle.load can execute arbitrary code; only point
            # dataset_path at caches produced by this project.
            with open(dataset_path, 'rb') as f:
                dataset = pickle.load(f)
            return dataset['data'], dataset['label'], dataset['numeric_cols'], dataset['category_cols']

        data, label = self.load_data()

        # Drop first so dropped columns are neither encoded nor reported in the
        # returned column lists (previously category_cols could name columns
        # that were no longer in ``data``).
        data = data.drop(self.config.data.drop_cols, axis=1)

        configured_numeric = self.config.data.numeric_cols
        numeric_lookup = set(configured_numeric)
        numeric_cols = [col for col in configured_numeric if col in data.columns]
        category_cols = [col for col in data.columns if col not in numeric_lookup]

        # Each column is encoded independently (the encoder is re-fitted per column).
        encoder = LabelEncoder()
        for col in category_cols:
            data[col] = encoder.fit_transform(data[col])

        imputer = Imputer(random_state=self.config.runner_option.random_seed)
        data['age'] = imputer.numeric_impute(data[['age']])
        # NOTE(review): if 'sex' is not in numeric_cols it has already been
        # label-encoded above, so its missing values may now be an integer code
        # rather than NaN — confirm Imputer.binary_impute expects that.
        data['sex'] = imputer.binary_impute(data[['sex']])

        if self.config.runner_option.save_data:
            self.save_data(data, label, numeric_cols, category_cols)

        return data, label, numeric_cols, category_cols
    
