{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6e9c5ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from folktables_data import get_folks\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f81b565",
   "metadata": {},
   "outputs": [],
   "source": [
    "from folktables import BasicProblem, ACSDataSource\n",
    "from sklearn.preprocessing import MinMaxScaler\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "657544f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Survey slice to load; change these to pull other years or states.\n",
    "SURVEY_YEAR = '2018'\n",
    "STATES = ['AL']\n",
    "\n",
    "ACSEmploymentCustom = BasicProblem(\n",
    "    features=[\n",
    "        'AGEP',\n",
    "        'SCHL',\n",
    "        'MAR',\n",
    "        'RELP',\n",
    "        'ESP',\n",
    "        'CIT',\n",
    "        'MIL',\n",
    "    ],\n",
    "    target='ESR',\n",
    "    # Positive class is ESR == 1 (presumably 'civilian employed, at work' in the ACS PUMS codebook; verify).\n",
    "    target_transform=lambda x: x == 1,\n",
    "    group='RAC1P',\n",
    "    preprocess=lambda x: x,\n",
    "    # np.nan_to_num's second positional parameter is `copy`, not the fill value,\n",
    "    # so the replacement for missing values must be passed by keyword.\n",
    "    postprocess=lambda x: np.nan_to_num(x, nan=-1),\n",
    ")\n",
    "data_source = ACSDataSource(survey_year=SURVEY_YEAR, horizon='1-Year', survey='person')\n",
    "acs_data = data_source.get_data(states=STATES, download=True)\n",
    "features, label, group = ACSEmploymentCustom.df_to_pandas(acs_data)\n",
    "X = pd.DataFrame(features)\n",
    "# Encode the boolean target as +1 (ESR == 1) / -1 (otherwise).\n",
    "y = label.iloc[:, 0].map({True: 1, False: -1}).astype(int)\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eb13a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Histogram of the +/-1 target; the bin edges put one bar centred on each of -1, 0 and 1.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(y, bins=[-1.5, -0.5, 0.5, 1.5], rwidth=0.8, color='skyblue', edgecolor='black')\n",
    "ax.set_xticks([-1, 1])\n",
    "ax.set(\n",
    "    xlabel='Target Value (y)',\n",
    "    ylabel='Frequency',\n",
    "    title='Distribution of Employment Target (y)',\n",
    ")\n",
    "\n",
    "# Save before showing so the written file contains the rendered figure.\n",
    "fig.savefig('y_histogram.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81df6460",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the feature matrix followed by the label vector, each on its own lines.\n",
    "print(X, y, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ce713e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give X and y a fresh 0..n-1 index so the index-aligned correlations below pair rows correctly.\n",
    "X = X.reset_index(drop=True)\n",
    "y = y.reset_index(drop=True)\n",
    "\n",
    "# Rescale every feature to [0, 1].\n",
    "X = pd.DataFrame(MinMaxScaler().fit_transform(X), columns=X.columns)\n",
    "\n",
    "# Flip (x -> 1 - x) every feature whose Pearson correlation with y is negative, so no feature\n",
    "# keeps a negative correlation with the label. Constant columns have a NaN correlation and are not flipped.\n",
    "# NOTE(review): the scaler and the flip decision are fit on the full dataset (the flip uses the labels);\n",
    "# if X/y are later split into train/test, fit both on the training split only to avoid leakage.\n",
    "corr_with_y = X.corrwith(y)\n",
    "to_flip = corr_with_y.index[corr_with_y < 0].tolist()\n",
    "X[to_flip] = 1 - X[to_flip]\n",
    "assert len(X) == len(y), 'features and labels must have the same number of rows'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c61eac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: features and labels must still be row-aligned after preprocessing.\n",
    "assert len(X) == len(y), 'features and labels must have the same number of rows'\n",
    "X.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SC_PI_ENV",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
