{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a4d4d6b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:53:46.370202Z",
     "iopub.status.busy": "2024-10-18T15:53:46.369656Z",
     "iopub.status.idle": "2024-10-18T15:53:46.860412Z",
     "shell.execute_reply": "2024-10-18T15:53:46.859579Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset Info:\n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 1265 entries, 0 to 1264\n",
      "Data columns (total 22 columns):\n",
      " #   Column             Non-Null Count  Dtype  \n",
      "---  ------             --------------  -----  \n",
      " 0   loc                1265 non-null   float64\n",
      " 1   v(g)               1265 non-null   float64\n",
      " 2   ev(g)              1265 non-null   float64\n",
      " 3   iv(g)              1265 non-null   float64\n",
      " 4   n                  1265 non-null   float64\n",
      " 5   v                  1265 non-null   float64\n",
      " 6   l                  1265 non-null   float64\n",
      " 7   d                  1265 non-null   float64\n",
      " 8   i                  1265 non-null   float64\n",
      " 9   e                  1265 non-null   float64\n",
      " 10  b                  1265 non-null   float64\n",
      " 11  t                  1265 non-null   float64\n",
      " 12  lOCode             1265 non-null   float64\n",
      " 13  lOComment          1265 non-null   int64  \n",
      " 14  lOBlank            1265 non-null   int64  \n",
      " 15  locCodeAndComment  1265 non-null   int64  \n",
      " 16  uniq_Op            1265 non-null   float64\n",
      " 17  uniq_Opnd          1265 non-null   float64\n",
      " 18  total_Op           1265 non-null   float64\n",
      " 19  total_Opnd         1265 non-null   float64\n",
      " 20  branchCount        1265 non-null   float64\n",
      " 21  defects            1265 non-null   bool   \n",
      "dtypes: bool(1), float64(18), int64(3)\n",
      "memory usage: 208.9 KB\n",
      "None\n",
      "\n",
      "First few rows of the dataset:\n",
      "    loc  v(g)  ev(g)  iv(g)     n       v     l      d      i        e  ...  \\\n",
      "0   4.0   1.0    1.0    1.0   3.0    4.75  1.00   1.00   4.75     4.75  ...   \n",
      "1  19.0   2.0    1.0    2.0  39.0  176.42  0.17   5.79  30.49  1020.71  ...   \n",
      "2   2.0   1.0    1.0    1.0   4.0    8.00  0.67   1.50   5.33    12.00  ...   \n",
      "3  11.0   1.0    1.0    1.0  14.0   51.81  0.25   4.00  12.95   207.22  ...   \n",
      "4   9.0   2.0    1.0    2.0  41.0  180.09  0.08  12.19  14.78  2194.79  ...   \n",
      "\n",
      "   lOCode  lOComment  lOBlank  locCodeAndComment  uniq_Op  uniq_Opnd  \\\n",
      "0     0.0          0        0                  0      2.0        1.0   \n",
      "1    13.0          0        4                  0      9.0       14.0   \n",
      "2     0.0          0        0                  0      3.0        1.0   \n",
      "3     6.0          0        1                  0      8.0        5.0   \n",
      "4     6.0          0        1                  0     13.0        8.0   \n",
      "\n",
      "   total_Op  total_Opnd  branchCount  defects  \n",
      "0       2.0         1.0          1.0    False  \n",
      "1      21.0        18.0          3.0    False  \n",
      "2       3.0         1.0          1.0    False  \n",
      "3       9.0         5.0          1.0     True  \n",
      "4      26.0        15.0          3.0     True  \n",
      "\n",
      "[5 rows x 22 columns]\n",
      "\n",
      "Missing values in the dataset:\n",
      "loc                  0\n",
      "v(g)                 0\n",
      "ev(g)                0\n",
      "iv(g)                0\n",
      "n                    0\n",
      "v                    0\n",
      "l                    0\n",
      "d                    0\n",
      "i                    0\n",
      "e                    0\n",
      "b                    0\n",
      "t                    0\n",
      "lOCode               0\n",
      "lOComment            0\n",
      "lOBlank              0\n",
      "locCodeAndComment    0\n",
      "uniq_Op              0\n",
      "uniq_Opnd            0\n",
      "total_Op             0\n",
      "total_Opnd           0\n",
      "branchCount          0\n",
      "defects              0\n",
      "dtype: int64\n",
      "\n",
      "Summary statistics for numerical columns:\n",
      "               loc         v(g)        ev(g)        iv(g)            n  \\\n",
      "count  1265.000000  1265.000000  1265.000000  1265.000000  1265.000000   \n",
      "mean     20.759763     2.910198     1.695178     2.610593    51.220791   \n",
      "std      31.166471     4.121797     2.220058     3.598494    88.540725   \n",
      "min       1.000000     1.000000     1.000000     1.000000     0.000000   \n",
      "25%       3.000000     1.000000     1.000000     1.000000     4.000000   \n",
      "50%       9.000000     1.000000     1.000000     1.000000    15.000000   \n",
      "75%      25.000000     3.000000     1.000000     3.000000    60.000000   \n",
      "max     288.000000    45.000000    26.000000    45.000000  1106.000000   \n",
      "\n",
      "                 v            l            d            i              e  ...  \\\n",
      "count  1265.000000  1265.000000  1265.000000  1265.000000    1265.000000  ...   \n",
      "mean    267.523067     0.315360     6.845763    21.272348    5630.293684  ...   \n",
      "std     551.542472     0.318018     8.135206    22.097028   19540.927555  ...   \n",
      "min       0.000000     0.000000     0.000000     0.000000       0.000000  ...   \n",
      "25%       8.000000     0.070000     1.500000     5.330000      12.000000  ...   \n",
      "50%      53.770000     0.180000     3.500000    13.500000     209.280000  ...   \n",
      "75%     278.630000     0.640000     9.500000    30.330000    2391.620000  ...   \n",
      "max    7918.820000     2.000000    53.750000   193.060000  324803.510000  ...   \n",
      "\n",
      "                  t       lOCode    lOComment      lOBlank  locCodeAndComment  \\\n",
      "count   1265.000000  1265.000000  1265.000000  1265.000000        1265.000000   \n",
      "mean     312.796087    14.860870     0.992885     1.704348           0.147826   \n",
      "std     1085.606427    25.555281     3.367160     3.626710           0.813920   \n",
      "min        0.000000     0.000000     0.000000     0.000000           0.000000   \n",
      "25%        0.670000     0.000000     0.000000     0.000000           0.000000   \n",
      "50%       11.630000     5.000000     0.000000     0.000000           0.000000   \n",
      "75%      132.870000    17.000000     0.000000     2.000000           0.000000   \n",
      "max    18044.640000   262.000000    44.000000    35.000000          12.000000   \n",
      "\n",
      "           uniq_Op    uniq_Opnd     total_Op   total_Opnd  branchCount  \n",
      "count  1265.000000  1265.000000  1265.000000  1265.000000  1265.000000  \n",
      "mean      7.584348     9.577233    31.948775    19.273676     4.812964  \n",
      "std       5.727508    12.499156    54.780209    34.016774     8.226025  \n",
      "min       0.000000     0.000000     0.000000     0.000000     1.000000  \n",
      "25%       3.000000     1.000000     3.000000     1.000000     1.000000  \n",
      "50%       6.000000     5.000000    10.000000     6.000000     1.000000  \n",
      "75%      11.000000    14.000000    37.000000    23.000000     5.000000  \n",
      "max      31.000000   120.000000   678.000000   428.000000    89.000000  \n",
      "\n",
      "[8 rows x 21 columns]\n",
      "\n",
      "Distribution of the target column 'defects':\n",
      "defects\n",
      "False    0.841897\n",
      "True     0.158103\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "Categorical columns:\n",
      "Index([], dtype='object')\n",
      "\n",
      "Numerical columns:\n",
      "Index(['loc', 'v(g)', 'ev(g)', 'iv(g)', 'n', 'v', 'l', 'd', 'i', 'e', 'b', 't',\n",
      "       'lOCode', 'lOComment', 'lOBlank', 'locCodeAndComment', 'uniq_Op',\n",
      "       'uniq_Opnd', 'total_Op', 'total_Opnd', 'branchCount'],\n",
      "      dtype='object')\n",
      "\n",
      "Correlation matrix for numerical features:\n",
      "                        loc      v(g)     ev(g)     iv(g)         n         v  \\\n",
      "loc                1.000000  0.904941  0.727279  0.897884  0.946337  0.935685   \n",
      "v(g)               0.904941  1.000000  0.808838  0.966308  0.925320  0.915120   \n",
      "ev(g)              0.727279  0.808838  1.000000  0.771624  0.761390  0.766888   \n",
      "iv(g)              0.897884  0.966308  0.771624  1.000000  0.895480  0.883756   \n",
      "n                  0.946337  0.925320  0.761390  0.895480  1.000000  0.994629   \n",
      "v                  0.935685  0.915120  0.766888  0.883756  0.994629  1.000000   \n",
      "l                 -0.419275 -0.369902 -0.262198 -0.357042 -0.396948 -0.360246   \n",
      "d                  0.842842  0.867440  0.709009  0.812041  0.882704  0.847344   \n",
      "i                  0.800354  0.677607  0.524859  0.677155  0.828330  0.800710   \n",
      "e                  0.823096  0.861730  0.744918  0.807998  0.910644  0.937281   \n",
      "b                  0.906241  0.888063  0.745416  0.857667  0.963827  0.969666   \n",
      "t                  0.823095  0.861729  0.744918  0.807998  0.910643  0.937281   \n",
      "lOCode             0.985030  0.925697  0.739334  0.922196  0.962132  0.950659   \n",
      "lOComment          0.684740  0.517000  0.427639  0.522013  0.536925  0.533080   \n",
      "lOBlank            0.829603  0.710369  0.579732  0.674606  0.771710  0.761379   \n",
      "locCodeAndComment  0.410060  0.376984  0.398340  0.348347  0.393152  0.410069   \n",
      "uniq_Op            0.777585  0.761583  0.624729  0.721651  0.797305  0.754043   \n",
      "uniq_Opnd          0.912365  0.824049  0.679190  0.798644  0.942163  0.926096   \n",
      "total_Op           0.944421  0.928561  0.761999  0.904119  0.998203  0.991426   \n",
      "total_Opnd         0.942262  0.913110  0.754662  0.874805  0.995333  0.992272   \n",
      "branchCount        0.904580  0.998748  0.810061  0.966100  0.926081  0.916020   \n",
      "\n",
      "                          l         d         i         e  ...         t  \\\n",
      "loc               -0.419275  0.842842  0.800354  0.823096  ...  0.823095   \n",
      "v(g)              -0.369902  0.867440  0.677607  0.861730  ...  0.861729   \n",
      "ev(g)             -0.262198  0.709009  0.524859  0.744918  ...  0.744918   \n",
      "iv(g)             -0.357042  0.812041  0.677155  0.807998  ...  0.807998   \n",
      "n                 -0.396948  0.882704  0.828330  0.910644  ...  0.910643   \n",
      "v                 -0.360246  0.847344  0.800710  0.937281  ...  0.937281   \n",
      "l                  1.000000 -0.482719 -0.401692 -0.243223  ... -0.243219   \n",
      "d                 -0.482719  1.000000  0.678330  0.766586  ...  0.766585   \n",
      "i                 -0.401692  0.678330  1.000000  0.592775  ...  0.592773   \n",
      "e                 -0.243223  0.766586  0.592775  1.000000  ...  1.000000   \n",
      "b                 -0.328547  0.819583  0.772323  0.909858  ...  0.909867   \n",
      "t                 -0.243219  0.766585  0.592773  1.000000  ...  1.000000   \n",
      "lOCode            -0.406069  0.857244  0.812209  0.842529  ...  0.842529   \n",
      "lOComment         -0.204070  0.468797  0.449673  0.432617  ...  0.432617   \n",
      "lOBlank           -0.344587  0.727240  0.641573  0.657751  ...  0.657751   \n",
      "locCodeAndComment -0.125750  0.327750  0.276608  0.421041  ...  0.421044   \n",
      "uniq_Op           -0.548106  0.924174  0.726877  0.608767  ...  0.608765   \n",
      "uniq_Opnd         -0.442773  0.830979  0.932769  0.762926  ...  0.762925   \n",
      "total_Op          -0.399218  0.881882  0.826123  0.905785  ...  0.905785   \n",
      "total_Opnd        -0.390175  0.877343  0.825600  0.911594  ...  0.911594   \n",
      "branchCount       -0.370004  0.868067  0.677766  0.863166  ...  0.863166   \n",
      "\n",
      "                     lOCode  lOComment   lOBlank  locCodeAndComment   uniq_Op  \\\n",
      "loc                0.985030   0.684740  0.829603           0.410060  0.777585   \n",
      "v(g)               0.925697   0.517000  0.710369           0.376984  0.761583   \n",
      "ev(g)              0.739334   0.427639  0.579732           0.398340  0.624729   \n",
      "iv(g)              0.922196   0.522013  0.674606           0.348347  0.721651   \n",
      "n                  0.962132   0.536925  0.771710           0.393152  0.797305   \n",
      "v                  0.950659   0.533080  0.761379           0.410069  0.754043   \n",
      "l                 -0.406069  -0.204070 -0.344587          -0.125750 -0.548106   \n",
      "d                  0.857244   0.468797  0.727240           0.327750  0.924174   \n",
      "i                  0.812209   0.449673  0.641573           0.276608  0.726877   \n",
      "e                  0.842529   0.432617  0.657751           0.421041  0.608767   \n",
      "b                  0.921822   0.521209  0.741981           0.416431  0.725222   \n",
      "t                  0.842529   0.432617  0.657751           0.421044  0.608765   \n",
      "lOCode             1.000000   0.606200  0.778142           0.383096  0.789319   \n",
      "lOComment          0.606200   1.000000  0.635630           0.311575  0.441710   \n",
      "lOBlank            0.778142   0.635630  1.000000           0.339919  0.698647   \n",
      "locCodeAndComment  0.383096   0.311575  0.339919           1.000000  0.301425   \n",
      "uniq_Op            0.789319   0.441710  0.698647           0.301425  1.000000   \n",
      "uniq_Opnd          0.917983   0.536348  0.776149           0.352005  0.842047   \n",
      "total_Op           0.961780   0.534274  0.764391           0.390400  0.798052   \n",
      "total_Opnd         0.955425   0.537156  0.777681           0.394705  0.790041   \n",
      "branchCount        0.925254   0.517805  0.708736           0.377859  0.762300   \n",
      "\n",
      "                   uniq_Opnd  total_Op  total_Opnd  branchCount  \n",
      "loc                 0.912365  0.944421    0.942262     0.904580  \n",
      "v(g)                0.824049  0.928561    0.913110     0.998748  \n",
      "ev(g)               0.679190  0.761999    0.754662     0.810061  \n",
      "iv(g)               0.798644  0.904119    0.874805     0.966100  \n",
      "n                   0.942163  0.998203    0.995333     0.926081  \n",
      "v                   0.926096  0.991426    0.992272     0.916020  \n",
      "l                  -0.442773 -0.399218   -0.390175    -0.370004  \n",
      "d                   0.830979  0.881882    0.877343     0.868067  \n",
      "i                   0.932769  0.826123    0.825600     0.677766  \n",
      "e                   0.762926  0.905785    0.911594     0.863166  \n",
      "b                   0.896522  0.960730    0.961829     0.888646  \n",
      "t                   0.762925  0.905785    0.911594     0.863166  \n",
      "lOCode              0.917983  0.961780    0.955425     0.925254  \n",
      "lOComment           0.536348  0.534274    0.537156     0.517805  \n",
      "lOBlank             0.776149  0.764391    0.777681     0.708736  \n",
      "locCodeAndComment   0.352005  0.390400    0.394705     0.377859  \n",
      "uniq_Op             0.842047  0.798052    0.790041     0.762300  \n",
      "uniq_Opnd           1.000000  0.933822    0.948466     0.825028  \n",
      "total_Op            0.933822  1.000000    0.987762     0.929219  \n",
      "total_Opnd          0.948466  0.987762    1.000000     0.914030  \n",
      "branchCount         0.825028  0.929219    0.914030     1.000000  \n",
      "\n",
      "[21 rows x 21 columns]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Exploratory look at the KC1 training split: schema, missing values,\n",
    "# summary statistics, class balance of the boolean target, and pairwise\n",
    "# correlations between the numeric features.\n",
    "TRAIN_PATH = '/data/datasets/kc1/split_train.csv'\n",
    "TARGET_COL = 'defects'\n",
    "\n",
    "# Load the training dataset\n",
    "train_df = pd.read_csv(TRAIN_PATH)\n",
    "\n",
    "# DataFrame.info() prints its own report and returns None; wrapping it in\n",
    "# print() added a stray \"None\" line to the output, so call it directly.\n",
    "print(\"Dataset Info:\")\n",
    "train_df.info()\n",
    "\n",
    "# Display the first few rows of the dataset\n",
    "print(\"\\nFirst few rows of the dataset:\")\n",
    "print(train_df.head())\n",
    "\n",
    "# Check for missing values (per-column null counts)\n",
    "print(\"\\nMissing values in the dataset:\")\n",
    "print(train_df.isnull().sum())\n",
    "\n",
    "# Summary statistics for numerical columns\n",
    "print(\"\\nSummary statistics for numerical columns:\")\n",
    "print(train_df.describe())\n",
    "\n",
    "# Class balance of the target (proportions, not counts)\n",
    "print(\"\\nDistribution of the target column 'defects':\")\n",
    "print(train_df[TARGET_COL].value_counts(normalize=True))\n",
    "\n",
    "# Identify categorical and numerical columns; pandas 'category' columns are\n",
    "# categorical too, not only 'object' ones.\n",
    "categorical_cols = train_df.select_dtypes(include=['object', 'category']).columns\n",
    "numerical_cols = train_df.select_dtypes(include=['number']).columns\n",
    "\n",
    "print(\"\\nCategorical columns:\")\n",
    "print(categorical_cols)\n",
    "print(\"\\nNumerical columns:\")\n",
    "print(numerical_cols)\n",
    "\n",
    "# Correlation matrix for numerical features (the bool target is not selected\n",
    "# by include=['number'], so it is not part of this matrix)\n",
    "print(\"\\nCorrelation matrix for numerical features:\")\n",
    "print(train_df[numerical_cols].corr())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2b45a229",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:53:46.875222Z",
     "iopub.status.busy": "2024-10-18T15:53:46.874955Z",
     "iopub.status.idle": "2024-10-18T15:53:47.298334Z",
     "shell.execute_reply": "2024-10-18T15:53:47.297501Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First few rows of the preprocessed train set:\n",
      "        loc      v(g)     ev(g)     iv(g)         n         v         l  \\\n",
      "0 -0.537962 -0.463621 -0.313259 -0.447751 -0.544832 -0.476621  2.153684   \n",
      "1 -0.056486 -0.220913 -0.313259 -0.169747 -0.138079 -0.165244 -0.457260   \n",
      "2 -0.602159 -0.463621 -0.313259 -0.447751 -0.533534 -0.470727  1.115598   \n",
      "3 -0.313273 -0.463621 -0.313259 -0.447751 -0.420547 -0.391263 -0.205603   \n",
      "4 -0.377470 -0.220913 -0.313259 -0.169747 -0.115482 -0.158587 -0.740375   \n",
      "\n",
      "          d         i         e  ...         t    lOCode  lOComment   lOBlank  \\\n",
      "0 -0.718860 -0.748014 -0.287999  ... -0.288005 -0.581749   -0.29499 -0.470129   \n",
      "1 -0.129828  0.417309 -0.235987  ... -0.235985 -0.072846   -0.29499  0.633235   \n",
      "2 -0.657375 -0.721756 -0.287628  ... -0.287627 -0.581749   -0.29499 -0.470129   \n",
      "3 -0.349947 -0.376776 -0.277634  ... -0.277638 -0.346871   -0.29499 -0.194288   \n",
      "4  0.657187 -0.293927 -0.175880  ... -0.175885 -0.346871   -0.29499 -0.194288   \n",
      "\n",
      "   locCodeAndComment   uniq_Op  uniq_Opnd  total_Op  total_Opnd  branchCount  \n",
      "0          -0.181694 -0.975390  -0.686496 -0.546924   -0.537408    -0.463708  \n",
      "1          -0.181694  0.247265   0.353985 -0.199946   -0.037457    -0.220481  \n",
      "2          -0.181694 -0.800725  -0.686496 -0.528662   -0.537408    -0.463708  \n",
      "3          -0.181694  0.072600  -0.366348 -0.419090   -0.419773    -0.463708  \n",
      "4          -0.181694  0.945925  -0.126237 -0.108636   -0.125684    -0.220481  \n",
      "\n",
      "[5 rows x 21 columns]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Load the dev and (target-less) test splits, separate features from the\n",
    "# target, and standardize numeric features with statistics fitted on train only.\n",
    "DATA_DIR = '/data/datasets/kc1'\n",
    "TARGET_COL = 'defects'\n",
    "\n",
    "# Copy the DataFrames to avoid modifying the originals\n",
    "train_df_copy = train_df.copy()\n",
    "dev_df = pd.read_csv(f'{DATA_DIR}/split_dev.csv')\n",
    "dev_df_copy = dev_df.copy()\n",
    "test_df = pd.read_csv(f'{DATA_DIR}/split_test_wo_target.csv')\n",
    "test_df_copy = test_df.copy()\n",
    "\n",
    "# Separate the target column\n",
    "X_train = train_df_copy.drop(columns=[TARGET_COL])\n",
    "y_train = train_df_copy[TARGET_COL]\n",
    "# Take exactly the training feature columns, in training order, from dev/test:\n",
    "# an extra column (e.g. an id) or a different column order in those files would\n",
    "# otherwise make the fitted transformers reject them later, and a missing\n",
    "# feature now fails here with a KeyError. .copy() avoids chained-assignment warnings.\n",
    "X_dev = dev_df_copy[X_train.columns].copy()\n",
    "y_dev = dev_df_copy[TARGET_COL]\n",
    "X_test = test_df_copy[X_train.columns].copy()\n",
    "\n",
    "# Scale numerical features: fit on train only, then apply to dev/test\n",
    "scaler = StandardScaler()\n",
    "numerical_cols = X_train.select_dtypes(include=['number']).columns\n",
    "X_train[numerical_cols] = scaler.fit_transform(X_train[numerical_cols])\n",
    "X_dev[numerical_cols] = scaler.transform(X_dev[numerical_cols])\n",
    "X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])\n",
    "\n",
    "# Print the first few rows of the preprocessed train set\n",
    "print(\"First few rows of the preprocessed train set:\")\n",
    "print(X_train.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "066e3036",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:53:52.214320Z",
     "iopub.status.busy": "2024-10-18T15:53:52.213675Z",
     "iopub.status.idle": "2024-10-18T15:53:53.319065Z",
     "shell.execute_reply": "2024-10-18T15:53:53.318165Z"
    }
   },
   "outputs": [],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "# Group the raw training columns by kind with the MetaGPT helper.\n",
    "# NOTE(review): this cell has no saved output or execution count; re-run it\n",
    "# so the notebook reflects a complete Restart & Run All.\n",
    "column_info = get_column_info(train_df)\n",
    "print(\"column_info\", column_info, sep=\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d94555f8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:54:10.332778Z",
     "iopub.status.busy": "2024-10-18T15:54:10.331799Z",
     "iopub.status.idle": "2024-10-18T15:54:10.379058Z",
     "shell.execute_reply": "2024-10-18T15:54:10.378079Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First few rows of the feature-engineered train set:\n",
      "        loc      v(g)     ev(g)     iv(g)         n         v         l  \\\n",
      "0 -0.537962 -0.463621 -0.313259 -0.447751 -0.544832 -0.476621  2.153684   \n",
      "1 -0.056486 -0.220913 -0.313259 -0.169747 -0.138079 -0.165244 -0.457260   \n",
      "2 -0.602159 -0.463621 -0.313259 -0.447751 -0.533534 -0.470727  1.115598   \n",
      "3 -0.313273 -0.463621 -0.313259 -0.447751 -0.420547 -0.391263 -0.205603   \n",
      "4 -0.377470 -0.220913 -0.313259 -0.169747 -0.115482 -0.158587 -0.740375   \n",
      "\n",
      "          d         i         e  ...  uniq_Op uniq_Opnd  uniq_Op total_Op  \\\n",
      "0 -0.718860 -0.748014 -0.287999  ...           0.669602          0.533465   \n",
      "1 -0.129828  0.417309 -0.235987  ...           0.087528         -0.049440   \n",
      "2 -0.657375 -0.721756 -0.287628  ...           0.549695          0.423313   \n",
      "3 -0.349947 -0.376776 -0.277634  ...          -0.026597         -0.030426   \n",
      "4  0.657187 -0.293927 -0.175880  ...          -0.119411         -0.102762   \n",
      "\n",
      "   uniq_Op total_Opnd  uniq_Op branchCount  uniq_Opnd total_Op  \\\n",
      "0            0.524183             0.452296            0.375461   \n",
      "1           -0.009262            -0.054517           -0.070778   \n",
      "2            0.430317             0.371303            0.362925   \n",
      "3           -0.030475            -0.033665            0.153533   \n",
      "4           -0.118888            -0.208558            0.013714   \n",
      "\n",
      "   uniq_Opnd total_Opnd  uniq_Opnd branchCount  total_Op total_Opnd  \\\n",
      "0              0.368929               0.318334             0.293922   \n",
      "1             -0.013259              -0.078047             0.007489   \n",
      "2              0.368929               0.318334             0.284108   \n",
      "3              0.153783               0.169879             0.175923   \n",
      "4              0.015866               0.027833             0.013654   \n",
      "\n",
      "   total_Op branchCount  total_Opnd branchCount  \n",
      "0              0.253613                0.249201  \n",
      "1              0.044084                0.008259  \n",
      "2              0.245145                0.249201  \n",
      "3              0.194335                0.194652  \n",
      "4              0.023952                0.027711  \n",
      "\n",
      "[5 rows x 231 columns]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "# Add every pairwise interaction of the scaled features (degree 2, no squared\n",
    "# terms, no bias column). The transformer is fitted on train only.\n",
    "poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)\n",
    "poly.fit(X_train)\n",
    "\n",
    "# One list of output names, built from the training columns, is shared by all\n",
    "# splits so their columns are guaranteed to line up.\n",
    "poly_feature_names = poly.get_feature_names_out(X_train.columns)\n",
    "\n",
    "# Convert back to DataFrames, keeping each split's row index so rows stay\n",
    "# aligned with y_train / y_dev.\n",
    "X_train_fe = pd.DataFrame(poly.transform(X_train), columns=poly_feature_names, index=X_train.index)\n",
    "X_dev_fe = pd.DataFrame(poly.transform(X_dev), columns=poly_feature_names, index=X_dev.index)\n",
    "X_test_fe = pd.DataFrame(poly.transform(X_test), columns=poly_feature_names, index=X_test.index)\n",
    "\n",
    "# Print the first few rows of the feature-engineered train set\n",
    "print(\"First few rows of the feature-engineered train set:\")\n",
    "print(X_train_fe.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "03aed2ed",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:54:15.409254Z",
     "iopub.status.busy": "2024-10-18T15:54:15.408583Z",
     "iopub.status.idle": "2024-10-18T15:54:15.434178Z",
     "shell.execute_reply": "2024-10-18T15:54:15.433167Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "column_info\n",
      "{'Category': [], 'Numeric': ['loc', 'v(g)', 'ev(g)', 'iv(g)', 'n', 'Too many cols, omission here...'], 'Datetime': [], 'Others': []}\n"
     ]
    }
   ],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "# Re-check column kinds after feature engineering; the interaction columns\n",
    "# are expected to be reported under 'Numeric'.\n",
    "column_info = get_column_info(X_train_fe)\n",
    "print(\"column_info\", column_info, sep=\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6779c7f2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:54:35.092040Z",
     "iopub.status.busy": "2024-10-18T15:54:35.091381Z",
     "iopub.status.idle": "2024-10-18T15:54:35.902902Z",
     "shell.execute_reply": "2024-10-18T15:54:35.902026Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dev F1 Score: 0.45652173913043476\n",
      "Train F1 Score: 0.8895027624309392\n",
      "Dev F1 Score: 0.45652173913043476\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "import joblib\n",
    "import os\n",
    "\n",
    "# Baseline model: a depth-limited random forest on the interaction features,\n",
    "# scored with F1 on train and dev; dev/test predictions are written to CSV.\n",
    "OUTPUT_DIR = '../workspace/kc1'\n",
    "\n",
    "# Train a base model using RandomForestClassifier\n",
    "base_model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)\n",
    "base_model.fit(X_train_fe, y_train)\n",
    "\n",
    "# Predict on the dev and test sets\n",
    "y_dev_pred = base_model.predict(X_dev_fe)\n",
    "y_test_pred = base_model.predict(X_test_fe)\n",
    "\n",
    "# F1 of the positive (True) class of the target\n",
    "train_f1 = f1_score(y_train, base_model.predict(X_train_fe))\n",
    "dev_f1 = f1_score(y_dev, y_dev_pred)\n",
    "\n",
    "# Save the predictions; to_csv does not create missing parent directories.\n",
    "dev_predictions = pd.DataFrame({'target': y_dev_pred})\n",
    "test_predictions = pd.DataFrame({'target': y_test_pred})\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "dev_predictions.to_csv(os.path.join(OUTPUT_DIR, 'dev_predictions.csv'), index=False)\n",
    "test_predictions.to_csv(os.path.join(OUTPUT_DIR, 'test_predictions.csv'), index=False)\n",
    "\n",
    "# Report the train and dev scores, each once\n",
    "print(f\"Train F1 Score: {train_f1}\")\n",
    "print(f\"Dev F1 Score: {dev_f1}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c4f4bba9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:54:55.750245Z",
     "iopub.status.busy": "2024-10-18T15:54:55.749588Z",
     "iopub.status.idle": "2024-10-18T15:54:56.522922Z",
     "shell.execute_reply": "2024-10-18T15:54:56.522047Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dev F1 Score: 0.45652173913043476\n",
      "Train F1 Score: 0.8895027624309392\n",
      "Dev F1 Score: 0.45652173913043476\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "import joblib\n",
    "import os\n",
    "\n",
    "# NOTE(review): this cell repeats the random-forest baseline verbatim; with\n",
    "# random_state=42 it retrains the same model and rewrites the same files.\n",
    "# Consider deleting it.\n",
    "OUTPUT_DIR = '../workspace/kc1'\n",
    "\n",
    "# Train a base model using RandomForestClassifier\n",
    "base_model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)\n",
    "base_model.fit(X_train_fe, y_train)\n",
    "\n",
    "# Predict on the dev and test sets\n",
    "y_dev_pred = base_model.predict(X_dev_fe)\n",
    "y_test_pred = base_model.predict(X_test_fe)\n",
    "\n",
    "# F1 of the positive (True) class of the target\n",
    "train_f1 = f1_score(y_train, base_model.predict(X_train_fe))\n",
    "dev_f1 = f1_score(y_dev, y_dev_pred)\n",
    "\n",
    "# Save the predictions; to_csv does not create missing parent directories.\n",
    "dev_predictions = pd.DataFrame({'target': y_dev_pred})\n",
    "test_predictions = pd.DataFrame({'target': y_test_pred})\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "dev_predictions.to_csv(os.path.join(OUTPUT_DIR, 'dev_predictions.csv'), index=False)\n",
    "test_predictions.to_csv(os.path.join(OUTPUT_DIR, 'test_predictions.csv'), index=False)\n",
    "\n",
    "# Report the train and dev scores, each once\n",
    "print(f\"Train F1 Score: {train_f1}\")\n",
    "print(f\"Dev F1 Score: {dev_f1}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bc8de177",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:55:01.681829Z",
     "iopub.status.busy": "2024-10-18T15:55:01.681047Z",
     "iopub.status.idle": "2024-10-18T15:55:01.696082Z",
     "shell.execute_reply": "2024-10-18T15:55:01.695159Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "column_info\n",
      "{'Category': [], 'Numeric': ['loc', 'v(g)', 'ev(g)', 'iv(g)', 'n', 'Too many cols, omission here...'], 'Datetime': [], 'Others': []}\n"
     ]
    }
   ],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "# Column kinds of the feature-engineered training frame.\n",
    "# NOTE(review): X_train_fe is unchanged since the earlier column-info cell on\n",
    "# the same frame, so this repeats that result; consider deleting it.\n",
    "column_info = get_column_info(X_train_fe)\n",
    "print(\"column_info\", column_info, sep=\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b73ef9c8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:55:31.356228Z",
     "iopub.status.busy": "2024-10-18T15:55:31.355574Z",
     "iopub.status.idle": "2024-10-18T15:55:40.652637Z",
     "shell.execute_reply": "2024-10-18T15:55:40.651570Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import StackingClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "import os\n",
    "\n",
    "# Stacking ensemble: StackingClassifier (cv=5) trains the logistic-regression\n",
    "# meta-model on cross-validated predictions of the four base models.\n",
    "OUTPUT_DIR = '../workspace/kc1'\n",
    "\n",
    "# Define base models. `use_label_encoder` is no longer passed to XGBClassifier:\n",
    "# it was deprecated in XGBoost 1.6 and later removed, so it is at best a no-op.\n",
    "# NOTE(review): the interaction products are not re-standardized after\n",
    "# PolynomialFeatures, which skews KNN's distance metric.\n",
    "base_models = [\n",
    "    ('rf', RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)),\n",
    "    ('xgb', XGBClassifier(eval_metric='logloss', random_state=42)),\n",
    "    ('et', ExtraTreesClassifier(n_estimators=100, max_depth=10, random_state=42)),\n",
    "    ('knn', KNeighborsClassifier(n_neighbors=5))\n",
    "]\n",
    "\n",
    "# Define the meta-model\n",
    "meta_model = LogisticRegression()\n",
    "\n",
    "# Create and train the stacking ensemble model\n",
    "stacking_model = StackingClassifier(estimators=base_models, final_estimator=meta_model, cv=5)\n",
    "stacking_model.fit(X_train_fe, y_train)\n",
    "\n",
    "# Predict on the dev and test sets\n",
    "y_dev_pred = stacking_model.predict(X_dev_fe)\n",
    "y_test_pred = stacking_model.predict(X_test_fe)\n",
    "\n",
    "# F1 of the positive (True) class of the target\n",
    "train_f1 = f1_score(y_train, stacking_model.predict(X_train_fe))\n",
    "dev_f1 = f1_score(y_dev, y_dev_pred)\n",
    "\n",
    "# Save the predictions (same files as the baseline, which are overwritten);\n",
    "# to_csv does not create missing parent directories.\n",
    "dev_predictions = pd.DataFrame({'target': y_dev_pred})\n",
    "test_predictions = pd.DataFrame({'target': y_test_pred})\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "dev_predictions.to_csv(os.path.join(OUTPUT_DIR, 'dev_predictions.csv'), index=False)\n",
    "test_predictions.to_csv(os.path.join(OUTPUT_DIR, 'test_predictions.csv'), index=False)\n",
    "\n",
    "# Report the train and dev scores, each once\n",
    "print(f\"Train F1 Score: {train_f1}\")\n",
    "print(f\"Dev F1 Score: {dev_f1}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13543192",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T15:56:11.089648Z",
     "iopub.status.busy": "2024-10-18T15:56:11.088969Z",
     "iopub.status.idle": "2024-10-18T15:56:19.663701Z",
     "shell.execute_reply": "2024-10-18T15:56:19.662818Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this cell repeats the preceding stacking cell verbatim; running it retrains the\n",
    "# same model and overwrites the same prediction CSVs -- consider deleting one of the two cells.\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.ensemble import StackingClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Stacking ensemble: three tree ensembles plus KNN as base learners, blended by a\n",
    "# logistic-regression meta-model fit on 5-fold out-of-fold base-model predictions.\n",
    "\n",
    "# Define base models\n",
    "base_models = [\n",
    "    ('rf', RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)),\n",
    "    # `use_label_encoder` is deprecated (XGBoost >= 1.6) and removed in 2.0, where passing it only warns\n",
    "    ('xgb', XGBClassifier(eval_metric='logloss', random_state=42)),\n",
    "    ('et', ExtraTreesClassifier(n_estimators=100, max_depth=10, random_state=42)),\n",
    "    # KNN is distance-based: standardize so large-magnitude metrics don't dominate the distance.\n",
    "    # NOTE(review): harmless if X_train_fe is already standardized -- confirm against the FE cells.\n",
    "    ('knn', make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5)))\n",
    "]\n",
    "\n",
    "# Define the meta-model\n",
    "meta_model = LogisticRegression()\n",
    "\n",
    "# Create and train the stacking ensemble model\n",
    "stacking_model = StackingClassifier(estimators=base_models, final_estimator=meta_model, cv=5)\n",
    "stacking_model.fit(X_train_fe, y_train)\n",
    "\n",
    "# Predict on the dev and test sets\n",
    "y_dev_pred = stacking_model.predict(X_dev_fe)\n",
    "y_test_pred = stacking_model.predict(X_test_fe)\n",
    "\n",
    "# F1 on train vs. dev; a large gap between them signals overfitting\n",
    "dev_f1 = f1_score(y_dev, y_dev_pred)\n",
    "train_f1 = f1_score(y_train, stacking_model.predict(X_train_fe))\n",
    "\n",
    "# Save the predictions; create the output directory so a fresh run doesn't fail in to_csv\n",
    "output_dir = '../workspace/kc1'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "dev_predictions = pd.DataFrame({'target': y_dev_pred})\n",
    "test_predictions = pd.DataFrame({'target': y_test_pred})\n",
    "dev_predictions.to_csv(os.path.join(output_dir, 'dev_predictions.csv'), index=False)\n",
    "test_predictions.to_csv(os.path.join(output_dir, 'test_predictions.csv'), index=False)\n",
    "\n",
    "# Report train and dev performance (once each)\n",
    "print(f\"Train F1 Score: {train_f1}\")\n",
    "print(f\"Dev F1 Score: {dev_f1}\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
