{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8eafda7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
    "\n",
    "import numpy as np\n",
    "from folktables import ACSDataSource\n",
    "import folktables\n",
    "from sklearn.preprocessing import OneHotEncoder, FunctionTransformer\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn import set_config\n",
    "import pandas as pd\n",
    "import os\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "set_config(transform_output=\"pandas\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62a5aef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "OCCP_GROUPS = {\n",
    "    \"MGR\": range(10, 50),\n",
    "    \"BUS\": range(50, 100),\n",
    "    \"SCI\": range(100, 200),\n",
    "    \"ENG\": range(200, 300),\n",
    "    \"MED\": range(300, 400),\n",
    "    \"EDU\": range(400, 500),\n",
    "    \"ART\": range(500, 600),\n",
    "    \"LAW\": range(600, 700),\n",
    "    \"SVC\": range(700, 800),\n",
    "    \"SAL\": range(800, 900),\n",
    "    \"OFF\": range(900, 1000),\n",
    "    \"CON\": range(1000, 1100),\n",
    "    \"PRD\": range(1100, 1200),\n",
    "    \"TRN\": range(1200, 1300),\n",
    "}\n",
    "\n",
    "def occp_group_transform(X):\n",
    "    X = np.asarray(X).astype(int).ravel()\n",
    "\n",
    "    out = []\n",
    "    for occp in X:\n",
    "        label = \"OTHER\"\n",
    "        for k, r in OCCP_GROUPS.items():\n",
    "            if occp in r:\n",
    "                label = k\n",
    "                break\n",
    "        out.append(label)\n",
    "\n",
    "    return np.array(out).reshape(-1, 1)\n",
    "\n",
    "\n",
    "\n",
    "def binarize_mar(X):   # married vs not\n",
    "    return (X == 1).astype(int)\n",
    "\n",
    "def binarize_sex(X):\n",
    "    return (X == 1).astype(int)\n",
    "\n",
    "def binarize_dis(X):\n",
    "    return (X == 1).astype(int)\n",
    "\n",
    "def binarize_cit(X):   # citizen vs not\n",
    "    return (X != 5).astype(int)\n",
    "\n",
    "def binarize_hisp(X):  # non-hispanic vs hispanic\n",
    "    return (X == 1).astype(int)\n",
    "\n",
    "def binarize_mig(X):\n",
    "    return (X == 1).astype(int)\n",
    "\n",
    "def binarize_esr(X):   # in labor force vs not\n",
    "    return X #np.isin(X, [1, 2]).astype(int)\n",
    "\n",
    "def binarize_cow(X):\n",
    "    return (X == 1).astype(int)\n",
    "\n",
    "\n",
    "ordinal_features = [\n",
    "    \"SCHL\",\n",
    "    \"RELP\",\n",
    "    \"AGEP\",\n",
    "    \"WKHP\", \n",
    "    'RAC1P',\n",
    "    \"MAR\"\n",
    "]\n",
    "\n",
    "binary_features = {\n",
    "    \"SEX\": binarize_sex,\n",
    "    \"DIS\": binarize_dis,\n",
    "    \"HISP\": binarize_hisp,\n",
    "    \"MIG\": binarize_mig,\n",
    "    \"ESR\": binarize_esr,\n",
    "    \"COW\": binarize_cow\n",
    "}\n",
    "\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        # Ordinal (pass-through)\n",
    "        (\"ord\", \"passthrough\", ordinal_features),\n",
    "\n",
    "        # Binary (custom)\n",
    "        *[\n",
    "            (\n",
    "                f\"bin_{col}\",\n",
    "                FunctionTransformer(fn, feature_names_out=\"one-to-one\"),\n",
    "                [col]\n",
    "            )\n",
    "            for col, fn in binary_features.items()\n",
    "        ],\n",
    "\n",
    "\n",
    "        # OCCP grouped → one-hot\n",
    "        (\n",
    "            \"occp\",\n",
    "            Pipeline(steps=[\n",
    "                (\"group\", FunctionTransformer(occp_group_transform)),\n",
    "                (\"ohe\", OneHotEncoder(sparse_output=False))\n",
    "            ]),\n",
    "            [\"OCCP\"]\n",
    "        ),\n",
    "\n",
    "    ],\n",
    "    remainder=\"drop\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f036747e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def correlation_with_label(df: pd.DataFrame, label_col: str, method: str = \"pearson\"):\n",
    "    if label_col not in df.columns:\n",
    "        raise ValueError(f\"Label column '{label_col}' not found in dataframe\")\n",
    "\n",
    "    # Keep only numeric columns\n",
    "    numeric_df = df.select_dtypes(include=\"number\")\n",
    "\n",
    "    if label_col not in numeric_df.columns:\n",
    "        raise ValueError(\"Label column must be numeric to compute correlation\")\n",
    "\n",
    "    corrs = numeric_df.corr(method=method)[label_col]\n",
    "\n",
    "    # Drop self-correlation and sort\n",
    "    corrs = corrs.drop(label_col).sort_values(key=abs, ascending=False)\n",
    "\n",
    "    return corrs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "69ee0695",
   "metadata": {},
   "outputs": [],
   "source": [
    "ACSIncome = folktables.BasicProblem(\n",
    "    features=[\n",
    "    'COW', #ONE HOT \n",
    "    'SCHL', #ordinal \n",
    "    'MAR', #binary \n",
    "    'OCCP', #ONE HOT according to code\n",
    "    'RELP', #consider ordinal \n",
    "    'SEX', #binary\n",
    "    'RAC1P', #ONE HOT \n",
    "    'AGEP', #ordinal \n",
    "    'DIS', #binary\n",
    "    'ESR', # binary\n",
    "    'HISP', #binary\n",
    "    'WKHP', #ordinal\n",
    "    'MIG', #binary\n",
    "    ],\n",
    "    target='PINCP',\n",
    "    target_transform=lambda x: (x > 50000).astype(int), #.replace({0: -1}),\n",
    "    preprocess=folktables.adult_filter,\n",
    "    postprocess=lambda x: x,  \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "009051d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label distribution BEFORE sampling:\n",
      "1 labels: 23430 (0.490)\n",
      "0 labels: 24351 (0.510)\n"
     ]
    }
   ],
   "source": [
    "task = ACSIncome\n",
    "weight_col = 'ord__WKHP'\n",
    "\n",
    "#------------------------\n",
    "# Load data\n",
    "#------------------------\n",
    "data_source = ACSDataSource(\n",
    "    survey_year='2018',\n",
    "    horizon='1-Year',\n",
    "    survey='person'\n",
    ")\n",
    "\n",
    "state = [\"NJ\"]\n",
    "acs_data = data_source.get_data(states=state, download=True)\n",
    "df_full, y_full, _ = task.df_to_pandas(acs_data)\n",
    "df_full = preprocessor.fit_transform(df_full)\n",
    "\n",
    "y_full = y_full.to_numpy()\n",
    "\n",
    "# ---- Print label ratios BEFORE sampling\n",
    "n_pos = np.sum(y_full == 1)\n",
    "n_neg = np.sum(y_full == 0)\n",
    "\n",
    "print(\"Label distribution BEFORE sampling:\")\n",
    "print(f\"1 labels: {n_pos} ({n_pos / len(y_full):.3f})\")\n",
    "print(f\"0 labels: {n_neg} ({n_neg / len(y_full):.3f})\")\n",
    "\n",
    "# ---- Indices for balanced sampling\n",
    "pos_idx = np.where(y_full == 1)[0]\n",
    "neg_idx = np.where(y_full == 0)[0]\n",
    "\n",
    "df_full[task.target] = y_full\n",
    "df_full[\"original_weight\"] = df_full[weight_col]\n",
    "\n",
    "cols =  [ 'ord__RAC1P',\n",
    "       'ord__MAR', 'bin_SEX__SEX', 'bin_DIS__DIS', 'bin_HISP__HISP',\n",
    "       'bin_MIG__MIG', 'bin_ESR__ESR', 'bin_COW__COW', 'occp__x0_ART',\n",
    "       'occp__x0_BUS', 'occp__x0_CON', 'occp__x0_EDU', 'occp__x0_ENG',\n",
    "       'occp__x0_LAW', 'occp__x0_MED', 'occp__x0_MGR', 'occp__x0_OFF',\n",
    "       'occp__x0_OTHER', 'occp__x0_PRD', 'occp__x0_SAL', 'occp__x0_SCI',\n",
    "       'occp__x0_SVC', 'occp__x0_TRN']\n",
    "rng = np.random.default_rng(42)  # 42 is the seed\n",
    "noise = rng.normal(0, 0.4, size=(len(df_full), len(cols)))\n",
    "df_full[cols] = df_full[cols].astype(float) + noise\n",
    "updated_features = [f for f in df_full.columns \n",
    "                    if ((f != weight_col) and (f != 'original_weight') and (f != task.target))]\n",
    "\n",
    "# all_x = df_full[updated_features].to_numpy()\n",
    "# all_y = df_full[task.target].to_numpy()\n",
    "# all_w = df_full[weight_col].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8e87cb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "folder = \"final_data\"\n",
    "os.makedirs(folder, exist_ok=True)\n",
    "\n",
    "df_full.to_csv(\"final_data/NJ_data_with_noise.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2ca8968",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
