{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "18a7461a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.configs import ApibenchDataConfig, MLLMDataConfig, TrainConfig, Olympus1DataConfig, Olympus2DataConfig\n",
    "from utils.prepareDataset import convert_to_conversational, convert_to_preference_dataset, load_dataset_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dbdc9ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.prepareDataset import load_dataset_json, ApibenchDataConfig\n",
    "# Load the APIBench train / val / test splits with load_dataset_json.\n",
    "dataset_config = ApibenchDataConfig()\n",
    "dataset_json_train, dataset_json_val, dataset_json_test = (\n",
    "    load_dataset_json(split)\n",
    "    for split in (dataset_config.train_set, dataset_config.val_set, dataset_config.test_set)\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a9476ceb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num domains: 40\n",
      "Num unique models: 852\n"
     ]
    }
   ],
   "source": [
    "from router_model import Router, ModelInfo\n",
    "import torch\n",
    "import random\n",
    "\n",
    "def build_domain_and_model_infos(dataset_json_train, experience=\"APIBench\"):\n",
    "    \"\"\"Build the domain vocabulary and one ModelInfo per unique model.\n",
    "\n",
    "    Args:\n",
    "        dataset_json_train: sequence of dict rows (iterated twice, so it must not be a\n",
    "            one-shot iterator). Rows without \"model_name\" or without a domain are skipped.\n",
    "        experience: which row layout the domain is read from:\n",
    "            \"APIBench\" -> row[\"api_data\"][\"domain\"]\n",
    "            \"MLLM\"     -> row[\"domain\"]\n",
    "            In both layouts the card text is row[\"api_data\"][\"description\"] (\"\" if absent).\n",
    "\n",
    "    Returns:\n",
    "        (all_models, domain2id, id2domain, num_domains, model_infos) where domain2id maps\n",
    "        the sorted domain strings to ids 0..num_domains-1, model_infos maps model_name to\n",
    "        ModelInfo (a model seen under several domains keeps its first occurrence), and\n",
    "        all_models is list(model_infos.values()) in first-seen order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if experience is neither \"APIBench\" nor \"MLLM\".\n",
    "    \"\"\"\n",
    "    def api_data_of(row):\n",
    "        # Tolerate rows whose \"api_data\" is missing or explicitly null.\n",
    "        return row.get(\"api_data\") or {}\n",
    "\n",
    "    if experience == \"APIBench\":\n",
    "        def domain_of(row):\n",
    "            return api_data_of(row).get(\"domain\")\n",
    "    elif experience == \"MLLM\":\n",
    "        def domain_of(row):\n",
    "            return row.get(\"domain\")\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown experience {experience!r}; expected 'APIBench' or 'MLLM'\")\n",
    "\n",
    "    # 1. Map domain strings to integer IDs (sorted, so the mapping is deterministic).\n",
    "    all_domain_strings = sorted({\n",
    "        dom for dom in map(domain_of, dataset_json_train) if dom is not None\n",
    "    })\n",
    "    domain2id = {d: i for i, d in enumerate(all_domain_strings)}\n",
    "    id2domain = {i: d for d, i in domain2id.items()}\n",
    "\n",
    "    num_domains = len(domain2id)\n",
    "    print(\"Num domains:\", num_domains)\n",
    "\n",
    "    # 2. Build unique models (model_name -> ModelInfo); the first occurrence wins.\n",
    "    model_infos = {}\n",
    "    for row in dataset_json_train:\n",
    "        mname = row.get(\"model_name\")\n",
    "        dom_str = domain_of(row)\n",
    "        if mname is None or dom_str is None or dom_str not in domain2id or mname in model_infos:\n",
    "            continue\n",
    "        model_infos[mname] = ModelInfo(\n",
    "            model_id=mname,\n",
    "            domain_id=domain2id[dom_str],\n",
    "            card_text=api_data_of(row).get(\"description\", \"\"),\n",
    "            cost=1.0,  # placeholder; could be derived from performance data\n",
    "        )\n",
    "\n",
    "    all_models = list(model_infos.values())\n",
    "    print(\"Num unique models:\", len(all_models))\n",
    "\n",
    "    return all_models, domain2id, id2domain, num_domains, model_infos\n",
    "all_models, domain2id, id2domain, num_domains, model_infos = build_domain_and_model_infos(\n",
    "    dataset_json_train, experience=\"APIBench\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5254eade",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_from_dataset(dataset_json, domain2id, model_infos, split_name=\"training\"):\n",
    "    \"\"\"Collect parallel (prompt, gold model, domain id) lists from one dataset split.\n",
    "\n",
    "    A row is kept only if it has an \"instruction\" and a \"model_name\", its domain\n",
    "    (api_data.domain, falling back to a top-level \"domain\") is a key of domain2id,\n",
    "    and its model name is a key of model_infos. Other rows are skipped silently.\n",
    "\n",
    "    NOTE: in this notebook domain2id / model_infos come from the training split, so\n",
    "    val/test rows whose gold model or domain never occurs in training are dropped and\n",
    "    the reported val/test accuracies cover seen models only.\n",
    "\n",
    "    Args:\n",
    "        dataset_json: iterable of dict rows.\n",
    "        domain2id: mapping domain string -> integer domain id.\n",
    "        model_infos: mapping model name -> ModelInfo (only membership is checked).\n",
    "        split_name: label for the printed count; the default prints\n",
    "            \"Num training examples:\".\n",
    "\n",
    "    Returns:\n",
    "        (prompts, best_models, domains): instruction strings, model names and\n",
    "        integer domain ids, aligned by position.\n",
    "    \"\"\"\n",
    "    prompts, best_models, domains = [], [], []\n",
    "\n",
    "    for row in dataset_json:\n",
    "        instr = row.get(\"instruction\")\n",
    "        mname = row.get(\"model_name\")\n",
    "        if instr is None or mname is None:\n",
    "            continue\n",
    "\n",
    "        # Prefer api_data.domain, fall back to the top-level field; a null\n",
    "        # \"api_data\" is treated like a missing one instead of raising.\n",
    "        dom_str = (row.get(\"api_data\") or {}).get(\"domain\") or row.get(\"domain\")\n",
    "        if dom_str is None or dom_str not in domain2id:\n",
    "            continue\n",
    "        # Models unseen when model_infos was built are skipped.\n",
    "        if mname not in model_infos:\n",
    "            continue\n",
    "\n",
    "        prompts.append(instr)\n",
    "        best_models.append(mname)\n",
    "        domains.append(domain2id[dom_str])\n",
    "\n",
    "    print(f\"Num {split_name} examples:\", len(prompts))\n",
    "\n",
    "    return prompts, best_models, domains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "12443de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_embeddings_and_indices(\n",
    "    all_models,\n",
    "    train_prompts, val_prompts, test_prompts,\n",
    "    train_best_models, val_best_models, test_best_models,\n",
    "    train_domains, val_domains, test_domains,\n",
    "    st_model_name=\"sentence-transformers/all-MiniLM-L6-v2\",\n",
    "    batch_size=64,\n",
    "):\n",
    "    \"\"\"Embed model cards and prompts with a SentenceTransformer and build index tensors.\n",
    "\n",
    "    Args:\n",
    "        all_models: list of ModelInfo; its order defines the model index space\n",
    "            (row i of model_embs_init / entry i of model_domains <-> all_models[i]).\n",
    "        train/val/test_prompts: prompt strings of each split.\n",
    "        train/val/test_best_models: gold model_id per prompt; each id must be the\n",
    "            model_id of some entry in all_models (KeyError otherwise).\n",
    "        train/val/test_domains: integer domain id per prompt.\n",
    "        st_model_name: SentenceTransformer checkpoint used for every encoding.\n",
    "        batch_size: encoding batch size.\n",
    "\n",
    "    Returns:\n",
    "        (model_embs_init, train_prompt_embs, val_prompt_embs, test_prompt_embs,\n",
    "         train_model_idx, val_model_idx, test_model_idx,\n",
    "         train_domain_ids, val_domain_ids, test_domain_ids, model_domains),\n",
    "        all on CUDA when available, otherwise on CPU; index tensors are int64.\n",
    "    \"\"\"\n",
    "    from sentence_transformers import SentenceTransformer\n",
    "    import torch\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    st_model = SentenceTransformer(st_model_name, device=str(device))\n",
    "\n",
    "    def encode_texts(texts):\n",
    "        # Frozen encoder: no autograd graph is needed.\n",
    "        with torch.no_grad():\n",
    "            return st_model.encode(\n",
    "                texts,\n",
    "                batch_size=batch_size,\n",
    "                convert_to_tensor=True,\n",
    "                device=device,\n",
    "                show_progress_bar=True,\n",
    "            )\n",
    "\n",
    "    def as_long(values):\n",
    "        return torch.tensor(values, dtype=torch.long, device=device)\n",
    "\n",
    "    # 1) Model card embeddings; the order of all_models fixes the model index space.\n",
    "    model_id2idx = {m.model_id: i for i, m in enumerate(all_models)}\n",
    "    model_domains = as_long([m.domain_id for m in all_models])\n",
    "    model_embs_init = encode_texts([m.card_text for m in all_models])  # [M, C]\n",
    "\n",
    "    emb_dim = model_embs_init.size(1)\n",
    "    print(\"Model emb dim:\", emb_dim)\n",
    "\n",
    "    # 2) Prompt embeddings (train / val / test)\n",
    "    train_prompt_embs = encode_texts(train_prompts)  # [N_train, C]\n",
    "    val_prompt_embs   = encode_texts(val_prompts)    # [N_val, C]\n",
    "    test_prompt_embs  = encode_texts(test_prompts)   # [N_test, C]\n",
    "\n",
    "    # 3) Gold model ids -> row indices into model_embs_init\n",
    "    train_model_idx = as_long([model_id2idx[m] for m in train_best_models])\n",
    "    val_model_idx   = as_long([model_id2idx[m] for m in val_best_models])\n",
    "    test_model_idx  = as_long([model_id2idx[m] for m in test_best_models])\n",
    "\n",
    "    train_domain_ids = as_long(train_domains)\n",
    "    val_domain_ids   = as_long(val_domains)\n",
    "    test_domain_ids  = as_long(test_domains)\n",
    "\n",
    "    return (\n",
    "        model_embs_init,\n",
    "        train_prompt_embs, val_prompt_embs, test_prompt_embs,\n",
    "        train_model_idx, val_model_idx, test_model_idx,\n",
    "        train_domain_ids, val_domain_ids, test_domain_ids,\n",
    "        model_domains,\n",
    "    )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "eca10deb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num domains: 40\n",
      "Num unique models: 852\n",
      "Num training examples: 6652\n",
      "Num training examples: 1174\n",
      "Num training examples: 867\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batches: 100%|██████████| 14/14 [00:00<00:00, 43.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model emb dim: 384\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batches: 100%|██████████| 104/104 [00:00<00:00, 186.02it/s]\n",
      "Batches: 100%|██████████| 19/19 [00:00<00:00, 179.30it/s]\n",
      "Batches: 100%|██████████| 14/14 [00:00<00:00, 164.08it/s]\n"
     ]
    }
   ],
   "source": [
    "# APIBench: vocabularies from train, per-split examples, then embeddings / index tensors.\n",
    "all_models, domain2id, id2domain, num_domains, model_infos = build_domain_and_model_infos(\n",
    "    dataset_json_train, experience=\"APIBench\"\n",
    ")\n",
    "train_prompts, train_best_models, train_domains = extract_from_dataset(\n",
    "    dataset_json_train, domain2id, model_infos\n",
    ")\n",
    "val_prompts, val_best_models, val_domains = extract_from_dataset(\n",
    "    dataset_json_val, domain2id, model_infos\n",
    ")\n",
    "test_prompts, test_best_models, test_domains = extract_from_dataset(\n",
    "    dataset_json_test, domain2id, model_infos\n",
    ")\n",
    "(\n",
    "    model_embs_init,\n",
    "    train_prompt_embs, val_prompt_embs, test_prompt_embs,\n",
    "    train_model_idx, val_model_idx, test_model_idx,\n",
    "    train_domain_ids, val_domain_ids, test_domain_ids,\n",
    "    model_domains,\n",
    ") = compute_embeddings_and_indices(\n",
    "    all_models,\n",
    "    train_prompts, val_prompts, test_prompts,\n",
    "    train_best_models, val_best_models, test_best_models,\n",
    "    train_domains, val_domains, test_domains,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7b28b191",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_router_st(router, prompts, gold_models, gold_domains, num_samples=500):\n",
    "    \"\"\"Text-level router evaluation on a random subset of prompts.\n",
    "\n",
    "    NOTE: a later cell in this notebook redefines eval_router_st with an\n",
    "    embedding-tensor signature; whichever cell ran last owns the name.\n",
    "\n",
    "    For each sampled prompt (without replacement, via the `random` module):\n",
    "      - domain hit: argmax of router.domain_head(router.prompt_encoder([prompt]))\n",
    "        equals the gold domain id;\n",
    "      - routing hit: router.route(prompt, top_k_domains=1) equals the gold model.\n",
    "    The router is put in eval mode for the evaluation and its previous train/eval\n",
    "    mode is restored afterwards.\n",
    "\n",
    "    Args:\n",
    "        router: module exposing parameters(), prompt_encoder, domain_head and route.\n",
    "        prompts, gold_models, gold_domains: parallel sequences.\n",
    "        num_samples: maximum number of prompts evaluated.\n",
    "\n",
    "    Returns:\n",
    "        (domain_acc, routing_acc); both are NaN when nothing is evaluated. The\n",
    "        values are also printed with three decimals.\n",
    "    \"\"\"\n",
    "    n = len(prompts)\n",
    "    idxs = random.sample(range(n), min(num_samples, n))\n",
    "\n",
    "    correct_models = 0\n",
    "    correct_domains = 0\n",
    "\n",
    "    was_training = router.training\n",
    "    router.eval()\n",
    "    try:\n",
    "        dev = next(router.parameters()).device\n",
    "        with torch.no_grad():\n",
    "            for i in idxs:\n",
    "                x = prompts[i]\n",
    "                h_q = router.prompt_encoder([x]).to(dev)\n",
    "                d_pred = router.domain_head(h_q).argmax(dim=-1).item()\n",
    "                if d_pred == gold_domains[i]:\n",
    "                    correct_domains += 1\n",
    "                if router.route(x, top_k_domains=1) == gold_models[i]:\n",
    "                    correct_models += 1\n",
    "    finally:\n",
    "        router.train(was_training)\n",
    "\n",
    "    # Avoid ZeroDivisionError when prompts is empty or num_samples == 0.\n",
    "    total = len(idxs)\n",
    "    domain_acc = correct_domains / total if total else float(\"nan\")\n",
    "    routing_acc = correct_models / total if total else float(\"nan\")\n",
    "\n",
    "    print(f\"Domain acc: {domain_acc:.3f}\")\n",
    "    print(f\"Routing acc: {routing_acc:.3f}\")\n",
    "    return domain_acc, routing_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "d7efd935",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def eval_router_st(\n",
    "    router,\n",
    "    prompt_embs: torch.Tensor,      # [N, C] on same device as router\n",
    "    gold_model_idx: torch.Tensor,   # [N]\n",
    "    gold_domains: torch.Tensor,     # [N]\n",
    "    num_samples: \"int | None\" = 1000,\n",
    "    top_k_domains: int = 1,\n",
    "):\n",
    "    \"\"\"Return (domain_acc, routing_acc) for a router on pre-computed prompt embeddings.\n",
    "\n",
    "    Evaluates a random subset of num_samples rows, or every row when num_samples is\n",
    "    None or >= N (such calls are deterministic).\n",
    "\n",
    "    - domain_acc: argmax of router.domain_head(prompt_embs) vs gold_domains, vectorised.\n",
    "    - routing_acc: router.route(h_q, top_k_domains=...) vs gold_model_idx, one row at a time.\n",
    "\n",
    "    NOTE(review): domain_head is fed the raw prompt embeddings here, whereas e2_step\n",
    "    trains it on the codebook-augmented output of router._encode_with_codebook;\n",
    "    confirm which input domain_head is meant to receive.\n",
    "\n",
    "    The router runs in eval mode and its previous train/eval mode is restored on exit,\n",
    "    so calling this from inside a training loop does not silently disable train mode.\n",
    "    Both accuracies are NaN when the evaluated subset is empty.\n",
    "    \"\"\"\n",
    "    was_training = router.training\n",
    "    router.eval()\n",
    "    try:\n",
    "        device = router.model_embeddings.device\n",
    "        N = prompt_embs.size(0)\n",
    "\n",
    "        if num_samples is None or num_samples >= N:\n",
    "            idxs = torch.arange(N, device=device)\n",
    "        else:\n",
    "            idxs = torch.randperm(N, device=device)[:num_samples]\n",
    "\n",
    "        batch_embs = prompt_embs[idxs]      # [B, C]\n",
    "        batch_midx = gold_model_idx[idxs]   # [B]\n",
    "        batch_doms = gold_domains[idxs]     # [B]\n",
    "        B = batch_embs.size(0)\n",
    "        if B == 0:\n",
    "            # Avoid a NaN mean / ZeroDivisionError on an empty subset.\n",
    "            return float(\"nan\"), float(\"nan\")\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # ----- domain accuracy (vectorised) -----\n",
    "            pred_domains = router.domain_head(batch_embs).argmax(dim=-1)  # [B]\n",
    "            domain_acc = (pred_domains == batch_doms).float().mean().item()\n",
    "\n",
    "            # ----- routing accuracy (routed one example at a time) -----\n",
    "            correct_models = 0\n",
    "            for b in range(B):\n",
    "                gm = int(batch_midx[b].item())\n",
    "                if router.route(batch_embs[b], top_k_domains=top_k_domains) == gm:\n",
    "                    correct_models += 1\n",
    "        routing_acc = correct_models / B\n",
    "    finally:\n",
    "        router.train(was_training)\n",
    "\n",
    "    return domain_acc, routing_acc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bbe8e531",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs built from APIBench in the pipeline cell above:\n",
    "#   all_models         List[ModelInfo]\n",
    "#   num_domains        int\n",
    "#   train_prompts      List[str]\n",
    "#   train_best_models  List[model_id str]\n",
    "#   train_domains      List[int] (domain IDs)\n",
    "# Fail fast if these are out of sync, e.g. after re-running only some cells.\n",
    "assert model_embs_init.size(0) == len(all_models) == model_domains.size(0)\n",
    "assert len(train_prompts) == len(train_best_models) == len(train_domains)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "from codebook_router import CodebookRouterST\n",
    "\n",
    "router_st = CodebookRouterST(\n",
    "    model_embs_init=model_embs_init.to(device),  # sentence-transformer embeddings of the model cards\n",
    "    model_domains=model_domains.to(device),\n",
    "    num_domains=num_domains,\n",
    "    codebook_size=64,\n",
    "    topk_codes=3,\n",
    "    vq_beta=0.25,\n",
    ").to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e82e8c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every split's tensors (train, val and test) onto the router's device so the\n",
    "# training loop and all eval_router_st calls index tensors on a single device.\n",
    "device = router_st.model_embeddings.device\n",
    "\n",
    "train_prompt_embs = train_prompt_embs.to(device)\n",
    "train_model_idx   = train_model_idx.to(device)\n",
    "train_domain_ids  = train_domain_ids.to(device)\n",
    "\n",
    "val_prompt_embs   = val_prompt_embs.to(device)\n",
    "val_model_idx     = val_model_idx.to(device)\n",
    "val_domain_ids    = val_domain_ids.to(device)\n",
    "\n",
    "test_prompt_embs  = test_prompt_embs.to(device)\n",
    "test_model_idx    = test_model_idx.to(device)\n",
    "test_domain_ids   = test_domain_ids.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "3b1cbb9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Global CB] Epoch 1/50 - train_loss=4.504 - val_domain=0.687 val_model=0.148\n",
      "[Global CB] Epoch 2/50 - train_loss=4.447 - val_domain=0.687 val_model=0.131\n",
      "[Global CB] Epoch 3/50 - train_loss=4.409 - val_domain=0.676 val_model=0.138\n",
      "[Global CB] Epoch 4/50 - train_loss=4.364 - val_domain=0.688 val_model=0.138\n",
      "[Global CB] Epoch 5/50 - train_loss=4.350 - val_domain=0.683 val_model=0.142\n",
      "[Global CB] Epoch 6/50 - train_loss=4.301 - val_domain=0.683 val_model=0.136\n",
      "[Global CB] Epoch 7/50 - train_loss=4.265 - val_domain=0.686 val_model=0.133\n",
      "[Global CB] Epoch 8/50 - train_loss=4.234 - val_domain=0.679 val_model=0.133\n",
      "[Global CB] Epoch 9/50 - train_loss=4.218 - val_domain=0.670 val_model=0.137\n",
      "[Global CB] Epoch 10/50 - train_loss=4.184 - val_domain=0.665 val_model=0.141\n",
      "[Global CB] Epoch 11/50 - train_loss=4.161 - val_domain=0.674 val_model=0.119\n",
      "Early stopping (global).\n"
     ]
    }
   ],
   "source": [
    "import torch, math\n",
    "\n",
    "# Global training of the codebook router on APIBench, early-stopped on the\n",
    "# validation routing accuracy; the best checkpoint is restored at the end.\n",
    "optimizer_global = torch.optim.Adam(\n",
    "    router_st.parameters(),\n",
    "    lr=1e-3,\n",
    "    weight_decay=1e-4,\n",
    ")\n",
    "\n",
    "num_epochs = 50\n",
    "batch_size = 64\n",
    "N_train = train_prompt_embs.size(0)\n",
    "\n",
    "best_val_model_acc = -math.inf\n",
    "best_state_global = None\n",
    "patience = 10\n",
    "epochs_no_improve = 0\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # eval_router_st may leave the router in eval mode; set training mode explicitly\n",
    "    # rather than relying on training_step to do it.\n",
    "    router_st.train()\n",
    "    perm = torch.randperm(N_train, device=device)\n",
    "    total_loss, num_batches = 0.0, 0\n",
    "\n",
    "    # start < N_train, so every slice below is non-empty.\n",
    "    for start in range(0, N_train, batch_size):\n",
    "        batch_idx = perm[start:start + batch_size]\n",
    "        b_embs = train_prompt_embs[batch_idx]\n",
    "        b_midx = train_model_idx[batch_idx]\n",
    "        b_dom  = train_domain_ids[batch_idx]\n",
    "\n",
    "        loss = router_st.training_step(\n",
    "            batch_prompt_embs=b_embs,\n",
    "            batch_model_idx=b_midx,\n",
    "            batch_domains=b_dom,\n",
    "            optimizer=optimizer_global,\n",
    "            domain_loss_weight=0.1,   # tune if you like\n",
    "            vq_loss_weight=0.01,      # small regulariser\n",
    "        )\n",
    "        # NOTE(review): non-positive losses are left out of the epoch average;\n",
    "        # presumably training_step returns 0 for skipped batches - confirm.\n",
    "        if loss > 0:\n",
    "            total_loss += loss\n",
    "            num_batches += 1\n",
    "\n",
    "    avg_loss = total_loss / num_batches if num_batches > 0 else float(\"nan\")\n",
    "\n",
    "    # Validate on the *full* val set (num_samples=None): a fresh random subset per\n",
    "    # epoch would make the early-stopping signal noisy and non-reproducible.\n",
    "    val_dom_acc, val_model_acc = eval_router_st(\n",
    "        router_st,\n",
    "        val_prompt_embs,\n",
    "        val_model_idx,\n",
    "        val_domain_ids,\n",
    "        num_samples=None,\n",
    "        top_k_domains=1,\n",
    "    )\n",
    "\n",
    "    print(\n",
    "        f\"[Global CB] Epoch {epoch+1}/{num_epochs} \"\n",
    "        f\"- train_loss={avg_loss:.3f} \"\n",
    "        f\"- val_domain={val_dom_acc:.3f} val_model={val_model_acc:.3f}\"\n",
    "    )\n",
    "\n",
    "    # Early stopping on val model acc\n",
    "    if val_model_acc > best_val_model_acc + 1e-5:\n",
    "        best_val_model_acc = val_model_acc\n",
    "        best_state_global = {k: v.detach().cpu().clone() for k, v in router_st.state_dict().items()}\n",
    "        epochs_no_improve = 0\n",
    "    else:\n",
    "        epochs_no_improve += 1\n",
    "        if epochs_no_improve >= patience:\n",
    "            print(\"Early stopping (global).\")\n",
    "            break\n",
    "\n",
    "# Restore best global checkpoint\n",
    "if best_state_global is not None:\n",
    "    router_st.load_state_dict(best_state_global)\n",
    "    router_st.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "ff7d0bdc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Global Train: 0.7070000171661377 0.263\n",
      "Global Val: 0.690000057220459 0.139\n",
      "Global Test: 0.6286043524742126 0.11534025374855825\n"
     ]
    }
   ],
   "source": [
    "# Accuracies of the restored global checkpoint (default top_k_domains). Train/val are\n",
    "# scored on a random subset of 1000 examples, test on every example.\n",
    "train_dom, train_model = eval_router_st(\n",
    "    router_st, train_prompt_embs, train_model_idx, train_domain_ids,\n",
    "    num_samples=1000,\n",
    ")\n",
    "val_dom, val_model = eval_router_st(\n",
    "    router_st, val_prompt_embs, val_model_idx, val_domain_ids,\n",
    "    num_samples=1000,\n",
    ")\n",
    "test_dom, test_model = eval_router_st(\n",
    "    router_st, test_prompt_embs, test_model_idx, test_domain_ids,\n",
    "    num_samples=None,\n",
    ")\n",
    "\n",
    "print(\"Global Train:\", train_dom, train_model)\n",
    "print(\"Global Val:\", val_dom, val_model)\n",
    "print(\"Global Test:\", test_dom, test_model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "46dec8a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train   domain/model: 0.6710000038146973 0.308\n",
      "Val     domain/model: 0.6770000457763672 0.151\n",
      "Test    domain/model: 0.6286043524742126 0.12110726643598616\n"
     ]
    }
   ],
   "source": [
    "# Same checkpoint, evaluated with top_k_domains=2 passed through to router.route;\n",
    "# every split uses eval_router_st's default num_samples.\n",
    "train_dom_acc, train_model_acc = eval_router_st(\n",
    "    router_st,\n",
    "    train_prompt_embs, train_model_idx, train_domain_ids,\n",
    "    top_k_domains=2,\n",
    ")\n",
    "val_dom_acc, val_model_acc = eval_router_st(\n",
    "    router_st,\n",
    "    val_prompt_embs, val_model_idx, val_domain_ids,\n",
    "    top_k_domains=2,\n",
    ")\n",
    "test_dom_acc, test_model_acc = eval_router_st(\n",
    "    router_st,\n",
    "    test_prompt_embs, test_model_idx, test_domain_ids,\n",
    "    top_k_domains=2,\n",
    ")\n",
    "\n",
    "print(\"Train   domain/model:\", train_dom_acc, train_model_acc)\n",
    "print(\"Val     domain/model:\", val_dom_acc, val_model_acc)\n",
    "print(\"Test    domain/model:\", test_dom_acc, test_model_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "bbbff7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MLLM train / val / test splits (the E2 data) with load_dataset_json.\n",
    "dataset_config = MLLMDataConfig()\n",
    "dataset_json_train_2, dataset_json_val_2, dataset_json_test_2 = (\n",
    "    load_dataset_json(split)\n",
    "    for split in (dataset_config.train_set, dataset_config.val_set, dataset_config.test_set)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "d6a10dd7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num domains: 35\n",
      "Num unique models: 481\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): MLLM rows are parsed with the APIBench branch, i.e. the domain is read\n",
    "# from api_data.domain; the printed counts (35 domains, 481 models) show that field is\n",
    "# populated for this data.\n",
    "all_models_2, domain2id_2, id2domain_2, num_domains_2, model_infos_2 = build_domain_and_model_infos(\n",
    "    dataset_json_train_2,\n",
    "    experience=\"APIBench\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "0a8a4c79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num domains: 35\n",
      "Num unique models: 481\n",
      "Num training examples: 3894\n",
      "Num training examples: 688\n",
      "Num training examples: 809\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batches: 100%|██████████| 8/8 [00:00<00:00, 56.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model emb dim: 384\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batches: 100%|██████████| 61/61 [00:00<00:00, 107.73it/s]\n",
      "Batches: 100%|██████████| 11/11 [00:00<00:00, 95.66it/s]\n",
      "Batches: 100%|██████████| 13/13 [00:00<00:00, 104.88it/s]\n"
     ]
    }
   ],
   "source": [
    "# MLLM (E2): same pipeline as for APIBench on the *_2 variables, using the APIBench\n",
    "# row layout for domains (the default of build_domain_and_model_infos).\n",
    "all_models_2, domain2id_2, id2domain_2, num_domains_2, model_infos_2 = build_domain_and_model_infos(\n",
    "    dataset_json_train_2, experience=\"APIBench\"\n",
    ")\n",
    "train_prompts_2, train_best_models_2, train_domains_2 = extract_from_dataset(\n",
    "    dataset_json_train_2, domain2id_2, model_infos_2\n",
    ")\n",
    "val_prompts_2, val_best_models_2, val_domains_2 = extract_from_dataset(\n",
    "    dataset_json_val_2, domain2id_2, model_infos_2\n",
    ")\n",
    "test_prompts_2, test_best_models_2, test_domains_2 = extract_from_dataset(\n",
    "    dataset_json_test_2, domain2id_2, model_infos_2\n",
    ")\n",
    "(\n",
    "    model_embs_init_2,\n",
    "    train_prompt_embs_2, val_prompt_embs_2, test_prompt_embs_2,\n",
    "    train_model_idx_2, val_model_idx_2, test_model_idx_2,\n",
    "    train_domain_ids_2, val_domain_ids_2, test_domain_ids_2,\n",
    "    model_domains_2,\n",
    ") = compute_embeddings_and_indices(\n",
    "    all_models_2,\n",
    "    train_prompts_2, val_prompts_2, test_prompts_2,\n",
    "    train_best_models_2, val_best_models_2, test_best_models_2,\n",
    "    train_domains_2, val_domains_2, test_domains_2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "1ef54da4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot of the trainable parameters right before E2 adaptation; e2_step\n",
    "# penalises the squared distance from these values.\n",
    "orig_params = dict(\n",
    "    (param_name, param.detach().clone())\n",
    "    for param_name, param in router_st.named_parameters()\n",
    "    if param.requires_grad  # NOTE(review): expected to be codebook + query_mlp only - confirm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "ccf1ed7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# E2 adaptation hyper-parameters (read as globals by e2_step).\n",
    "lambda_reg = 1e-3      # strength of the L2 pull towards the pre-E2 parameters\n",
    "domain_loss_w = 0.05   # domain-loss weight, below the 0.1 used in global training\n",
    "vq_loss_w = 0.005      # weight on the VQ loss for this phase\n",
    "\n",
    "optimizer_e2 = torch.optim.Adam(\n",
    "    params=[p for p in router_st.parameters() if p.requires_grad],\n",
    "    weight_decay=0.0,  # no weight decay: e2_step's explicit anchor penalty replaces it\n",
    "    lr=5e-4,\n",
    ")\n",
    "\n",
    "def e2_step(router, b_embs, b_midx, b_dom, optimizer, *,\n",
    "            anchor_params=None, reg_weight=None, domain_weight=None, vq_weight=None):\n",
    "    \"\"\"Run one optimisation step of the E2 adaptation phase and return the loss.\n",
    "\n",
    "    With (h, vq_loss) = router._encode_with_codebook(b_embs), the minimised loss is\n",
    "\n",
    "        CE(h @ router.model_embeddings.T, b_midx)\n",
    "        + domain_weight * CE(router.domain_head(h), b_dom)\n",
    "        + vq_weight * vq_loss\n",
    "        + reg_weight * sum over trainable params p of ||p - anchor_params[name]||^2\n",
    "\n",
    "    The keyword-only arguments default to the notebook globals orig_params,\n",
    "    lambda_reg, domain_loss_w and vq_loss_w, so existing positional calls behave\n",
    "    exactly as before while the function can also be driven without globals.\n",
    "\n",
    "    NOTE(review): b_midx must index rows of router.model_embeddings and b_dom must\n",
    "    index domain_head's outputs. router_st was built from the APIBench model list and\n",
    "    domain ids, while the E2 tensors (train_model_idx_2, train_domain_ids_2) index the\n",
    "    separate MLLM lists (all_models_2, domain2id_2), so the targets probably refer to\n",
    "    different models/domains - consistent with the ~0 E2 val accuracy. Remap indices\n",
    "    (or extend the router's model table) before trusting E2 results.\n",
    "\n",
    "    Returns:\n",
    "        The total loss of this step as a Python float.\n",
    "    \"\"\"\n",
    "    if anchor_params is None:\n",
    "        anchor_params = orig_params\n",
    "    if reg_weight is None:\n",
    "        reg_weight = lambda_reg\n",
    "    if domain_weight is None:\n",
    "        domain_weight = domain_loss_w\n",
    "    if vq_weight is None:\n",
    "        vq_weight = vq_loss_w\n",
    "\n",
    "    router.train()\n",
    "    dev = router.model_embeddings.device\n",
    "\n",
    "    # 1) Codebook-augmented representation\n",
    "    h, vq_loss = router._encode_with_codebook(b_embs.to(dev))  # [B, C]\n",
    "    m_idx = b_midx.to(dev)\n",
    "    d_idx = b_dom.to(dev)\n",
    "\n",
    "    # 2) Domain loss\n",
    "    domain_loss = F.cross_entropy(router.domain_head(h), d_idx)\n",
    "\n",
    "    # 3) Model loss over all models in the router's table\n",
    "    ce_loss = F.cross_entropy(h @ router.model_embeddings.T, m_idx)  # scores: [B, M]\n",
    "\n",
    "    # 4) L2 pull towards the anchor (pre-E2) parameters\n",
    "    reg_loss = sum(\n",
    "        ((p - anchor_params[name].to(dev)).pow(2).sum()\n",
    "         for name, p in router.named_parameters() if p.requires_grad),\n",
    "        0.0,\n",
    "    )\n",
    "\n",
    "    total_loss = (\n",
    "        ce_loss\n",
    "        + domain_weight * domain_loss\n",
    "        + vq_weight * vq_loss\n",
    "        + reg_weight * reg_loss\n",
    "    )\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    total_loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    return float(total_loss.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "4a8d5dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the E2 hyper-parameters and re-create optimizer_e2 with fresh Adam state.\n",
    "# e2_step is defined in the previous cell; it is intentionally not redefined here so\n",
    "# the notebook keeps a single copy of that function.\n",
    "lambda_reg      = 1e-3   # \"how hard\" we try not to move\n",
    "domain_loss_w   = 0.05   # slightly smaller than 0.1 during adaptation\n",
    "vq_loss_w       = 0.005  # smaller VQ weight for this phase\n",
    "\n",
    "optimizer_e2 = torch.optim.Adam(\n",
    "    [p for p in router_st.parameters() if p.requires_grad],\n",
    "    lr=5e-4,\n",
    "    weight_decay=0.0,     # turn off weight_decay; we use explicit reg instead\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "b0216d86",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[E2 no-replay] Epoch 1/30 - train_loss=7.234 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 2/30 - train_loss=7.195 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 3/30 - train_loss=7.158 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 4/30 - train_loss=7.123 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 5/30 - train_loss=7.087 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 6/30 - train_loss=7.060 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 7/30 - train_loss=7.033 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 8/30 - train_loss=7.009 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 9/30 - train_loss=7.000 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 10/30 - train_loss=6.993 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 11/30 - train_loss=6.982 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 12/30 - train_loss=6.972 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 13/30 - train_loss=6.962 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 14/30 - train_loss=6.950 - E2_val_domain=0.211 E2_val_model=0.003\n",
      "[E2 no-replay] Epoch 15/30 - train_loss=6.942 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 16/30 - train_loss=6.938 - E2_val_domain=0.211 E2_val_model=0.001\n",
      "[E2 no-replay] Epoch 17/30 - train_loss=6.932 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 18/30 - train_loss=6.926 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 19/30 - train_loss=6.922 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 20/30 - train_loss=6.921 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 21/30 - train_loss=6.916 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 22/30 - train_loss=6.914 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 23/30 - train_loss=6.912 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 24/30 - train_loss=6.913 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 25/30 - train_loss=6.911 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 26/30 - train_loss=6.910 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 27/30 - train_loss=6.909 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 28/30 - train_loss=6.909 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 29/30 - train_loss=6.912 - E2_val_domain=0.211 E2_val_model=0.000\n",
      "[E2 no-replay] Epoch 30/30 - train_loss=6.913 - E2_val_domain=0.211 E2_val_model=0.000\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "# E2 adaptation without replay: optimise only on the E2 split through e2_step\n",
    "# and keep the checkpoint with the best E2 validation model accuracy.\n",
    "# `math` and `device` are set here so this cell does not rely on names that\n",
    "# are otherwise only created by cells further down the notebook.\n",
    "# NOTE(review): e2_step and optimizer_e2 must already exist when this runs;\n",
    "# optimizer_e2 is also (re)created by the fine-tune cell below - confirm order.\n",
    "device = router_st.model_embeddings.device\n",
    "\n",
    "# Make sure these are your E2-only tensors\n",
    "N_e2 = train_prompt_embs_2.size(0)\n",
    "batch_size_e2  = 64\n",
    "num_epochs_e2  = 30\n",
    "patience_e2    = 30  # == num_epochs_e2, so early stopping can never fire here\n",
    "\n",
    "best_val_model_e2 = -math.inf\n",
    "best_state_e2     = None\n",
    "epochs_no_improve = 0\n",
    "\n",
    "for epoch in range(num_epochs_e2):\n",
    "    perm = torch.randperm(N_e2, device=device)\n",
    "    total_loss, num_batches = 0.0, 0\n",
    "\n",
    "    for start in range(0, N_e2, batch_size_e2):\n",
    "        batch_idx = perm[start:start+batch_size_e2]\n",
    "        if batch_idx.numel() == 0:\n",
    "            continue\n",
    "\n",
    "        b_embs = train_prompt_embs_2[batch_idx]\n",
    "        b_midx = train_model_idx_2[batch_idx]\n",
    "        b_dom  = train_domain_ids_2[batch_idx]\n",
    "\n",
    "        loss = e2_step(router_st, b_embs, b_midx, b_dom, optimizer_e2)\n",
    "        total_loss += loss\n",
    "        num_batches += 1\n",
    "\n",
    "    avg_loss = total_loss / num_batches if num_batches > 0 else float(\"nan\")\n",
    "\n",
    "    # Evaluate on *E2 val only*\n",
    "    val_dom_e2, val_model_e2 = eval_router_st(\n",
    "        router_st,\n",
    "        val_prompt_embs_2,\n",
    "        val_model_idx_2,\n",
    "        val_domain_ids_2,\n",
    "        num_samples=None,\n",
    "        top_k_domains=1,\n",
    "    )\n",
    "\n",
    "    print(\n",
    "        f\"[E2 no-replay] Epoch {epoch+1}/{num_epochs_e2} \"\n",
    "        f\"- train_loss={avg_loss:.3f} \"\n",
    "        f\"- E2_val_domain={val_dom_e2:.3f} E2_val_model={val_model_e2:.3f}\"\n",
    "    )\n",
    "\n",
    "    # Require a 1e-3 margin so noise-level changes do not count as progress\n",
    "    if val_model_e2 > best_val_model_e2 + 1e-3:\n",
    "        best_val_model_e2 = val_model_e2\n",
    "        best_state_e2 = {\n",
    "            k: v.detach().cpu().clone()\n",
    "            for k, v in router_st.state_dict().items()\n",
    "        }\n",
    "        epochs_no_improve = 0\n",
    "    else:\n",
    "        epochs_no_improve += 1\n",
    "        if epochs_no_improve >= patience_e2:\n",
    "            print(\"Early stopping on E2 (no replay, codebook+reg).\")\n",
    "            break\n",
    "\n",
    "# Restore best E2-adapted checkpoint\n",
    "if best_state_e2 is not None:\n",
    "    router_st.load_state_dict(best_state_e2)\n",
    "    router_st.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "6c3440a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E2 Train   domain/model: 0.20107857882976532 0.001027221366204417\n",
      "E2 Val     domain/model: 0.21075581014156342 0.0029069767441860465\n",
      "E2 Test    domain/model: 0.19901111721992493 0.0\n"
     ]
    }
   ],
   "source": [
    "# Score the current router on each E2 split with top-1 domain routing over\n",
    "# every example (num_samples=None means no sub-sampling).\n",
    "train_dom_e2, train_model_e2 = eval_router_st(\n",
    "    router_st, train_prompt_embs_2, train_model_idx_2, train_domain_ids_2,\n",
    "    num_samples=None, top_k_domains=1,\n",
    ")\n",
    "print(\"E2 Train   domain/model:\", train_dom_e2, train_model_e2)\n",
    "\n",
    "val_dom_e2, val_model_e2 = eval_router_st(\n",
    "    router_st, val_prompt_embs_2, val_model_idx_2, val_domain_ids_2,\n",
    "    num_samples=None, top_k_domains=1,\n",
    ")\n",
    "print(\"E2 Val     domain/model:\", val_dom_e2, val_model_e2)\n",
    "\n",
    "# The test split is the headline number for the before/after comparison.\n",
    "test_dom_e2, test_model_e2 = eval_router_st(\n",
    "    router_st, test_prompt_embs_2, test_model_idx_2, test_domain_ids_2,\n",
    "    num_samples=None, top_k_domains=1,\n",
    ")\n",
    "print(\"E2 Test    domain/model:\", test_dom_e2, test_model_e2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3d18a5e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[E2 finetune] Epoch 1/30 - train_loss=7.050 - E2_val_domain=0.265 E2_val_model=0.003\n",
      "[E2 finetune] Epoch 2/30 - train_loss=6.895 - E2_val_domain=0.273 E2_val_model=0.006\n",
      "[E2 finetune] Epoch 3/30 - train_loss=6.761 - E2_val_domain=0.285 E2_val_model=0.006\n",
      "[E2 finetune] Epoch 4/30 - train_loss=6.639 - E2_val_domain=0.295 E2_val_model=0.010\n",
      "[E2 finetune] Epoch 5/30 - train_loss=6.530 - E2_val_domain=0.311 E2_val_model=0.010\n",
      "[E2 finetune] Epoch 6/30 - train_loss=6.429 - E2_val_domain=0.339 E2_val_model=0.009\n",
      "[E2 finetune] Epoch 7/30 - train_loss=6.335 - E2_val_domain=0.369 E2_val_model=0.015\n",
      "[E2 finetune] Epoch 8/30 - train_loss=6.248 - E2_val_domain=0.394 E2_val_model=0.015\n",
      "[E2 finetune] Epoch 9/30 - train_loss=6.165 - E2_val_domain=0.417 E2_val_model=0.019\n",
      "[E2 finetune] Epoch 10/30 - train_loss=6.087 - E2_val_domain=0.435 E2_val_model=0.019\n",
      "[E2 finetune] Epoch 11/30 - train_loss=6.012 - E2_val_domain=0.449 E2_val_model=0.020\n",
      "[E2 finetune] Epoch 12/30 - train_loss=5.940 - E2_val_domain=0.458 E2_val_model=0.020\n",
      "[E2 finetune] Epoch 13/30 - train_loss=5.871 - E2_val_domain=0.465 E2_val_model=0.022\n",
      "[E2 finetune] Epoch 14/30 - train_loss=5.804 - E2_val_domain=0.468 E2_val_model=0.022\n",
      "[E2 finetune] Epoch 15/30 - train_loss=5.739 - E2_val_domain=0.481 E2_val_model=0.020\n",
      "[E2 finetune] Epoch 16/30 - train_loss=5.676 - E2_val_domain=0.493 E2_val_model=0.020\n",
      "[E2 finetune] Epoch 17/30 - train_loss=5.615 - E2_val_domain=0.513 E2_val_model=0.026\n",
      "[E2 finetune] Epoch 18/30 - train_loss=5.557 - E2_val_domain=0.532 E2_val_model=0.025\n",
      "[E2 finetune] Epoch 19/30 - train_loss=5.498 - E2_val_domain=0.538 E2_val_model=0.025\n",
      "[E2 finetune] Epoch 20/30 - train_loss=5.442 - E2_val_domain=0.545 E2_val_model=0.025\n",
      "[E2 finetune] Epoch 21/30 - train_loss=5.386 - E2_val_domain=0.560 E2_val_model=0.023\n",
      "[E2 finetune] Epoch 22/30 - train_loss=5.333 - E2_val_domain=0.560 E2_val_model=0.023\n",
      "Early stopping on E2 val model acc.\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "# Plain fine-tuning on E2 (no replay) via router_st.training_step, with early\n",
    "# stopping on E2 validation model accuracy and rollback to the best weights.\n",
    "device = router_st.model_embeddings.device\n",
    "\n",
    "# A small learning rate limits how quickly earlier knowledge is overwritten.\n",
    "optimizer_e2 = torch.optim.Adam(router_st.parameters(), lr=5e-4, weight_decay=1e-4)\n",
    "\n",
    "# Loop settings\n",
    "num_epochs_e2 = 30\n",
    "batch_size_e2 = 64\n",
    "N_e2 = train_prompt_embs_2.size(0)\n",
    "\n",
    "# Early-stopping bookkeeping\n",
    "best_val_model_e2, best_state_e2 = -math.inf, None\n",
    "patience_e2 = 5\n",
    "epochs_no_improve_e2 = 0\n",
    "\n",
    "for epoch in range(num_epochs_e2):\n",
    "    # Fresh shuffle of the E2 training indices every epoch\n",
    "    perm = torch.randperm(N_e2, device=device)\n",
    "    total_loss, num_batches = 0.0, 0\n",
    "\n",
    "    for start in range(0, N_e2, batch_size_e2):\n",
    "        batch_idx = perm[start:start + batch_size_e2]\n",
    "        if batch_idx.numel() == 0:\n",
    "            continue\n",
    "\n",
    "        loss = router_st.training_step(\n",
    "            batch_prompt_embs=train_prompt_embs_2[batch_idx],\n",
    "            batch_model_idx=train_model_idx_2[batch_idx],\n",
    "            batch_domains=train_domain_ids_2[batch_idx],\n",
    "            optimizer=optimizer_e2,\n",
    "            domain_loss_weight=0.1,  # a lower value (e.g. 0.05) keeps the domain head steadier\n",
    "        )\n",
    "        # Only strictly positive batch losses enter the epoch average\n",
    "        if loss > 0:\n",
    "            total_loss += loss\n",
    "            num_batches += 1\n",
    "\n",
    "    avg_loss = float(\"nan\") if num_batches == 0 else total_loss / num_batches\n",
    "\n",
    "    # Validation on the E2 split\n",
    "    val_dom_e2, val_model_e2 = eval_router_st(\n",
    "        router_st, val_prompt_embs_2, val_model_idx_2, val_domain_ids_2,\n",
    "        num_samples=None, top_k_domains=1,\n",
    "    )\n",
    "    print(\n",
    "        f\"[E2 finetune] Epoch {epoch+1}/{num_epochs_e2} \"\n",
    "        f\"- train_loss={avg_loss:.3f} \"\n",
    "        f\"- E2_val_domain={val_dom_e2:.3f} E2_val_model={val_model_e2:.3f}\"\n",
    "    )\n",
    "\n",
    "    if val_model_e2 > best_val_model_e2 + 1e-3:\n",
    "        # New best: snapshot weights on CPU and reset the patience counter\n",
    "        best_val_model_e2 = val_model_e2\n",
    "        best_state_e2 = {\n",
    "            name: tensor.detach().cpu().clone()\n",
    "            for name, tensor in router_st.state_dict().items()\n",
    "        }\n",
    "        epochs_no_improve_e2 = 0\n",
    "        continue\n",
    "\n",
    "    epochs_no_improve_e2 += 1\n",
    "    if epochs_no_improve_e2 >= patience_e2:\n",
    "        print(\"Early stopping on E2 val model acc.\")\n",
    "        break\n",
    "\n",
    "# Roll back to the best E2 checkpoint that was seen\n",
    "if best_state_e2 is not None:\n",
    "    router_st.load_state_dict(best_state_e2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f6a52a44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== After E2 fine-tune ===\n",
      "E2 Train   domain/model: 0.5295326113700867 0.02953261427837699\n",
      "E2 Val     domain/model: 0.5130813717842102 0.02616279069767442\n",
      "E2 Test    domain/model: 0.5018541216850281 0.022249690976514216\n",
      "=== Per-experience test after E2 fine-tune ===\n",
      "Train   domain/model: 0.44200003147125244 0.051\n",
      "Val     domain/model: 0.44700002670288086 0.041\n",
      "Test    domain/model: 0.4498269855976105 0.03690888119953864\n"
     ]
    }
   ],
   "source": [
    "print(\"=== After E2 fine-tune ===\")\n",
    "\n",
    "# E2 splits, scored on every example (num_samples=None)\n",
    "train_dom_e2_ft, train_model_e2_ft = eval_router_st(\n",
    "    router_st,\n",
    "    train_prompt_embs_2,\n",
    "    train_model_idx_2,\n",
    "    train_domain_ids_2,\n",
    "    num_samples=None,\n",
    ")\n",
    "val_dom_e2_ft, val_model_e2_ft = eval_router_st(\n",
    "    router_st,\n",
    "    val_prompt_embs_2,\n",
    "    val_model_idx_2,\n",
    "    val_domain_ids_2,\n",
    "    num_samples=None,\n",
    ")\n",
    "test_dom_e2_ft, test_model_e2_ft = eval_router_st(\n",
    "    router_st,\n",
    "    test_prompt_embs_2,\n",
    "    test_model_idx_2,\n",
    "    test_domain_ids_2,\n",
    "    num_samples=None,\n",
    ")\n",
    "\n",
    "print(\"E2 Train   domain/model:\", train_dom_e2_ft, train_model_e2_ft)\n",
    "print(\"E2 Val     domain/model:\", val_dom_e2_ft, val_model_e2_ft)\n",
    "print(\"E2 Test    domain/model:\", test_dom_e2_ft, test_model_e2_ft)\n",
    "\n",
    "# Earlier-experience splits measure forgetting (top-2 domains, default num_samples)\n",
    "print(\"=== Per-experience test after E2 fine-tune ===\")\n",
    "train_dom_acc, train_model_acc = eval_router_st(\n",
    "    router_st,\n",
    "    train_prompt_embs,\n",
    "    train_model_idx,\n",
    "    train_domain_ids,\n",
    "    top_k_domains=2,\n",
    ")\n",
    "val_dom_acc, val_model_acc = eval_router_st(\n",
    "    router_st,\n",
    "    val_prompt_embs,\n",
    "    val_model_idx,\n",
    "    val_domain_ids,\n",
    "    top_k_domains=2,\n",
    ")\n",
    "test_dom_acc, test_model_acc = eval_router_st(\n",
    "    router_st,\n",
    "    test_prompt_embs,\n",
    "    test_model_idx,\n",
    "    test_domain_ids,\n",
    "    top_k_domains=2,\n",
    ")\n",
    "\n",
    "print(\"Train   domain/model:\", train_dom_acc, train_model_acc)\n",
    "print(\"Val     domain/model:\", val_dom_acc, val_model_acc)\n",
    "print(\"Test    domain/model:\", test_dom_acc, test_model_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "dbd877c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Domain pretrain] Epoch 1/50 - avg loss 0.2844\n",
      "[Domain pretrain] Epoch 2/50 - avg loss 0.2739\n",
      "[Domain pretrain] Epoch 3/50 - avg loss 0.2619\n",
      "[Domain pretrain] Epoch 4/50 - avg loss 0.2545\n",
      "[Domain pretrain] Epoch 5/50 - avg loss 0.2424\n",
      "[Domain pretrain] Epoch 6/50 - avg loss 0.2370\n",
      "[Domain pretrain] Epoch 7/50 - avg loss 0.2290\n",
      "[Domain pretrain] Epoch 8/50 - avg loss 0.2200\n",
      "[Domain pretrain] Epoch 9/50 - avg loss 0.2117\n",
      "[Domain pretrain] Epoch 10/50 - avg loss 0.2045\n",
      "[Domain pretrain] Epoch 11/50 - avg loss 0.2013\n",
      "[Domain pretrain] Epoch 12/50 - avg loss 0.1964\n",
      "[Domain pretrain] Epoch 13/50 - avg loss 0.1873\n",
      "[Domain pretrain] Epoch 14/50 - avg loss 0.1788\n",
      "[Domain pretrain] Epoch 15/50 - avg loss 0.1738\n",
      "[Domain pretrain] Epoch 16/50 - avg loss 0.1684\n",
      "[Domain pretrain] Epoch 17/50 - avg loss 0.1650\n",
      "[Domain pretrain] Epoch 18/50 - avg loss 0.1596\n",
      "[Domain pretrain] Epoch 19/50 - avg loss 0.1534\n",
      "[Domain pretrain] Epoch 20/50 - avg loss 0.1481\n",
      "[Domain pretrain] Epoch 21/50 - avg loss 0.1422\n",
      "[Domain pretrain] Epoch 22/50 - avg loss 0.1392\n",
      "[Domain pretrain] Epoch 23/50 - avg loss 0.1343\n",
      "[Domain pretrain] Epoch 24/50 - avg loss 0.1311\n",
      "[Domain pretrain] Epoch 25/50 - avg loss 0.1251\n",
      "[Domain pretrain] Epoch 26/50 - avg loss 0.1229\n",
      "[Domain pretrain] Epoch 27/50 - avg loss 0.1179\n",
      "[Domain pretrain] Epoch 28/50 - avg loss 0.1149\n",
      "[Domain pretrain] Epoch 29/50 - avg loss 0.1104\n",
      "[Domain pretrain] Epoch 30/50 - avg loss 0.1075\n",
      "[Domain pretrain] Epoch 31/50 - avg loss 0.1052\n",
      "[Domain pretrain] Epoch 32/50 - avg loss 0.1010\n",
      "[Domain pretrain] Epoch 33/50 - avg loss 0.0974\n",
      "[Domain pretrain] Epoch 34/50 - avg loss 0.0951\n",
      "[Domain pretrain] Epoch 35/50 - avg loss 0.0922\n",
      "[Domain pretrain] Epoch 36/50 - avg loss 0.0894\n",
      "[Domain pretrain] Epoch 37/50 - avg loss 0.0878\n",
      "[Domain pretrain] Epoch 38/50 - avg loss 0.0838\n",
      "[Domain pretrain] Epoch 39/50 - avg loss 0.0821\n",
      "[Domain pretrain] Epoch 40/50 - avg loss 0.0791\n",
      "[Domain pretrain] Epoch 41/50 - avg loss 0.0756\n",
      "[Domain pretrain] Epoch 42/50 - avg loss 0.0745\n",
      "[Domain pretrain] Epoch 43/50 - avg loss 0.0716\n",
      "[Domain pretrain] Epoch 44/50 - avg loss 0.0711\n",
      "[Domain pretrain] Epoch 45/50 - avg loss 0.0679\n",
      "[Domain pretrain] Epoch 46/50 - avg loss 0.0640\n",
      "[Domain pretrain] Epoch 47/50 - avg loss 0.0631\n",
      "[Domain pretrain] Epoch 48/50 - avg loss 0.0605\n",
      "[Domain pretrain] Epoch 49/50 - avg loss 0.0594\n",
      "[Domain pretrain] Epoch 50/50 - avg loss 0.0582\n",
      "(0.997, 0.048)\n",
      "(0.732, 0.03)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def pretrain_domain(router, prompts, domains,\n",
    "                    num_epochs=5, batch_size=128, lr=1e-3):\n",
    "    \"\"\"Pre-train only the prompt encoder and the domain head with cross-entropy.\n",
    "\n",
    "    Args:\n",
    "        router: router exposing `prompt_encoder` (list of prompts -> [B, C]\n",
    "            features) and `domain_head` ([B, C] -> [B, D] logits).\n",
    "        prompts: sequence of raw prompts, index-aligned with `domains`.\n",
    "        domains: integer domain id for each prompt.\n",
    "        num_epochs: passes over the shuffled data.\n",
    "        batch_size: prompts per optimisation step.\n",
    "        lr: Adam learning rate.\n",
    "\n",
    "    Prints the mean batch loss per epoch (NaN if `prompts` is empty).\n",
    "    \"\"\"\n",
    "    device = next(router.parameters()).device\n",
    "\n",
    "    # Other cells freeze the domain head (requires_grad=False). If this cell is\n",
    "    # re-run after them, Adam would silently skip the head it is meant to train.\n",
    "    for p in router.domain_head.parameters():\n",
    "        p.requires_grad_(True)\n",
    "\n",
    "    # Only prompt encoder + domain head are optimised here\n",
    "    params = [*router.prompt_encoder.parameters(), *router.domain_head.parameters()]\n",
    "    optimizer = torch.optim.Adam(params, lr=lr)\n",
    "\n",
    "    n = len(prompts)\n",
    "    indices = list(range(n))\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        random.shuffle(indices)\n",
    "        total_loss, num_batches = 0.0, 0\n",
    "\n",
    "        for start in range(0, n, batch_size):\n",
    "            batch_idx = indices[start:start + batch_size]\n",
    "            batch_prompts = [prompts[i] for i in batch_idx]\n",
    "            batch_domains = [domains[i] for i in batch_idx]\n",
    "\n",
    "            # Re-assert train mode every batch; the encoder may switch modes\n",
    "            # internally (TODO confirm), so this is kept inside the loop.\n",
    "            router.train()\n",
    "            h_q = router.prompt_encoder(batch_prompts).to(device)    # [B, C]\n",
    "            logits = router.domain_head(h_q)                         # [B, D]\n",
    "            labels = torch.tensor(batch_domains, dtype=torch.long, device=device)\n",
    "\n",
    "            loss = F.cross_entropy(logits, labels)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += float(loss.item())\n",
    "            num_batches += 1\n",
    "\n",
    "        # Guard against an empty training set instead of dividing by zero\n",
    "        avg_loss = total_loss / num_batches if num_batches > 0 else float(\"nan\")\n",
    "        print(f\"[Domain pretrain] Epoch {epoch+1}/{num_epochs} - avg loss {avg_loss:.4f}\")\n",
    "\n",
    "# run domain pretraining\n",
    "pretrain_domain(router, train_prompts, train_domains,\n",
    "                num_epochs=50, batch_size=128, lr=1e-3)\n",
    "\n",
    "print(eval_router_reg(router, train_prompts, train_best_models, train_domains))\n",
    "print(eval_router_reg(router, val_prompts, val_best_models, val_domains))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "f0d7b046",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 - train_loss=3.459 - val_domain=0.708 val_model=0.060\n",
      "Epoch 2/50 - train_loss=3.273 - val_domain=0.708 val_model=0.067\n",
      "Epoch 3/50 - train_loss=3.180 - val_domain=0.698 val_model=0.063\n",
      "Epoch 4/50 - train_loss=3.084 - val_domain=0.706 val_model=0.064\n",
      "Epoch 5/50 - train_loss=2.985 - val_domain=0.690 val_model=0.068\n",
      "Epoch 6/50 - train_loss=2.924 - val_domain=0.692 val_model=0.071\n",
      "Epoch 7/50 - train_loss=2.858 - val_domain=0.693 val_model=0.066\n",
      "Epoch 8/50 - train_loss=2.791 - val_domain=0.674 val_model=0.070\n",
      "Epoch 9/50 - train_loss=2.730 - val_domain=0.681 val_model=0.074\n",
      "Epoch 10/50 - train_loss=2.697 - val_domain=0.674 val_model=0.070\n",
      "Epoch 11/50 - train_loss=2.651 - val_domain=0.662 val_model=0.076\n",
      "Epoch 12/50 - train_loss=2.614 - val_domain=0.652 val_model=0.066\n",
      "Epoch 13/50 - train_loss=2.570 - val_domain=0.647 val_model=0.068\n",
      "Epoch 14/50 - train_loss=2.569 - val_domain=0.640 val_model=0.070\n",
      "Epoch 15/50 - train_loss=2.531 - val_domain=0.626 val_model=0.077\n",
      "Epoch 16/50 - train_loss=2.516 - val_domain=0.643 val_model=0.079\n",
      "Epoch 17/50 - train_loss=2.499 - val_domain=0.620 val_model=0.069\n",
      "Epoch 18/50 - train_loss=2.489 - val_domain=0.613 val_model=0.077\n",
      "Epoch 19/50 - train_loss=2.490 - val_domain=0.609 val_model=0.077\n",
      "Epoch 20/50 - train_loss=2.478 - val_domain=0.602 val_model=0.073\n",
      "Epoch 21/50 - train_loss=2.475 - val_domain=0.603 val_model=0.073\n",
      "Epoch 22/50 - train_loss=2.472 - val_domain=0.577 val_model=0.077\n",
      "Epoch 23/50 - train_loss=2.467 - val_domain=0.573 val_model=0.066\n",
      "Epoch 24/50 - train_loss=2.474 - val_domain=0.556 val_model=0.075\n",
      "Epoch 25/50 - train_loss=2.471 - val_domain=0.553 val_model=0.068\n",
      "Epoch 26/50 - train_loss=2.478 - val_domain=0.543 val_model=0.073\n",
      "Early stopping triggered.\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def eval_router_reg(router, prompts, gold_models, gold_domains, num_samples=1000):\n",
    "    \"\"\"Estimate domain-head accuracy and routing accuracy on sampled examples.\n",
    "\n",
    "    Args:\n",
    "        router: router exposing `prompt_encoder`, `domain_head` and `route`.\n",
    "        prompts, gold_models, gold_domains: parallel, index-aligned sequences.\n",
    "        num_samples: number of examples drawn without replacement via\n",
    "            `random.sample`, capped at len(prompts). ``None`` scores every\n",
    "            example in order.\n",
    "\n",
    "    Returns:\n",
    "        (domain_acc, model_acc) as floats, or (nan, nan) when there is nothing\n",
    "        to evaluate. The router's train/eval mode is restored before returning.\n",
    "    \"\"\"\n",
    "    n = len(prompts)\n",
    "    if num_samples is None:\n",
    "        idxs = list(range(n))\n",
    "    else:\n",
    "        idxs = random.sample(range(n), min(num_samples, n))\n",
    "    if not idxs:\n",
    "        return (float(\"nan\"), float(\"nan\"))\n",
    "\n",
    "    correct_models = 0\n",
    "    correct_domains = 0\n",
    "    dev = next(router.parameters()).device\n",
    "    was_training = router.training\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            router.eval()\n",
    "            for i in idxs:\n",
    "                x  = prompts[i]\n",
    "                gm = gold_models[i]\n",
    "                gd = gold_domains[i]\n",
    "\n",
    "                # Domain accuracy from the domain head alone\n",
    "                h_q = router.prompt_encoder([x]).to(dev)\n",
    "                logits = router.domain_head(h_q)\n",
    "                d_pred = logits.argmax(dim=-1).item()\n",
    "                if d_pred == gd:\n",
    "                    correct_domains += 1\n",
    "\n",
    "                # Routing accuracy of the full router (top-1 domain)\n",
    "                m_pred = router.route(x, top_k_domains=1)\n",
    "                if m_pred == gm:\n",
    "                    correct_models += 1\n",
    "    finally:\n",
    "        # Restore the caller's mode so evaluating inside a training loop does\n",
    "        # not leave the router in eval mode for the next training steps.\n",
    "        router.train(was_training)\n",
    "\n",
    "    return (correct_domains/len(idxs), correct_models/len(idxs))\n",
    "\n",
    "\n",
    "# After domain pretrain:\n",
    "for p in router.domain_head.parameters():\n",
    "    p.requires_grad = False   # keep that good domain head for now\n",
    "\n",
    "optimizer_joint = torch.optim.Adam(\n",
    "    (p for p in router.parameters() if p.requires_grad),\n",
    "    lr=1e-3,\n",
    "    weight_decay=1e-4,\n",
    ")\n",
    "\n",
    "num_epochs = 50\n",
    "batch_size = 64\n",
    "n = len(train_prompts)\n",
    "indices = list(range(n))\n",
    "\n",
    "best_val_model_acc = -math.inf\n",
    "best_state_dict = None\n",
    "patience = 10\n",
    "epochs_without_improve = 0\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    random.shuffle(indices)\n",
    "    total_loss, num_batches = 0.0, 0\n",
    "\n",
    "    # --- training ---\n",
    "    for start in range(0, n, batch_size):\n",
    "        batch_idx = indices[start:start+batch_size]\n",
    "        bp = [train_prompts[i]      for i in batch_idx]\n",
    "        bm = [train_best_models[i]  for i in batch_idx]\n",
    "        bd = [train_domains[i]      for i in batch_idx]\n",
    "\n",
    "        loss = router.training_step(\n",
    "            batch_prompts=bp,\n",
    "            batch_best_models=bm,\n",
    "            batch_domains=bd,\n",
    "            optimizer=optimizer_joint,\n",
    "            domain_loss_weight=0.0,   # domain head frozen in this phase\n",
    "            vq_loss_weight=0.01,\n",
    "            cost_penalty=0.0,\n",
    "        )\n",
    "        if loss > 0:\n",
    "            total_loss += loss\n",
    "            num_batches += 1\n",
    "\n",
    "    avg_loss = total_loss/num_batches if num_batches > 0 else float(\"nan\")\n",
    "\n",
    "    # --- validation ---\n",
    "    # NOTE: default num_samples=1000, so this is a random subset if val is larger\n",
    "    val_dom_acc, val_model_acc = eval_router_reg(\n",
    "        router, val_prompts, val_best_models, val_domains\n",
    "    )\n",
    "    print(\n",
    "        f\"Epoch {epoch+1}/{num_epochs} \"\n",
    "        f\"- train_loss={avg_loss:.3f} \"\n",
    "        f\"- val_domain={val_dom_acc:.3f} val_model={val_model_acc:.3f}\"\n",
    "    )\n",
    "\n",
    "    # early stopping on val_model_acc\n",
    "    if val_model_acc > best_val_model_acc + 1e-3:  # small margin\n",
    "        best_val_model_acc = val_model_acc\n",
    "        best_state_dict = {k: v.detach().cpu().clone() for k, v in router.state_dict().items()}\n",
    "        epochs_without_improve = 0\n",
    "    else:\n",
    "        epochs_without_improve += 1\n",
    "        if epochs_without_improve >= patience:\n",
    "            print(\"Early stopping triggered.\")\n",
    "            break\n",
    "\n",
    "# restore best val checkpoint\n",
    "if best_state_dict is not None:\n",
    "    router.load_state_dict(best_state_dict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f08206c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Stage2 frozen domain] Epoch 1/20 - avg loss: 10.1534\n",
      "[Stage2 frozen domain] Epoch 2/20 - avg loss: 7.5560\n",
      "[Stage2 frozen domain] Epoch 3/20 - avg loss: 6.2465\n",
      "[Stage2 frozen domain] Epoch 4/20 - avg loss: 5.4781\n",
      "[Stage2 frozen domain] Epoch 5/20 - avg loss: 4.9774\n",
      "[Stage2 frozen domain] Epoch 6/20 - avg loss: 4.6246\n",
      "[Stage2 frozen domain] Epoch 7/20 - avg loss: 4.3082\n",
      "[Stage2 frozen domain] Epoch 8/20 - avg loss: 4.0800\n",
      "[Stage2 frozen domain] Epoch 9/20 - avg loss: 3.8595\n",
      "[Stage2 frozen domain] Epoch 10/20 - avg loss: 3.6452\n",
      "[Stage2 frozen domain] Epoch 11/20 - avg loss: 3.4840\n",
      "[Stage2 frozen domain] Epoch 12/20 - avg loss: 3.2926\n",
      "[Stage2 frozen domain] Epoch 13/20 - avg loss: 3.1260\n",
      "[Stage2 frozen domain] Epoch 14/20 - avg loss: 2.9610\n",
      "[Stage2 frozen domain] Epoch 15/20 - avg loss: 2.8368\n",
      "[Stage2 frozen domain] Epoch 16/20 - avg loss: 2.6863\n",
      "[Stage2 frozen domain] Epoch 17/20 - avg loss: 2.5506\n",
      "[Stage2 frozen domain] Epoch 18/20 - avg loss: 2.4287\n",
      "[Stage2 frozen domain] Epoch 19/20 - avg loss: 2.2941\n",
      "[Stage2 frozen domain] Epoch 20/20 - avg loss: 2.2019\n",
      "Domain acc: 0.990\n",
      "Routing acc: 0.590\n",
      "Domain acc: 0.706\n",
      "Routing acc: 0.082\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "\n",
    "# Freeze domain head parameters\n",
    "for p in router.domain_head.parameters():\n",
    "    p.requires_grad = False\n",
    "\n",
    "# New optimizer over everything that still has requires_grad=True\n",
    "optimizer_stage2 = torch.optim.Adam(\n",
    "    (p for p in router.parameters() if p.requires_grad),\n",
    "    lr=1e-3,\n",
    ")\n",
    "\n",
    "def joint_train_frozen_domain(router,\n",
    "                              prompts, best_models, domains,\n",
    "                              optimizer,\n",
    "                              num_epochs=20, batch_size=64,\n",
    "                              vq_loss_weight=0.01):\n",
    "    \"\"\"Stage 2: joint training with the domain loss switched off.\n",
    "\n",
    "    Calls `router.training_step` on shuffled mini-batches with\n",
    "    domain_loss_weight=0.0 and cost_penalty=0.0, and prints the mean of the\n",
    "    strictly positive per-batch losses after every epoch (NaN if none).\n",
    "\n",
    "    Args:\n",
    "        router: router exposing `training_step`.\n",
    "        prompts, best_models, domains: parallel, index-aligned training data.\n",
    "        optimizer: optimizer forwarded to `training_step`.\n",
    "        num_epochs: passes over the shuffled data.\n",
    "        batch_size: examples per step.\n",
    "        vq_loss_weight: VQ loss weight forwarded to `training_step`.\n",
    "    \"\"\"\n",
    "    n = len(prompts)\n",
    "    indices = list(range(n))\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        random.shuffle(indices)\n",
    "        total_loss, num_batches = 0.0, 0\n",
    "\n",
    "        for start in range(0, n, batch_size):\n",
    "            batch_idx = indices[start:start+batch_size]\n",
    "            bp = [prompts[i]      for i in batch_idx]\n",
    "            bm = [best_models[i]  for i in batch_idx]\n",
    "            bd = [domains[i]      for i in batch_idx]\n",
    "\n",
    "            loss = router.training_step(\n",
    "                batch_prompts=bp,\n",
    "                batch_best_models=bm,\n",
    "                batch_domains=bd,\n",
    "                optimizer=optimizer,\n",
    "                domain_loss_weight=0.0,       # <-- no domain loss in stage 2\n",
    "                vq_loss_weight=vq_loss_weight,\n",
    "                cost_penalty=0.0,\n",
    "            )\n",
    "            if loss > 0:\n",
    "                total_loss += loss\n",
    "                num_batches += 1\n",
    "\n",
    "        avg_loss = total_loss / num_batches if num_batches > 0 else float(\"nan\")\n",
    "        print(f\"[Stage2 frozen domain] Epoch {epoch+1}/{num_epochs} - avg loss: {avg_loss:.4f}\")\n",
    "\n",
    "# run stage 2\n",
    "joint_train_frozen_domain(\n",
    "    router,\n",
    "    train_prompts,\n",
    "    train_best_models,\n",
    "    train_domains,\n",
    "    optimizer_stage2,\n",
    "    num_epochs=20,\n",
    "    batch_size=64,\n",
    "    vq_loss_weight=0.01,\n",
    ")\n",
    "# Print both results: only a cell's last expression is displayed, so a bare\n",
    "# call on the train split would have its returned accuracies discarded.\n",
    "print(eval_router_reg(router, train_prompts, train_best_models, train_domains))\n",
    "print(eval_router_reg(router, val_prompts, val_best_models, val_domains))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "2b6bd4ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Stage3 small domain loss] Epoch 1/20 - avg loss: 0.0141\n",
      "[Stage3 small domain loss] Epoch 2/20 - avg loss: 0.0128\n",
      "[Stage3 small domain loss] Epoch 3/20 - avg loss: 0.0119\n",
      "[Stage3 small domain loss] Epoch 4/20 - avg loss: 0.0103\n",
      "[Stage3 small domain loss] Epoch 5/20 - avg loss: 0.0097\n",
      "[Stage3 small domain loss] Epoch 6/20 - avg loss: 0.0088\n",
      "[Stage3 small domain loss] Epoch 7/20 - avg loss: 0.0084\n",
      "[Stage3 small domain loss] Epoch 8/20 - avg loss: 0.0078\n",
      "[Stage3 small domain loss] Epoch 9/20 - avg loss: 0.0074\n",
      "[Stage3 small domain loss] Epoch 10/20 - avg loss: 0.0072\n",
      "[Stage3 small domain loss] Epoch 11/20 - avg loss: 0.0065\n",
      "[Stage3 small domain loss] Epoch 12/20 - avg loss: 0.0059\n",
      "[Stage3 small domain loss] Epoch 13/20 - avg loss: 0.0060\n",
      "[Stage3 small domain loss] Epoch 14/20 - avg loss: 0.0059\n",
      "[Stage3 small domain loss] Epoch 15/20 - avg loss: 0.0053\n",
      "[Stage3 small domain loss] Epoch 16/20 - avg loss: 0.0049\n",
      "[Stage3 small domain loss] Epoch 17/20 - avg loss: 0.0046\n",
      "[Stage3 small domain loss] Epoch 18/20 - avg loss: 0.0044\n",
      "[Stage3 small domain loss] Epoch 19/20 - avg loss: 0.0044\n",
      "[Stage3 small domain loss] Epoch 20/20 - avg loss: 0.0041\n",
      "Domain acc: 1.000\n",
      "Routing acc: 0.998\n",
      "Domain acc: 0.706\n",
      "Routing acc: 0.176\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "\n",
    "# Un-freeze domain head\n",
    "for p in router.domain_head.parameters():\n",
    "    p.requires_grad = True\n",
    "\n",
    "# New optimizer over all parameters\n",
    "optimizer_stage3 = torch.optim.Adam(router.parameters(), lr=1e-3)\n",
    "\n",
    "def joint_train_small_domain_loss(router,\n",
    "                                  prompts, best_models, domains,\n",
    "                                  optimizer,\n",
    "                                  num_epochs=20, batch_size=64,\n",
    "                                  domain_loss_weight=0.01,\n",
    "                                  vq_loss_weight=0.01):\n",
    "    \"\"\"Stage 3: joint training with a small but non-zero domain loss.\n",
    "\n",
    "    Calls `router.training_step` on shuffled mini-batches with cost_penalty=0.0\n",
    "    and prints the mean of the strictly positive per-batch losses after every\n",
    "    epoch (NaN if none).\n",
    "\n",
    "    Args:\n",
    "        router: router exposing `training_step`.\n",
    "        prompts, best_models, domains: parallel, index-aligned training data.\n",
    "        optimizer: optimizer forwarded to `training_step`.\n",
    "        num_epochs: passes over the shuffled data.\n",
    "        batch_size: examples per step.\n",
    "        domain_loss_weight: domain loss weight forwarded to `training_step`.\n",
    "        vq_loss_weight: VQ loss weight forwarded to `training_step`.\n",
    "    \"\"\"\n",
    "    n = len(prompts)\n",
    "    indices = list(range(n))\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        random.shuffle(indices)\n",
    "        total_loss, num_batches = 0.0, 0\n",
    "\n",
    "        for start in range(0, n, batch_size):\n",
    "            batch_idx = indices[start:start+batch_size]\n",
    "            bp = [prompts[i]      for i in batch_idx]\n",
    "            bm = [best_models[i]  for i in batch_idx]\n",
    "            bd = [domains[i]      for i in batch_idx]\n",
    "\n",
    "            loss = router.training_step(\n",
    "                batch_prompts=bp,\n",
    "                batch_best_models=bm,\n",
    "                batch_domains=bd,\n",
    "                optimizer=optimizer,\n",
    "                domain_loss_weight=domain_loss_weight,  # small but non-zero\n",
    "                vq_loss_weight=vq_loss_weight,\n",
    "                cost_penalty=0.0,\n",
    "            )\n",
    "            if loss > 0:\n",
    "                total_loss += loss\n",
    "                num_batches += 1\n",
    "\n",
    "        avg_loss = total_loss / num_batches if num_batches > 0 else float(\"nan\")\n",
    "        print(f\"[Stage3 small domain loss] Epoch {epoch+1}/{num_epochs} - avg loss: {avg_loss:.4f}\")\n",
    "\n",
    "# run stage 3\n",
    "joint_train_small_domain_loss(\n",
    "    router,\n",
    "    train_prompts,\n",
    "    train_best_models,\n",
    "    train_domains,\n",
    "    optimizer_stage3,\n",
    "    num_epochs=20,\n",
    "    batch_size=64,\n",
    "    domain_loss_weight=0.01,\n",
    "    vq_loss_weight=0.01,\n",
    ")\n",
    "# Print both results: only a cell's last expression is displayed, so a bare\n",
    "# call on the train split would have its returned accuracies discarded.\n",
    "print(eval_router_reg(router, train_prompts, train_best_models, train_domains))\n",
    "print(eval_router_reg(router, val_prompts, val_best_models, val_domains))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
