"""Preprocess datasets from TextTabBench (TTB). Adapted from: https://github.com/mrazmartin/TextTabBench/tree/main"""

from pathlib import Path
import json
import pandas as pd
import numpy as np
import re
from sklearn.preprocessing import LabelEncoder

# Root folder that holds the raw and the processed tabular datasets.
DATA_DIR = Path(__file__).parents[2] / "data/tabular"

# Dataset configurations are stored in two JSON files next to this module.
data_configs = json.loads((Path(__file__).parent / "ttb_data_config.json").read_text())
extra_configs = json.loads(
    (Path(__file__).parent / "ttb_extra_config.json").read_text()
)

# Merge both sources; on a name clash the entry from the extra file wins.
all_configs = dict(data_configs)
all_configs.update(extra_configs)


# Per-dataset limit on the fraction of missing values a column may contain;
# columns above the limit are dropped during basic cleaning. Every dataset
# starts at 0.5 and a handful are given a looser limit afterwards.
missing_ratio_thresholds = dict.fromkeys(
    [
        "customer_complaints",
        "job_frauds",
        "hs_cards",
        "kickstarter",
        "osha_accidents",
        "spotify",
        "airbnb",
        "beer",
        "calif_houses",
        "laptops",
        "mercari",
        "sf_permits",
        "wine",
        "diabetes",
        "lending_club",
        "okcupid",
        "covid_trials",
        "drugs_rating",
        "insurance_complaints",
        "it_salary",
        "stack_overflow",
    ],
    0.5,
)
# Dataset-specific overrides (updating keeps the original key order).
missing_ratio_thresholds.update(
    {
        "hs_cards": 0.63,
        "sf_permits": 0.52,
        "wine": 0.6,
        "drugs_rating": 0.54,
    }
)

# Columns removed from each dataset after manual inspection. Keys are dataset
# names; values list the column names to drop. preprocess_data skips (and
# reports) any listed column that is not present in the dataframe.
columns_to_drop = {
    "customer_complaints": [
        "Complaint ID",
        "Date sent to company",
        "Timely response?",
        "Consumer disputed?",
    ],
    "job_frauds": ["job_id"],
    "hs_cards": ["id"],
    "kickstarter": ["usd_pledged", "id"],
    "osha_accidents": [
        "summary_nr",
        "proj_cost",
        "proj_type",
        "nature_of_inj",
        "part_of_body",
        "event_type",
        "evn_factor",
        "hum_factor",
        "task_assigned",
    ],
    "spotify": ["track_id"],
    # NOTE: the trailing comments below close the group of entries that
    # precedes them (they were originally end-of-line group comments).
    "airbnb": [
        "id",
        "listing_url",
        "thumbnail_url",
        "host_url",
        "medium_url",
        "picture_url",
        "xl_picture_url",
        "host_id",
        "host_thumbnail_url",
        "host_picture_url",  # up to here: removing bad features
        "name",
        "host_name",
        "host_location",
        "street",
        "host_neighbourhood",
        "host_listings_count",  # removing duplicate features
        "host_total_listings_count",
        "weekly_price",
        "cleaning_fee",  # removing potential leakage
        "maximum_nights",
        "first_review",
        "last_review",  # removing weak features
    ],
    "beer": ["review_aroma", "review_appearance", "review_palate", "review_taste"],
    "calif_houses": [
        "Id",
        "Address",  # non-informative
        "High School",
        "Middle School",
        "Elementary School",  # too fine-grained
        "State",
        "Region",  # only 2 states, and region is redundant with ZIP
        "Sold Price",
        "Listed Price",  # way too correlated with the target
        "Last Sold On",
        "Listed On",  # last sold is often empty (original note on "Listed On" is truncated)
    ],
    "laptops": ["link", "Part Number", "Model Number", "Model Name"],
    "mercari": ["train_id", "test_id"],
    "sf_permits": [
        "Permit Number",
        "Record ID",
        "Permit Expiration Date",
        "Completed Date",
        "First Construction Document Date",
        "Current Status Date",
        "Permit Creation Date",
    ],
    "wine": [],
    "diabetes": [],
    "lending_club": [
        "id",
        "member_id",
        "issue_d",
        "url",
        "last_pymnt_d",
        "last_credit_pull_d",
    ],
    "okcupid": ["last_online"],
    "covid_trials": [
        "Rank",
        "NCT Number",
        "First Posted",
        "URL",
        "Last Update Posted",
        "Primary Completion Date",
        "Other IDs",
        # NOTE(review): possibly a typo for "Funded Bys" - verify against the
        # raw CSV header before changing, since it alters the feature set.
        "Founded Bys",
        "Study Results",  # study results are post-outcome data - leaks the future
    ],
    "drugs_rating": ["drug_link", "medical_condition_url", "generic_name", "drug_name"],
    "insurance_complaints": ["File No."],
    "it_salary": ["Timestamp"],
    "stack_overflow": ["ResponseId"],
}


