from .data_module import DataModule

from sklearn.datasets import fetch_openml
import os
import pickle
from types import SimpleNamespace
from typing import Tuple, List
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import numpy as np
from preprocessing.imputation import Imputer

class WallRobotNavigationDataModule(DataModule):
    """Data module that loads an OpenML dataset (wall-robot navigation) as tabular data.

    Configuration read by this class:
        * ``config.data.data_id`` -- OpenML dataset id passed to ``fetch_openml``.
        * ``config.data.dataset_path`` -- path of a pickled, already-prepared dataset;
          used as a cache when the file exists.
        * ``config.data.data_home`` (optional) -- OpenML download cache directory,
          defaults to ``'./data_cache'``.
        * ``config.runner_option.save_data`` -- when truthy, the prepared dataset is
          handed to ``self.save_data``.
    """

    def __init__(self, 
        config: SimpleNamespace
        ) -> None:
        super().__init__(config)
    
    def load_data(self) -> Tuple[pd.DataFrame, pd.Series]:
        """Fetch the raw features and integer-encoded labels from OpenML.

        Returns:
            ``(data, label)``: ``data`` is the feature DataFrame returned by
            ``fetch_openml``; ``label`` is the target transformed by
            ``LabelEncoder`` (integers ``0..n_classes-1``) and indexed like ``data``.
        """
        data_home = getattr(self.config.data, "data_home", "./data_cache")
        # as_frame=True guarantees a DataFrame; prepare_data relies on `data.columns`,
        # and older scikit-learn versions default to returning a NumPy array.
        robot = fetch_openml(
            data_id=self.config.data.data_id,
            data_home=data_home,
            as_frame=True,
        )

        data = robot.data

        # Share data's index so features and labels stay row-aligned.
        label = pd.Series(
            LabelEncoder().fit_transform(robot.target),
            index=data.index,
        )

        return data, label
    
    def prepare_data(self) -> Tuple[pd.DataFrame, pd.Series, List[str], List[str]]:
        """Return the prepared dataset, loading it from cache when available.

        If ``config.data.dataset_path`` exists, it is unpickled and its ``'data'``,
        ``'label'``, ``'numeric_cols'`` and ``'category_cols'`` entries are returned.
        Otherwise the data is fetched via ``load_data``, every column is treated as
        numeric, and the result is optionally persisted through ``self.save_data``.

        Returns:
            ``(data, label, numeric_cols, category_cols)`` with both column
            collections as plain lists of column names.
        """
        dataset_path = self.config.data.dataset_path
        if os.path.exists(dataset_path):
            # NOTE(security): pickle.load can execute arbitrary code contained in the
            # file; only point dataset_path at trusted, locally produced caches.
            with open(dataset_path, 'rb') as f:
                dataset = pickle.load(f)
            # Coerce to list so cached and freshly prepared results share one type
            # (older caches may hold a pandas Index).
            return (
                dataset['data'],
                dataset['label'],
                list(dataset['numeric_cols']),
                list(dataset['category_cols']),
            )

        data, label = self.load_data()

        # All features of this dataset are handled as numeric; the encoding loop below
        # is a no-op kept as a hook should categorical columns ever be listed here.
        category_cols: List[str] = []
        numeric_cols: List[str] = list(data.columns)

        for col in category_cols:
            # A fresh encoder per column; fit_transform refits anyway.
            data[col] = LabelEncoder().fit_transform(data[col])

        if self.config.runner_option.save_data:
            # save_data is not defined in this class; presumably inherited from
            # DataModule -- NOTE(review): confirm its on-disk format matches the
            # keys read back above.
            self.save_data(data, label, numeric_cols, category_cols)
            
        return data, label, numeric_cols, category_cols