from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pandas import DataFrame
    from numpy import ndarray
    from torch import Tensor

import torch
import pandas as pd
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler

from env.user import ROOT, PROJECT_PATH
from helpers.saving import save_to_pickle

class PrepareCaliforniaHousing():
    """Prepare the California Housing data set.

    Imports the California Housing data set and performs a number of
    transformations such as:

    > Convert categorical features to one-hot encoded features.
    > Performs train-validation-test split.
    > Z-normalizes the data.

    Additionally, it also creates the trial subsample training sets and reads
    them again from disk.

    After construction the instance exposes ``X_train`` / ``y_train`` as
    pandas objects (``subsample`` concatenates them again) and ``X_val``,
    ``y_val``, ``X_test``, ``y_test`` as numpy arrays.
    """

    def __init__(
            self,
            random_state: int = 42,
            split: tuple[float, float] = (0.6, 0.25),
    ) -> None:
        """Load, encode, shuffle, split and z-normalize the data set.

        Args:
            random_state: Seed used for the shuffle and for both splits.
            split: Two fractions. ``split[0]`` is the share of all rows that
                goes to the training set; ``split[1]`` is the share of the
                *remaining* rows that goes to the validation set. Any
                indexable pair (tuple or list) is accepted.
        """
        raw_df = self.open_data()
        raw_df, trans_col_names = self.transform_cat_vars(dataframe=raw_df)
        raw_df = self.shuffle(raw_df, seed=random_state)

        # default split: 0.6 - 0.1 - 0.3
        # (0.25 of the 0.4 that is left after the training split is 0.1)
        train_df, holdout_df = self.split(
            dataframe=raw_df,
            random_state=random_state,
            train_size=split[0],
        )
        val_df, test_df = self.split(
            dataframe=holdout_df,
            random_state=random_state,
            train_size=split[1],
        )

        # Give every partition a clean 0..n-1 index so the frames stored on
        # the instance (and later pickled by ``subsample``) do not carry the
        # row labels of the shuffled parent frame.
        train_df = train_df.reset_index(drop=True)
        val_df = val_df.reset_index(drop=True)
        test_df = test_df.reset_index(drop=True)

        # NOTE: normalization happens before the target is separated, so
        # 'median_house_value' is z-normalized together with the features.
        train_df, val_df, test_df = self.znormalize(
            trainframe=train_df,
            valframe=val_df,
            testframe=test_df,
            cat_col_names=trans_col_names,
        )
        X_train, y_train = self.extract_features(dataframe=train_df)
        X_val, y_val = self.extract_features(dataframe=val_df)
        X_test, y_test = self.extract_features(dataframe=test_df)

        # The training data stays in pandas form; ``subsample`` needs it.
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val.to_numpy()
        self.y_val = y_val.to_numpy()
        self.X_test = X_test.to_numpy()
        self.y_test = y_test.to_numpy()

    def open_data(self) -> DataFrame:
        """Read the raw housing CSV located under ``ROOT``."""
        return pd.read_csv(ROOT + "/data/california_housing/housing.csv")

    def split(
            self,
            dataframe: DataFrame,
            random_state: int,
            train_size: float,
    ) -> tuple[DataFrame, DataFrame]:
        """Split ``dataframe`` in two with ``train_test_split``.

        Args:
            dataframe: Frame to partition.
            random_state: Seed forwarded to ``train_test_split``.
            train_size: Forwarded to ``train_test_split`` as the size of the
                first returned part (a fraction in the default pipeline).

        Returns:
            The two parts, first the one sized by ``train_size``.
        """
        first_part, second_part = train_test_split(
            dataframe,
            train_size=train_size,
            random_state=random_state,
        )
        return first_part, second_part

    def transform_cat_vars(
            self,
            dataframe: DataFrame,
    ) -> tuple[DataFrame, ndarray]:
        """One-hot encode every ``object``-dtype column of ``dataframe``.

        Returns:
            A tuple ``(encoded_frame, onehot_names)``: the non-object columns
            followed by the one-hot columns, and the array of one-hot column
            names as reported by the encoder.
        """
        cat_cols = dataframe.select_dtypes(include=['object']).columns
        cat_df = dataframe.loc[:, cat_cols]
        num_df = dataframe.drop(columns=cat_cols)

        enc = OneHotEncoder()
        onehot_values = enc.fit_transform(cat_df).toarray()
        onehot_names = enc.get_feature_names_out()
        cat_df_onehot = pd.DataFrame(
            onehot_values,
            columns=onehot_names,
            # Reuse the caller's index: concat(axis=1) aligns on row labels,
            # so a fresh RangeIndex would misalign any non-default index.
            index=dataframe.index,
        )

        return pd.concat([num_df, cat_df_onehot], axis=1), onehot_names

    def shuffle(self, dataframe: DataFrame, seed: int) -> DataFrame:
        """Return all rows of ``dataframe`` in random order, re-indexed 0..n-1."""
        shuffled = dataframe.sample(frac=1, random_state=seed)
        return shuffled.reset_index(drop=True)

    def extract_features(
            self,
            dataframe: DataFrame,
    ) -> tuple[DataFrame, pd.Series]:
        """Separate the ``'median_house_value'`` target from the other columns.

        Returns:
            A tuple ``(X, y)``: every column except the target, and the
            target column.
        """
        y = dataframe.loc[:, 'median_house_value']
        X = dataframe.drop(columns=['median_house_value'])

        return X, y

    def _standardize(
            self,
            scaler,
            frame: DataFrame,
            cat_col_names: ndarray,
    ) -> DataFrame:
        """Apply an already fitted scaler to the non-categorical columns.

        Args:
            scaler: Fitted ``StandardScaler``.
            frame: Frame holding both numeric and one-hot columns.
            cat_col_names: Names of the one-hot columns, left untouched.

        Returns:
            The scaled numeric columns followed by the one-hot columns,
            carrying the same index as ``frame``.
        """
        cat_part = frame.loc[:, cat_col_names]
        num_part = frame.drop(columns=cat_col_names)

        scaled = pd.DataFrame(
            scaler.transform(num_part),
            columns=scaler.get_feature_names_out(),
            # Keep the row labels so the concat below lines up row by row
            # even if ``frame`` does not have a 0..n-1 index.
            index=frame.index,
        )
        # Missing values become 0.0, which is the training-set mean of each
        # column once it has been standardized.
        scaled = scaled.fillna(0.0)

        return pd.concat([scaled, cat_part], axis=1)

    def znormalize(
            self,
            trainframe: DataFrame,
            valframe: DataFrame,
            testframe: DataFrame,
            cat_col_names: ndarray,
    ) -> tuple[DataFrame, DataFrame, DataFrame]:
        """Z-normalize the non-categorical columns of all three partitions.

        The scaler is fitted on ``trainframe`` only and then applied to the
        validation and test frames. Missing values are replaced by 0.0 after
        scaling.

        Args:
            trainframe: Training partition; the scaler is fitted on it.
            valframe: Validation partition.
            testframe: Test partition.
            cat_col_names: Names of the one-hot columns to leave unscaled.

        Returns:
            The normalized ``(train, val, test)`` frames.
        """
        scaler = StandardScaler()
        scaler.fit(trainframe.drop(columns=cat_col_names))

        X_train = self._standardize(scaler, trainframe, cat_col_names)
        X_val = self._standardize(scaler, valframe, cat_col_names)
        X_test = self._standardize(scaler, testframe, cat_col_names)

        return X_train, X_val, X_test

    def subsample(
            self,
            trials: int,
            exp_name: str,
            frac: float = 0.9,
            *,
            replace: bool = False,
            save: bool = True,
            safe_mode: bool = True,
    ) -> None:
        """Draw one training subsample per trial and optionally pickle it.

        Trial ``t`` is drawn with ``random_state=t``. The most recent
        subsample is kept on ``self.subframe``.

        Args:
            trials: Number of subsamples to draw.
            exp_name: Experiment name; used as folder name under
                ``PROJECT_PATH/data`` and as file-name suffix.
            frac: Fraction of the training rows drawn per trial.
            replace: Whether rows are drawn with replacement.
            save: Whether each subsample is handed to ``save_to_pickle``.
            safe_mode: Forwarded unchanged to ``save_to_pickle``.
        """
        path = PROJECT_PATH + '/data/' + exp_name

        dataframe = pd.concat([self.X_train, self.y_train], axis=1)
        for t in range(trials):
            self.subframe = dataframe.sample(
                frac=frac,
                replace=replace,
                random_state=t,
            )
            if save:
                save_to_pickle(
                    folder_path=path,
                    file_name=f"trial_{t}_{exp_name}",
                    file=self.subframe,
                    safe_mode=safe_mode,
                )

    def read_from_disk(self, exp_name: str, file_name: str) -> None:
        """Load a pickled subsample into ``X_train_sub`` / ``y_train_sub``.

        Both attributes are stored as numpy arrays.

        Args:
            exp_name: Experiment folder under ``PROJECT_PATH/data``.
            file_name: Name of the pickle file inside that folder.
        """
        path = PROJECT_PATH + '/data/' + exp_name + '/' + file_name
        # NOTE(security): unpickling can execute arbitrary code; only read
        # files that were written by this project.
        with open(path, "rb") as handle:
            df = pd.read_pickle(handle)

        X_sub, y_sub = self.extract_features(df)
        # ``to_numpy`` discards the index, so no re-indexing is needed first.
        self.X_train_sub = X_sub.to_numpy()
        self.y_train_sub = y_sub.to_numpy()


class CaliforniaHousing(Dataset):
    """Torch ``Dataset`` over a feature array and a target array.

    Both arrays are turned into ``float32`` tensors. The target tensor gets
    an extra dimension inserted at position 1.
    """

    def __init__(
            self,
            Xarray: ndarray,
            yarray: ndarray,
    ) -> None:
        """Convert the numpy inputs and store them as ``self.X`` / ``self.y``."""
        super().__init__()
        self.X = self.convert_to_tensor(Xarray).type(torch.float32)
        self.y = (
            self.convert_to_tensor(yarray)
            .unsqueeze(dim=1)
            .type(torch.float32)
        )

    def convert_to_tensor(self, array: ndarray) -> Tensor:
        """Wrap ``array`` in a tensor via ``torch.from_numpy``."""
        tensor = torch.from_numpy(array)
        return tensor

    def __len__(self):
        """Return the size of the first dimension of ``self.X``."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at ``idx``."""
        features = self.X[idx]
        target = self.y[idx]
        return features, target


        