def _drop_empty_columns(data_df, threshold=0.5):
    """
    Return a copy of ``data_df`` without its sparsely populated columns.

    A column is removed when its fraction of missing values is strictly
    greater than ``threshold``. The removed column labels are printed.
    """
    na_fraction = data_df.isna().mean()
    too_sparse = na_fraction.loc[na_fraction.gt(threshold)].index

    # report which columns are going away
    print(f"Dropped: {too_sparse}")
    return data_df.drop(columns=too_sparse)


def _drop_single_value_columns(data_df):
    """
    Return a copy of ``data_df`` without constant columns.

    A column counts as constant when it holds exactly one distinct non-missing
    value; columns that are entirely missing (zero distinct values) are kept.
    """
    distinct_counts = data_df.nunique()
    constant_cols = [col for col, count in distinct_counts.items() if count == 1]
    return data_df.drop(columns=constant_cols)


def _to_unix_seconds(dates):
    """Convert a datetime Series to float UNIX timestamps (seconds since epoch).

    Missing dates (NaT) become NaN instead of an integer sentinel value.
    """
    return dates.apply(lambda ts: ts.timestamp() if pd.notnull(ts) else np.nan)


def preprocess_data(dataset_name, dataset_config):
    """Clean one raw TextTabBench CSV and store the result as a parquet file.

    Steps: load the raw CSV, optionally downsample, run generic cleaning
    (sparse/constant columns, duplicates, missing targets, unnamed columns),
    drop the manually selected columns, apply dataset-specific cleaning,
    log-transform regression targets (except for "beer") and save.

    Args:
        dataset_name: Dataset identifier; used as key into
            ``missing_ratio_thresholds`` and ``columns_to_drop`` and to build
            the input/output paths.
        dataset_config: Mapping with the keys ``"task"`` (``"clf"`` or
            ``"reg"``), ``"rename_files"`` (the first entry is the raw CSV file
            name) and ``"target"`` (name of the target column).

    Returns:
        None. The result is written to
        ``DATA_DIR / "processed" / <task folder> / "ttb_<dataset_name>.parquet"``.
        Nothing is written when the CSV cannot be read, when the cleaned frame
        has 1050 rows or fewer, or for the "okcupid" dataset.

    Raises:
        ValueError: If ``dataset_config["task"]`` is neither "clf" nor "reg".
        KeyError: If ``dataset_name`` is missing from the lookup tables or an
            expected column is absent from the data.
    """
    print("========================================")
    print(f"Processing dataset: {dataset_name}")
    print("========================================")

    ## Resolve the task folder (shared by the raw and the processed layout)
    if dataset_config["task"] == "clf":
        task_folder = "classification"
    elif dataset_config["task"] == "reg":
        task_folder = "regression"
    else:
        raise ValueError(f"Unknown task: {dataset_config['task']}")

    ## Load the data
    raw_path = (
        DATA_DIR
        / "raw/texttabbench"
        / task_folder
        / dataset_name
        / dataset_config["rename_files"][0]
    )

    # Load data into dataframe; fall back to latin-1 when utf-8 decoding fails
    try:
        df_file = pd.read_csv(raw_path, encoding="utf-8", on_bad_lines="skip")
    except UnicodeDecodeError:
        try:
            df_file = pd.read_csv(raw_path, encoding="latin-1", on_bad_lines="skip")
        except Exception as e:
            print(f"Failed to read {raw_path}: {e}")
            return

    # Handle large datasets
    if dataset_name in [
        "customer_complaints",
        "kickstarter",
        "calif_houses",
        "mercari",
    ]:

        # Downsample with fallback if not enough rows
        train_sample_size = min(
            len(df_file),
            (12000 if dataset_name in ["mercari", "calif_houses"] else 24000),
        )

        df_file = df_file.sample(n=train_sample_size, random_state=42).reset_index(
            drop=True
        )

    if dataset_name == "spotify":
        # there are way too many different targets -> downsample to 10 only
        target_genres = [
            "pop",
            "rock",
            "hip-hop",
            "jazz",
            "classical",
            "metal",
            "electronic",
            "indie",
            "r-n-b",
            "country",
        ]

        shape_before = df_file.shape
        # drop all rows which are not in the target_genres
        df_file.drop(
            df_file[~df_file["track_genre"].isin(target_genres)].index, inplace=True
        )
        print(f"Shape before: {shape_before}, after: {df_file.shape}")

    ## Run some basic data cleaning
    missing_ratio_threshold = missing_ratio_thresholds[dataset_name]

    df_size = df_file.shape
    if dataset_name == "stack_overflow":
        # Filter first so the missing-value ratios are computed on usable rows
        df_file = df_file[df_file["ConvertedCompYearly"].notna()]
    # 1. Drop columns whose missing ratio exceeds the per-dataset threshold
    df_file = _drop_empty_columns(df_file, threshold=missing_ratio_threshold)
    # 2. Drop columns with only one unique value
    df_file = _drop_single_value_columns(df_file)
    # 3. remove duplicates
    df_file = df_file.drop_duplicates()
    # 4. remove rows with missing target values
    if dataset_name not in [
        "covid_trials",
        "sf_permits",
        "stack_overflow",
    ]:  # have not created the target column yet
        df_file = df_file[df_file[dataset_config["target"]].notna()]
    # 5. drop unnamed columns; copy so later in-place edits act on an
    #    independent frame rather than on a filtered slice
    df_file = df_file.loc[:, ~df_file.columns.str.contains("^Unnamed")].copy()
    print(f"Dataframe shape before/after basic cleaning: {df_size} / {df_file.shape}")

    ## Drop columns based on manual inspection
    cols_to_drop = columns_to_drop[dataset_name]
    df_size = df_file.shape
    for col in cols_to_drop:
        if col in df_file.columns:
            df_file.drop(col, axis=1, inplace=True)
        else:
            print(f"Column {col} not found in dataframe")
    print(f"Dataframe shape before/afrer by-hand cleaning: {df_size} / {df_file.shape}")

    ## Run some custom data cleaning
    print(f"Dataframe shape before custom cleaning: {df_file.shape}")

    if dataset_name == "hs_cards":
        # 1. drop the cards belonging to class 'deathknight' and 'dream'
        df_file = df_file[df_file["player_class"] != "DREAM"]
        df_file = df_file[df_file["player_class"] != "DEATHKNIGHT"]

    elif dataset_name == "osha_accidents":
        # 1. Convert 'Event Date' to timestamp (unparsable dates become NaN)
        df_file["Event Date"] = _to_unix_seconds(
            pd.to_datetime(df_file["Event Date"], format="%m/%d/%Y", errors="coerce")
        )

    elif dataset_name == "customer_complaints":
        # 1. Drop non-essential target categories & make safe copy
        to_drop = [
            "Closed",
            "In progress",
            "Untimely response",
            "Closed with relief",
        ]
        df_file = df_file[~df_file[dataset_config["target"]].isin(to_drop)].copy()

        # 2. Convert 'Date received' to timestamp (unparsable dates become NaN)
        df_file["Date received"] = _to_unix_seconds(
            pd.to_datetime(
                df_file["Date received"], format="%m/%d/%Y", errors="coerce"
            )
        )

        # 3. Reduce 'ZIP code' to the first run of three digits; values without
        #    such a run (e.g. masked ZIP codes) become missing
        if "ZIP code" in df_file.columns:
            zip_prefix = df_file["ZIP code"].astype(str).str.extract(r"(\d{3})")
            # nullable integer dtype keeps the missing entries
            df_file["ZIP code"] = zip_prefix[0].astype("Int64")

    elif dataset_name == "kickstarter":
        # Parse both date columns and store them as UNIX seconds. Going through
        # _to_unix_seconds keeps unparsable dates as NaN; a raw int64 cast
        # would turn NaT into a huge negative number.
        for date_col in ("launched_at", "deadline"):
            df_file[date_col] = _to_unix_seconds(
                pd.to_datetime(df_file[date_col], errors="coerce")
            )

    elif dataset_name == "airbnb":
        # 1. Convert 'host_since' to timestamp seconds (NaN when unparsable)
        if "host_since" in df_file.columns:
            df_file["host_since"] = _to_unix_seconds(
                pd.to_datetime(
                    df_file["host_since"], format="%Y-%m-%d", errors="coerce"
                )
            )
        # 2. Convert price to float: eg. "$85.00" -> 85.0
        if "price" in df_file.columns:
            df_file["price"] = df_file["price"].str.replace(r"[^0-9.]", "", regex=True)
            df_file["price"] = pd.to_numeric(df_file["price"], errors="coerce")

    elif dataset_name == "beer":
        # 1. Make sure all reviews scores are credible, so at least XX 'number of reviews'
        min_number_of_reviews = 5
        df_file.drop(
            df_file[df_file["number_of_reviews"] < min_number_of_reviews].index,
            inplace=True,
        )

    elif dataset_name == "calif_houses":
        # Convert each date column that is still present to UNIX seconds.
        # (Both columns are currently listed in columns_to_drop, so this only
        # runs if that list changes.) Each column is converted from itself;
        # previously the "Listed On" branch re-converted "Last Sold On".
        for date_col in ("Last Sold On", "Listed On"):
            if date_col in df_file.columns:
                df_file[date_col] = _to_unix_seconds(
                    pd.to_datetime(df_file[date_col], format="%Y-%m-%d")
                )

    elif dataset_name == "laptops":
        # Step 1: strip everything that is not a digit (currency sign, ',', ...)
        # NOTE(review): this also removes decimal points - confirm that prices
        # are whole numbers in the raw data.
        if df_file["Price"].dtype == "object":
            df_file["Price"] = df_file["Price"].str.replace(r"[^\d]", "", regex=True)

        # Step 2: Convert the column to a numeric dtype
        df_file["Price"] = pd.to_numeric(df_file["Price"])

        # remove rows with missing values in the Price column
        df_file.dropna(subset=["Price"], inplace=True)

    elif dataset_name == "sf_permits":
        # 1. Create target column BEFORE overwriting dates
        filed_date = pd.to_datetime(df_file["Filed Date"], errors="coerce")
        issued_date = pd.to_datetime(df_file["Issued Date"], errors="coerce")

        # target = number of days between filing and issuing the permit
        df_file[dataset_config["target"]] = (issued_date - filed_date).dt.days

        df_file.drop(["Issued Date"], axis=1, inplace=True)

        # 2. Split Location "(lat, lon)" into latitude & longitude
        if "Location" in df_file.columns:
            location_split = df_file["Location"].str.split(",", expand=True)
            location_split[0] = (
                location_split[0]
                .str.replace("(", "", regex=False)
                .apply(pd.to_numeric, errors="coerce")
            )
            location_split[1] = (
                location_split[1]
                .str.replace(")", "", regex=False)
                .apply(pd.to_numeric, errors="coerce")
            )
            location_split.columns = ["Location_Latitude", "Location_Longitude"]
            df_file = pd.concat([df_file, location_split], axis=1)
            df_file.drop(["Location"], axis=1, inplace=True)

        # 3. Convert Filed Date to UNIX timestamp AFTER using it for target
        df_file["Filed Date"] = _to_unix_seconds(filed_date)

        # 4. drop all rows with unreasonably large target values
        #    (the comparison also removes rows whose target is NaN)
        df_file = df_file[
            df_file[dataset_config["target"]] < 1000
        ]  # TODO: set a threshold for the target value

        # 5. drop all rows with a negative duration
        df_file = df_file[
            df_file[dataset_config["target"]] >= 0
        ]  # TODO: set a threshold for the target value

    elif dataset_name == "wine":
        # 1. Make the target column numeric
        # Step 1: strip everything that is not a digit (currency sign, ',', ...)
        # NOTE(review): this also removes decimal points (e.g. "9.99" becomes
        # "999") - confirm that this is intended for the raw price format.
        if df_file["Price"].dtype == "object":
            df_file["Price"] = df_file["Price"].str.replace(r"[^\d]", "", regex=True)

        # Step 2: Convert the column to a numeric dtype
        df_file["Price"] = pd.to_numeric(df_file["Price"])

        # remove rows with missing values in the Price column
        df_file.dropna(subset=["Price"], inplace=True)

        # 2. Make the ABV column numeric
        if df_file["ABV"].dtype == "object":
            # strip the "ABV" prefix and the percent sign
            df_file["ABV"] = (
                df_file["ABV"]
                .str.replace(r"ABV\s*", "", regex=True)
                .str.replace("%", "", regex=False)
            )
            df_file["ABV"] = pd.to_numeric(df_file["ABV"], errors="coerce")

        # no-op for an already numeric column; raises on unparsable leftovers
        df_file["ABV"] = pd.to_numeric(df_file["ABV"])

        # remove rows with missing values in the ABV column
        df_file.dropna(subset=["ABV"], inplace=True)

        # 3. Make the vintage missing when it is not a number
        if df_file["Vintage"].dtype == "object":
            # strip the "Vintage" prefix and any percent sign
            df_file["Vintage"] = (
                df_file["Vintage"]
                .str.replace(r"Vintage\s*", "", regex=True)
                .str.replace("%", "", regex=False)
            )
            df_file["Vintage"] = pd.to_numeric(df_file["Vintage"], errors="coerce")

        # no-op for an already numeric column; raises on unparsable leftovers
        df_file["Vintage"] = pd.to_numeric(df_file["Vintage"])

    elif dataset_name == "covid_trials":
        # Define valid statuses that imply study is finished
        valid_statuses = [
            "Completed",
            "Withdrawn",
            "Terminated",
            "Approved for marketing",
        ]
        # Define sanity thresholds
        max_reasonable_days = 365 * 10  # 10 years

        # 1. Filter to only studies with a final status
        df_file = df_file[df_file["Status"].isin(valid_statuses)].copy()

        # 2. Parse dates
        start_date = pd.to_datetime(df_file["Start Date"], errors="coerce")
        completion_date = pd.to_datetime(df_file["Completion Date"], errors="coerce")

        # 3. Compute target duration in days
        target_days = (completion_date - start_date).dt.days

        # 4. Add target column to df_file
        df_file[dataset_config["target"]] = target_days

        # 5. Drop rows with:
        #    - NaN target
        #    - Negative target (Completion before Start)
        #    - Unreasonably large durations (> 10 years)
        df_file = df_file[
            df_file[dataset_config["target"]].notna()
            & (df_file[dataset_config["target"]] >= 0)
            & (df_file[dataset_config["target"]] <= max_reasonable_days)
        ].copy()

        # 6. Drop 'Completion Date'
        df_file.drop(["Completion Date"], axis=1, inplace=True)

        # 7. Convert 'Start Date' to UNIX timestamp (only for kept rows)
        df_file["Start Date"] = _to_unix_seconds(start_date.loc[df_file.index])

    elif dataset_name == "drugs_rating":
        # one robust pattern for all cases
        URL_RE = re.compile(
            r"(?:https?:\/\/)?"  # optional protocol
            r"\/?"  # optional leading slash
            r"(?:www\.)?"  # optional www.
            r"drugs\.com\/[^\s\|,]+"  # domain + path
        )

        # 1. clean up the free-text list of related drugs
        if "related_drugs" in df_file.columns:
            df_file["related_drugs"] = (
                df_file["related_drugs"]
                # remove any drugs.com URL
                .str.replace(URL_RE, "", regex=True)
                # replace colons (and surrounding whitespace) by a single space
                .str.replace(r"\s*:\s*", " ", regex=True)
                # drop bracketed notes
                .str.replace(r"\[.*?\]", "", regex=True)
                # tidy pipe spacing
                .str.replace(r"\s*\|\s*", " | ", regex=True)
                # drop stray commas
                .str.replace(",", " ", regex=False)
                # collapse repeated whitespace
                .str.replace(r"\s+", " ", regex=True)
                # trim leading/trailing pipes or spaces
                .str.strip(" |")
            )
        # 2. remove any rows with way too low number of ratings
        if "no_of_reviews" in df_file.columns:
            df_file = df_file[df_file["no_of_reviews"] > 5]

    elif dataset_name == "insurance_complaints":
        # Parse dates
        opened_dt = pd.to_datetime(df_file["Opened"], errors="coerce")
        closed_dt = pd.to_datetime(df_file["Closed"], errors="coerce")

        # Compute duration (days to resolve)
        df_file["days_to_resolve"] = (closed_dt - opened_dt).dt.days

        # Convert 'Opened' to timestamp (seconds since epoch)
        df_file["Opened_ts"] = _to_unix_seconds(opened_dt)

        # Drop 'Closed' (already reflected in duration)
        df_file.drop(["Closed"], axis=1, inplace=True)

        # Drop original 'Opened' (keep only Opened_ts)
        df_file.drop(["Opened"], axis=1, inplace=True)

    elif dataset_name == "lending_club":
        # 1. Remove rows with target value 'Current'
        df_file = df_file[df_file[dataset_config["target"]] != "Current"].copy()

        # 2. Clean 'term': keep the first run of digits as nullable integer
        if "term" in df_file.columns:
            df_file["term"] = (
                df_file["term"].astype(str).str.extract(r"(\d+)")[0].astype("Int64")
            )

        # 3. Clean 'int_rate': remove '%' and convert to float
        if "int_rate" in df_file.columns:
            df_file["int_rate"] = (
                df_file["int_rate"].astype(str).str.rstrip("%").astype(float)
            )

        # 4. Clean 'revol_util': remove '%' and convert to float
        if "revol_util" in df_file.columns:
            df_file["revol_util"] = (
                df_file["revol_util"].astype(str).str.rstrip("%").astype(float)
            )

        # 5. Drop leaking features (known post-loan outcome info)
        leakage_cols = [
            "total_pymnt",
            "total_pymnt_inv",
            "total_rec_prncp",
            "total_rec_int",
            "total_rec_late_fee",
            "recoveries",
            "collection_recovery_fee",
            "last_pymnt_amnt",
            "out_prncp",
            "out_prncp_inv",
            "pub_rec_bankruptcies",
        ]
        leakage_cols_to_drop = [col for col in leakage_cols if col in df_file.columns]
        df_file = df_file.drop(columns=leakage_cols_to_drop)

    elif dataset_name == "stack_overflow":
        # Special case for this dataset: the salary column is multiplied by a
        # fixed per-currency rate taken from the table below.
        # NOTE(review): the column name "ConvertedCompYearly" suggests the value
        # may already be converted - confirm that applying the rate is intended.
        currency_to_usd = {
            "AED": 0.2723,  # United Arab Emirates dirham
            "AFN": 0.0154,  # Afghan afghani
            "ALL": 0.0105,  # Albanian lek
            "AMD": 0.0025,  # Armenian dram
            "AOA": 0.0012,  # Angolan kwanza
            "ARS": 0.0012,  # Argentine peso
            "AUD": 0.6810,  # Australian dollar
            "AZN": 0.5882,  # Azerbaijan manat
            "BAM": 0.5530,  # Bosnia and Herzegovina convertible mark
            "BDT": 0.0085,  # Bangladeshi taka
            "BGN": 0.5530,  # Bulgarian lev
            "BHD": 2.6596,  # Bahraini dinar
            "BIF": 0.00035,  # Burundi franc
            "BND": 0.7520,  # Brunei dollar
            "BOB": 0.1450,  # Bolivian boliviano
            "BRL": 0.2040,  # Brazilian real
            "BSD": 1.0000,  # Bahamian dollar
            "BTN": 0.0120,  # Bhutanese ngultrum
            "BWP": 0.0740,  # Botswana pula
            "BYN": 0.3930,  # Belarusian ruble
            "CAD": 0.7410,  # Canadian dollar
            "CHF": 1.1370,  # Swiss franc
            "CLP": 0.0012,  # Chilean peso
            "CNY": 0.1400,  # Chinese Yuan Renminbi
            "COP": 0.00025,  # Colombian peso
            "CRC": 0.0019,  # Costa Rican colon
            "CUP": 0.0417,  # Cuban peso
            "CVE": 0.0100,  # Cape Verdean escudo
            "CZK": 0.0430,  # Czech koruna
            "DKK": 0.1450,  # Danish krone
            "DOP": 0.0170,  # Dominican peso
            "DZD": 0.0074,  # Algerian dinar
            "EGP": 0.0320,  # Egyptian pound
            "ETB": 0.0175,  # Ethiopian birr
            "EUR": 1.0850,  # European Euro
            "FJD": 0.4510,  # Fijian dollar
            "FKP": 1.2700,  # Falkland Islands pound
            "GBP": 1.2700,  # Pound sterling
            "GEL": 0.3750,  # Georgian lari
            "GGP": 1.2700,  # Guernsey Pound
            "GHS": 0.0830,  # Ghanaian cedi
            "GTQ": 0.1280,  # Guatemalan quetzal
            "HKD": 0.1280,  # Hong Kong dollar
            "HNL": 0.0410,  # Honduran lempira
            "HUF": 0.0028,  # Hungarian forint
            "IDR": 0.000064,  # Indonesian rupiah
            "ILS": 0.2850,  # Israeli new shekel
            "IMP": 1.2700,  # Manx pound
            "INR": 0.0120,  # Indian rupee
            "IQD": 0.00076,  # Iraqi dinar
            "IRR": 0.000024,  # Iranian rial
            "ISK": 0.0073,  # Icelandic krona
            "JMD": 0.0065,  # Jamaican dollar
            "JOD": 1.4100,  # Jordanian dinar
            "JPY": 0.0066,  # Japanese yen
            "KES": 0.0063,  # Kenyan shilling
            "KGS": 0.0113,  # Kyrgyzstani som
            "KHR": 0.00024,  # Cambodian riel
            "KRW": 0.00077,  # South Korean won
            "KWD": 3.2500,  # Kuwaiti dinar
            "KZT": 0.0022,  # Kazakhstani tenge
            "LKR": 0.0031,  # Sri Lankan rupee
            "LYD": 0.2060,  # Libyan dinar
            "MAD": 0.0990,  # Moroccan dirham
            "MDL": 0.0560,  # Moldovan leu
            "MGA": 0.00022,  # Malagasy ariary
            "MKD": 0.0170,  # Macedonian denar
            "MMK": 0.00048,  # Myanmar kyat
            "MNT": 0.00029,  # Mongolian tugrik
            "MOP": 0.1240,  # Macanese pataca
            "MRU": 0.0250,  # Mauritanian ouguiya
            "MUR": 0.0220,  # Mauritian rupee
            "MVR": 0.0650,  # Maldivian rufiyaa
            "MWK": 0.00061,  # Malawian kwacha
            "MXN": 0.0580,  # Mexican peso
            "MYR": 0.2150,  # Malaysian ringgit
            "MZN": 0.0156,  # Mozambican metical
            "NAD": 0.0530,  # Namibian dollar
            "NGN": 0.0013,  # Nigerian naira
            "NIO": 0.0270,  # Nicaraguan cordoba
            "NOK": 0.0950,  # Norwegian krone
            "NPR": 0.0075,  # Nepalese rupee
            "NZD": 0.6210,  # New Zealand dollar
            "OMR": 2.6000,  # Omani rial
            "PEN": 0.2640,  # Peruvian sol
            "PHP": 0.0180,  # Philippine peso
            "PKR": 0.0036,  # Pakistani rupee
            "PLN": 0.2410,  # Polish zloty
            "PYG": 0.00014,  # Paraguayan guarani
            "QAR": 0.2747,  # Qatari riyal
            "RON": 0.2180,  # Romanian leu
            "RSD": 0.0091,  # Serbian dinar
            "RUB": 0.0108,  # Russian ruble
            "RWF": 0.00083,  # Rwandan franc
            "SAR": 0.2666,  # Saudi Arabian riyal
            "SEK": 0.0960,  # Swedish krona
            "SGD": 0.7520,  # Singapore dollar
            "SYP": 0.00040,  # Syrian pound
            "THB": 0.0280,  # Thai baht
            "TJS": 0.0910,  # Tajikistani somoni
            "TMT": 0.2850,  # Turkmen manat
            "TND": 0.3200,  # Tunisian dinar
            "TRY": 0.0340,  # Turkish lira
            "TTD": 0.1470,  # Trinidad and Tobago dollar
            "TWD": 0.0320,  # New Taiwan dollar
            "TZS": 0.00040,  # Tanzanian shilling
            "UAH": 0.0270,  # Ukrainian hryvnia
            "UGX": 0.00027,  # Ugandan shilling
            "USD": 1.0000,  # United States dollar
            "UYU": 0.0250,  # Uruguayan peso
            "UZS": 0.000082,  # Uzbekistani som
            "VES": 0.000029,  # Venezuelan bolivar
            "VND": 0.000041,  # Vietnamese dong
            "WST": 0.3650,  # Samoan tala
            "XAF": 0.0017,  # Central African CFA franc
            "XDR": 1.3400,  # SDR (Special Drawing Right)
            "XOF": 0.0017,  # West African CFA franc
            "XPF": 0.0091,  # CFP franc
            "YER": 0.0040,  # Yemeni rial
            "ZAR": 0.0530,  # South African rand
            "ZMW": 0.0520,  # Zambian kwacha
        }

        def extract_currency_code(value):
            """Extracts the currency code (e.g., 'USD') from values like 'USD United States dollar'."""
            if pd.isna(value):
                return None
            return str(value).split()[0].strip()

        # Add a column for currency codes (e.g., 'USD', 'EUR', etc.)
        df_file["Currency_Code"] = df_file["Currency"].apply(extract_currency_code)

        def convert_row_salary(row):
            """Return the row's salary times its currency rate, or NaN if unknown."""
            currency_code = row["Currency_Code"]
            salary = row["ConvertedCompYearly"]
            if pd.isna(salary) or pd.isna(currency_code):
                return np.nan
            rate = currency_to_usd.get(currency_code)
            if rate is None:
                print(f"⚠️ Missing rate for {currency_code}")
                return np.nan
            return salary * rate

        # Create new column for salary in USD
        df_file["Yearly_Salary_USD"] = df_file.apply(convert_row_salary, axis=1)

        # Define sensible salary boundaries in USD
        MIN_SALARY_USD = 1000
        MAX_SALARY_USD = 1_000_000

        # Filter out unreasonable salaries (this also removes NaN salaries)
        df_file = df_file[
            (df_file["Yearly_Salary_USD"] >= MIN_SALARY_USD)
            & (df_file["Yearly_Salary_USD"] <= MAX_SALARY_USD)
        ].copy()

        # drop the currency columns
        df_file = df_file.drop(columns=["Currency", "Currency_Code"])

    print(f"Dataframe shape after custom cleaning: {df_file.shape}")

    # Log transform the target column for regression tasks
    if dataset_config["task"] == "reg" and dataset_name != "beer":
        target_col = dataset_config["target"]
        # work on an independent copy: some branches above leave df_file as a
        # filtered slice, and assigning into a slice triggers pandas warnings
        df_file = df_file.copy()
        # log1p is applied to positive values only; zero, negative and missing
        # targets are left unchanged
        df_file[target_col] = df_file[target_col].apply(
            lambda x: np.log1p(x) if x > 0 else x
        )

    ## Save preprocessed data
    save_path = DATA_DIR / "processed" / task_folder / f"ttb_{dataset_name}.parquet"
    if (
        len(df_file) > 1050 and dataset_name != "okcupid"
    ):  # okcupid has too many long sentences
        # make sure the output folder exists before writing
        save_path.parent.mkdir(parents=True, exist_ok=True)
        df_file.to_parquet(save_path, index=False)
        print(f"Saved preprocessed data to: {save_path}. Shape: {df_file.shape}")
    else:
        print(f"⚠️ Skipping saving preprocessed data. Shape too small: {df_file.shape}")
    return


