{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from models import train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Wine features (data_x.csv) and targets (data_y.csv) and convert them to tensors for train.Trainer.\n",
    "# NOTE(review): \".../Wine\" looks like a placeholder path - point DATA_DIR at the real dataset directory.\n",
    "DATA_DIR = \".../Wine\"\n",
    "\n",
    "features_raw = pd.read_csv(f\"{DATA_DIR}/data_x.csv\")\n",
    "targets_raw = pd.read_csv(f\"{DATA_DIR}/data_y.csv\")\n",
    "# Every target row must pair with a feature row; fail here rather than deep inside training.\n",
    "assert len(features_raw) == len(targets_raw), (\n",
    "    f\"row mismatch: {len(features_raw)} feature rows vs {len(targets_raw)} target rows\"\n",
    ")\n",
    "\n",
    "in_features = features_raw.shape[1]\n",
    "columns_list = features_raw.columns\n",
    "\n",
    "# Feature dtype follows the CSV (e.g. float64 for float columns); targets are cast to float32.\n",
    "data_x = torch.tensor(features_raw.to_numpy())\n",
    "data_y = torch.tensor(targets_raw.to_numpy()).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/psh/.local/lib/python3.9/site-packages/sklearn/preprocessing/_data.py:2663: UserWarning: n_quantiles (1000) is greater than the total number of samples (971). n_quantiles is set to n_samples.\n",
      "  warnings.warn(\n",
      "/home/psh/.local/lib/python3.9/site-packages/sklearn/preprocessing/_data.py:2663: UserWarning: n_quantiles (1000) is greater than the total number of samples (499). n_quantiles is set to n_samples.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 0 || train rmse : 0.8742167353630066 , val rmse : 0.8934306059655316, test rmse : 0.8980950041901952\n",
      "Epoch : 1 || train rmse : 0.8737459778785706 , val rmse : 0.8929758230006271, test rmse : 0.8976725809027063\n",
      "Epoch : 2 || train rmse : 0.8732752203941345 , val rmse : 0.8925207501164532, test rmse : 0.8972501338252349\n",
      "Epoch : 3 || train rmse : 0.8728042244911194 , val rmse : 0.8920651028399678, test rmse : 0.896827428410586\n",
      "Epoch : 4 || train rmse : 0.8723328709602356 , val rmse : 0.891608573921469, test rmse : 0.8964041986255925\n",
      "Epoch : 5 || train rmse : 0.871860682964325 , val rmse : 0.8911508642404321, test rmse : 0.8959801813821134\n",
      "Epoch : 6 || train rmse : 0.8713873028755188 , val rmse : 0.890691685645921, test rmse : 0.8955551212185102\n",
      "Epoch : 7 || train rmse : 0.870912492275238 , val rmse : 0.8902307364035761, test rmse : 0.8951287579589532\n",
      "Epoch : 8 || train rmse : 0.8704360127449036 , val rmse : 0.889767701956419, test rmse : 0.8947008152130402\n",
      "Epoch : 9 || train rmse : 0.869957447052002 , val rmse : 0.8893022530750695, test rmse : 0.8942710107990871\n",
      "Epoch : 10 || train rmse : 0.8694764375686646 , val rmse : 0.8888340429318788, test rmse : 0.8938390468435724\n",
      "Epoch : 11 || train rmse : 0.8689926266670227 , val rmse : 0.8883627285871796, test rmse : 0.8934046290326039\n",
      "Epoch : 12 || train rmse : 0.8685056567192078 , val rmse : 0.8878879537100681, test rmse : 0.892967461949196\n",
      "Epoch : 13 || train rmse : 0.8680151104927063 , val rmse : 0.8874093724178689, test rmse : 0.8925272545730423\n",
      "Epoch : 14 || train rmse : 0.8675206303596497 , val rmse : 0.8869266366432677, test rmse : 0.8920837122123512\n",
      "Epoch : 15 || train rmse : 0.8670218586921692 , val rmse : 0.8864393948265784, test rmse : 0.8916365196161422\n",
      "Epoch : 16 || train rmse : 0.8665182590484619 , val rmse : 0.8859472925237374, test rmse : 0.891185329912288\n",
      "Epoch : 17 || train rmse : 0.866009533405304 , val rmse : 0.8854499613822541, test rmse : 0.8907297448220112\n",
      "Epoch : 18 || train rmse : 0.8654953837394714 , val rmse : 0.8849470173242229, test rmse : 0.8902693307294883\n",
      "Epoch : 19 || train rmse : 0.8649752140045166 , val rmse : 0.8844380691746391, test rmse : 0.8898036426045642\n",
      "Epoch : 20 || train rmse : 0.864448606967926 , val rmse : 0.8839227167552733, test rmse : 0.8893322391527366\n",
      "Epoch : 21 || train rmse : 0.8639152646064758 , val rmse : 0.8834005503582198, test rmse : 0.8888546923234728\n",
      "Epoch : 22 || train rmse : 0.8633747100830078 , val rmse : 0.882871154585323, test rmse : 0.8883705973793827\n",
      "Epoch : 23 || train rmse : 0.8628264665603638 , val rmse : 0.8823341095988253, test rmse : 0.8878795621008249\n",
      "Epoch : 24 || train rmse : 0.8622700572013855 , val rmse : 0.8817889879597447, test rmse : 0.8873811982141582\n",
      "Epoch : 25 || train rmse : 0.8617051243782043 , val rmse : 0.8812353519478943, test rmse : 0.8868751193226817\n",
      "Epoch : 26 || train rmse : 0.8611311912536621 , val rmse : 0.8806727526849142, test rmse : 0.8863609233229767\n",
      "Epoch : 27 || train rmse : 0.8605476021766663 , val rmse : 0.8801007335549996, test rmse : 0.8858382002300491\n",
      "Epoch : 28 || train rmse : 0.8599539995193481 , val rmse : 0.8795188211043004, test rmse : 0.8853065242942327\n",
      "Epoch : 29 || train rmse : 0.8593498468399048 , val rmse : 0.8789265328697787, test rmse : 0.8847654636421031\n",
      "Epoch : 30 || train rmse : 0.8587345480918884 , val rmse : 0.8783233723530373, test rmse : 0.884214568885479\n",
      "Epoch : 31 || train rmse : 0.8581076860427856 , val rmse : 0.8777088285982315, test rmse : 0.8836533878687476\n",
      "Epoch : 32 || train rmse : 0.8574686050415039 , val rmse : 0.877082381513262, test rmse : 0.8830814622705253\n",
      "Epoch : 33 || train rmse : 0.856816828250885 , val rmse : 0.8764434999798464, test rmse : 0.8824983215869048\n",
      "Epoch : 34 || train rmse : 0.8561516404151917 , val rmse : 0.8757916319910781, test rmse : 0.8819034866895492\n",
      "Epoch : 35 || train rmse : 0.8554725050926208 , val rmse : 0.8751262227524018, test rmse : 0.8812964753538796\n",
      "Epoch : 36 || train rmse : 0.8547788262367249 , val rmse : 0.874446700318683, test rmse : 0.8806767925240415\n",
      "Epoch : 37 || train rmse : 0.8540700078010559 , val rmse : 0.8737524771974812, test rmse : 0.88004393251828\n",
      "Epoch : 38 || train rmse : 0.8533453345298767 , val rmse : 0.8730429610585072, test rmse : 0.8793973868547402\n",
      "Epoch : 39 || train rmse : 0.85260409116745 , val rmse : 0.8723175429659175, test rmse : 0.8787366412383827\n",
      "Epoch : 40 || train rmse : 0.8518456816673279 , val rmse : 0.8715756033691163, test rmse : 0.8780611735523133\n",
      "Epoch : 41 || train rmse : 0.851069450378418 , val rmse : 0.8708165145454372, test rmse : 0.8773704570537241\n",
      "Epoch : 42 || train rmse : 0.8502746224403381 , val rmse : 0.8700396412865322, test rmse : 0.8766639656554662\n",
      "Epoch : 43 || train rmse : 0.8494605422019958 , val rmse : 0.8692443408033991, test rmse : 0.8759411761634991\n",
      "Epoch : 44 || train rmse : 0.848626434803009 , val rmse : 0.8684299648295485, test rmse : 0.8752015628729223\n",
      "Epoch : 45 || train rmse : 0.8477716445922852 , val rmse : 0.8675958643900368, test rmse : 0.874444615952503\n",
      "Epoch : 46 || train rmse : 0.8468953371047974 , val rmse : 0.8667413915876309, test rmse : 0.8736698292798689\n",
      "Epoch : 47 || train rmse : 0.8459969162940979 , val rmse : 0.865865906408964, test rmse : 0.8728767198936654\n",
      "Epoch : 48 || train rmse : 0.8450755476951599 , val rmse : 0.864968778521031, test rmse : 0.8720648217146042\n",
      "Epoch : 49 || train rmse : 0.8441305756568909 , val rmse : 0.8640493915062384, test rmse : 0.8712336987643022\n",
      "Epoch : 50 || train rmse : 0.8431613445281982 , val rmse : 0.8631071555200862, test rmse : 0.87038295368123\n",
      "Epoch : 51 || train rmse : 0.8421670198440552 , val rmse : 0.8621415111381608, test rmse : 0.8695122325586288\n",
      "Epoch : 52 || train rmse : 0.8411471247673035 , val rmse : 0.8611519440528232, test rmse : 0.8686212421838437\n",
      "Epoch : 53 || train rmse : 0.8401010632514954 , val rmse : 0.860137985612006, test rmse : 0.8677097518568779\n",
      "Epoch : 54 || train rmse : 0.8390282392501831 , val rmse : 0.8590992399339339, test rmse : 0.8667776231018595\n",
      "Epoch : 55 || train rmse : 0.8379281759262085 , val rmse : 0.8580353943584543, test rmse : 0.8658248183219914\n",
      "Epoch : 56 || train rmse : 0.8368006348609924 , val rmse : 0.85694623263784, test rmse : 0.8648514081995998\n",
      "Epoch : 57 || train rmse : 0.8356452584266663 , val rmse : 0.8558316647513274, test rmse : 0.8638576221166706\n",
      "Epoch : 58 || train rmse : 0.8344619870185852 , val rmse : 0.85469174106564, test rmse : 0.8628438555861673\n",
      "Epoch : 59 || train rmse : 0.833250880241394 , val rmse : 0.8535266891464499, test rmse : 0.8618106946797797\n",
      "Epoch : 60 || train rmse : 0.8320122361183167 , val rmse : 0.8523369348189481, test rmse : 0.8607589647521385\n",
      "Epoch : 61 || train rmse : 0.8307466506958008 , val rmse : 0.8511231402775442, test rmse : 0.8596897518364546\n",
      "Epoch : 62 || train rmse : 0.829454779624939 , val rmse : 0.849886216048136, test rmse : 0.8586044579599947\n",
      "Epoch : 63 || train rmse : 0.8281379342079163 , val rmse : 0.8486273674523565, test rmse : 0.8575048138099873\n",
      "Epoch : 64 || train rmse : 0.8267974853515625 , val rmse : 0.8473480918060041, test rmse : 0.8563929236117352\n",
      "Epoch : 65 || train rmse : 0.8254354596138 , val rmse : 0.8460501746186222, test rmse : 0.8552712653794576\n",
      "Epoch : 66 || train rmse : 0.8240540623664856 , val rmse : 0.8447356803334393, test rmse : 0.8541426642773584\n",
      "Epoch : 67 || train rmse : 0.8226562142372131 , val rmse : 0.8434069250446574, test rmse : 0.853010212335561\n",
      "Epoch : 68 || train rmse : 0.8212451338768005 , val rmse : 0.8420664983346069, test rmse : 0.8518771822697471\n",
      "Epoch : 69 || train rmse : 0.8198243975639343 , val rmse : 0.8407176584230729, test rmse : 0.8507468104472728\n",
      "Epoch : 70 || train rmse : 0.818398118019104 , val rmse : 0.8393644128238196, test rmse : 0.8496221645461377\n",
      "Epoch : 71 || train rmse : 0.8169698119163513 , val rmse : 0.838008870491443, test rmse : 0.8485049143400085\n",
      "Epoch : 72 || train rmse : 0.8155409693717957 , val rmse : 0.8366524810935406, test rmse : 0.8473957563288297\n",
      "Epoch : 73 || train rmse : 0.8141133189201355 , val rmse : 0.835295857548931, test rmse : 0.8462952153594143\n",
      "Epoch : 74 || train rmse : 0.8126877546310425 , val rmse : 0.8339399549510886, test rmse : 0.8452057348908144\n",
      "Epoch : 75 || train rmse : 0.8112657070159912 , val rmse : 0.8325870028371688, test rmse : 0.8441288667558227\n",
      "Epoch : 76 || train rmse : 0.8098492622375488 , val rmse : 0.8312394550843029, test rmse : 0.8430642599461563\n",
      "Epoch : 77 || train rmse : 0.808438241481781 , val rmse : 0.8298965124582631, test rmse : 0.8420130275475393\n",
      "Epoch : 78 || train rmse : 0.8070335984230042 , val rmse : 0.8285580328917918, test rmse : 0.8409767102857642\n",
      "Epoch : 79 || train rmse : 0.8056356310844421 , val rmse : 0.8272257570722414, test rmse : 0.8399565752588496\n",
      "Epoch : 80 || train rmse : 0.8042427897453308 , val rmse : 0.8258998220597239, test rmse : 0.8389486542615966\n",
      "Epoch : 81 || train rmse : 0.8028558492660522 , val rmse : 0.8245814095179583, test rmse : 0.837954778096401\n",
      "Epoch : 82 || train rmse : 0.8014734983444214 , val rmse : 0.8232716727625778, test rmse : 0.8369757442165164\n",
      "Epoch : 83 || train rmse : 0.8000954389572144 , val rmse : 0.8219730028946747, test rmse : 0.8360119810089742\n",
      "Epoch : 84 || train rmse : 0.7987222671508789 , val rmse : 0.8206885323619376, test rmse : 0.8350625553887225\n",
      "Epoch : 85 || train rmse : 0.7973531484603882 , val rmse : 0.8194164981526101, test rmse : 0.8341281907797424\n",
      "Epoch : 86 || train rmse : 0.7959877848625183 , val rmse : 0.8181573966227058, test rmse : 0.833208766316617\n",
      "Epoch : 87 || train rmse : 0.7946280241012573 , val rmse : 0.8169160141898749, test rmse : 0.8323036882342476\n",
      "Epoch : 88 || train rmse : 0.7932743430137634 , val rmse : 0.8156912884551192, test rmse : 0.8314111000635624\n",
      "Epoch : 89 || train rmse : 0.7919284701347351 , val rmse : 0.8144784945129732, test rmse : 0.8305302016776677\n",
      "Epoch : 90 || train rmse : 0.7905892133712769 , val rmse : 0.813278321003808, test rmse : 0.8296596868603752\n",
      "Epoch : 91 || train rmse : 0.7892578840255737 , val rmse : 0.8120950754468969, test rmse : 0.8287991731764043\n",
      "Epoch : 92 || train rmse : 0.787936806678772 , val rmse : 0.810929977314245, test rmse : 0.827945875153245\n",
      "Epoch : 93 || train rmse : 0.786627471446991 , val rmse : 0.809784666870293, test rmse : 0.8270986790431789\n",
      "Epoch : 94 || train rmse : 0.7853342890739441 , val rmse : 0.8086546242489892, test rmse : 0.8262619761386893\n",
      "Epoch : 95 || train rmse : 0.784058690071106 , val rmse : 0.8075381372168693, test rmse : 0.8254367546994518\n",
      "Epoch : 96 || train rmse : 0.7827978730201721 , val rmse : 0.8064305305816313, test rmse : 0.824629823234541\n",
      "Epoch : 97 || train rmse : 0.7815545797348022 , val rmse : 0.805333160462018, test rmse : 0.8238391278346794\n",
      "Epoch : 98 || train rmse : 0.7803259491920471 , val rmse : 0.8042526056870488, test rmse : 0.8230639316192121\n",
      "Epoch : 99 || train rmse : 0.779104471206665 , val rmse : 0.803190083343098, test rmse : 0.8223039631577781\n",
      "Random seed : 0 || measure: 0.8223039631577781\n"
     ]
    }
   ],
   "source": [
    "# Train the tree model on the loaded tensors once per random seed and collect the returned measure.\n",
    "\n",
    "# --- configuration ---\n",
    "gpu_index = 3\n",
    "# Use the requested GPU when it exists on this machine; otherwise fall back to CPU instead of failing.\n",
    "# NOTE(review): assumes train.Trainer accepts any torch device string, including \"cpu\" - confirm.\n",
    "device = f\"cuda:{gpu_index}\" if torch.cuda.device_count() > gpu_index else \"cpu\"\n",
    "\n",
    "tree_num = 10\n",
    "max_order = 1\n",
    "regression = True\n",
    "\n",
    "multiclass = 2  # forwarded to Trainer; presumably only relevant when regression=False - confirm\n",
    "lr_rate = 0.01\n",
    "epoch_num = 100\n",
    "measure = \"rmse\"\n",
    "num_train_batch = 4096\n",
    "num_test_batch = 2048\n",
    "seeds = range(1)  # widen (e.g. range(5)) to evaluate more random_state values\n",
    "model_path = None  # identical for every seed, so set once outside the loop\n",
    "\n",
    "all_measure = []\n",
    "\n",
    "for seed in seeds:\n",
    "    model, measure_result, _ = train.Trainer(\n",
    "        data_x=data_x,\n",
    "        data_y=data_y,\n",
    "        tree_num=tree_num,\n",
    "        max_order=max_order,\n",
    "        device=device,\n",
    "        model_path=model_path,\n",
    "        measure=measure,\n",
    "        regression=regression,\n",
    "        random_state=seed,\n",
    "        multiclass=multiclass,  # use the config value rather than a duplicated literal\n",
    "        lr_rate=lr_rate,\n",
    "        epoch_num=epoch_num,\n",
    "        num_train_batch=num_train_batch,\n",
    "        num_test_batch=num_test_batch,\n",
    "        init_train=True,\n",
    "        reg_lambda=0.0,\n",
    "        features_list=\"all\",\n",
    "        uniform_transform=True,\n",
    "    )\n",
    "    all_measure.append(measure_result)\n",
    "    print(f\"Random seed : {seed} || measure: {measure_result}\")\n",
    "\n",
    "# Aggregate across seeds (std is 0 when only one seed is run).\n",
    "print(f\"{measure} over {len(all_measure)} seed(s): mean {np.mean(all_measure)}, std {np.std(all_measure)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_lat",
   "language": "python",
   "name": "conda_lat"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
