{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ad0e10c6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:29:09.373240Z",
     "iopub.status.busy": "2024-10-18T17:29:09.372711Z",
     "iopub.status.idle": "2024-10-18T17:29:09.863709Z",
     "shell.execute_reply": "2024-10-18T17:29:09.862838Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset Info:\n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 1386 entries, 0 to 1385\n",
      "Data columns (total 17 columns):\n",
      " #   Column                Non-Null Count  Dtype  \n",
      "---  ------                --------------  -----  \n",
      " 0   short.line.density.5  1386 non-null   float64\n",
      " 1   short.line.density.2  1386 non-null   float64\n",
      " 2   vedge.mean            1386 non-null   float64\n",
      " 3   vegde.sd              1386 non-null   float64\n",
      " 4   hedge.mean            1386 non-null   float64\n",
      " 5   hedge.sd              1386 non-null   float64\n",
      " 6   intensity.mean        1386 non-null   float64\n",
      " 7   rawred.mean           1386 non-null   float64\n",
      " 8   rawblue.mean          1386 non-null   float64\n",
      " 9   rawgreen.mean         1386 non-null   float64\n",
      " 10  exred.mean            1386 non-null   float64\n",
      " 11  exblue.mean           1386 non-null   float64\n",
      " 12  exgreen.mean          1386 non-null   float64\n",
      " 13  value.mean            1386 non-null   float64\n",
      " 14  saturation.mean       1386 non-null   float64\n",
      " 15  hue.mean              1386 non-null   float64\n",
      " 16  class                 1386 non-null   object \n",
      "dtypes: float64(16), object(1)\n",
      "memory usage: 184.2+ KB\n",
      "None\n",
      "\n",
      "Dev Dataset Info:\n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 462 entries, 0 to 461\n",
      "Data columns (total 17 columns):\n",
      " #   Column                Non-Null Count  Dtype  \n",
      "---  ------                --------------  -----  \n",
      " 0   short.line.density.5  462 non-null    float64\n",
      " 1   short.line.density.2  462 non-null    float64\n",
      " 2   vedge.mean            462 non-null    float64\n",
      " 3   vegde.sd              462 non-null    float64\n",
      " 4   hedge.mean            462 non-null    float64\n",
      " 5   hedge.sd              462 non-null    float64\n",
      " 6   intensity.mean        462 non-null    float64\n",
      " 7   rawred.mean           462 non-null    float64\n",
      " 8   rawblue.mean          462 non-null    float64\n",
      " 9   rawgreen.mean         462 non-null    float64\n",
      " 10  exred.mean            462 non-null    float64\n",
      " 11  exblue.mean           462 non-null    float64\n",
      " 12  exgreen.mean          462 non-null    float64\n",
      " 13  value.mean            462 non-null    float64\n",
      " 14  saturation.mean       462 non-null    float64\n",
      " 15  hue.mean              462 non-null    float64\n",
      " 16  class                 462 non-null    object \n",
      "dtypes: float64(16), object(1)\n",
      "memory usage: 61.5+ KB\n",
      "None\n",
      "\n",
      "Train Dataset Head:\n",
      "   short.line.density.5  short.line.density.2  vedge.mean    vegde.sd  \\\n",
      "0              0.000000              0.111111   11.666700  125.867000   \n",
      "1              0.111111              0.000000    2.888890    3.896310   \n",
      "2              0.000000              0.000000    2.333330    3.600000   \n",
      "3              0.000000              0.000000    0.444444    0.162962   \n",
      "4              0.000000              0.000000    3.055550    3.275610   \n",
      "\n",
      "   hedge.mean  hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "0    0.944444  0.685184         8.70370      5.88889       13.8889   \n",
      "1    1.944440  4.907410        86.29630     70.66670      112.1110   \n",
      "2    2.777780  6.162960         3.77778      1.00000        8.0000   \n",
      "3    1.666670  1.155550       102.51900     87.55560      125.0000   \n",
      "4    2.388890  2.870670        56.92590     50.44440       71.4444   \n",
      "\n",
      "   rawgreen.mean  exred.mean  exblue.mean  exgreen.mean  value.mean  \\\n",
      "0        6.33333    -8.44444      15.5556      -7.11111     13.8889   \n",
      "1       76.11110   -46.88890      77.4444     -30.55560    112.1110   \n",
      "2        2.33333    -8.33333      12.6667      -4.33333      8.0000   \n",
      "3       95.00000   -44.88890      67.4444     -22.55560    125.0000   \n",
      "4       48.88890   -19.44440      43.5556     -24.11110     71.4444   \n",
      "\n",
      "   saturation.mean  hue.mean    class  \n",
      "0         0.771517  -2.06692   window  \n",
      "1         0.369638  -2.23156      sky  \n",
      "2         0.909236  -2.26821  foliage  \n",
      "3         0.299601  -2.30247      sky  \n",
      "4         0.315862  -2.02178     path  \n",
      "\n",
      "Dev Dataset Head:\n",
      "   short.line.density.5  short.line.density.2  vedge.mean  vegde.sd  \\\n",
      "0              0.000000                   0.0    0.777777  0.385185   \n",
      "1              0.000000                   0.0    1.500000  1.633330   \n",
      "2              0.111111                   0.0    0.555555  0.118519   \n",
      "3              0.000000                   0.0    0.555555  0.403687   \n",
      "4              0.000000                   0.0    0.888889  0.720082   \n",
      "\n",
      "   hedge.mean  hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "0     3.66667  4.044440        22.33330     21.66670       28.4444   \n",
      "1     1.55556  0.874074        21.70370     21.22220       27.5556   \n",
      "2     1.61111  0.596295        22.44440     21.55560       29.2222   \n",
      "3     0.50000  0.459468         7.07407      4.66667       12.2222   \n",
      "4     7.94444  5.425930        42.77780     37.00000       54.4444   \n",
      "\n",
      "   rawgreen.mean  exred.mean  exblue.mean  exgreen.mean  value.mean  \\\n",
      "0       16.88890    -2.00000      18.3333     -16.33330     28.4444   \n",
      "1       16.33330    -1.44444      17.5556     -16.11110     27.5556   \n",
      "2       16.55560    -2.66667      20.3333     -17.66670     29.2222   \n",
      "3        4.33333    -7.22222      15.4444      -8.22222     12.2222   \n",
      "4       36.88890   -17.33330      35.0000     -17.66670     54.4444   \n",
      "\n",
      "   saturation.mean  hue.mean      class  \n",
      "0         0.403687  -1.62559  brickface  \n",
      "1         0.407364  -1.63224  brickface  \n",
      "2         0.432819  -1.67432  brickface  \n",
      "3         0.645891  -2.04872     window  \n",
      "4         0.318271  -2.08858     cement  \n",
      "\n",
      "Missing Values in Train Dataset:\n",
      "short.line.density.5    0\n",
      "short.line.density.2    0\n",
      "vedge.mean              0\n",
      "vegde.sd                0\n",
      "hedge.mean              0\n",
      "hedge.sd                0\n",
      "intensity.mean          0\n",
      "rawred.mean             0\n",
      "rawblue.mean            0\n",
      "rawgreen.mean           0\n",
      "exred.mean              0\n",
      "exblue.mean             0\n",
      "exgreen.mean            0\n",
      "value.mean              0\n",
      "saturation.mean         0\n",
      "hue.mean                0\n",
      "class                   0\n",
      "dtype: int64\n",
      "\n",
      "Missing Values in Dev Dataset:\n",
      "short.line.density.5    0\n",
      "short.line.density.2    0\n",
      "vedge.mean              0\n",
      "vegde.sd                0\n",
      "hedge.mean              0\n",
      "hedge.sd                0\n",
      "intensity.mean          0\n",
      "rawred.mean             0\n",
      "rawblue.mean            0\n",
      "rawgreen.mean           0\n",
      "exred.mean              0\n",
      "exblue.mean             0\n",
      "exgreen.mean            0\n",
      "value.mean              0\n",
      "saturation.mean         0\n",
      "hue.mean                0\n",
      "class                   0\n",
      "dtype: int64\n",
      "\n",
      "Summary Statistics for Numerical Columns in Train Dataset:\n",
      "       short.line.density.5  short.line.density.2   vedge.mean     vegde.sd  \\\n",
      "count           1386.000000           1386.000000  1386.000000  1386.000000   \n",
      "mean               0.014991              0.005291     1.958755     5.962304   \n",
      "std                0.041560              0.025484     2.835625    41.653966   \n",
      "min                0.000000              0.000000     0.000000     0.000000   \n",
      "25%                0.000000              0.000000     0.722222     0.374073   \n",
      "50%                0.000000              0.000000     1.277780     0.833333   \n",
      "75%                0.000000              0.000000     2.166670     1.780638   \n",
      "max                0.333333              0.222222    27.277800   752.241000   \n",
      "\n",
      "        hedge.mean     hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "count  1386.000000  1386.000000     1386.000000  1386.000000   1386.000000   \n",
      "mean      2.456069     8.355612       37.267543    32.955108     44.520602   \n",
      "std       3.715051    55.195733       37.915687    34.808332     43.250870   \n",
      "min       0.000000     0.000000        0.000000     0.000000      0.000000   \n",
      "25%       0.777779     0.429627        7.601850     7.111110     10.027775   \n",
      "50%       1.444440     0.962963       21.944450    20.000000     28.444400   \n",
      "75%       2.611110     2.272213       54.768525    48.833350     67.222200   \n",
      "max      44.722200  1386.330000      143.444000   137.111000    150.889000   \n",
      "\n",
      "       rawgreen.mean   exred.mean  exblue.mean  exgreen.mean   value.mean  \\\n",
      "count    1386.000000  1386.000000  1386.000000   1386.000000  1386.000000   \n",
      "mean       34.326922   -12.937309    21.759179     -8.821869    45.463924   \n",
      "std        36.095060    11.538322    19.557821     11.626460    42.643531   \n",
      "min         0.000000   -49.666700   -12.111100    -33.888900     0.000000   \n",
      "25%         6.333330   -18.666700     4.805557    -17.000000    11.888900   \n",
      "50%        20.444400   -11.111100    20.222200    -11.000000    29.000000   \n",
      "75%        47.777800    -4.666670    36.194425     -3.472220    67.222200   \n",
      "max       142.556000     6.333330    82.000000     24.666700   150.889000   \n",
      "\n",
      "       saturation.mean     hue.mean  \n",
      "count      1386.000000  1386.000000  \n",
      "mean          0.430001    -1.374036  \n",
      "std           0.227892     1.548843  \n",
      "min           0.000000    -2.869950  \n",
      "25%           0.285691    -2.195807  \n",
      "50%           0.377733    -2.056115  \n",
      "75%           0.541620    -1.597228  \n",
      "max           1.000000     2.912480  \n",
      "\n",
      "Summary Statistics for Numerical Columns in Dev Dataset:\n",
      "       short.line.density.5  short.line.density.2  vedge.mean    vegde.sd  \\\n",
      "count            462.000000            462.000000  462.000000  462.000000   \n",
      "mean               0.013709              0.003848    1.755411    4.231910   \n",
      "std                0.038017              0.021615    2.506415   33.302043   \n",
      "min                0.000000              0.000000    0.000000    0.000000   \n",
      "25%                0.000000              0.000000    0.611112    0.329630   \n",
      "50%                0.000000              0.000000    1.166670    0.828695   \n",
      "75%                0.000000              0.000000    2.111110    1.779508   \n",
      "max                0.222222              0.222222   27.944400  572.996000   \n",
      "\n",
      "       hedge.mean      hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "count  462.000000  4.620000e+02      462.000000   462.000000    462.000000   \n",
      "mean     2.373737  9.183724e+00       37.002637    32.929778     43.904041   \n",
      "std      3.797932  7.403887e+01       38.787255    35.595809     44.042005   \n",
      "min      0.000000 -1.589460e-08        0.000000     0.000000      0.000000   \n",
      "25%      0.777779  4.073950e-01        6.731485     7.333330      8.472220   \n",
      "50%      1.388890  9.333325e-01       22.388850    19.666700     28.555550   \n",
      "75%      2.444440  2.096718e+00       49.694450    44.777800     61.527800   \n",
      "max     44.722200  1.386330e+03      143.444000   136.889000    150.889000   \n",
      "\n",
      "       rawgreen.mean  exred.mean  exblue.mean  exgreen.mean  value.mean  \\\n",
      "count     462.000000  462.000000   462.000000    462.000000  462.000000   \n",
      "mean       34.174122  -12.218615    20.704184     -8.485571   44.853296   \n",
      "std        37.106392   11.498509    19.275191     11.350619   43.438083   \n",
      "min         0.000000  -44.888900   -12.444400    -33.888900    0.000000   \n",
      "25%         5.333333  -17.888900     3.333330    -16.222200   11.361075   \n",
      "50%        20.222200  -10.500000    20.000000    -10.388900   29.111100   \n",
      "75%        43.277750   -3.222220    34.666700     -2.583337   61.527800   \n",
      "max       142.556000    7.222220    78.777800     22.222200  150.889000   \n",
      "\n",
      "       saturation.mean    hue.mean  \n",
      "count       462.000000  462.000000  \n",
      "mean          0.415956   -1.358292  \n",
      "std           0.220064    1.537385  \n",
      "min           0.000000   -3.044180  \n",
      "25%           0.280482   -2.182335  \n",
      "50%           0.372031   -2.043790  \n",
      "75%           0.533433   -1.549287  \n",
      "max           1.000000    2.835610  \n",
      "\n",
      "Summary Statistics for Categorical Columns in Train Dataset:\n",
      "         class\n",
      "count     1386\n",
      "unique       7\n",
      "top     cement\n",
      "freq       203\n",
      "\n",
      "Summary Statistics for Categorical Columns in Dev Dataset:\n",
      "         class\n",
      "count      462\n",
      "unique       7\n",
      "top     window\n",
      "freq        74\n",
      "\n",
      "Target Distribution in Train Dataset:\n",
      "class\n",
      "cement       203\n",
      "foliage      199\n",
      "window       198\n",
      "path         198\n",
      "sky          196\n",
      "brickface    196\n",
      "grass        196\n",
      "Name: count, dtype: int64\n",
      "\n",
      "Target Distribution in Dev Dataset:\n",
      "class\n",
      "window       74\n",
      "path         73\n",
      "brickface    70\n",
      "sky          67\n",
      "grass        65\n",
      "cement       61\n",
      "foliage      52\n",
      "Name: count, dtype: int64\n",
      "\n",
      "Correlation Matrix for Numerical Features in Train Dataset:\n",
      "                      short.line.density.5  short.line.density.2  vedge.mean  \\\n",
      "short.line.density.5              1.000000             -0.007615   -0.035065   \n",
      "short.line.density.2             -0.007615              1.000000    0.302275   \n",
      "vedge.mean                       -0.035065              0.302275    1.000000   \n",
      "vegde.sd                         -0.038175              0.208119    0.628376   \n",
      "hedge.mean                       -0.021769              0.318856    0.576441   \n",
      "hedge.sd                         -0.041130              0.231032    0.477500   \n",
      "intensity.mean                   -0.000907             -0.006906    0.001888   \n",
      "rawred.mean                       0.001642             -0.011760   -0.007489   \n",
      "rawblue.mean                     -0.005892              0.001704    0.016427   \n",
      "rawgreen.mean                     0.002619             -0.012464   -0.006513   \n",
      "exred.mean                        0.023803             -0.038355   -0.086386   \n",
      "exblue.mean                      -0.033817              0.051472    0.098001   \n",
      "exgreen.mean                      0.033265             -0.048521   -0.079125   \n",
      "value.mean                        0.000348             -0.001210    0.013237   \n",
      "saturation.mean                  -0.065179              0.008694   -0.070412   \n",
      "hue.mean                          0.115209             -0.077339   -0.109574   \n",
      "\n",
      "                      vegde.sd  hedge.mean  hedge.sd  intensity.mean  \\\n",
      "short.line.density.5 -0.038175   -0.021769 -0.041130       -0.000907   \n",
      "short.line.density.2  0.208119    0.318856  0.231032       -0.006906   \n",
      "vedge.mean            0.628376    0.576441  0.477500        0.001888   \n",
      "vegde.sd              1.000000    0.527575  0.691195        0.003516   \n",
      "hedge.mean            0.527575    1.000000  0.695762        0.042106   \n",
      "hedge.sd              0.691195    0.695762  1.000000        0.016397   \n",
      "intensity.mean        0.003516    0.042106  0.016397        1.000000   \n",
      "rawred.mean          -0.001461    0.034188  0.011342        0.998054   \n",
      "rawblue.mean          0.007023    0.051325  0.019725        0.995720   \n",
      "rawgreen.mean         0.004075    0.038220  0.017099        0.995724   \n",
      "exred.mean           -0.047886   -0.105678 -0.058995       -0.825536   \n",
      "exblue.mean           0.026140    0.095621  0.035496        0.789978   \n",
      "exgreen.mean          0.003552   -0.055975 -0.001164       -0.509608   \n",
      "value.mean            0.004735    0.049879  0.017621        0.997325   \n",
      "saturation.mean       0.002915   -0.134086 -0.027258       -0.623178   \n",
      "hue.mean             -0.069363   -0.091332 -0.073541       -0.324428   \n",
      "\n",
      "                      rawred.mean  rawblue.mean  rawgreen.mean  exred.mean  \\\n",
      "short.line.density.5     0.001642     -0.005892       0.002619    0.023803   \n",
      "short.line.density.2    -0.011760      0.001704      -0.012464   -0.038355   \n",
      "vedge.mean              -0.007489      0.016427      -0.006513   -0.086386   \n",
      "vegde.sd                -0.001461      0.007023       0.004075   -0.047886   \n",
      "hedge.mean               0.034188      0.051325       0.038220   -0.105678   \n",
      "hedge.sd                 0.011342      0.019725       0.017099   -0.058995   \n",
      "intensity.mean           0.998054      0.995720       0.995724   -0.825536   \n",
      "rawred.mean              1.000000      0.990604       0.993845   -0.788738   \n",
      "rawblue.mean             0.990604      1.000000       0.984293   -0.850768   \n",
      "rawgreen.mean            0.993845      0.984293       1.000000   -0.821475   \n",
      "exred.mean              -0.788738     -0.850768      -0.821475    1.000000   \n",
      "exblue.mean              0.767354      0.843262       0.739035   -0.842992   \n",
      "exgreen.mean            -0.508070     -0.574202      -0.427944    0.425647   \n",
      "value.mean               0.991845      0.998641       0.989782   -0.855360   \n",
      "saturation.mean         -0.631306     -0.610609      -0.623370    0.429920   \n",
      "hue.mean                -0.324173     -0.379935      -0.254503    0.264417   \n",
      "\n",
      "                      exblue.mean  exgreen.mean  value.mean  saturation.mean  \\\n",
      "short.line.density.5    -0.033817      0.033265    0.000348        -0.065179   \n",
      "short.line.density.2     0.051472     -0.048521   -0.001210         0.008694   \n",
      "vedge.mean               0.098001     -0.079125    0.013237        -0.070412   \n",
      "vegde.sd                 0.026140      0.003552    0.004735         0.002915   \n",
      "hedge.mean               0.095621     -0.055975    0.049879        -0.134086   \n",
      "hedge.sd                 0.035496     -0.001164    0.017621        -0.027258   \n",
      "intensity.mean           0.789978     -0.509608    0.997325        -0.623178   \n",
      "rawred.mean              0.767354     -0.508070    0.991845        -0.631306   \n",
      "rawblue.mean             0.843262     -0.574202    0.998641        -0.610609   \n",
      "rawgreen.mean            0.739035     -0.427944    0.989782        -0.623370   \n",
      "exred.mean              -0.842992      0.425647   -0.855360         0.429920   \n",
      "exblue.mean              1.000000     -0.845580    0.824915        -0.426605   \n",
      "exgreen.mean            -0.845580      1.000000   -0.538782         0.290966   \n",
      "value.mean               0.824915     -0.538782    1.000000        -0.619384   \n",
      "saturation.mean         -0.426605      0.290966   -0.619384         1.000000   \n",
      "hue.mean                -0.633751      0.803673   -0.335812        -0.045524   \n",
      "\n",
      "                      hue.mean  \n",
      "short.line.density.5  0.115209  \n",
      "short.line.density.2 -0.077339  \n",
      "vedge.mean           -0.109574  \n",
      "vegde.sd             -0.069363  \n",
      "hedge.mean           -0.091332  \n",
      "hedge.sd             -0.073541  \n",
      "intensity.mean       -0.324428  \n",
      "rawred.mean          -0.324173  \n",
      "rawblue.mean         -0.379935  \n",
      "rawgreen.mean        -0.254503  \n",
      "exred.mean            0.264417  \n",
      "exblue.mean          -0.633751  \n",
      "exgreen.mean          0.803673  \n",
      "value.mean           -0.335812  \n",
      "saturation.mean      -0.045524  \n",
      "hue.mean              1.000000  \n",
      "\n",
      "Correlation Matrix for Numerical Features in Dev Dataset:\n",
      "                      short.line.density.5  short.line.density.2  vedge.mean  \\\n",
      "short.line.density.5              1.000000              0.000846    0.012078   \n",
      "short.line.density.2              0.000846              1.000000    0.228732   \n",
      "vedge.mean                        0.012078              0.228732    1.000000   \n",
      "vegde.sd                         -0.027314              0.221549    0.681193   \n",
      "hedge.mean                       -0.002917              0.319104    0.505723   \n",
      "hedge.sd                         -0.035134              0.315316    0.554271   \n",
      "intensity.mean                    0.017489              0.006856    0.040137   \n",
      "rawred.mean                       0.015496             -0.000205    0.027916   \n",
      "rawblue.mean                      0.019885              0.019940    0.058247   \n",
      "rawgreen.mean                     0.016377             -0.001973    0.029952   \n",
      "exred.mean                       -0.033073             -0.071279   -0.146916   \n",
      "exblue.mean                       0.030726              0.095298    0.156964   \n",
      "exgreen.mean                     -0.018675             -0.089624   -0.117721   \n",
      "value.mean                        0.024119              0.016319    0.057065   \n",
      "saturation.mean                  -0.047993              0.021139   -0.083988   \n",
      "hue.mean                          0.076389             -0.092742   -0.084204   \n",
      "\n",
      "                      vegde.sd  hedge.mean  hedge.sd  intensity.mean  \\\n",
      "short.line.density.5 -0.027314   -0.002917 -0.035134        0.017489   \n",
      "short.line.density.2  0.221549    0.319104  0.315316        0.006856   \n",
      "vedge.mean            0.681193    0.505723  0.554271        0.040137   \n",
      "vegde.sd              1.000000    0.497473  0.765082        0.031676   \n",
      "hedge.mean            0.497473    1.000000  0.731429        0.045419   \n",
      "hedge.sd              0.765082    0.731429  1.000000        0.032195   \n",
      "intensity.mean        0.031676    0.045419  0.032195        1.000000   \n",
      "rawred.mean           0.027393    0.036840  0.026693        0.998368   \n",
      "rawblue.mean          0.035480    0.056038  0.036083        0.995999   \n",
      "rawgreen.mean         0.030943    0.040577  0.032527        0.996008   \n",
      "exred.mean           -0.066150   -0.117494 -0.077908       -0.847817   \n",
      "exblue.mean           0.051983    0.109939  0.052983        0.790424   \n",
      "exgreen.mean         -0.021265   -0.067669 -0.011050       -0.483406   \n",
      "value.mean            0.034180    0.054543  0.034451        0.997503   \n",
      "saturation.mean      -0.026936   -0.129714 -0.033746       -0.597960   \n",
      "hue.mean             -0.050047   -0.099132 -0.065608       -0.340205   \n",
      "\n",
      "                      rawred.mean  rawblue.mean  rawgreen.mean  exred.mean  \\\n",
      "short.line.density.5     0.015496      0.019885       0.016377   -0.033073   \n",
      "short.line.density.2    -0.000205      0.019940      -0.001973   -0.071279   \n",
      "vedge.mean               0.027916      0.058247       0.029952   -0.146916   \n",
      "vegde.sd                 0.027393      0.035480       0.030943   -0.066150   \n",
      "hedge.mean               0.036840      0.056038       0.040577   -0.117494   \n",
      "hedge.sd                 0.026693      0.036083       0.032527   -0.077908   \n",
      "intensity.mean           0.998368      0.995999       0.996008   -0.847817   \n",
      "rawred.mean              1.000000      0.991615       0.994531   -0.816154   \n",
      "rawblue.mean             0.991615      1.000000       0.985192   -0.870049   \n",
      "rawgreen.mean            0.994531      0.985192       1.000000   -0.843065   \n",
      "exred.mean              -0.816154     -0.870049      -0.843065    1.000000   \n",
      "exblue.mean              0.770221      0.842002       0.740441   -0.845784   \n",
      "exgreen.mean            -0.481173     -0.548471      -0.403339    0.423248   \n",
      "value.mean               0.992817      0.998628       0.990384   -0.874104   \n",
      "saturation.mean         -0.604367     -0.587618      -0.597925    0.438390   \n",
      "hue.mean                -0.339479     -0.393231      -0.274457    0.290013   \n",
      "\n",
      "                      exblue.mean  exgreen.mean  value.mean  saturation.mean  \\\n",
      "short.line.density.5     0.030726     -0.018675    0.024119        -0.047993   \n",
      "short.line.density.2     0.095298     -0.089624    0.016319         0.021139   \n",
      "vedge.mean               0.156964     -0.117721    0.057065        -0.083988   \n",
      "vegde.sd                 0.051983     -0.021265    0.034180        -0.026936   \n",
      "hedge.mean               0.109939     -0.067669    0.054543        -0.129714   \n",
      "hedge.sd                 0.052983     -0.011050    0.034451        -0.033746   \n",
      "intensity.mean           0.790424     -0.483406    0.997503        -0.597960   \n",
      "rawred.mean              0.770221     -0.481173    0.992817        -0.604367   \n",
      "rawblue.mean             0.842002     -0.548471    0.998628        -0.587618   \n",
      "rawgreen.mean            0.740441     -0.403339    0.990384        -0.597925   \n",
      "exred.mean              -0.845784      0.423248   -0.874104         0.438390   \n",
      "exblue.mean              1.000000     -0.841358    0.823514        -0.418152   \n",
      "exgreen.mean            -0.841358      1.000000   -0.512968         0.265988   \n",
      "value.mean               0.823514     -0.512968    1.000000        -0.593934   \n",
      "saturation.mean         -0.418152      0.265988   -0.593934         1.000000   \n",
      "hue.mean                -0.641719      0.795951   -0.350652        -0.029741   \n",
      "\n",
      "                      hue.mean  \n",
      "short.line.density.5  0.076389  \n",
      "short.line.density.2 -0.092742  \n",
      "vedge.mean           -0.084204  \n",
      "vegde.sd             -0.050047  \n",
      "hedge.mean           -0.099132  \n",
      "hedge.sd             -0.065608  \n",
      "intensity.mean       -0.340205  \n",
      "rawred.mean          -0.339479  \n",
      "rawblue.mean         -0.393231  \n",
      "rawgreen.mean        -0.274457  \n",
      "exred.mean            0.290013  \n",
      "exblue.mean          -0.641719  \n",
      "exgreen.mean          0.795951  \n",
      "value.mean           -0.350652  \n",
      "saturation.mean      -0.029741  \n",
      "hue.mean              1.000000  \n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Load the dataset\n",
    "train_df = pd.read_csv('/data/datasets/segment/split_train.csv')\n",
    "dev_df = pd.read_csv('/data/datasets/segment/split_dev.csv')\n",
    "\n",
    "# Display basic information about the dataset\n",
    "print(\"Train Dataset Info:\")\n",
    "print(train_df.info())\n",
    "print(\"\\nDev Dataset Info:\")\n",
    "print(dev_df.info())\n",
    "\n",
    "# Display the first few rows of the dataset\n",
    "print(\"\\nTrain Dataset Head:\")\n",
    "print(train_df.head())\n",
    "print(\"\\nDev Dataset Head:\")\n",
    "print(dev_df.head())\n",
    "\n",
    "# Check for missing values\n",
    "print(\"\\nMissing Values in Train Dataset:\")\n",
    "print(train_df.isnull().sum())\n",
    "print(\"\\nMissing Values in Dev Dataset:\")\n",
    "print(dev_df.isnull().sum())\n",
    "\n",
    "# Summary statistics for numerical columns\n",
    "print(\"\\nSummary Statistics for Numerical Columns in Train Dataset:\")\n",
    "print(train_df.describe())\n",
    "print(\"\\nSummary Statistics for Numerical Columns in Dev Dataset:\")\n",
    "print(dev_df.describe())\n",
    "\n",
    "# Summary statistics for categorical columns\n",
    "print(\"\\nSummary Statistics for Categorical Columns in Train Dataset:\")\n",
    "print(train_df.describe(include=['object']))\n",
    "print(\"\\nSummary Statistics for Categorical Columns in Dev Dataset:\")\n",
    "print(dev_df.describe(include=['object']))\n",
    "\n",
    "# Check the distribution of the target column\n",
    "print(\"\\nTarget Distribution in Train Dataset:\")\n",
    "print(train_df['class'].value_counts())\n",
    "print(\"\\nTarget Distribution in Dev Dataset:\")\n",
    "print(dev_df['class'].value_counts())\n",
    "\n",
    "# Check the correlation matrix for numerical features\n",
    "def get_numerical_features(df):\n",
    "    numerical_features = df.select_dtypes(include=[np.number])\n",
    "    if 'class' in numerical_features.columns:\n",
    "        numerical_features = numerical_features.drop(columns=['class'])\n",
    "    return numerical_features\n",
    "\n",
    "numerical_features_train = get_numerical_features(train_df)\n",
    "numerical_features_dev = get_numerical_features(dev_df)\n",
    "\n",
    "print(\"\\nCorrelation Matrix for Numerical Features in Train Dataset:\")\n",
    "print(numerical_features_train.corr())\n",
    "print(\"\\nCorrelation Matrix for Numerical Features in Dev Dataset:\")\n",
    "print(numerical_features_dev.corr())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "428752b0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:29:09.881938Z",
     "iopub.status.busy": "2024-10-18T17:29:09.881703Z",
     "iopub.status.idle": "2024-10-18T17:29:10.391156Z",
     "shell.execute_reply": "2024-10-18T17:29:10.390275Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed Train Dataset Head:\n",
      "   short.line.density.5  short.line.density.2  vedge.mean  vegde.sd  \\\n",
      "0             -0.360842              4.153902    3.424800  2.879629   \n",
      "1              2.313633             -0.207695    0.328136 -0.049617   \n",
      "2             -0.360842             -0.207695    0.132144 -0.056733   \n",
      "3             -0.360842             -0.207695   -0.534223 -0.139277   \n",
      "4             -0.360842             -0.207695    0.386931 -0.064524   \n",
      "\n",
      "   hedge.mean  hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "0   -0.407039 -0.139018       -0.753624    -0.777859     -0.708489   \n",
      "1   -0.137767 -0.062495        1.293566     1.083798      1.563316   \n",
      "2    0.086628 -0.039739       -0.883588    -0.918362     -0.844695   \n",
      "3   -0.212563 -0.130493        1.721583     1.569170      1.861430   \n",
      "4   -0.018089 -0.099408        0.518663     0.502627      0.622728   \n",
      "\n",
      "   rawgreen.mean  exred.mean  exblue.mean  exgreen.mean  value.mean  \\\n",
      "0      -0.775832    0.389527    -0.317306      0.147197   -0.740708   \n",
      "1       1.158033   -2.943569     2.848238     -1.870008    1.563452   \n",
      "2      -0.886690    0.399160    -0.465070      0.386202   -0.878854   \n",
      "3       1.681531   -2.770171     2.336749     -1.181674    1.865811   \n",
      "4       0.403580   -0.564158     1.114863     -1.315512    0.609468   \n",
      "\n",
      "   saturation.mean  hue.mean    class  \n",
      "0         1.499123 -0.447518   window  \n",
      "1        -0.264972 -0.553855      sky  \n",
      "2         2.103657 -0.577526  foliage  \n",
      "3        -0.572408 -0.599654      sky  \n",
      "4        -0.501028 -0.418363     path  \n",
      "\n",
      "Processed Dev Dataset Head:\n",
      "   short.line.density.5  short.line.density.2  vedge.mean  vegde.sd  \\\n",
      "0             -0.360842             -0.207695   -0.416629 -0.133940   \n",
      "1             -0.360842             -0.207695   -0.161841 -0.103965   \n",
      "2              2.313633             -0.207695   -0.495025 -0.140344   \n",
      "3             -0.360842             -0.207695   -0.495025 -0.133496   \n",
      "4             -0.360842             -0.207695   -0.377431 -0.125897   \n",
      "\n",
      "   hedge.mean  hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "0    0.325982 -0.078135       -0.394023    -0.324419     -0.371831   \n",
      "1   -0.242482 -0.135595       -0.410634    -0.337193     -0.392388   \n",
      "2   -0.227524 -0.140629       -0.391091    -0.327612     -0.353841   \n",
      "3   -0.526715 -0.143109       -0.796619    -0.812985     -0.747038   \n",
      "4    1.477867 -0.053097        0.145382     0.116247      0.229530   \n",
      "\n",
      "   rawgreen.mean  exred.mean  exblue.mean  exgreen.mean  value.mean  \\\n",
      "0      -0.483288    0.948254    -0.175230     -0.646297   -0.399256   \n",
      "1      -0.498686    0.996420    -0.215008     -0.627178   -0.420106   \n",
      "2      -0.492525    0.890454    -0.072932     -0.761025   -0.381009   \n",
      "3      -0.831261    0.495492    -0.322994      0.051595   -0.779807   \n",
      "4       0.071004   -0.381128     0.677253     -0.761025    0.210670   \n",
      "\n",
      "   saturation.mean  hue.mean      class  \n",
      "0        -0.115510 -0.162473  brickface  \n",
      "1        -0.099369 -0.166768  brickface  \n",
      "2         0.012368 -0.193947  brickface  \n",
      "3         0.947673 -0.435763     window  \n",
      "4        -0.490454 -0.461507     cement  \n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "def preprocess_data(df, scaler=None):\n",
    "    \"\"\"Standard-scale the numeric feature columns of ``df``.\n",
    "\n",
    "    The 'class' target column, if present, is set aside before scaling and\n",
    "    re-attached afterwards, so it is never scaled.\n",
    "\n",
    "    Args:\n",
    "        df: input DataFrame; a copy is processed and ``df`` itself is untouched.\n",
    "        scaler: an already fitted StandardScaler to reuse (e.g. the one fitted on\n",
    "            the training split). If None, a new StandardScaler is fitted on ``df``.\n",
    "\n",
    "    Returns:\n",
    "        (processed DataFrame, the scaler that was used)\n",
    "    \"\"\"\n",
    "    features = df.copy()\n",
    "    target = None\n",
    "    if 'class' in features.columns:\n",
    "        target = features['class']\n",
    "        features = features.drop(columns=['class'])\n",
    "\n",
    "    numeric_cols = features.select_dtypes(include=[np.number]).columns\n",
    "    if scaler is not None:\n",
    "        scaled_values = scaler.transform(features[numeric_cols])\n",
    "    else:\n",
    "        scaler = StandardScaler()\n",
    "        scaled_values = scaler.fit_transform(features[numeric_cols])\n",
    "    features[numeric_cols] = scaled_values\n",
    "\n",
    "    if target is not None:\n",
    "        features['class'] = target\n",
    "\n",
    "    return features, scaler\n",
    "\n",
    "\n",
    "# The scaler is fitted on train only and reused for dev and test, so no statistics leak from them.\n",
    "train_df_processed, scaler = preprocess_data(train_df)\n",
    "dev_df_processed, _ = preprocess_data(dev_df, scaler)\n",
    "test_df = pd.read_csv('/data/datasets/segment/split_test_wo_target.csv')\n",
    "test_df_processed, _ = preprocess_data(test_df, scaler)\n",
    "\n",
    "# Show the first rows of the scaled train and dev splits\n",
    "for header, frame in ((\"Processed Train Dataset Head:\", train_df_processed),\n",
    "                      (\"\\nProcessed Dev Dataset Head:\", dev_df_processed)):\n",
    "    print(header)\n",
    "    print(frame.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0f6d4de8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:29:10.405793Z",
     "iopub.status.busy": "2024-10-18T17:29:10.405481Z",
     "iopub.status.idle": "2024-10-18T17:29:10.453007Z",
     "shell.execute_reply": "2024-10-18T17:29:10.452180Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed Train Dataset with Polynomial Features Head:\n",
      "   short.line.density.5  short.line.density.2  vedge.mean  vegde.sd  \\\n",
      "0             -0.360842              4.153902    3.424800  2.879629   \n",
      "1              2.313633             -0.207695    0.328136 -0.049617   \n",
      "2             -0.360842             -0.207695    0.132144 -0.056733   \n",
      "3             -0.360842             -0.207695   -0.534223 -0.139277   \n",
      "4             -0.360842             -0.207695    0.386931 -0.064524   \n",
      "\n",
      "   hedge.mean  hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "0   -0.407039 -0.139018       -0.753624    -0.777859     -0.708489   \n",
      "1   -0.137767 -0.062495        1.293566     1.083798      1.563316   \n",
      "2    0.086628 -0.039739       -0.883588    -0.918362     -0.844695   \n",
      "3   -0.212563 -0.130493        1.721583     1.569170      1.861430   \n",
      "4   -0.018089 -0.099408        0.518663     0.502627      0.622728   \n",
      "\n",
      "   rawgreen.mean  ...  exblue.mean value.mean  exblue.mean saturation.mean  \\\n",
      "0      -0.775832  ...                0.235031                    -0.475681   \n",
      "1       1.158033  ...                4.453084                    -0.754704   \n",
      "2      -0.886690  ...                0.408729                    -0.978348   \n",
      "3       1.681531  ...                4.359932                    -1.337573   \n",
      "4       0.403580  ...                0.679473                    -0.558578   \n",
      "\n",
      "   exblue.mean hue.mean  exgreen.mean value.mean  \\\n",
      "0              0.142000                -0.109030   \n",
      "1             -1.577510                -2.923669   \n",
      "2              0.268590                -0.339415   \n",
      "3             -1.401240                -2.204781   \n",
      "4             -0.466417                -0.801762   \n",
      "\n",
      "   exgreen.mean saturation.mean  exgreen.mean hue.mean  \\\n",
      "0                      0.220666              -0.065873   \n",
      "1                      0.495500               1.035713   \n",
      "2                      0.812436              -0.223042   \n",
      "3                      0.676400               0.708596   \n",
      "4                      0.659109               0.550361   \n",
      "\n",
      "   value.mean saturation.mean  value.mean hue.mean  saturation.mean hue.mean  \\\n",
      "0                   -1.110413             0.331480                 -0.670884   \n",
      "1                   -0.414271            -0.865926                  0.146756   \n",
      "2                   -1.848808             0.507561                 -1.214917   \n",
      "3                   -1.068005            -1.118841                  0.343247   \n",
      "4                   -0.305361            -0.254979                  0.209612   \n",
      "\n",
      "     class  \n",
      "0   window  \n",
      "1      sky  \n",
      "2  foliage  \n",
      "3      sky  \n",
      "4     path  \n",
      "\n",
      "[5 rows x 137 columns]\n",
      "\n",
      "Processed Dev Dataset with Polynomial Features Head:\n",
      "   short.line.density.5  short.line.density.2  vedge.mean  vegde.sd  \\\n",
      "0             -0.360842             -0.207695   -0.416629 -0.133940   \n",
      "1             -0.360842             -0.207695   -0.161841 -0.103965   \n",
      "2              2.313633             -0.207695   -0.495025 -0.140344   \n",
      "3             -0.360842             -0.207695   -0.495025 -0.133496   \n",
      "4             -0.360842             -0.207695   -0.377431 -0.125897   \n",
      "\n",
      "   hedge.mean  hedge.sd  intensity.mean  rawred.mean  rawblue.mean  \\\n",
      "0    0.325982 -0.078135       -0.394023    -0.324419     -0.371831   \n",
      "1   -0.242482 -0.135595       -0.410634    -0.337193     -0.392388   \n",
      "2   -0.227524 -0.140629       -0.391091    -0.327612     -0.353841   \n",
      "3   -0.526715 -0.143109       -0.796619    -0.812985     -0.747038   \n",
      "4    1.477867 -0.053097        0.145382     0.116247      0.229530   \n",
      "\n",
      "   rawgreen.mean  ...  exblue.mean value.mean  exblue.mean saturation.mean  \\\n",
      "0      -0.483288  ...                0.069962                     0.020241   \n",
      "1      -0.498686  ...                0.090326                     0.021365   \n",
      "2      -0.492525  ...                0.027788                    -0.000902   \n",
      "3      -0.831261  ...                0.251873                    -0.306093   \n",
      "4       0.071004  ...                0.142677                    -0.332161   \n",
      "\n",
      "   exblue.mean hue.mean  exgreen.mean value.mean  \\\n",
      "0              0.028470                 0.258038   \n",
      "1              0.035857                 0.263481   \n",
      "2              0.014145                 0.289958   \n",
      "3              0.140749                -0.040234   \n",
      "4             -0.312557                -0.160325   \n",
      "\n",
      "   exgreen.mean saturation.mean  exgreen.mean hue.mean  \\\n",
      "0                      0.074654               0.105006   \n",
      "1                      0.062322               0.104593   \n",
      "2                     -0.009413               0.147598   \n",
      "3                      0.048895              -0.022483   \n",
      "4                      0.373247               0.351219   \n",
      "\n",
      "   value.mean saturation.mean  value.mean hue.mean  saturation.mean hue.mean  \\\n",
      "0                    0.046118             0.064868                  0.018767   \n",
      "1                    0.041746             0.070060                  0.016572   \n",
      "2                   -0.004712             0.073895                 -0.002399   \n",
      "3                   -0.739002             0.339811                 -0.412961   \n",
      "4                   -0.103324            -0.097226                  0.226348   \n",
      "\n",
      "       class  \n",
      "0  brickface  \n",
      "1  brickface  \n",
      "2  brickface  \n",
      "3     window  \n",
      "4     cement  \n",
      "\n",
      "[5 rows x 137 columns]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "\n",
    "def engineer_polynomial_features(df, degree=2):\n",
    "    \"\"\"Expand the feature columns of ``df`` with interaction terms.\n",
    "\n",
    "    Uses PolynomialFeatures(interaction_only=True, include_bias=False): the result\n",
    "    holds the original columns plus the products of up to ``degree`` distinct\n",
    "    columns (no squared terms, no constant column). The 'class' column, if\n",
    "    present, is excluded from the expansion and re-attached unchanged.\n",
    "\n",
    "    Args:\n",
    "        df: input DataFrame (not modified).\n",
    "        degree: maximum number of distinct columns multiplied together.\n",
    "\n",
    "    Returns:\n",
    "        New DataFrame with the expanded features (plus 'class' if present),\n",
    "        sharing the row index of ``df``.\n",
    "    \"\"\"\n",
    "    if 'class' in df.columns:\n",
    "        y = df['class']\n",
    "        X = df.drop(columns=['class'])\n",
    "    else:\n",
    "        X = df\n",
    "        y = None\n",
    "\n",
    "    # Interaction terms depend only on the column layout, not on data values, so\n",
    "    # fitting a fresh transformer per split yields the same output columns.\n",
    "    poly = PolynomialFeatures(degree=degree, interaction_only=True, include_bias=False)\n",
    "    X_poly = poly.fit_transform(X)\n",
    "    # Reuse the input index: with a fresh RangeIndex the label-aligned assignment\n",
    "    # of `y` below would produce NaN or mismatched targets for any df whose index\n",
    "    # is not 0..n-1 (e.g. after filtering or sampling).\n",
    "    X_poly_df = pd.DataFrame(X_poly, columns=poly.get_feature_names_out(X.columns), index=X.index)\n",
    "\n",
    "    if y is not None:\n",
    "        X_poly_df['class'] = y\n",
    "\n",
    "    return X_poly_df\n",
    "\n",
    "\n",
    "# Apply polynomial feature engineering to train, dev, and test sets\n",
    "train_df_poly = engineer_polynomial_features(train_df_processed)\n",
    "dev_df_poly = engineer_polynomial_features(dev_df_processed)\n",
    "test_df_poly = engineer_polynomial_features(test_df_processed)\n",
    "\n",
    "# Print the head of the processed datasets with polynomial features\n",
    "print(\"Processed Train Dataset with Polynomial Features Head:\")\n",
    "print(train_df_poly.head())\n",
    "print(\"\\nProcessed Dev Dataset with Polynomial Features Head:\")\n",
    "print(dev_df_poly.head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cef9472",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:29:10.473553Z",
     "iopub.status.busy": "2024-10-18T17:29:10.473109Z",
     "iopub.status.idle": "2024-10-18T17:32:35.354617Z",
     "shell.execute_reply": "2024-10-18T17:32:35.353605Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import StackingClassifier, GradientBoostingClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Label-encode the string target. The encoded labels live only in y_train / y_dev;\n",
    "# the 'class' columns of train_df_poly / dev_df_poly stay as strings. Writing the\n",
    "# codes back into those frames made this cell non-idempotent: on a re-run the encoder\n",
    "# was re-fitted on the integer codes, and inverse_transform then wrote integers\n",
    "# instead of class names into the saved prediction files.\n",
    "label_encoder = LabelEncoder()\n",
    "y_train = pd.Series(label_encoder.fit_transform(train_df_poly['class']), index=train_df_poly.index, name='class')\n",
    "y_dev = pd.Series(label_encoder.transform(dev_df_poly['class']), index=dev_df_poly.index, name='class')\n",
    "\n",
    "# Feature matrices. The test frame is put into the training column order because\n",
    "# sklearn rejects DataFrames whose feature names are ordered differently at predict time.\n",
    "X_train = train_df_poly.drop(columns=['class'])\n",
    "X_dev = dev_df_poly.drop(columns=['class'])\n",
    "X_test = test_df_poly[X_train.columns]\n",
    "\n",
    "# Base models for stacking.\n",
    "# 'mlogloss' is XGBoost's multi-class log loss ('logloss' is the binary metric).\n",
    "# use_label_encoder is omitted: it is deprecated in XGBoost and ignored by 2.x releases.\n",
    "base_models = [\n",
    "    ('knn', KNeighborsClassifier(n_neighbors=5)),\n",
    "    ('rf', RandomForestClassifier(n_estimators=100, random_state=42)),\n",
    "    ('et', ExtraTreesClassifier(n_estimators=100, random_state=42)),\n",
    "    ('xgb', XGBClassifier(eval_metric='mlogloss', random_state=42)),\n",
    "    ('gb', GradientBoostingClassifier(n_estimators=100, random_state=42))\n",
    "]\n",
    "\n",
    "# Stacking model with a logistic-regression meta-learner\n",
    "stacking_model = StackingClassifier(\n",
    "    estimators=base_models,\n",
    "    final_estimator=LogisticRegression(max_iter=1000)\n",
    ")\n",
    "\n",
    "# Training the stacking model\n",
    "stacking_model.fit(X_train, y_train)\n",
    "\n",
    "# Predictions and weighted F1 on the dev set\n",
    "dev_predictions = stacking_model.predict(X_dev)\n",
    "f1_weighted = f1_score(y_dev, dev_predictions, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set: {f1_weighted}\")\n",
    "\n",
    "# Predictions on test set\n",
    "test_predictions = stacking_model.predict(X_test)\n",
    "\n",
    "# Map integer predictions back to class names and save them\n",
    "dev_predictions_df = pd.DataFrame({'target': label_encoder.inverse_transform(dev_predictions)})\n",
    "test_predictions_df = pd.DataFrame({'target': label_encoder.inverse_transform(test_predictions)})\n",
    "\n",
    "# to_csv does not create missing parent directories\n",
    "Path('../workspace/segment').mkdir(parents=True, exist_ok=True)\n",
    "dev_predictions_df.to_csv('../workspace/segment/dev_predictions.csv', index=False)\n",
    "test_predictions_df.to_csv('../workspace/segment/test_predictions.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ccd5e914",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:32:45.050657Z",
     "iopub.status.busy": "2024-10-18T17:32:45.049973Z",
     "iopub.status.idle": "2024-10-18T17:32:45.178462Z",
     "shell.execute_reply": "2024-10-18T17:32:45.177653Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1 Weighted Score on Dev Set: 0.9268104524601448\n",
      "\n",
      "Class Distribution in Train Dataset:\n",
      "class\n",
      "cement       0.146465\n",
      "foliage      0.143579\n",
      "window       0.142857\n",
      "path         0.142857\n",
      "sky          0.141414\n",
      "brickface    0.141414\n",
      "grass        0.141414\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "Class Distribution in Dev Dataset:\n",
      "class\n",
      "window       0.160173\n",
      "path         0.158009\n",
      "brickface    0.151515\n",
      "sky          0.145022\n",
      "grass        0.140693\n",
      "cement       0.132035\n",
      "foliage      0.112554\n",
      "Name: proportion, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# Re-score the fitted stacking_model on the dev split with the same weighted F1 metric\n",
    "dev_predictions = stacking_model.predict(X_dev)\n",
    "f1_weighted = f1_score(y_true=y_dev, y_pred=dev_predictions, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set: {f1_weighted}\")\n",
    "\n",
    "# Relative class frequencies in the raw train/dev splits\n",
    "for header, frame in ((\"\\nClass Distribution in Train Dataset:\", train_df),\n",
    "                      (\"\\nClass Distribution in Dev Dataset:\", dev_df)):\n",
    "    print(header)\n",
    "    print(frame['class'].value_counts(normalize=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04fcd220",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:32:50.525142Z",
     "iopub.status.busy": "2024-10-18T17:32:50.524454Z",
     "iopub.status.idle": "2024-10-18T17:32:51.547607Z",
     "shell.execute_reply": "2024-10-18T17:32:51.546887Z"
    }
   },
   "outputs": [],
   "source": [
    "from metagpt.tools.libs.data_preprocess import get_column_info\n",
    "\n",
    "# Summarise the columns of the polynomial-feature training frame with metagpt's helper\n",
    "column_info = get_column_info(train_df_poly)\n",
    "print(\"column_info\", column_info, sep=\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1db801b0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:33:21.954953Z",
     "iopub.status.busy": "2024-10-18T17:33:21.953957Z",
     "iopub.status.idle": "2024-10-18T17:40:01.118204Z",
     "shell.execute_reply": "2024-10-18T17:40:01.117118Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from sklearn.ensemble import StackingClassifier, GradientBoostingClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Define base models (200 estimators each). GradientBoostingClassifier is imported\n",
    "# above rather than relying on an import left over from an earlier cell.\n",
    "# 'mlogloss' is XGBoost's multi-class log loss ('logloss' is the binary metric);\n",
    "# use_label_encoder is omitted: it is deprecated in XGBoost and ignored by 2.x releases.\n",
    "base_models = [\n",
    "    ('knn', KNeighborsClassifier(n_neighbors=5)),\n",
    "    ('rf', RandomForestClassifier(n_estimators=200, random_state=42)),\n",
    "    ('et', ExtraTreesClassifier(n_estimators=200, random_state=42)),\n",
    "    ('xgb', XGBClassifier(eval_metric='mlogloss', random_state=42, n_estimators=200, learning_rate=0.1)),\n",
    "    ('gb', GradientBoostingClassifier(n_estimators=200, random_state=42))\n",
    "]\n",
    "\n",
    "# Define the stacking model. multi_class is not passed: 'auto' is already the\n",
    "# default and the parameter is deprecated since scikit-learn 1.5.\n",
    "stacking_model = StackingClassifier(\n",
    "    estimators=base_models,\n",
    "    final_estimator=LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs')\n",
    ")\n",
    "\n",
    "# Fit the stacking model\n",
    "stacking_model.fit(X_train, y_train)\n",
    "\n",
    "# Predict on the dev set\n",
    "dev_predictions = stacking_model.predict(X_dev)\n",
    "f1_weighted = f1_score(y_dev, dev_predictions, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set: {f1_weighted}\")\n",
    "\n",
    "# Predict on the test set, with columns in training order (sklearn rejects reordered feature names)\n",
    "test_predictions = stacking_model.predict(test_df_poly[X_train.columns])\n",
    "\n",
    "# Save the predictions as class names; to_csv does not create missing directories\n",
    "dev_predictions_df = pd.DataFrame({'target': label_encoder.inverse_transform(dev_predictions)})\n",
    "test_predictions_df = pd.DataFrame({'target': label_encoder.inverse_transform(test_predictions)})\n",
    "Path('../workspace/segment').mkdir(parents=True, exist_ok=True)\n",
    "dev_predictions_df.to_csv('../workspace/segment/dev_predictions.csv', index=False)\n",
    "test_predictions_df.to_csv('../workspace/segment/test_predictions.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "87447b91",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:40:10.198690Z",
     "iopub.status.busy": "2024-10-18T17:40:10.197705Z",
     "iopub.status.idle": "2024-10-18T17:40:10.363507Z",
     "shell.execute_reply": "2024-10-18T17:40:10.362538Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1 Weighted Score on Dev Set: 0.9268104524601448\n",
      "\n",
      "Class Distribution in Train Dataset:\n",
      "class\n",
      "cement       0.146465\n",
      "foliage      0.143579\n",
      "window       0.142857\n",
      "path         0.142857\n",
      "sky          0.141414\n",
      "brickface    0.141414\n",
      "grass        0.141414\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "Class Distribution in Dev Dataset:\n",
      "class\n",
      "window       0.160173\n",
      "path         0.158009\n",
      "brickface    0.151515\n",
      "sky          0.145022\n",
      "grass        0.140693\n",
      "cement       0.132035\n",
      "foliage      0.112554\n",
      "Name: proportion, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# Re-score the currently fitted stacking_model on the dev split\n",
    "dev_predictions = stacking_model.predict(X_dev)\n",
    "f1_weighted = f1_score(y_true=y_dev, y_pred=dev_predictions, average='weighted')\n",
    "print(f\"F1 Weighted Score on Dev Set: {f1_weighted}\")\n",
    "\n",
    "# Relative class frequencies of the raw splits, to compare train vs. dev balance\n",
    "class_share_headers = {\n",
    "    \"\\nClass Distribution in Train Dataset:\": train_df,\n",
    "    \"\\nClass Distribution in Dev Dataset:\": dev_df,\n",
    "}\n",
    "for header, frame in class_share_headers.items():\n",
    "    print(header)\n",
    "    print(frame['class'].value_counts(normalize=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eabaa5c5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-18T17:40:19.526832Z",
     "iopub.status.busy": "2024-10-18T17:40:19.526220Z",
     "iopub.status.idle": "2024-10-18T17:40:19.674051Z",
     "shell.execute_reply": "2024-10-18T17:40:19.673022Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Predict on the test set; columns are selected in training order because sklearn\n",
    "# rejects DataFrames whose feature names are ordered differently at predict time.\n",
    "test_predictions = stacking_model.predict(test_df_poly[X_train.columns])\n",
    "\n",
    "# Map integer predictions back to class names and save them\n",
    "dev_predictions_df = pd.DataFrame({'target': label_encoder.inverse_transform(dev_predictions)})\n",
    "test_predictions_df = pd.DataFrame({'target': label_encoder.inverse_transform(test_predictions)})\n",
    "# to_csv does not create missing parent directories\n",
    "Path('../workspace/segment').mkdir(parents=True, exist_ok=True)\n",
    "dev_predictions_df.to_csv('../workspace/segment/dev_predictions.csv', index=False)\n",
    "test_predictions_df.to_csv('../workspace/segment/test_predictions.csv', index=False)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