def add_label_encoding(dataset_name, dataset_config):
    """
    Label-encode the target column of a preprocessed classification dataset.

    The parquet file written by ``preprocess_data`` is loaded, its target
    column is replaced by the integer codes of a ``LabelEncoder`` and the file
    is overwritten in place. Regression datasets are left untouched.

    Args:
        dataset_name: Dataset identifier used to build the parquet file name.
        dataset_config: Mapping with the keys ``"task"`` ("clf" or "reg") and
            ``"target"`` (name of the target column).

    Returns:
        None.

    Raises:
        ValueError: If ``dataset_config["task"]`` is neither "clf" nor "reg".
    """
    task = dataset_config["task"]
    if task == "reg":
        return
    if task != "clf":
        raise ValueError(f"Unknown task: {dataset_config['task']}")

    data_path = (
        DATA_DIR / "processed" / "classification" / f"ttb_{dataset_name}.parquet"
    )
    # preprocess_data does not write a file for every dataset (e.g. okcupid,
    # too few rows, unreadable CSV); skip those instead of crashing the run.
    if not data_path.exists():
        print(f"⚠️ Skipping label encoding, no preprocessed file at: {data_path}")
        return

    # Load the preprocessed data
    data_df = pd.read_parquet(data_path)

    target_col = dataset_config.get("target")

    le = LabelEncoder()
    data_df[target_col] = le.fit_transform(data_df[target_col])

    # Save the updated data
    data_df.to_parquet(data_path, index=False)
    print(f"Updated data with label encoding saved to: {data_path}")
    return


if __name__ == "__main__":
    # Preprocess every configured dataset, then label-encode its target.
    for dataset_name in all_configs:
        dataset_config = all_configs[dataset_name]
        preprocess_data(dataset_name, dataset_config)
        add_label_encoding(dataset_name, dataset_config)
