import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig as Config
import pandas as pd
import logging
import os

_log = logging.getLogger(__name__)


def calculate_imbalance_ratio(df, target_column):
    """Return the class imbalance ratio of a column and its class counts.

    The ratio is the size of the most frequent class divided by the size of
    the least frequent class, rounded to two decimal places.

    Args:
        df: DataFrame holding the data.
        target_column: Name of the column whose class balance is measured.

    Returns:
        A ``(ratio, counts)`` tuple, where ``counts`` maps each value found
        in ``target_column`` to its number of occurrences.
    """
    class_sizes = df[target_column].value_counts()
    ratio = class_sizes.max() / class_sizes.min()
    return round(ratio, 2), class_sizes.to_dict()


# uncomment line below if running this file directly
# @hydra.main(version_base=None, config_path='../conf', config_name="datagen")
def preprocess_data(cfg: Config) -> None:
    """Clean one raw dataset CSV and save it to the clean-data directory.

    Reads ``<cfg.paths.raw_data_dir>/<cfg.dataset.name>.csv``, keeps only the
    columns listed in ``cfg.dataset.numerical_columns``,
    ``cfg.dataset.categorical_columns`` and ``cfg.dataset.target_column``
    (in the order they appear in the file), drops duplicate rows and rows
    containing any missing value, logs the imbalance ratio of the target
    column, and writes the result (without the index) to
    ``<cfg.paths.clean_data_dir>/clean_<cfg.dataset.name>.csv``.

    Args:
        cfg: Configuration exposing the ``dataset`` and ``paths`` entries
            described above.

    Raises:
        KeyError: If the configured target column is absent from the raw CSV.
    """
    dataset_cfg = cfg.dataset
    target_column = dataset_cfg.target_column
    raw_path = os.path.join(cfg.paths.raw_data_dir, dataset_cfg.name + ".csv")
    save_path = os.path.join(
        cfg.paths.clean_data_dir, "clean_" + dataset_cfg.name + ".csv"
    )

    # Load and filter
    df = pd.read_csv(raw_path)
    expected_cols = [
        *dataset_cfg.numerical_columns,
        *dataset_cfg.categorical_columns,
        target_column,
    ]

    # Surface schema drift up front: the column filter below would otherwise
    # silently ignore any configured column that the CSV does not contain.
    available_cols = set(df.columns)
    missing_cols = [col for col in expected_cols if col not in available_cols]
    if target_column in missing_cols:
        # Same exception type the later column lookup would raise, but raised
        # before any work is done and with an actionable message.
        raise KeyError(
            f"Target column '{target_column}' not found in {raw_path}"
        )
    if missing_cols:
        _log.warning(
            f"[{dataset_cfg.name}] Expected columns missing from {raw_path}: "
            f"{missing_cols}"
        )

    # Build the lookup set once; keep the file's own column order.
    wanted_cols = set(expected_cols)
    df = df[[col for col in df.columns if col in wanted_cols]]

    # Drop bad rows
    df = df.drop_duplicates()
    df = df.dropna()

    # Report IR
    ir, class_dist = calculate_imbalance_ratio(df, target_column)
    _log.info(f"[{dataset_cfg.name}] IR = {ir} | Class counts = {class_dist}")

    # Save clean dataset
    os.makedirs(cfg.paths.clean_data_dir, exist_ok=True)
    df.to_csv(save_path, index=False)
    _log.info(f"Cleaned data saved to: {save_path}")


# Script entry point.
# NOTE(review): preprocess_data requires a `cfg` argument, so this bare call
# raises TypeError unless the commented-out @hydra.main decorator above the
# function is enabled (presumably Hydra then builds and passes the config).
if __name__ == "__main__":
    preprocess_data()