{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8638937d-bdce-41b1-a03b-87f5fd839b0f",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 176022,
     "status": "ok",
     "timestamp": 1747330716957,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "8638937d-bdce-41b1-a03b-87f5fd839b0f",
    "outputId": "4a544cbb-4512-4b4e-e61c-efc2740b93a0"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n#Installation\\n\\n# Clone the ibm/tsfm\\n! git clone https://github.com/ibm-granite/granite-tsfm.git\\n! ls\\n\\n# Change directory. Move inside the tsfm repo.\\n%cd granite-tsfm\\n! ls\\n\\n# Relax requirement for python version < 3.12\\n! sed -i.orig \\'s/3\\\\.12/3.13/g\\' pyproject.toml\\n\\n# Install the tsfm library\\n#! pip install \".[notebooks]\"\\n#! python3 -m pip install \".[notebooks]\"\\n! pip3 install \".[notebooks]\"\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: the install recipe below is wrapped in a string literal, so running this cell only\n",
    "# echoes the text; none of the shell (!) or magic (%) commands are executed. To install\n",
    "# granite-tsfm, copy the commands into their own cell without the surrounding quotes.\n",
    "'''\n",
    "#Installation\n",
    "\n",
    "# Clone the ibm/tsfm\n",
    "! git clone https://github.com/ibm-granite/granite-tsfm.git\n",
    "! ls\n",
    "\n",
    "# Change directory. Move inside the tsfm repo.\n",
    "%cd granite-tsfm\n",
    "! ls\n",
    "\n",
    "# Relax requirement for python version < 3.12\n",
    "! sed -i.orig 's/3\\.12/3.13/g' pyproject.toml\n",
    "\n",
    "# Install the tsfm library\n",
    "#! pip install \".[notebooks]\"\n",
    "#! python3 -m pip install \".[notebooks]\"\n",
    "! pip3 install \".[notebooks]\"\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70d97912-fcc0-430a-a2e5-c3f375fed980",
   "metadata": {
    "executionInfo": {
     "elapsed": 47618,
     "status": "ok",
     "timestamp": 1747330778190,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "70d97912-fcc0-430a-a2e5-c3f375fed980"
   },
   "outputs": [],
   "source": [
    "# Standard\n",
    "import os, types\n",
    "import math\n",
    "import tempfile\n",
    "import torch\n",
    "import time\n",
    "\n",
    "# Third Party\n",
    "from torch.optim import AdamW\n",
    "from torch.optim.lr_scheduler import OneCycleLR\n",
    "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
    "from transformers.integrations import INTEGRATION_TO_CALLBACK\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from botocore.client import Config\n",
    "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
    "\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "57b7a684-c267-4074-be18-a04db01618fb",
   "metadata": {
    "executionInfo": {
     "elapsed": 661,
     "status": "ok",
     "timestamp": 1747330778890,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "57b7a684-c267-4074-be18-a04db01618fb"
   },
   "outputs": [],
   "source": [
    "# tsfm library\n",
    "from tsfm_public import (\n",
    "    TimeSeriesPreprocessor,\n",
    "    TinyTimeMixerForPrediction,\n",
    "    TrackingCallback,\n",
    "    count_parameters,\n",
    "    get_datasets\n",
    ")\n",
    "from tsfm_public.toolkit.visualization import plot_predictions\n",
    "from tsfm_public.toolkit.get_model import get_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "723a4bef-c377-4266-a212-386e80503a91",
   "metadata": {
    "executionInfo": {
     "elapsed": 32,
     "status": "ok",
     "timestamp": 1747330778951,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "723a4bef-c377-4266-a212-386e80503a91"
   },
   "outputs": [],
   "source": [
    "# Set seed for reproducibility\n",
    "SEED = 42\n",
    "set_seed(SEED)\n",
    "\n",
    "# TTM model id/path passed to get_model() (a Hugging Face Hub repo id)\n",
    "TTM_MODEL_PATH = \"ibm-granite/granite-timeseries-ttm-r2\"\n",
    "\n",
    "# Forecasting parameters\n",
    "CONTEXT_LENGTH = 512\n",
    "\n",
    "# Granite-TTM-R2 supports forecast lengths up to 720; Granite-TTM-R1 supports up to 96\n",
    "PREDICTION_LENGTH = 96\n",
    "\n",
    "# Results dir for fine-tuning checkpoints and logs. Resolved under the working directory\n",
    "# (not the filesystem root) so it is writable without root; override via $TTM_OUT_DIR.\n",
    "OUT_DIR = os.environ.get(\"TTM_OUT_DIR\", os.path.join(os.getcwd(), \"TTM\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa44878-a70e-42ca-9db8-1deeaae07c9e",
   "metadata": {
    "executionInfo": {
     "elapsed": 846,
     "status": "ok",
     "timestamp": 1747330779803,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "3fa44878-a70e-42ca-9db8-1deeaae07c9e"
   },
   "outputs": [],
   "source": [
    "# Column holding the measurement timestamps\n",
    "timestamp_column = \"DATE\"\n",
    "id_columns = []\n",
    "\n",
    "# Load the 5G trace and convert its timestamp column to pandas datetimes\n",
    "df = pd.read_csv('5G_millisecond.csv')\n",
    "df[timestamp_column] = pd.to_datetime(df[timestamp_column])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f9b73104-aa27-4cf6-8198-5b95a78cee5f",
   "metadata": {
    "executionInfo": {
     "elapsed": 26,
     "status": "ok",
     "timestamp": 1747330779853,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "f9b73104-aa27-4cf6-8198-5b95a78cee5f"
   },
   "outputs": [],
   "source": [
    "# Forecast target and the exogenous (control) channels given to the preprocessor\n",
    "target_columns = [\"mac_dl_brate\"]\n",
    "control_columns = [\"mac_dl_cqi\", \"mac_dl_mcs\", \"mac_dl_ok\", \"mac_dl_nok\"]\n",
    "\n",
    "# NOTE(review): id_columns is NOT passed to the preprocessor (column_specifiers uses []),\n",
    "# so presumably all rows are treated as a single series. Confirm whether per-UE series\n",
    "# (grouping by ue_ident) were intended.\n",
    "id_columns = [\"ue_ident\"]\n",
    "\n",
    "column_specifiers = {\n",
    "    \"timestamp_column\": timestamp_column,\n",
    "    \"id_columns\": [],\n",
    "    \"target_columns\": target_columns,\n",
    "    \"control_columns\": control_columns,\n",
    "}\n",
    "\n",
    "# Fractional train/test split. NOTE(review): there is no \"valid\" entry and train + test\n",
    "# already sum to 1.0 -- confirm get_datasets still yields a usable validation split, since\n",
    "# few-shot fine-tuning evaluates and early-stops on it.\n",
    "split_config = {\n",
    "    \"train\": 0.8,\n",
    "    \"test\": 0.2,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "463f20c5-6da6-4fa0-a0b0-cdad7a2442af",
   "metadata": {
    "executionInfo": {
     "elapsed": 17,
     "status": "ok",
     "timestamp": 1747333478580,
     "user": {
      "displayName": "subina khanal",
      "userId": "01847233947579723981"
     },
     "user_tz": -120
    },
    "id": "463f20c5-6da6-4fa0-a0b0-cdad7a2442af"
   },
   "outputs": [],
   "source": [
    "def zeroshot_eval_with_metrics(dataset_name, batch_size, context_length, forecast_length, seed, cutoff=96):\n",
    "    \"\"\"Evaluate the pretrained TTM model zero-shot on the test split and return (RMSE, MAE).\n",
    "\n",
    "    Args:\n",
    "        dataset_name: DataFrame to evaluate on. Any non-DataFrame value (e.g. a name string)\n",
    "            falls back to the global ``df``.\n",
    "        batch_size: Per-device evaluation batch size.\n",
    "        context_length: Number of past time steps fed to the model.\n",
    "        forecast_length: Number of future time steps predicted per window.\n",
    "        seed: Passed to ``set_seed`` and ``TrainingArguments``. Presumably has no effect on\n",
    "            zero-shot predictions (the stored summary shows std=0 across seeds).\n",
    "        cutoff: Only the first ``cutoff`` forecast steps of each window are scored\n",
    "            (``None`` scores all ``forecast_length`` steps).\n",
    "\n",
    "    Returns:\n",
    "        ``(rmse, mae)`` on the target channel, presumably in the preprocessor's min-max\n",
    "        scaled space (``scaling=True``) rather than original units -- confirm before\n",
    "        comparing with models evaluated on raw values.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the test split contains no windows.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "    data = dataset_name if isinstance(dataset_name, pd.DataFrame) else df\n",
    "\n",
    "    tsp = TimeSeriesPreprocessor(\n",
    "        **column_specifiers,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_length,\n",
    "        scaling=True,\n",
    "        encode_categorical=False,\n",
    "        scaler_type=\"minmax\",\n",
    "    )\n",
    "\n",
    "    # Pretrained TTM, used as-is (no fine-tuning)\n",
    "    zeroshot_model = get_model(\n",
    "        TTM_MODEL_PATH,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_length,\n",
    "        freq_prefix_tuning=False,\n",
    "        freq='100ms',\n",
    "        prefer_l1_loss=False,\n",
    "        prefer_longer_context=True,\n",
    "    )\n",
    "\n",
    "    # Only the test split is needed for zero-shot evaluation\n",
    "    _, _, dset_test = get_datasets(\n",
    "        tsp, data, split_config, use_frequency_token=zeroshot_model.config.resolution_prefix_tuning\n",
    "    )\n",
    "    num_samples = len(dset_test)\n",
    "    if num_samples == 0:\n",
    "        # Fail clearly instead of dividing by zero in the per-sample timing below\n",
    "        raise ValueError(\"Test split is empty; check split_config, context_length and forecast_length.\")\n",
    "\n",
    "    # Channel 0 of future_values / predictions is scored as the target\n",
    "    target_column_index = 0\n",
    "    # One channel per window, stacked then reshaped to (num_samples, forecast_length)\n",
    "    actuals = np.concatenate(\n",
    "        [item[\"future_values\"].numpy()[:, target_column_index:target_column_index + 1] for item in dset_test],\n",
    "        axis=0,\n",
    "    )\n",
    "    actuals = actuals.reshape(-1, forecast_length)[:, :cutoff]\n",
    "    print('Actuals Shape', actuals.shape)\n",
    "\n",
    "    # TemporaryDirectory deletes the Trainer's scratch output dir afterwards\n",
    "    with tempfile.TemporaryDirectory() as temp_dir:\n",
    "        zeroshot_trainer = Trainer(\n",
    "            model=zeroshot_model,\n",
    "            args=TrainingArguments(\n",
    "                output_dir=temp_dir,\n",
    "                per_device_eval_batch_size=batch_size,\n",
    "                seed=seed,\n",
    "                report_to=\"none\",\n",
    "            ),\n",
    "        )\n",
    "\n",
    "        # Time only the prediction pass (perf_counter is monotonic)\n",
    "        start_time = time.perf_counter()\n",
    "        predictions_dict = zeroshot_trainer.predict(dset_test)\n",
    "        inference_time = time.perf_counter() - start_time\n",
    "\n",
    "    inference_per_sample = inference_time / num_samples\n",
    "    print(f\"TTM Total Zero Shot inference time: {inference_time:.4f} seconds\")\n",
    "    print(f\"TTM Zero Shot Inference time per sample: {inference_per_sample:.6f} seconds\")\n",
    "\n",
    "    # predictions[0] holds the forecasts; predictions[1] (backbone embeddings) is not used here\n",
    "    predictions_np = predictions_dict.predictions[0]\n",
    "    predictions = predictions_np[:, :, target_column_index]\n",
    "    predictions = predictions.reshape(-1, forecast_length)[:, :cutoff]\n",
    "    print('Predictions Shape', predictions.shape)\n",
    "\n",
    "    rmse = np.sqrt(mean_squared_error(actuals, predictions))\n",
    "    mae = mean_absolute_error(actuals, predictions)\n",
    "\n",
    "    return rmse, mae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d1e7855",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds for repeated zero-shot runs (also reused by the few-shot loop)\n",
    "seeds = [42, 123, 99]\n",
    "results = {\"seed\": [], \"rmse\": [], \"mae\": []}\n",
    "\n",
    "for s in seeds:\n",
    "    rmse, mae = zeroshot_eval_with_metrics(\n",
    "        dataset_name=df,\n",
    "        batch_size=64,\n",
    "        context_length=CONTEXT_LENGTH,\n",
    "        forecast_length=PREDICTION_LENGTH,\n",
    "        seed=s,\n",
    "    )\n",
    "    # Record this run under each metric key\n",
    "    for key, value in ((\"seed\", s), (\"rmse\", rmse), (\"mae\", mae)):\n",
    "        results[key].append(value)\n",
    "    print(f\"Seed {s} -> RMSE: {rmse:.4f}, MAE: {mae:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a78947cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "====== Summary over Seeds ======\n",
      "RMSE: mean=0.0359, std=0.0000\n",
      "MAE : mean=0.0230, std=0.0000\n"
     ]
    }
   ],
   "source": [
    "# Mean and population standard deviation (np.std default, ddof=0) of the per-seed metrics\n",
    "rmse_values = np.asarray(results[\"rmse\"])\n",
    "mae_values = np.asarray(results[\"mae\"])\n",
    "rmse_mean, rmse_std = rmse_values.mean(), rmse_values.std()\n",
    "mae_mean, mae_std = mae_values.mean(), mae_values.std()\n",
    "\n",
    "print(\"\\n====== Summary over Seeds ======\")\n",
    "print(f\"RMSE: mean={rmse_mean:.4f}, std={rmse_std:.4f}\")\n",
    "print(f\"MAE : mean={mae_mean:.4f}, std={mae_std:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9b3769d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fewshot_finetune_eval_with_metrics(\n",
    "    dataset_name,\n",
    "    batch_size,\n",
    "    learning_rate=None,\n",
    "    context_length=512,\n",
    "    forecast_length=96,\n",
    "    fewshot_percent=0.01,\n",
    "    freeze_backbone=True,\n",
    "    num_epochs=30,\n",
    "    save_dir=OUT_DIR,\n",
    "    loss=\"mse\",\n",
    "    quantile=0.5,\n",
    "    seed=42,\n",
    "    cutoff=96,\n",
    "):\n",
    "    \"\"\"Few-shot fine-tune TTM and return (RMSE, MAE) on the test split.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: DataFrame to use, or a dataset name. Any non-DataFrame value falls back\n",
    "            to the global ``df``; a name string containing \"ett\" selects head_dropout=0.7.\n",
    "        batch_size: Per-device train/eval batch size (also used by the LR finder).\n",
    "        learning_rate: Learning rate; if None, ``optimal_lr_finder`` suggests one.\n",
    "        context_length: Number of past time steps fed to the model.\n",
    "        forecast_length: Number of future time steps predicted per window.\n",
    "        fewshot_percent: Percentage (not fraction) of the training split to use;\n",
    "            0.01 means 0.01%.\n",
    "        freeze_backbone: If True, backbone parameters get ``requires_grad = False``.\n",
    "        num_epochs: Maximum training epochs (early stopping may stop sooner).\n",
    "        save_dir: Directory for checkpoints and logs. The default is OUT_DIR as it was\n",
    "            when this cell was executed.\n",
    "        loss: Training loss passed to ``get_model``; evaluation switches the model to \"mse\".\n",
    "        quantile: Quantile passed to ``get_model``.\n",
    "        seed: Passed to ``set_seed`` and ``TrainingArguments``.\n",
    "        cutoff: Only the first ``cutoff`` forecast steps of each window are scored\n",
    "            (``None`` scores all ``forecast_length`` steps).\n",
    "\n",
    "    Returns:\n",
    "        ``(rmse, mae)`` on the target channel, presumably in the preprocessor's min-max\n",
    "        scaled space (``scaling=True``) rather than original units.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the few-shot training split or the test split is empty.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "    data = dataset_name if isinstance(dataset_name, pd.DataFrame) else df\n",
    "\n",
    "    print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
    "\n",
    "    # Data prep\n",
    "    tsp = TimeSeriesPreprocessor(\n",
    "        **column_specifiers,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_length,\n",
    "        scaling=True,\n",
    "        encode_categorical=False,\n",
    "        scaler_type=\"minmax\",\n",
    "    )\n",
    "\n",
    "    # Shared get_model arguments (TTM config args can also be provided here)\n",
    "    model_kwargs = dict(\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_length,\n",
    "        freq_prefix_tuning=False,\n",
    "        freq=None,\n",
    "        prefer_l1_loss=False,\n",
    "        prefer_longer_context=True,\n",
    "        loss=loss,\n",
    "        quantile=quantile,\n",
    "    )\n",
    "    # Head dropout 0.7 for ETT datasets. Only a name string is checked: `\"ett\" in <DataFrame>`\n",
    "    # would test for a column called \"ett\" instead.\n",
    "    if isinstance(dataset_name, str) and \"ett\" in dataset_name:\n",
    "        model_kwargs[\"head_dropout\"] = 0.7\n",
    "    finetune_forecast_model = get_model(TTM_MODEL_PATH, **model_kwargs)\n",
    "\n",
    "    dset_train, dset_val, dset_test = get_datasets(\n",
    "        tsp,\n",
    "        data,\n",
    "        split_config,\n",
    "        fewshot_fraction=fewshot_percent / 100,\n",
    "        fewshot_location=\"first\",\n",
    "        use_frequency_token=finetune_forecast_model.config.resolution_prefix_tuning,\n",
    "    )\n",
    "    # Fail early with a clear message; an empty train split would otherwise fail later and\n",
    "    # less clearly (OneCycleLR rejects steps_per_epoch=0).\n",
    "    if len(dset_train) == 0:\n",
    "        raise ValueError(\n",
    "            f\"Few-shot training split is empty ({fewshot_percent}% of the train split); increase fewshot_percent.\"\n",
    "        )\n",
    "    if len(dset_test) == 0:\n",
    "        raise ValueError(\"Test split is empty; check split_config, context_length and forecast_length.\")\n",
    "\n",
    "    if freeze_backbone:\n",
    "        print(\n",
    "            \"Number of params before freezing backbone\",\n",
    "            count_parameters(finetune_forecast_model),\n",
    "        )\n",
    "\n",
    "        # Freeze the backbone of the model\n",
    "        for param in finetune_forecast_model.backbone.parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "        print(\n",
    "            \"Number of params after freezing the backbone\",\n",
    "            count_parameters(finetune_forecast_model),\n",
    "        )\n",
    "\n",
    "    # Find optimal learning rate\n",
    "    # Use with caution: set it manually if the suggested learning rate is not suitable\n",
    "    if learning_rate is None:\n",
    "        learning_rate, finetune_forecast_model = optimal_lr_finder(\n",
    "            finetune_forecast_model,\n",
    "            dset_train,\n",
    "            batch_size=batch_size,\n",
    "        )\n",
    "        print(\"OPTIMAL SUGGESTED LEARNING RATE =\", learning_rate)\n",
    "\n",
    "    print(f\"Using learning rate = {learning_rate}\")\n",
    "    finetune_forecast_args = TrainingArguments(\n",
    "        output_dir=os.path.join(save_dir, \"finetune_output\"),\n",
    "        overwrite_output_dir=True,\n",
    "        learning_rate=learning_rate,\n",
    "        num_train_epochs=num_epochs,\n",
    "        do_eval=True,\n",
    "        evaluation_strategy=\"epoch\",  # NOTE: newer transformers releases name this `eval_strategy`\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        per_device_eval_batch_size=batch_size,\n",
    "        dataloader_num_workers=8,\n",
    "        report_to=\"none\",\n",
    "        save_strategy=\"epoch\",\n",
    "        logging_strategy=\"epoch\",\n",
    "        save_total_limit=1,\n",
    "        logging_dir=os.path.join(save_dir, \"finetune_logs\"),\n",
    "        load_best_model_at_end=True,  # Load the best model when training ends\n",
    "        metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "        greater_is_better=False,  # Lower loss is better\n",
    "        seed=seed,\n",
    "    )\n",
    "\n",
    "    early_stopping_callback = EarlyStoppingCallback(\n",
    "        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop\n",
    "        early_stopping_threshold=1e-5,  # Minimum improvement required to consider as improvement\n",
    "    )\n",
    "    tracking_callback = TrackingCallback()\n",
    "\n",
    "    # Optimizer and one-cycle LR schedule sized to the few-shot train split\n",
    "    optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
    "    scheduler = OneCycleLR(\n",
    "        optimizer,\n",
    "        learning_rate,\n",
    "        epochs=num_epochs,\n",
    "        steps_per_epoch=math.ceil(len(dset_train) / batch_size),\n",
    "    )\n",
    "\n",
    "    finetune_forecast_trainer = Trainer(\n",
    "        model=finetune_forecast_model,\n",
    "        args=finetune_forecast_args,\n",
    "        train_dataset=dset_train,\n",
    "        eval_dataset=dset_val,\n",
    "        callbacks=[early_stopping_callback, tracking_callback],\n",
    "        optimizers=(optimizer, scheduler),\n",
    "    )\n",
    "    finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK[\"codecarbon\"])\n",
    "\n",
    "    # Fine tune\n",
    "    finetune_forecast_trainer.train()\n",
    "\n",
    "    # Evaluation\n",
    "    print(\"+\" * 20, f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\", \"+\" * 20)\n",
    "    finetune_forecast_trainer.model.loss = \"mse\"  # fixing metric to mse for evaluation\n",
    "    fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
    "    print(fewshot_output)\n",
    "    print(\"+\" * 60)\n",
    "\n",
    "    # Channel 0 of future_values / predictions is scored as the target\n",
    "    target_column_index = 0\n",
    "    # One channel per window, stacked then reshaped to (num_samples, forecast_length)\n",
    "    actuals = np.concatenate(\n",
    "        [item[\"future_values\"].numpy()[:, target_column_index:target_column_index + 1] for item in dset_test],\n",
    "        axis=0,\n",
    "    )\n",
    "    actuals = actuals.reshape(-1, forecast_length)[:, :cutoff]\n",
    "    print('Actuals Shape', actuals.shape)\n",
    "\n",
    "    # Time only the prediction pass (perf_counter is monotonic)\n",
    "    start_time = time.perf_counter()\n",
    "    predictions_dict = finetune_forecast_trainer.predict(dset_test)\n",
    "    inference_time = time.perf_counter() - start_time\n",
    "    inference_per_sample = inference_time / len(dset_test)\n",
    "\n",
    "    print(f\"TTM Total finetune inference time: {inference_time:.4f} seconds\")\n",
    "    print(f\"TTM finetune Inference time per sample: {inference_per_sample:.6f} seconds\")\n",
    "\n",
    "    # predictions[0] holds the forecasts; predictions[1] (backbone embeddings) is not used here\n",
    "    predictions_np = predictions_dict.predictions[0]\n",
    "    predictions = predictions_np[:, :, target_column_index]\n",
    "    predictions = predictions.reshape(-1, forecast_length)[:, :cutoff]\n",
    "    print('Predictions Shape', predictions.shape)\n",
    "\n",
    "    rmse = np.sqrt(mean_squared_error(actuals, predictions))\n",
    "    mae = mean_absolute_error(actuals, predictions)\n",
    "\n",
    "    return rmse, mae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cada6db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot fine-tuning repeated for every seed used in the zero-shot loop.\n",
    "# NOTE(review): fewshot_percent is a percentage, so 0.01 trains on 0.01% of the train\n",
    "# split -- confirm 1% was not intended.\n",
    "f_results = {\"seed\": [], \"rmse\": [], \"mae\": []}\n",
    "\n",
    "for s in seeds:\n",
    "    rmse, mae = fewshot_finetune_eval_with_metrics(\n",
    "        dataset_name=df,\n",
    "        batch_size=128,\n",
    "        learning_rate=None,\n",
    "        context_length=CONTEXT_LENGTH,\n",
    "        forecast_length=PREDICTION_LENGTH,\n",
    "        fewshot_percent=0.01,\n",
    "        seed=s,\n",
    "    )\n",
    "    for key, value in zip((\"seed\", \"rmse\", \"mae\"), (s, rmse, mae)):\n",
    "        f_results[key].append(value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6ec0be64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "====== Summary over Seeds ======\n",
      "RMSE: mean=0.0359, std=0.0000\n",
      "MAE : mean=0.0230, std=0.0000\n"
     ]
    }
   ],
   "source": [
    "# Mean and population standard deviation (np.std default, ddof=0) of the fine-tuned metrics\n",
    "ft_rmse_values = np.asarray(f_results[\"rmse\"])\n",
    "ft_mae_values = np.asarray(f_results[\"mae\"])\n",
    "ft_rmse_mean, ft_rmse_std = ft_rmse_values.mean(), ft_rmse_values.std()\n",
    "ft_mae_mean, ft_mae_std = ft_mae_values.mean(), ft_mae_values.std()\n",
    "\n",
    "print(\"\\n====== Summary over Seeds ======\")\n",
    "print(f\"RMSE: mean={ft_rmse_mean:.4f}, std={ft_rmse_std:.4f}\")\n",
    "print(f\"MAE : mean={ft_mae_mean:.4f}, std={ft_mae_std:.4f}\")\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "tf-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
