{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cb2e8796",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:55:40.730187Z",
     "iopub.status.busy": "2024-10-18T16:55:40.729662Z",
     "iopub.status.idle": "2024-10-18T16:55:41.600726Z",
     "shell.execute_reply": "2024-10-18T16:55:41.599877Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Basic information about the train dataset:\n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 1200 entries, 0 to 1199\n",
      "Columns: 217 entries, att1 to class\n",
      "dtypes: int64(217)\n",
      "memory usage: 2.0 MB\n",
      "None\n",
      "\n",
      "First few rows of the train dataset:\n",
      "   att1  att2  att3  att4  att5  att6  att7  att8  att9  att10  ...  att208  \\\n",
      "0   382   482  1039   751   655   920    39    27    23     12  ...     604   \n",
      "1   376   486   969   792   644   850    36    26    19     13  ...     643   \n",
      "2   130   126   617   668   619   578     5     4     0      3  ...     531   \n",
      "3   275   381   844   755   606   715    27    22    15      4  ...     608   \n",
      "4   296   192   754   736   677   723    16    16    16     16  ...     693   \n",
      "\n",
      "   att209  att210  att211  att212  att213  att214  att215  att216  class  \n",
      "0     666     869      25      14      22      11      12      15      8  \n",
      "1     655     763      22      17      18      12      14      17      8  \n",
      "2     600     625      13      35      11      14      15      10      1  \n",
      "3     615     686      13      19      18      17      13      17     10  \n",
      "4     626     734       8      27      17       9      11      16      6  \n",
      "\n",
      "[5 rows x 217 columns]\n",
      "\n",
      "Missing values in the train dataset:\n",
      "att1      0\n",
      "att2      0\n",
      "att3      0\n",
      "att4      0\n",
      "att5      0\n",
      "         ..\n",
      "att213    0\n",
      "att214    0\n",
      "att215    0\n",
      "att216    0\n",
      "class     0\n",
      "Length: 217, dtype: int64\n",
      "\n",
      "Summary statistics for numerical columns:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              att1         att2         att3         att4         att5  \\\n",
      "count  1200.000000  1200.000000  1200.000000  1200.000000  1200.000000   \n",
      "mean    272.878333   322.301667   774.733333   755.000833   639.276667   \n",
      "std      92.269379   108.526797   137.246084   109.975746    48.461262   \n",
      "min      67.000000    81.000000   515.000000   545.000000   437.000000   \n",
      "25%     206.000000   248.750000   657.000000   667.000000   606.000000   \n",
      "50%     274.000000   322.000000   768.000000   736.000000   636.000000   \n",
      "75%     340.000000   402.000000   878.000000   829.000000   666.000000   \n",
      "max     495.000000   565.000000  1264.000000  1134.000000   811.000000   \n",
      "\n",
      "              att6         att7         att8         att9        att10  ...  \\\n",
      "count  1200.000000  1200.000000  1200.000000  1200.000000  1200.000000  ...   \n",
      "mean    683.240833    19.415833    18.415000    15.332500     9.042500  ...   \n",
      "std      86.149014    11.195864     7.418803     8.814383     4.086924  ...   \n",
      "min     439.000000     0.000000     1.000000     0.000000     0.000000  ...   \n",
      "25%     640.000000    10.750000    13.000000     8.000000     5.000000  ...   \n",
      "50%     677.000000    20.000000    18.000000    15.000000    10.000000  ...   \n",
      "75%     716.000000    28.000000    24.000000    21.000000    12.000000  ...   \n",
      "max    1012.000000    40.000000    38.000000    38.000000    17.000000  ...   \n",
      "\n",
      "            att208       att209       att210       att211       att212  \\\n",
      "count  1200.000000  1200.000000  1200.000000  1200.000000  1200.000000   \n",
      "mean    643.785833   639.945000   687.387500    17.729167    21.523333   \n",
      "std     119.386184    57.025025    81.299049     5.200002     8.030870   \n",
      "min     432.000000   427.000000   462.000000     3.000000     1.000000   \n",
      "25%     546.000000   603.000000   639.000000    14.000000    15.000000   \n",
      "50%     628.500000   642.000000   682.000000    18.000000    23.000000   \n",
      "75%     727.250000   672.000000   734.000000    22.000000    27.000000   \n",
      "max    1012.000000   846.000000  1002.000000    30.000000    37.000000   \n",
      "\n",
      "            att213       att214       att215       att216        class  \n",
      "count  1200.000000  1200.000000  1200.000000  1200.000000  1200.000000  \n",
      "mean     17.702500    11.940833    13.662500    13.276667     5.460833  \n",
      "std       5.800462     2.548741     2.002149     4.681708     2.856326  \n",
      "min       1.000000     4.000000     6.000000     0.000000     1.000000  \n",
      "25%      13.750000    10.000000    12.000000    11.000000     3.000000  \n",
      "50%      18.000000    12.000000    14.000000    14.000000     5.000000  \n",
      "75%      21.000000    14.000000    15.000000    17.000000     8.000000  \n",
      "max      35.000000    18.000000    19.000000    21.000000    10.000000  \n",
      "\n",
      "[8 rows x 217 columns]\n",
      "\n",
      "Distribution of the target column 'class':\n",
      "class\n",
      "5     130\n",
      "7     129\n",
      "8     127\n",
      "1     124\n",
      "2     122\n",
      "9     119\n",
      "3     117\n",
      "4     115\n",
      "6     109\n",
      "10    108\n",
      "Name: count, dtype: int64\n",
      "\n",
      "Data types of the columns:\n",
      "att1      int64\n",
      "att2      int64\n",
      "att3      int64\n",
      "att4      int64\n",
      "att5      int64\n",
      "          ...  \n",
      "att213    int64\n",
      "att214    int64\n",
      "att215    int64\n",
      "att216    int64\n",
      "class     int64\n",
      "Length: 217, dtype: object\n",
      "\n",
      "Correlation matrix for numerical features:\n",
      "            att1      att2      att3      att4      att5      att6      att7  \\\n",
      "att1    1.000000  0.509515  0.557634  0.499373  0.348385  0.154185  0.497244   \n",
      "att2    0.509515  1.000000  0.646589  0.375526  0.254996  0.325720  0.782921   \n",
      "att3    0.557634  0.646589  1.000000 -0.065607  0.052975  0.327449  0.895533   \n",
      "att4    0.499373  0.375526 -0.065607  1.000000  0.505390  0.122594  0.012070   \n",
      "att5    0.348385  0.254996  0.052975  0.505390  1.000000  0.062408  0.086529   \n",
      "...          ...       ...       ...       ...       ...       ...       ...   \n",
      "att213  0.267408  0.370004  0.107930  0.471357  0.308207  0.105832  0.158024   \n",
      "att214 -0.068481 -0.365590 -0.266966 -0.048658 -0.159313 -0.038861 -0.355162   \n",
      "att215 -0.192742 -0.286265 -0.267997 -0.133739 -0.200722 -0.260535 -0.303037   \n",
      "att216  0.107492  0.086418  0.215119  0.019012 -0.145803  0.505645  0.203463   \n",
      "class   0.331509 -0.050180  0.085129  0.124407  0.058036  0.220588  0.023395   \n",
      "\n",
      "            att8      att9     att10  ...    att208    att209    att210  \\\n",
      "att1    0.335980  0.467901  0.362195  ...  0.513003  0.303419  0.070371   \n",
      "att2    0.653777  0.537845  0.469034  ...  0.295300  0.157499  0.261395   \n",
      "att3    0.338692  0.117909  0.360686  ... -0.114445 -0.073032  0.357372   \n",
      "att4    0.259118  0.790392  0.174240  ...  0.951307  0.496611 -0.119870   \n",
      "att5    0.178712  0.479380  0.200180  ...  0.494211  0.766556 -0.040173   \n",
      "...          ...       ...       ...  ...       ...       ...       ...   \n",
      "att213  0.288378  0.557727  0.045567  ...  0.443632  0.350091  0.234667   \n",
      "att214 -0.278127 -0.213259 -0.409946  ... -0.073973 -0.059007 -0.062708   \n",
      "att215 -0.145425 -0.253896  0.005118  ... -0.060129 -0.194848 -0.215828   \n",
      "att216 -0.087065  0.155555 -0.085440  ... -0.079369 -0.010555  0.405280   \n",
      "class   0.115891  0.081496 -0.079912  ...  0.109863  0.117306  0.061997   \n",
      "\n",
      "          att211    att212    att213    att214    att215    att216     class  \n",
      "att1    0.340139 -0.298115  0.267408 -0.068481 -0.192742  0.107492  0.331509  \n",
      "att2    0.348727 -0.727688  0.370004 -0.365590 -0.286265  0.086418 -0.050180  \n",
      "att3    0.394878 -0.424341  0.107930 -0.266966 -0.267997  0.215119  0.085129  \n",
      "att4    0.016130 -0.229574  0.471357 -0.048658 -0.133739  0.019012  0.124407  \n",
      "att5    0.168150 -0.131261  0.308207 -0.159313 -0.200722 -0.145803  0.058036  \n",
      "...          ...       ...       ...       ...       ...       ...       ...  \n",
      "att213 -0.050095 -0.295566  1.000000 -0.111426 -0.218284  0.002204 -0.135388  \n",
      "att214 -0.258716  0.442190 -0.111426  1.000000  0.083197  0.033245  0.121635  \n",
      "att215  0.061869  0.146065 -0.218284  0.083197  1.000000 -0.121361 -0.044388  \n",
      "att216 -0.159204  0.070103  0.002204  0.033245 -0.121361  1.000000  0.260328  \n",
      "class  -0.002147 -0.020848 -0.135388  0.121635 -0.044388  0.260328  1.000000  \n",
      "\n",
      "[217 rows x 217 columns]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Load the train dataset\n",
    "train_df = pd.read_csv('/data/datasets/mfeat-factors/split_train.csv')\n",
    "\n",
    "# Display basic information about the dataset\n",
    "print(\"Basic information about the train dataset:\")\n",
    "print(train_df.info())\n",
    "\n",
    "# Display the first few rows of the dataset\n",
    "print(\"\\nFirst few rows of the train dataset:\")\n",
    "print(train_df.head())\n",
    "\n",
    "# Check for missing values\n",
    "print(\"\\nMissing values in the train dataset:\")\n",
    "print(train_df.isnull().sum())\n",
    "\n",
    "# Summary statistics for numerical columns\n",
    "print(\"\\nSummary statistics for numerical columns:\")\n",
    "print(train_df.describe())\n",
    "\n",
    "# Check the distribution of the target column 'class'\n",
    "print(\"\\nDistribution of the target column 'class':\")\n",
    "print(train_df['class'].value_counts())\n",
    "\n",
    "# Check the data types of the columns\n",
    "print(\"\\nData types of the columns:\")\n",
    "print(train_df.dtypes)\n",
    "\n",
    "# Analyze the correlation between numerical features\n",
    "numerical_columns = train_df.select_dtypes(include=[np.number]).columns\n",
    "correlation_matrix = train_df[numerical_columns].corr()\n",
    "print(\"\\nCorrelation matrix for numerical features:\")\n",
    "print(correlation_matrix)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "070ef064",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:55:41.617697Z",
     "iopub.status.busy": "2024-10-18T16:55:41.617404Z",
     "iopub.status.idle": "2024-10-18T16:55:42.141886Z",
     "shell.execute_reply": "2024-10-18T16:55:42.140988Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First few rows of the processed train dataset:\n",
      "       att1      att2      att3      att4      att5      att6      att7  \\\n",
      "0  1.183135  1.472124  1.926298 -0.036394  0.324587  2.749397  1.749962   \n",
      "1  1.118081  1.508997  1.416053  0.336571  0.097507  1.936513  1.481894   \n",
      "2 -1.549137 -1.809539 -1.149753 -0.791421 -0.418584 -1.222123 -1.288140   \n",
      "3  0.023004  0.541090  0.504900 -0.000008 -0.686952  0.368808  0.677690   \n",
      "4  0.250693 -1.201141 -0.151130 -0.172845  0.778747  0.461709 -0.305225   \n",
      "\n",
      "       att8      att9     att10  ...    att208    att209    att210    att211  \\\n",
      "0  1.157677  0.870248  0.723951  ... -0.333392  0.457095  2.234814  1.398820   \n",
      "1  1.022828  0.416255  0.968736  ... -0.006585  0.264117  0.930442  0.821656   \n",
      "2 -1.943846 -1.740212 -1.479112  ... -0.945108 -0.700774 -0.767703 -0.909834   \n",
      "3  0.483433 -0.037738 -1.234327  ... -0.299874 -0.437622 -0.017074 -0.909834   \n",
      "4 -0.325660  0.075760  1.703090  ...  0.412399 -0.244644  0.573585 -1.871773   \n",
      "\n",
      "     att212    att213    att214    att215    att216  class  \n",
      "0 -0.937192  0.741198 -0.369290 -0.830704  0.368253      8  \n",
      "1 -0.563478  0.051310  0.023224  0.168639  0.795625      8  \n",
      "2  1.678808 -1.155993  0.808252  0.668311 -0.700179      1  \n",
      "3 -0.314335  0.051310  1.985795 -0.331032  0.795625     10  \n",
      "4  0.682236 -0.121162 -1.154319 -1.330375  0.581939      6  \n",
      "\n",
      "[5 rows x 217 columns]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1021943/2306521906.py:23: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df_copy['class'] = y\n",
      "/tmp/ipykernel_1021943/2306521906.py:23: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df_copy['class'] = y\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Function to preprocess data\n",
    "def preprocess_data(df, scaler=None):\n",
    "    df_copy = df.copy()\n",
    "    \n",
    "    # Separate target column if it exists\n",
    "    if 'class' in df_copy.columns:\n",
    "        y = df_copy.pop('class')\n",
    "    else:\n",
    "        y = None\n",
    "    \n",
    "    # Scale numerical features\n",
    "    numerical_columns = df_copy.select_dtypes(include=[np.number]).columns\n",
    "    if scaler is None:\n",
    "        scaler = StandardScaler()\n",
    "        df_copy[numerical_columns] = scaler.fit_transform(df_copy[numerical_columns])\n",
    "    else:\n",
    "        df_copy[numerical_columns] = scaler.transform(df_copy[numerical_columns])\n",
    "    \n",
    "    # Reattach the target column if it was separated\n",
    "    if y is not None:\n",
    "        df_copy['class'] = y\n",
    "    \n",
    "    return df_copy, scaler\n",
    "\n",
    "# Load dev and test datasets\n",
    "dev_df = pd.read_csv('/data/datasets/mfeat-factors/split_dev.csv')\n",
    "test_df = pd.read_csv('/data/datasets/mfeat-factors/split_test_wo_target.csv')\n",
    "\n",
    "# Preprocess train, dev, and test datasets\n",
    "train_df_processed, scaler = preprocess_data(train_df)\n",
    "dev_df_processed, _ = preprocess_data(dev_df, scaler)\n",
    "test_df_processed, _ = preprocess_data(test_df, scaler)\n",
    "\n",
    "# Display the first few rows of the processed train dataset\n",
    "print(\"First few rows of the processed train dataset:\")\n",
    "print(train_df_processed.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f417d0b7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:55:47.193091Z",
     "iopub.status.busy": "2024-10-18T16:55:47.192427Z",
     "iopub.status.idle": "2024-10-18T16:55:48.287039Z",
     "shell.execute_reply": "2024-10-18T16:55:48.286164Z"
    }
   },
   "outputs": [],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "# Using the processed train dataset from the finished tasks\n",
    "column_info = get_column_info(train_df_processed)\n",
    "print(\"column_info\")\n",
    "print(column_info)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f755df93",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:56:08.688160Z",
     "iopub.status.busy": "2024-10-18T16:56:08.687045Z",
     "iopub.status.idle": "2024-10-18T16:56:09.335616Z",
     "shell.execute_reply": "2024-10-18T16:56:09.334791Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First few rows of the processed train dataset after removing highly correlated features:\n",
      "       att1      att2      att3      att4      att5      att6      att7  \\\n",
      "0  1.183135  1.472124  1.926298 -0.036394  0.324587  2.749397  1.749962   \n",
      "1  1.118081  1.508997  1.416053  0.336571  0.097507  1.936513  1.481894   \n",
      "2 -1.549137 -1.809539 -1.149753 -0.791421 -0.418584 -1.222123 -1.288140   \n",
      "3  0.023004  0.541090  0.504900 -0.000008 -0.686952  0.368808  0.677690   \n",
      "4  0.250693 -1.201141 -0.151130 -0.172845  0.778747  0.461709 -0.305225   \n",
      "\n",
      "       att8      att9     att10  ...    att205    att206    att209    att210  \\\n",
      "0  1.157677  0.870248  0.723951  ...  0.667057  1.215774  0.457095  2.234814   \n",
      "1  1.022828  0.416255  0.968736  ...  0.821444 -0.257395  0.264117  0.930442   \n",
      "2 -1.943846 -1.740212 -1.479112  ... -2.111917  0.493632 -0.700774 -0.767703   \n",
      "3  0.483433 -0.037738 -1.234327  ...  0.281088 -0.502923 -0.437622 -0.017074   \n",
      "4 -0.325660  0.075760  1.703090  ...  0.543547  0.175890 -0.244644  0.573585   \n",
      "\n",
      "     att211    att212    att213    att214    att215  class  \n",
      "0  1.398820 -0.937192  0.741198 -0.369290 -0.830704      8  \n",
      "1  0.821656 -0.563478  0.051310  0.023224  0.168639      8  \n",
      "2 -0.909834  1.678808 -1.155993  0.808252  0.668311      1  \n",
      "3 -0.909834 -0.314335  0.051310  1.985795 -0.331032     10  \n",
      "4 -1.871773  0.682236 -0.121162 -1.154319 -1.330375      6  \n",
      "\n",
      "[5 rows x 130 columns]\n"
     ]
    }
   ],
   "source": [
    "# Perform correlation analysis to identify highly correlated features\n",
    "correlation_threshold = 0.9\n",
    "highly_correlated = set()\n",
    "\n",
    "# Calculate the correlation matrix for numerical features\n",
    "correlation_matrix = train_df_processed.corr()\n",
    "\n",
    "# Identify pairs of features with a correlation coefficient greater than the threshold\n",
    "for i in range(len(correlation_matrix.columns)):\n",
    "    for j in range(i):\n",
    "        if abs(correlation_matrix.iloc[i, j]) > correlation_threshold:\n",
    "            colname = correlation_matrix.columns[i]\n",
    "            highly_correlated.add(colname)\n",
    "\n",
    "# Remove highly correlated features from the datasets\n",
    "def remove_highly_correlated_features(df, highly_correlated):\n",
    "    df_copy = df.copy()\n",
    "    if 'class' in df_copy.columns:\n",
    "        y = df_copy.pop('class')\n",
    "    else:\n",
    "        y = None\n",
    "    df_copy = df_copy.drop(columns=highly_correlated)\n",
    "    if y is not None:\n",
    "        df_copy['class'] = y\n",
    "    return df_copy\n",
    "\n",
    "train_df_processed = remove_highly_correlated_features(train_df_processed, highly_correlated)\n",
    "dev_df_processed = remove_highly_correlated_features(dev_df_processed, highly_correlated)\n",
    "test_df_processed = remove_highly_correlated_features(test_df_processed, highly_correlated)\n",
    "\n",
    "print(\"First few rows of the processed train dataset after removing highly correlated features:\")\n",
    "print(train_df_processed.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6d3c8a3d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:56:14.091876Z",
     "iopub.status.busy": "2024-10-18T16:56:14.091216Z",
     "iopub.status.idle": "2024-10-18T16:56:14.110607Z",
     "shell.execute_reply": "2024-10-18T16:56:14.109345Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "column_info\n",
      "{'Category': [], 'Numeric': ['att1', 'att2', 'att3', 'att4', 'att5', 'att6', 'att7', 'att8', 'att9', 'att10', 'att11', 'att12', 'att13', 'att14', 'att15', 'att17', 'att18', 'att20', 'att21', 'att22', 'att23', 'att24', 'att25', 'att26', 'att27', 'att28', 'att29', 'att30', 'att31', 'att32', 'att33', 'att34', 'att35', 'att36', 'att37', 'att38', 'att41', 'att42', 'att43', 'att44', 'att45', 'att46', 'att47', 'att48', 'att49', 'att55', 'att56', 'att60', 'att61', 'att62', 'att64', 'att69', 'att71', 'att72', 'att73', 'att74', 'att75', 'att77', 'att79', 'att80', 'att82', 'att83', 'att84', 'att85', 'att86', 'att92', 'att94', 'att95', 'att96', 'att97', 'att98', 'att99', 'att101', 'att103', 'att104', 'att106', 'att107', 'att108', 'att109', 'att110', 'att112', 'att115', 'att116', 'att117', 'att118', 'att119', 'att120', 'att121', 'att122', 'att126', 'att128', 'att130', 'att131', 'att132', 'att133', 'att134', 'att138', 'att139', 'att140', 'att141', 'att145', 'att146', 'att152', 'att156', 'att157', 'att158', 'att159', 'att163', 'att164', 'att168', 'att169', 'att172', 'att178', 'att182', 'att184', 'att189', 'att193', 'att195', 'att203', 'att204', 'att205', 'att206', 'att209', 'att210', 'att211', 'att212', 'att213', 'att214', 'att215', 'class'], 'Datetime': [], 'Others': []}\n"
     ]
    }
   ],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "column_info = get_column_info(train_df_processed)\n",
    "print(\"column_info\")\n",
    "print(column_info)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a795880a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:56:51.405367Z",
     "iopub.status.busy": "2024-10-18T16:56:51.404938Z",
     "iopub.status.idle": "2024-10-18T16:57:09.834239Z",
     "shell.execute_reply": "2024-10-18T16:57:09.830288Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.ensemble import StackingClassifier\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# Prepare the data\n",
    "X_train = train_df_processed.drop(columns=['class'])\n",
    "y_train = train_df_processed['class']\n",
    "X_dev = dev_df_processed.drop(columns=['class'])\n",
    "y_dev = dev_df_processed['class']\n",
    "X_test = test_df_processed\n",
    "\n",
    "# Encode the target labels\n",
    "label_encoder = LabelEncoder()\n",
    "y_train_encoded = label_encoder.fit_transform(y_train)\n",
    "y_dev_encoded = label_encoder.transform(y_dev)\n",
    "\n",
    "# Define base models\n",
    "base_models = [\n",
    "    ('xgb', XGBClassifier(use_label_encoder=False, eval_metric='mlogloss', n_estimators=200, max_depth=5, learning_rate=0.1)),\n",
    "    ('rf', RandomForestClassifier(n_estimators=200, max_depth=10, random_state=42)),\n",
    "    ('knn', KNeighborsClassifier(n_neighbors=5)),\n",
    "    ('lr', LogisticRegression(max_iter=200, multi_class='auto', solver='lbfgs'))\n",
    "]\n",
    "\n",
    "# Define the stacking model\n",
    "stacking_model = StackingClassifier(estimators=base_models, final_estimator=LogisticRegression(max_iter=200, multi_class='auto', solver='lbfgs'))\n",
    "\n",
    "# Train the stacking model\n",
    "stacking_model.fit(X_train, y_train_encoded)\n",
    "\n",
    "# Predict on the dev set\n",
    "y_dev_pred = stacking_model.predict(X_dev)\n",
    "\n",
    "# Calculate F1 weighted score on the dev set\n",
    "f1_weighted_dev = f1_score(y_dev_encoded, y_dev_pred, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set: {f1_weighted_dev}\")\n",
    "\n",
    "# Predict on the test set\n",
    "y_test_pred = stacking_model.predict(X_test)\n",
    "\n",
    "# Save the predictions\n",
    "dev_predictions = pd.DataFrame({'target': label_encoder.inverse_transform(y_dev_pred)})\n",
    "test_predictions = pd.DataFrame({'target': label_encoder.inverse_transform(y_test_pred)})\n",
    "\n",
    "dev_predictions.to_csv('../workspace/mfeat-factors/dev_predictions.csv', index=False)\n",
    "test_predictions.to_csv('../workspace/mfeat-factors/test_predictions.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6c59068",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:57:20.265698Z",
     "iopub.status.busy": "2024-10-18T16:57:20.264872Z",
     "iopub.status.idle": "2024-10-18T16:57:23.699434Z",
     "shell.execute_reply": "2024-10-18T16:57:23.697884Z"
    }
   },
   "outputs": [],
   "source": [
    "# Evaluate the base models on the dev set and print the f1 weighted score\n",
    "base_model_scores = {}\n",
    "for name, model in base_models:\n",
    "    model.fit(X_train, y_train_encoded)\n",
    "    y_dev_pred = model.predict(X_dev)\n",
    "    f1_weighted_dev = f1_score(y_dev_encoded, y_dev_pred, average='weighted')\n",
    "    base_model_scores[name] = f1_weighted_dev\n",
    "    print(f\"F1 Weighted Score for {name} on Dev Set: {f1_weighted_dev}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "33f7a010",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:57:35.843233Z",
     "iopub.status.busy": "2024-10-18T16:57:35.842528Z",
     "iopub.status.idle": "2024-10-18T16:57:35.856461Z",
     "shell.execute_reply": "2024-10-18T16:57:35.855188Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "column_info\n",
      "{'Category': [], 'Numeric': ['att1', 'att2', 'att3', 'att4', 'att5', 'att6', 'att7', 'att8', 'att9', 'att10', 'att11', 'att12', 'att13', 'att14', 'att15', 'att17', 'att18', 'att20', 'att21', 'att22', 'att23', 'att24', 'att25', 'att26', 'att27', 'att28', 'att29', 'att30', 'att31', 'att32', 'att33', 'att34', 'att35', 'att36', 'att37', 'att38', 'att41', 'att42', 'att43', 'att44', 'att45', 'att46', 'att47', 'att48', 'att49', 'att55', 'att56', 'att60', 'att61', 'att62', 'att64', 'att69', 'att71', 'att72', 'att73', 'att74', 'att75', 'att77', 'att79', 'att80', 'att82', 'att83', 'att84', 'att85', 'att86', 'att92', 'att94', 'att95', 'att96', 'att97', 'att98', 'att99', 'att101', 'att103', 'att104', 'att106', 'att107', 'att108', 'att109', 'att110', 'att112', 'att115', 'att116', 'att117', 'att118', 'att119', 'att120', 'att121', 'att122', 'att126', 'att128', 'att130', 'att131', 'att132', 'att133', 'att134', 'att138', 'att139', 'att140', 'att141', 'att145', 'att146', 'att152', 'att156', 'att157', 'att158', 'att159', 'att163', 'att164', 'att168', 'att169', 'att172', 'att178', 'att182', 'att184', 'att189', 'att193', 'att195', 'att203', 'att204', 'att205', 'att206', 'att209', 'att210', 'att211', 'att212', 'att213', 'att214', 'att215', 'class'], 'Datetime': [], 'Others': []}\n"
     ]
    }
   ],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "column_info = get_column_info(train_df_processed)\n",
    "print(\"column_info\")\n",
    "print(column_info)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bd3a0b2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:58:05.135282Z",
     "iopub.status.busy": "2024-10-18T16:58:05.134610Z",
     "iopub.status.idle": "2024-10-18T16:58:23.440617Z",
     "shell.execute_reply": "2024-10-18T16:58:23.438785Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train a weighted ensemble model using the best performing models from the base model evaluation\n",
    "\n",
    "# Define the best performing models based on the evaluation results\n",
    "best_models = [\n",
    "    ('xgb', XGBClassifier(use_label_encoder=False, eval_metric='mlogloss', n_estimators=200, max_depth=5, learning_rate=0.1)),\n",
    "    ('rf', RandomForestClassifier(n_estimators=200, max_depth=10, random_state=42)),\n",
    "    ('lr', LogisticRegression(max_iter=200, multi_class='auto', solver='lbfgs'))\n",
    "]\n",
    "\n",
    "# Define the stacking model with the best models\n",
    "stacking_model_best = StackingClassifier(estimators=best_models, final_estimator=LogisticRegression(max_iter=200, multi_class='auto', solver='lbfgs'))\n",
    "\n",
    "# Train the stacking model\n",
    "stacking_model_best.fit(X_train, y_train_encoded)\n",
    "\n",
    "# Predict on the dev set\n",
    "y_dev_pred_best = stacking_model_best.predict(X_dev)\n",
    "\n",
    "# Calculate F1 weighted score on the dev set\n",
    "f1_weighted_dev_best = f1_score(y_dev_encoded, y_dev_pred_best, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set with Best Models: {f1_weighted_dev_best}\")\n",
    "\n",
    "# Predict on the test set\n",
    "y_test_pred_best = stacking_model_best.predict(X_test)\n",
    "\n",
    "# Save the predictions\n",
    "dev_predictions_best = pd.DataFrame({'target': label_encoder.inverse_transform(y_dev_pred_best)})\n",
    "test_predictions_best = pd.DataFrame({'target': label_encoder.inverse_transform(y_test_pred_best)})\n",
    "\n",
    "dev_predictions_best.to_csv('../workspace/mfeat-factors/dev_predictions_best.csv', index=False)\n",
    "test_predictions_best.to_csv('../workspace/mfeat-factors/test_predictions_best.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b72694a6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T16:58:52.796064Z",
     "iopub.status.busy": "2024-10-18T16:58:52.795058Z",
     "iopub.status.idle": "2024-10-18T16:58:53.047249Z",
     "shell.execute_reply": "2024-10-18T16:58:53.045475Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1 Weighted Score on Train Set with Best Models: 1.0\n",
      "F1 Weighted Score on Dev Set with Best Models: 0.974909978712511\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the final ensemble model on the train and dev sets and print the f1 weighted scores\n",
    "\n",
    "# Predict on the train set\n",
    "y_train_pred_best = stacking_model_best.predict(X_train)\n",
    "\n",
    "# Calculate F1 weighted score on the train set\n",
    "f1_weighted_train_best = f1_score(y_train_encoded, y_train_pred_best, average='weighted')\n",
    "print(f\"F1 Weighted Score on Train Set with Best Models: {f1_weighted_train_best}\")\n",
    "\n",
    "# Predict on the dev set\n",
    "y_dev_pred_best = stacking_model_best.predict(X_dev)\n",
    "\n",
    "# Calculate F1 weighted score on the dev set\n",
    "f1_weighted_dev_best = f1_score(y_dev_encoded, y_dev_pred_best, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set with Best Models: {f1_weighted_dev_best}\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
