from .data_module import DataModule
import pandas as pd
from sklearn.datasets import fetch_openml
import pickle
import os
from types import SimpleNamespace
from typing import Tuple, List

class SteelPlatesFaultDataModule(DataModule):
    """Data module for the OpenML Steel Plates Faults dataset.

    The OpenML frame stores the fault type as several indicator columns
    (``V28``..``V33``) plus the separate ``target`` series. ``load_data``
    moves those indicator columns out of the feature matrix, derives one
    more indicator from ``target``, and collapses all seven into a single
    integer label via ``argmax``.
    """

    # Indicator columns removed from the features and used to build the label.
    _FAULT_COLUMNS: List[str] = ['V28', 'V29', 'V30', 'V31', 'V32', 'V33']

    def __init__(self,
        config: SimpleNamespace
        ) -> None:
        """Store ``config`` via the ``DataModule`` base class.

        Args:
            config: Namespace read later as ``config.data.data_id``,
                ``config.data.dataset_path``, optionally
                ``config.data.data_home``, and
                ``config.runner_option.save_data``.
        """
        super().__init__(config)

    @staticmethod
    def _is_target_two(value) -> bool:
        """Return True when an OpenML target value equals ``2``.

        Older ``fetch_openml`` versions yielded ``bytes`` (``b'2'``), newer
        ones yield ``str`` (``'2'``); both are accepted.
        """
        if isinstance(value, bytes):
            value = value.decode()
        return str(value).strip() == '2'

    def load_data(self) -> Tuple[pd.DataFrame, pd.Series]:
        """Download the dataset and split it into features and an integer label.

        Returns:
            ``(data, label)`` where ``data`` is the feature frame without the
            indicator columns, and ``label`` is an integer Series (same index)
            holding the position, among the six indicator columns followed by
            the target-derived ``Class`` column, of the largest value in each
            row (``0``..``6``).
        """
        # Optional override; defaults to the original hard-coded cache dir.
        data_home = getattr(self.config.data, 'data_home', './data_cache')
        steel = fetch_openml(data_id=self.config.data.data_id, data_home=data_home)

        data = steel.data

        # .copy(): the new column below must go into an independent frame,
        # not a view of `data` (avoids SettingWithCopyWarning).
        label = data[self._FAULT_COLUMNS].copy()
        data = data.drop(self._FAULT_COLUMNS, axis=1)

        # Target value 2 becomes a seventh indicator column. Cast to int so a
        # categorical `target` cannot turn `label.values` into an object array.
        label['Class'] = steel.target.apply(self._is_target_two).astype(int)

        label = pd.Series(label.values.argmax(1), index=data.index)
        return data, label

    def prepare_data(self) -> Tuple[pd.DataFrame, pd.Series, List[str], List[str]]:
        """Return the dataset, from the pickle cache if present, else freshly loaded.

        Returns:
            ``(data, label, numeric_cols, category_cols)``. On a fresh load
            every feature column is listed as numeric and ``category_cols`` is
            empty.
        """
        dataset_path = self.config.data.dataset_path
        if os.path.exists(dataset_path):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # dataset_path at a cache file produced by a trusted run.
            with open(dataset_path, 'rb') as f:
                dataset = pickle.load(f)
            return dataset['data'], dataset['label'], dataset['numeric_cols'], dataset['category_cols']

        data, label = self.load_data()

        category_cols: List[str] = []
        numeric_cols = [str(col) for col in data.columns]

        # Presumably DataModule.save_data writes the cache read above — confirm.
        if self.config.runner_option.save_data:
            self.save_data(data, label, numeric_cols, category_cols)

        return data, label, numeric_cols, category_cols
