{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8292ed9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/wall/works/research/MultiModalPFN/mmpfn/models/dino_v2/layers/swiglu_ffn.py:51: UserWarning: xFormers is not available (SwiGLU)\n",
      "  warnings.warn(\"xFormers is not available (SwiGLU)\")\n",
      "/home/wall/works/research/MultiModalPFN/mmpfn/models/dino_v2/layers/attention.py:33: UserWarning: xFormers is not available (Attention)\n",
      "  warnings.warn(\"xFormers is not available (Attention)\")\n",
      "/home/wall/works/research/MultiModalPFN/mmpfn/models/dino_v2/layers/block.py:40: UserWarning: xFormers is not available (Block)\n",
      "  warnings.warn(\"xFormers is not available (Block)\")\n",
      "/home/wall/anaconda3/envs/mmpfn2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import os \n",
    "\n",
    "from torch.utils.data import random_split\n",
    "\n",
    "from mmpfn.datasets.pad_ufes_20 import PADUFES20Dataset\n",
    "\n",
    "import os \n",
    "import torch \n",
    "import numpy as np \n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "from mmpfn.models.mmpfn_v2 import MMPFNClassifier\n",
    "from mmpfn.models.mmpfn_v2.constants import ModelInterfaceConfig\n",
    "from mmpfn.models.mmpfn_v2.preprocessing import PreprocessorConfig\n",
    "from mmpfn.scripts_finetune_mm.finetune_mmpfn_main import fine_tune_mmpfn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5facc593",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use GPU 0 unless CUDA_VISIBLE_DEVICES was already set outside the notebook.\n",
    "# It only takes effect if set before CUDA is first initialized in this kernel.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5dd5c4c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load embeddings from embeddings/pad_ufes_20/pad_ufes_20_dinov3.pt\n"
     ]
    }
   ],
   "source": [
    "# Root of the PAD-UFES-20 data; set PAD_UFES_20_DATA_PATH to override the default below.\n",
    "# os.path.expanduser(\"~\") is used because os.getenv('HOME') returns None when HOME is unset.\n",
    "data_path = os.environ.get(\n",
    "    \"PAD_UFES_20_DATA_PATH\",\n",
    "    os.path.join(os.path.expanduser(\"~\"), \"works/research/MultiModalPFN/mmpfn/data/pad_ufes_20\"),\n",
    ")\n",
    "dataset = PADUFES20Dataset(data_path)\n",
    "# Load the image embeddings (presumably stored on dataset.embeddings, which the loop below reads).\n",
    "_ = dataset.get_embeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "868f0040",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fine-tuning Steps:   4%|▍         | 4/100 [00:06<03:35,  2.25s/it, Best Val. Loss=0.521, Best Val. Score=-0.521, Training Loss=0.508, Val. Loss=0.521, Patience=47, Utilization=0, Grad Norm=8.58][2025-09-23 23:31:15,168] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps:  17%|█▋        | 17/100 [00:29<02:25,  1.75s/it, Best Val. Loss=0.482, Best Val. Score=-0.482, Training Loss=0.502, Val. Loss=0.49, Patience=35, Utilization=0, Grad Norm=9.08] [2025-09-23 23:31:38,201] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps: 101it [02:44,  1.65s/it, Best Val. Loss=0.471, Best Val. Score=-0.471, Training Loss=0.397, Val. Loss=0.543, Patience=-48, Utilization=0, Grad Norm=8.44]                         \n",
      "[2025-09-23 23:33:51,993] INFO - Initial Validation Loss: 0.5660938875519983 Best Validation Loss: 0.4705977997870326 Total Steps: 101 Best Step: 36 Total Time Spent: 168.80583572387695\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy_score (Finetuned): 0.8652173913043478\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fine-tuning Steps:   2%|▏         | 2/100 [00:02<03:41,  2.26s/it, Best Val. Loss=0.59, Best Val. Score=-0.59, Training Loss=0.524, Val. Loss=0.59, Patience=49, Utilization=0, Grad Norm=6.91][2025-09-23 23:34:00,366] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps:  22%|██▏       | 22/100 [00:38<02:33,  1.97s/it, Best Val. Loss=0.521, Best Val. Score=-0.521, Training Loss=0.468, Val. Loss=0.521, Patience=30, Utilization=0, Grad Norm=9.84][2025-09-23 23:34:36,793] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps: 101it [03:03,  1.83s/it, Best Val. Loss=0.461, Best Val. Score=-0.461, Training Loss=0.395, Val. Loss=0.461, Patience=-48, Utilization=0, Grad Norm=8.14]                         \n",
      "[2025-09-23 23:37:00,544] INFO - Initial Validation Loss: 0.5935690765318904 Best Validation Loss: 0.46058942037713024 Total Steps: 101 Best Step: 100 Total Time Spent: 185.22291922569275\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy_score (Finetuned): 0.8652173913043478\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fine-tuning Steps:   3%|▎         | 3/100 [00:03<02:40,  1.66s/it, Best Val. Loss=0.604, Best Val. Score=-0.604, Training Loss=0.43, Val. Loss=0.613, Patience=48, Utilization=0, Grad Norm=9.77][2025-09-23 23:37:10,131] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps:  14%|█▍        | 14/100 [00:25<02:46,  1.93s/it, Best Val. Loss=0.538, Best Val. Score=-0.538, Training Loss=0.452, Val. Loss=0.538, Patience=38, Utilization=0, Grad Norm=7.21][2025-09-23 23:37:32,189] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps: 101it [02:45,  1.66s/it, Best Val. Loss=0.503, Best Val. Score=-0.503, Training Loss=0.255, Val. Loss=0.503, Patience=-48, Utilization=0, Grad Norm=5.69]                         \n",
      "[2025-09-23 23:39:51,566] INFO - Initial Validation Loss: 0.6042756883815655 Best Validation Loss: 0.503314944787894 Total Steps: 101 Best Step: 99 Total Time Spent: 167.3140733242035\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy_score (Finetuned): 0.8478260869565217\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fine-tuning Steps:   2%|▏         | 2/100 [00:02<03:16,  2.01s/it, Best Val. Loss=0.524, Best Val. Score=-0.524, Training Loss=0.637, Val. Loss=0.524, Patience=49, Utilization=0, Grad Norm=7.39][2025-09-23 23:39:59,645] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps:  13%|█▎        | 13/100 [00:21<02:54,  2.00s/it, Best Val. Loss=0.511, Best Val. Score=-0.511, Training Loss=0.53, Val. Loss=0.511, Patience=39, Utilization=0, Grad Norm=8.06] [2025-09-23 23:40:18,795] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps: 101it [02:45,  1.66s/it, Best Val. Loss=0.455, Best Val. Score=-0.455, Training Loss=0.379, Val. Loss=0.456, Patience=-48, Utilization=0, Grad Norm=6.48]                         \n",
      "[2025-09-23 23:42:42,553] INFO - Initial Validation Loss: 0.5362367285994355 Best Validation Loss: 0.4547377325446615 Total Steps: 101 Best Step: 90 Total Time Spent: 167.64795780181885\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy_score (Finetuned): 0.8543478260869565\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fine-tuning Steps:   2%|▏         | 2/100 [00:02<04:18,  2.64s/it, Best Val. Loss=0.478, Best Val. Score=-0.478, Training Loss=0.549, Val. Loss=0.478, Patience=49, Utilization=0, Grad Norm=7.05][2025-09-23 23:42:51,214] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps:   9%|▉         | 9/100 [00:14<02:48,  1.85s/it, Best Val. Loss=0.449, Best Val. Score=-0.449, Training Loss=0.48, Val. Loss=0.449, Patience=43, Utilization=0, Grad Norm=13.9] [2025-09-23 23:43:03,547] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps:  26%|██▌       | 26/100 [00:45<02:05,  1.70s/it, Best Val. Loss=0.397, Best Val. Score=-0.397, Training Loss=0.505, Val. Loss=0.42, Patience=27, Utilization=0, Grad Norm=31.4] [2025-09-23 23:43:34,345] INFO - \n",
      "Optimizer step skipped due to NaNs/infs in grad scaling.\n",
      "Fine-tuning Steps: 101it [02:50,  1.70s/it, Best Val. Loss=0.377, Best Val. Score=-0.377, Training Loss=0.29, Val. Loss=0.39, Patience=-47, Utilization=0, Grad Norm=4.08]                          \n",
      "[2025-09-23 23:45:37,702] INFO - Initial Validation Loss: 0.5028540863596721 Best Validation Loss: 0.37733133012739 Total Steps: 101 Best Step: 59 Total Time Spent: 171.71857380867004\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy_score (Finetuned): 0.8478260869565217\n"
     ]
    }
   ],
   "source": [
    "# Fine-tune and evaluate MMPFN on N_SEEDS random train/test splits (one split per seed).\n",
    "\n",
    "# Mixer options are passed to BOTH fine-tuning and inference; defining them once keeps\n",
    "# the evaluated classifier configured exactly like the checkpoint it loads.\n",
    "MIXER_KWARGS = dict(mixer_type='MGM+CQAM', mgm_heads=128, cqam_heads=12)  # mixer_type: 'MGM' or 'MGM+CQAM'\n",
    "N_SEEDS = 5\n",
    "TRAIN_FRACTION = 0.8\n",
    "save_path_to_fine_tuned_model = \"./finetuned_mmpfn_pad_ufes_20.ckpt\"  # overwritten for every seed\n",
    "\n",
    "\n",
    "def fill_nan_with_train_min(X_train, X_test):\n",
    "    \"\"\"Replace NaNs column-wise, in place, with (training-column minimum - 1).\n",
    "\n",
    "    The fill value is computed from the training split only and applied to both\n",
    "    splits, so a missing value is encoded identically at fit and predict time and\n",
    "    no test-set statistics leak into preprocessing. A column that is all-NaN in the\n",
    "    training split keeps its NaNs (np.nanmin returns NaN with a RuntimeWarning).\n",
    "    \"\"\"\n",
    "    for i in range(X_train.shape[1]):\n",
    "        fill_value = np.nanmin(X_train[:, i]) - 1  # taken before X_train is modified\n",
    "        X_train[np.isnan(X_train[:, i]), i] = fill_value\n",
    "        X_test[np.isnan(X_test[:, i]), i] = fill_value\n",
    "\n",
    "\n",
    "accuracy_scores = []\n",
    "train_len = int(len(dataset) * TRAIN_FRACTION)\n",
    "test_len = len(dataset) - train_len\n",
    "for seed in range(N_SEEDS):\n",
    "    # random_split uses torch's global RNG; numpy is seeded as well for reproducibility.\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    train_dataset, test_dataset = random_split(dataset, [train_len, test_len])\n",
    "\n",
    "    # Index the underlying arrays directly. Assuming they are numpy arrays, list indexing\n",
    "    # returns copies, so the in-place NaN filling below leaves `dataset` untouched.\n",
    "    X_train = train_dataset.dataset.x[train_dataset.indices]\n",
    "    y_train = train_dataset.dataset.y[train_dataset.indices]\n",
    "    X_test = test_dataset.dataset.x[test_dataset.indices]\n",
    "    y_test = test_dataset.dataset.y[test_dataset.indices]\n",
    "    image_train = train_dataset.dataset.embeddings[train_dataset.indices]\n",
    "    image_test = test_dataset.dataset.embeddings[test_dataset.indices]\n",
    "    fill_nan_with_train_min(X_train, X_test)\n",
    "\n",
    "    fine_tune_mmpfn(\n",
    "        # path_to_base_model=\"auto\",\n",
    "        save_path_to_fine_tuned_model=save_path_to_fine_tuned_model,\n",
    "        # Finetuning HPs\n",
    "        # NOTE(review): logged runs took ~170 s each despite time_limit=60; confirm its units/semantics.\n",
    "        time_limit=60,\n",
    "        finetuning_config={\"learning_rate\": 0.00001, \"batch_size\": 1, \"max_steps\": 100},\n",
    "        validation_metric=\"log_loss\",\n",
    "        # Input Data\n",
    "        X_train=pd.DataFrame(X_train),\n",
    "        image_train=image_train,\n",
    "        y_train=pd.Series(y_train),\n",
    "        categorical_features_index=None,\n",
    "        device=\"cuda\",  # use \"cpu\" if you don't have a GPU\n",
    "        task_type=\"multiclass\",\n",
    "        # Optional\n",
    "        show_training_curve=False,  # True shows a final report after fine-tuning\n",
    "        logger_level=0,  # 0 shows all logs; higher values show fewer\n",
    "        freeze_input=True,  # Freeze the input layers (encoder and y_encoder) during finetuning\n",
    "        **MIXER_KWARGS,\n",
    "    )\n",
    "\n",
    "    # Disable preprocessing at inference time to match fine-tuning.\n",
    "    no_preprocessing_inference_config = ModelInterfaceConfig(\n",
    "        FINGERPRINT_FEATURE=False,\n",
    "        PREPROCESS_TRANSFORMS=[PreprocessorConfig(name='none')],\n",
    "    )\n",
    "\n",
    "    # Evaluate the fine-tuned checkpoint on the held-out split.\n",
    "    model_finetuned = MMPFNClassifier(\n",
    "        model_path=save_path_to_fine_tuned_model,\n",
    "        inference_config=no_preprocessing_inference_config,\n",
    "        ignore_pretraining_limits=True,\n",
    "        **MIXER_KWARGS,\n",
    "    )\n",
    "    clf_finetuned = model_finetuned.fit(X_train, image_train, y_train)\n",
    "    acc_score = accuracy_score(y_test, clf_finetuned.predict(X_test, image_test))\n",
    "    print(\"accuracy_score (Finetuned):\", acc_score)\n",
    "    accuracy_scores.append(acc_score)\n",
    "\n",
    "    # Drop this seed's model before the next fine-tuning run so the CUDA caching\n",
    "    # allocator can release any GPU memory it held.\n",
    "    del model_finetuned, clf_finetuned\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "33aa39dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean Accuracy: 0.8560869565217392\n",
      "Std Accuracy: 0.007826086956521745\n"
     ]
    }
   ],
   "source": [
    "# Summarize accuracy across seeds. This is the population std (ddof=0, NumPy's default);\n",
    "# use ddof=1 for the sample std over seeds.\n",
    "scores_per_seed = np.asarray(accuracy_scores)\n",
    "mean_accuracy = scores_per_seed.mean()\n",
    "std_accuracy = scores_per_seed.std(ddof=0)\n",
    "print(\"Mean Accuracy:\", mean_accuracy)\n",
    "print(\"Std Accuracy:\", std_accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mmpfn2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
