"""
Contains functions for pre-processing tabular data.
This includes encoding all types of features into numerical ones.
"""

import json
import logging
import numpy as np
import pandas as pd


# Logger used by every function in this module. Calling getLogger() without a
# name returns the root logger, so the INFO level set here is applied to the
# root logger and not only to this module.
log = logging.getLogger()
log.setLevel(logging.INFO)


def load_data(name: str, base_path: str):
    """
    Load a raw dataset and its configuration file from a local path.

    The dataset directory is ``base_path + name + "/"``. It has to contain a
    ``conf.json`` file whose ``"raw_name"`` entry names the raw data file in
    the same directory. An optional ``"load"`` entry holds keyword arguments
    that are forwarded to the pandas reader.

    Args:
        name: Dataset name; used as directory name and in the log message.
        base_path: Directory containing the dataset directories. It is
            concatenated as-is, so it must end with a path separator.

    Returns:
        A tuple ``(df, config)`` with the loaded pandas dataframe and the
        configuration as a dictionary.

    Raises:
        ValueError: If the raw file name ends with neither ``xlsx`` nor
            ``csv`` (compared case-insensitively).
    """
    # Load config file (JSON is UTF-8 by specification, so do not depend on
    # the platform's default encoding).
    dataset_path = base_path + name + "/"
    with open(dataset_path + "conf.json", "r", encoding="utf-8") as fp:
        config = json.load(fp)

    # Load data file
    file_path = dataset_path + config["raw_name"]
    load_kwargs = config.get("load") or {}
    extension_probe = file_path.lower()
    if extension_probe.endswith("xlsx"):
        df = pd.read_excel(file_path, **load_kwargs)
    elif extension_probe.endswith("csv"):
        df = pd.read_csv(file_path, **load_kwargs)
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception that tells the user which file was rejected.
        raise ValueError(f"Unknown file format: {file_path}")
    log.info(f"{name} dataset loaded successfully.")
    return df, config


def save_data(df, name: str, base_path: str):
    """
    Write the dataframe to ``processed.csv`` inside the dataset directory
    ``base_path + name + "/"``. The row index is not written to the file.
    """
    target_file = "".join([base_path, name, "/", "processed.csv"])
    df.to_csv(target_file, index=False)
    log.info(f"Saved processed {name} dataset.")


def handle_missing(df, col_thresh: float = 0.5, row_thresh: float = 0.5):
    """
    Remove sparse rows/columns and impute the remaining missing values.

    Steps:
      1. ``"?"`` and empty / whitespace-only strings are treated as missing.
      2. Object columns that can be cast to float are converted.
      3. Rows with fewer than ``int(row_thresh * n_cols)`` present values and
         columns with fewer than ``int(col_thresh * n_rows)`` present values
         are dropped (both counts refer to the shape before dropping).
      4. Columns that are still entirely missing are dropped, because there
         is nothing to impute them from.
      5. Remaining gaps are filled with the mode for object columns and with
         the mean for all other columns.

    The input dataframe is not modified; a new dataframe is returned.

    Args:
        df: Dataframe to clean.
        col_thresh: Minimum fraction of present values a column needs to be
            kept, in ``(0, 1]``.
        row_thresh: Minimum fraction of present values a row needs to be
            kept, in ``(0, 1]``.

    Returns:
        A dataframe without missing values.

    Raises:
        ValueError: If a threshold is outside ``(0, 1]``.
    """
    if not 0 < col_thresh <= 1:
        raise ValueError(f"col_thresh must be in (0, 1], got {col_thresh}")
    if not 0 < row_thresh <= 1:
        raise ValueError(f"row_thresh must be in (0, 1], got {row_thresh}")

    # Replace missing with NaNs. The non-inplace form returns a new frame, so
    # every later assignment leaves the caller's dataframe untouched.
    df = df.replace("?", np.nan)
    df = df.replace(r"^\s*$", np.nan, regex=True)

    conv_columns = []
    for col in df.columns:
        if df[col].dtype != "object":
            continue
        try:
            as_float = df[col].astype(float)
        except (ValueError, TypeError, OverflowError):
            # Genuinely non-numeric column: keep it as object.
            continue
        df[col] = as_float
        conv_columns.append(col)
    log.info(f"Converted {conv_columns} from object to float.")

    n_rows, n_cols = df.shape
    min_values_per_row = int(row_thresh * n_cols)
    min_values_per_col = int(col_thresh * n_rows)
    # Remove rows and cols with more than thresh NaNs
    df = df.dropna(axis=0, thresh=min_values_per_row)
    df = df.dropna(axis=1, thresh=min_values_per_col)

    # A threshold that rounds down to 0 lets completely empty columns through.
    # They have no mode/mean to impute from, so drop them as well. Combining
    # with any() keeps the columns of a dataframe that has no rows at all.
    missing = df.isna()
    has_missing = missing.any()
    unfillable = has_missing & missing.all()
    if unfillable.any():
        df = df.loc[:, ~unfillable]
        has_missing = has_missing[~unfillable]
    log.info(
        (f"Removed {n_rows - df.shape[0]} rows " f"and {n_cols - df.shape[1]} columns.")
    )

    # Fill missing values with mean and mode imputation
    n_nan = df.isna().sum().sum()
    fill_values = {}
    for col in has_missing.index[has_missing]:
        column = df[col]
        if column.dtype == "object":
            fill_values[col] = column.mode().iloc[0]
        else:
            fill_values[col] = column.mean()
    if fill_values:
        df = df.fillna(value=fill_values)
    # Internal sanity check: every gap must have been filled at this point.
    assert df.isna().sum().sum() == 0
    log.info(f"Imputed {n_nan} values")
    return df


