{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8292ed9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import os \n",
    "\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "\n",
    "from mmpfn.datasets.petfinder import PetfinderDataset\n",
    "\n",
    "import os \n",
    "import torch \n",
    "import numpy as np \n",
    "import pandas as pd\n",
    "\n",
    "from math import ceil\n",
    "from PIL import Image\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, root_mean_squared_error\n",
    "\n",
    "from mmpfn.models.mmpfn_v2 import MMPFNClassifier\n",
    "from mmpfn.models.dino_v2.models.vision_transformer import vit_base\n",
    "from mmpfn.models.mmpfn_v2.constants import ModelInterfaceConfig\n",
    "from mmpfn.models.mmpfn_v2.preprocessing import PreprocessorConfig\n",
    "from mmpfn.scripts_finetune_mm.finetune_mmpfn_main import fine_tune_mmpfn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd5c4c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the PetFinder dataset (tabular features, images and embeddings).\n",
    "# Path.home() resolves the home directory even where $HOME is unset (e.g. Windows);\n",
    "# os.getenv('HOME') returns None there and os.path.join(None, ...) raises TypeError.\n",
    "# Set the PETFINDER_DATA_PATH environment variable to use a different location.\n",
    "data_path = os.environ.get(\n",
    "    \"PETFINDER_DATA_PATH\",\n",
    "    str(Path.home() / \"works/research/MultiModalPFN/mmpfn/data/petfinder\"),\n",
    ")\n",
    "dataset = PetfinderDataset(data_path)\n",
    "# Return values are discarded; these calls presumably cache images/embeddings on `dataset`\n",
    "# (the next cell reads `dataset.embeddings`) -- TODO confirm against PetfinderDataset.\n",
    "_ = dataset.get_images()\n",
    "_ = dataset.get_embeddings(multimodal_type='all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7909c263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune MMPFN on a random train/test split and report held-out accuracy, once per seed.\n",
    "N_SEEDS = 1  # increase for a more stable accuracy estimate\n",
    "TRAIN_FRACTION = 0.8\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Mixer settings shared by fine-tuning and inference; both calls must agree (presumably so the\n",
    "# saved checkpoint loads back correctly), so they are defined once here.\n",
    "MIXER_KWARGS = dict(mixer_type='MGM+CQAM', mgm_heads=64, cqam_heads=16)  # mixer_type: MGM | MGM+CQAM\n",
    "# Checkpoint named after this dataset; the previous name (pad_ufes_20) was copied from another\n",
    "# notebook and would overwrite / be confused with that experiment's checkpoint.\n",
    "save_path_to_fine_tuned_model = \"./finetuned_mmpfn_petfinder.ckpt\"\n",
    "\n",
    "# Disables preprocessing at inference time to match fine-tuning (loop-invariant, built once).\n",
    "no_preprocessing_inference_config = ModelInterfaceConfig(\n",
    "    FINGERPRINT_FEATURE=False,\n",
    "    PREPROCESS_TRANSFORMS=[PreprocessorConfig(name='none')],\n",
    ")\n",
    "\n",
    "accuracy_scores = []\n",
    "for seed in range(N_SEEDS):\n",
    "    # The torch seed drives random_split (and torch randomness during fine-tuning); seed numpy too.\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    train_len = int(len(dataset) * TRAIN_FRACTION)\n",
    "    test_len = len(dataset) - train_len\n",
    "    train_dataset, test_dataset = random_split(dataset, [train_len, test_len])\n",
    "\n",
    "    X_train = train_dataset.dataset.x[train_dataset.indices]\n",
    "    y_train = train_dataset.dataset.y[train_dataset.indices]\n",
    "    X_test = test_dataset.dataset.x[test_dataset.indices]\n",
    "    y_test = test_dataset.dataset.y[test_dataset.indices]\n",
    "    image_train = train_dataset.dataset.embeddings[train_dataset.indices]\n",
    "    image_test = test_dataset.dataset.embeddings[test_dataset.indices]\n",
    "\n",
    "    # Replace NaNs with (train column minimum - 1), a value below every observed one.\n",
    "    # Fill values come from the TRAIN split only and are reused for the test split, so a missing\n",
    "    # value maps to the same number in both splits and no test-set statistics leak in.\n",
    "    # (Assumes dataset.x is a numeric numpy array -- TODO confirm.)\n",
    "    fill_values = np.nanmin(X_train, axis=0) - 1\n",
    "    X_train = np.where(np.isnan(X_train), fill_values, X_train)\n",
    "    X_test = np.where(np.isnan(X_test), fill_values, X_test)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    fine_tune_mmpfn(\n",
    "        save_path_to_fine_tuned_model=save_path_to_fine_tuned_model,\n",
    "        # Finetuning HPs\n",
    "        time_limit=60,\n",
    "        finetuning_config={\"learning_rate\": 0.00001, \"batch_size\": 1, \"max_steps\": 100},\n",
    "        validation_metric=\"log_loss\",\n",
    "        # Input data\n",
    "        X_train=pd.DataFrame(X_train),\n",
    "        image_train=image_train,\n",
    "        y_train=pd.Series(y_train),\n",
    "        categorical_features_index=None,\n",
    "        device=DEVICE,\n",
    "        task_type=\"multiclass\",\n",
    "        # Optional\n",
    "        show_training_curve=False,  # Shows a final report after finetuning.\n",
    "        logger_level=0,  # 0 shows all logs; higher values show fewer\n",
    "        freeze_input=True,  # Freeze the input layers (encoder and y_encoder) during finetuning\n",
    "        **MIXER_KWARGS,\n",
    "    )\n",
    "\n",
    "    # Evaluate the fine-tuned checkpoint on the held-out split.\n",
    "    model_finetuned = MMPFNClassifier(\n",
    "        model_path=save_path_to_fine_tuned_model,\n",
    "        inference_config=no_preprocessing_inference_config,\n",
    "        ignore_pretraining_limits=True,\n",
    "        **MIXER_KWARGS,\n",
    "    )\n",
    "\n",
    "    clf_finetuned = model_finetuned.fit(X_train, image_train, y_train)\n",
    "    acc_score = accuracy_score(y_test, clf_finetuned.predict(X_test, image_test))\n",
    "    print(\"accuracy_score (Finetuned):\", acc_score)\n",
    "    accuracy_scores.append(acc_score)\n",
    "\n",
    "print(f\"mean accuracy over {len(accuracy_scores)} seed(s): {np.mean(accuracy_scores):.4f} +/- {np.std(accuracy_scores):.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mmpfn2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
