{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "from datasets import load_dataset, Dataset\n",
    "from operator import truediv\n",
    "from datasets import concatenate_datasets\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer, EvalPrediction\n",
    "from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, hamming_loss\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from accelerate import Accelerator\n",
    "exp_name = 'baseline_od'\n",
    "device = Accelerator.device\n",
    "accelerator = Accelerator()\n",
    "base_model = \"roberta-base\"\n",
    "EXP = 'cost' #should be perf or cost\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_model)\n",
    "ds = load_dataset(\"CARROT-LLM-Routing/SPROUT-o3mini\")\n",
    "## CORRECTLY FORMAT DATA\n",
    "\n",
    "ds_train = ds['train']\n",
    "ds_test = ds['test']\n",
    "ds_validation = ds['validation']\n",
    "all_models = ds['train'].to_pandas().columns.tolist()[6:]\n",
    "print(all_models)\n",
    "train_inputs = ds_train['prompt']\n",
    "val_inputs = ds_validation['prompt']\n",
    "test_inputs = ds_test['prompt']\n",
    "def get_labels(split, models, mode):\n",
    "    labels = []\n",
    "    if mode == 'cost':\n",
    "        for m in models:\n",
    "            tmp = [sample['num_output_tokens'] for sample in split[m]]\n",
    "            labels.append(tmp)\n",
    "        return np.array(labels).T\n",
    "    else:\n",
    "        for m in models:\n",
    "            tmp = [sample['score'] for sample in split[m]]\n",
    "            labels.append(tmp)\n",
    "        return np.array(labels).T\n",
    "if EXP == 'cost':\n",
    "    train_labels = get_labels(ds_train, all_models, 'cost')\n",
    "    val_labels = get_labels(ds_validation, all_models, 'cost')\n",
    "    test_labels = get_labels(ds_test, all_models, 'cost')\n",
    "    model_max_token_counts = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]\n",
    "    data_train = Dataset.from_dict({\"input\": train_inputs, \"count\": train_labels})\n",
    "    data_val = Dataset.from_dict({\"input\": val_inputs, \"count\": val_labels})\n",
    "    data_test = Dataset.from_dict({\"input\": test_inputs, \"count\": test_labels})\n",
    "    def preprocess_function(examples):\n",
    "        floating_counts = []\n",
    "        for count_by_case in examples[\"count\"]:\n",
    "            floating_counts.append(list(map(truediv, count_by_case, model_max_token_counts)))\n",
    "        examples = tokenizer(examples[\"input\"], truncation=True)\n",
    "        examples[\"label\"] = floating_counts\n",
    "        return examples\n",
    "    tokenized_train = data_train.map(preprocess_function, batched=True, remove_columns=[\"count\"])\n",
    "    tokenized_val = data_val.map(preprocess_function, batched=True, remove_columns=[\"count\"])\n",
    "    tokenized_test = data_test.map(preprocess_function, batched=True, remove_columns=[\"count\"])\n",
    "    tokenized_train.set_format(\"torch\")\n",
    "    tokenized_val.set_format(\"torch\")\n",
    "    tokenized_test.set_format(\"torch\")\n",
    "else:\n",
    "    train_labels = get_labels(ds_train, all_models, 'perf')\n",
    "    val_labels = get_labels(ds_validation, all_models, 'perf')\n",
    "    test_labels = get_labels(ds_test, all_models, 'perf')\n",
    "    data_train = Dataset.from_dict({\"input\": train_inputs, \"label\": train_labels})\n",
    "    data_val = Dataset.from_dict({\"input\": val_inputs, \"label\": val_labels})\n",
    "    data_test = Dataset.from_dict({\"input\": test_inputs, \"label\": test_labels})\n",
    "    def preprocess_function(examples):\n",
    "        return tokenizer(examples[\"input\"], truncation=True)\n",
    "    tokenized_train = data_train.map(preprocess_function, batched=True)\n",
    "    tokenized_val = data_val.map(preprocess_function, batched=True)\n",
    "    tokenized_test = data_test.map(preprocess_function, batched=True)\n",
    "    tokenized_train.set_format(\"torch\")\n",
    "    tokenized_val.set_format(\"torch\")\n",
    "    tokenized_test.set_format(\"torch\")\n",
    "\n",
    "## TRAIN COST/PERF PREDICTION\n",
    "def compute_metrics_for_regression(eval_pred: EvalPrediction):\n",
    "    logits, labels = eval_pred\n",
    "    mse = np.mean(np.square(labels-logits))\n",
    "    return {\"mse\": mse}\n",
    "def multi_label_metrics(predictions, labels, threshold=0.5):\n",
    "    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)\n",
    "    sigmoid = torch.nn.Sigmoid()\n",
    "    probs = sigmoid(torch.Tensor(predictions))\n",
    "    # next, use threshold to turn them into integer predictions\n",
    "    y_pred = np.zeros(probs.shape)\n",
    "    y_pred[np.where(probs >= threshold)] = 1\n",
    "    # finally, compute metrics\n",
    "    y_true = np.zeros(labels.shape)\n",
    "    y_true[np.where(labels >= threshold)] = 1 \n",
    "    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n",
    "    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n",
    "    accuracy = accuracy_score(y_true, y_pred)\n",
    "    # return as dictionary\n",
    "    metrics = {'f1': f1_micro_average,\n",
    "               'roc_auc': roc_auc,\n",
    "               'accuracy': accuracy,\n",
    "               'all_accuracy': (y_true == y_pred).mean()}\n",
    "    return metrics\n",
    "\n",
    "\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "if EXP == 'cost':\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(base_model, \n",
    "                                                           problem_type=\"regression\", \n",
    "                                                           num_labels=len(all_models),\n",
    "                                                           )\n",
    "    batch_size = 16\n",
    "    metric_name = 'mse'\n",
    "    num_train_epochs = 5\n",
    "    args = TrainingArguments(\n",
    "        exp_name + \"-router-\" + base_model,\n",
    "        evaluation_strategy = \"epoch\",\n",
    "        save_strategy = \"epoch\",\n",
    "        learning_rate=2e-5,\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        per_device_eval_batch_size=batch_size,\n",
    "        num_train_epochs=num_train_epochs,\n",
    "        weight_decay=0.01,\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=metric_name,\n",
    "        greater_is_better=False,\n",
    "        save_total_limit=1,\n",
    "        logging_dir='./fmselect-logs',\n",
    "        logging_steps=10,\n",
    "        report_to='none'\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model,\n",
    "        args,\n",
    "        train_dataset=tokenized_train,\n",
    "        eval_dataset=tokenized_val,\n",
    "        tokenizer=tokenizer,\n",
    "        compute_metrics=compute_metrics_for_regression,\n",
    "        data_collator=data_collator\n",
    "    )\n",
    "    trainer.train()\n",
    "\n",
    "else:\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(base_model, \n",
    "                                                           problem_type=\"multi_label_classification\", \n",
    "                                                           num_labels=len(all_models), device_map = 'auto') #reference_compile=False)\n",
    "    batch_size = 16\n",
    "    metric_name = 'all_accuracy'\n",
    "    num_train_epochs = 5\n",
    "    args = TrainingArguments(\n",
    "        exp_name + \"-router-\" + base_model,\n",
    "        evaluation_strategy = \"epoch\",\n",
    "        save_strategy = \"epoch\",\n",
    "        learning_rate=2e-5,\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        per_device_eval_batch_size=batch_size,\n",
    "        num_train_epochs=5,\n",
    "        weight_decay=0.01,\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=metric_name,\n",
    "        greater_is_better=True,\n",
    "        save_total_limit=1,\n",
    "        logging_dir='./fmselect-logs',\n",
    "        logging_steps=10,\n",
    "        report_to = 'none')\n",
    "    def compute_metrics(p: EvalPrediction):\n",
    "        preds = p.predictions[0] if isinstance(p.predictions, \n",
    "                tuple) else p.predictions\n",
    "        result = multi_label_metrics(\n",
    "            predictions=preds, \n",
    "            labels=p.label_ids)\n",
    "        return result\n",
    "    trainer = Trainer(\n",
    "        model,\n",
    "        args,\n",
    "        train_dataset=tokenized_train,\n",
    "        eval_dataset=tokenized_val,\n",
    "        tokenizer=tokenizer,\n",
    "        compute_metrics=compute_metrics,\n",
    "        data_collator=data_collator)\n",
    "    trainer.train()"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "['openai-o3-mini', 'wxai-llama-3-405b-instruct', 'wxai-llama-3-3-70b-instruct', 'wxai-llama-3-2-3b-instruct', 'wxai-llama-3-2-1b-instruct', 'wxai-llama-3-1-8b-instruct', 'wxai-llama-3-1-70b-instruct', 'aws-titan-text-premier-v1', 'wxai-mixtral-8x7b-instruct-v01', 'aws-claude-3-5-sonnet-v1', 'openai-gpt-4o-mini', 'openai-gpt-4o', 'wxai-granite-3-8b-instruct-8k-max-tokens', 'wxai-granite-3-2b-instruct-8k-max-tokens']\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6dda145bffb47dba0bf042e30a6c9ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/30301 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39c8bd83b2a74632ba7ac4dd7c52f85a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6493 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "329667859890433ebefc56f9b74e5c11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6494 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/skunk/routing2/lib/python3.10/site-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n",
      "/tmp/ipykernel_140129/722317140.py:128: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n",
      "/home/skunk/routing2/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4735' max='4735' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4735/4735 49:21, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.051300</td>\n",
       "      <td>0.966218</td>\n",
       "      <td>0.966313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.231800</td>\n",
       "      <td>0.965187</td>\n",
       "      <td>0.965257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.825000</td>\n",
       "      <td>0.946590</td>\n",
       "      <td>0.946680</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.019100</td>\n",
       "      <td>0.955795</td>\n",
       "      <td>0.955899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.987700</td>\n",
       "      <td>0.960254</td>\n",
       "      <td>0.960357</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/home/skunk/routing2/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n",
      "/home/skunk/routing2/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n",
      "/home/skunk/routing2/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n",
      "/home/skunk/routing2/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n",
      "/home/skunk/routing2/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "if EXP == 'perf':\n",
    "    trainer.save_model(\"./HFPERFo3\")\n",
    "if EXP == 'cost':\n",
    "    trainer.save_model(\"./HFcosto3\")"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}