def format_dates(df):
    """
    Convert date-like columns to datetimes and expand them into features.

    A column counts as date-like when more than half of its values, rendered
    as strings, match one of the supported date / datetime patterns. Such
    columns are parsed with ``pd.to_datetime(errors="coerce")``, so values
    that cannot be parsed become ``NaT``. Every datetime column (of any
    resolution, timezone-aware or not) is then replaced by four columns
    ``<col>_year``, ``<col>_month``, ``<col>_day`` and ``<col>_dayofweek``,
    appended at the end of the dataframe.

    The input dataframe is not modified; a new dataframe is returned.

    Args:
        df: Dataframe whose columns should be scanned for dates.

    Returns:
        A dataframe in which all datetime columns are replaced by their
        year, month, day and weekday features.
    """
    # Possible date or datetime regex patterns
    date_regex = "|".join(
        [
            r"^\d{4}-\d{2}-\d{2}( \d{2}:\d{2}:\d{2})?$",
            r"^\d{1,2}/\d{1,2}/\d{4}( \d{2}:\d{2}:\d{2})?$",
            r"^\d{1,2}-\d{1,2}-\d{4}( \d{2}:\d{2}:\d{2})?$",
            r"^\d{1,2} [ADFJMNOS]\w* \d{4}( \d{2}:\d{2}:\d{2})?$",
        ]
    )

    def is_date_column(series):
        if len(series) == 0:
            return False
        try:
            matches = series.astype(str).str.match(date_regex).sum()
        except Exception:
            # Detection is best effort: a column that cannot be inspected as
            # text is simply not treated as a date column.
            return False
        # If more than 50% match the pattern, consider it a date column
        return matches / len(series) > 0.5

    # Work on a copy so the caller's dataframe keeps its original columns.
    df = df.copy()

    # Convert string-date to datetime columns
    parsed_cols = [col for col in df.columns if is_date_column(df[col])]
    for col in parsed_cols:
        df[col] = pd.to_datetime(df[col], errors="coerce")
    log.info(f"Converted {parsed_cols} to datetime objects.")

    # Extract year, month, day and weekday from datetime columns. Comparing
    # the dtype with "datetime64[ns]" would skip timezone-aware columns and
    # columns stored with another resolution, so use the generic dtype check.
    date_cols = [
        col for col in df.columns if pd.api.types.is_datetime64_any_dtype(df[col])
    ]
    for col in date_cols:
        parts = df[col].dt
        df[col + "_year"] = parts.year
        df[col + "_month"] = parts.month
        df[col + "_day"] = parts.day
        df[col + "_dayofweek"] = parts.dayofweek
    df = df.drop(columns=date_cols)
    log.info(f"Extracted year, month, day and weekday from {date_cols}.")
    return df


def encode_categorical(df, name: str):
    """
    Encode all categorical (object dtype) features numerically.

    Object columns that hold month names are encoded ordinally (1-12). A
    column is treated as a month column when it contains more than three
    distinct months; abbreviated and full month names are accepted in any
    letter case. All remaining object columns are one-hot encoded.

    For the ``"Adult"`` dataset, dots are removed from the ``income`` labels
    before encoding.

    The input dataframe is not modified; a new dataframe is returned.

    Args:
        df: Dataframe to encode.
        name: Dataset name; only used to enable the ``"Adult"`` special case.

    Returns:
        The encoded dataframe. One-hot columns hold integer 0/1 values.

    Raises:
        ValueError: If a detected month column contains a value that is not a
            month name (it cannot be cast to int).
    """
    # Accept "jan"/"january", ... for every month; matching is done on the
    # stripped, lower-cased value so "APR", "Apr" and "June" are all handled.
    full_names = [
        "january",
        "february",
        "march",
        "april",
        "may",
        "june",
        "july",
        "august",
        "september",
        "october",
        "november",
        "december",
    ]
    months_map = {}
    for number, full_name in enumerate(full_names, start=1):
        months_map[full_name] = number
        months_map[full_name[:3]] = number
    months_map["sept"] = 9

    def normalize(value):
        return value.strip().lower() if isinstance(value, str) else value

    # Work on a copy so the caller's dataframe is left untouched.
    df = df.copy()

    # encode predefined ordinal (Shoppers)
    for col in df.select_dtypes(include=["object"]).columns:
        month_numbers = df[col].map(normalize).map(months_map)
        if month_numbers.nunique() > 3:
            # Values that are not month names are NaN here, which makes the
            # integer cast fail loudly instead of producing silent garbage.
            df[col] = month_numbers.astype(int)

    # One-hot encode remaining categorical columns
    categorical_cols = list(df.select_dtypes(include=["object"]).columns)
    if name == "Adult":
        # Some labels carry a trailing dot (e.g. ">50K."); merge them with
        # the dot-free spelling so they end up in the same one-hot column.
        df["income"] = df["income"].str.replace(".", "", regex=False)
    df = pd.get_dummies(df, columns=categorical_cols, drop_first=False, dtype="int")
    log.info(f"One-hot encoded {list(categorical_cols)}.")
    return df


def preprocessing(
    name: str, base_path: str = "./../../artifacts/data/"
) -> pd.DataFrame:
    """
    End-to-end preprocessing of one dataset.

    Loads the raw data, expands date columns, handles missing values and
    encodes categorical features, then checks that no missing values remain,
    saves the result next to the raw data and returns it.
    """
    raw, _config = load_data(name, base_path)
    processed = encode_categorical(handle_missing(format_dates(raw)), name)
    remaining_nans = processed.isna().sum().sum()
    assert remaining_nans == 0
    save_data(processed, name, base_path)
    return processed
