{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from transformers import ViTForImageClassification, ViTFeatureExtractor, Trainer, TrainingArguments, ViTImageProcessor,AutoImageProcessor\n",
    "import os\n",
    "from PIL import Image\n",
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "import pyarrow.parquet as pq\n",
    "import io\n",
    "from sklearn.model_selection import train_test_split\n",
    "from peft import LoraConfig, TaskType, get_peft_model,PeftModel, PeftConfig\n",
    "#from transformers import LoRAConfig, LoRAAdapter\n",
    "from utils import get_mnist_data, get_EuroSAT_data,CustomTensorDataset, get_cifar10_data, get_car_data,get_fruits_data, get_GTSRB_data, get_DTD_data, get_resis_data, get_grabage_data, get_plants_data\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "\n",
    "def set_random():\n",
    "    random.seed(42)\n",
    "    np.random.seed(42)\n",
    "    torch.manual_seed(42)\n",
    "    # 如果使用 GPU，还需固定 GPU 相关随机数\n",
    "    torch.cuda.manual_seed(42)\n",
    "    torch.cuda.manual_seed_all(42)  # 用于多 GPU 情况\n",
    "    # 确保卷积操作等确定性（针对特定卷积算法）\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "def test_result(num_samples=1000):\n",
    "    \"\"\"Print and return the accuracy of the notebook-level ``model`` on ``test_dataset``.\n",
    "\n",
    "    Reads the globals ``model`` and ``test_dataset``. Each dataset item must be a\n",
    "    dict with a ``\"pixel_values\"`` tensor and a ``\"labels\"`` entry, and the model\n",
    "    output must expose ``.logits``.\n",
    "\n",
    "    Args:\n",
    "        num_samples: Upper bound on how many leading items of ``test_dataset``\n",
    "            are scored. It is capped at ``len(test_dataset)`` so short datasets\n",
    "            no longer raise an IndexError or skew the denominator.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of scored items whose argmax prediction equals the label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If there is no sample to evaluate.\n",
    "    \"\"\"\n",
    "    total = min(num_samples, len(test_dataset))\n",
    "    if total <= 0:\n",
    "        raise ValueError(\"test_result needs at least one test sample to evaluate\")\n",
    "    set_random()\n",
    "    # Follow the model's device instead of assuming CUDA is where it lives.\n",
    "    device = next(model.parameters()).device\n",
    "    # Evaluate with dropout disabled, then restore whatever mode the model was in.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    wrong = 0\n",
    "    try:\n",
    "        for i in tqdm(range(total)):\n",
    "            sample = test_dataset[i]\n",
    "            with torch.no_grad():\n",
    "                outputs = model(sample[\"pixel_values\"].unsqueeze(0).to(device))\n",
    "            predicted_class = torch.argmax(outputs.logits, dim=-1).item()\n",
    "            if predicted_class != int(sample[\"labels\"]):\n",
    "                wrong += 1\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "    accuracy = 1 - wrong / total\n",
    "    print(accuracy)\n",
    "    return accuracy\n",
    "    \n",
    "\n",
    "def compute_accuracy(eval_pred):\n",
    "    \"\"\"Return ``{\"accuracy\": ...}`` for a ``(logits, labels)`` pair.\n",
    "\n",
    "    ``eval_pred`` is unpacked into logits and labels; the predicted class is the\n",
    "    argmax over the last axis and the score comes from ``accuracy_score``.\n",
    "    NOTE(review): presumably meant as a Trainer ``compute_metrics`` hook; it is\n",
    "    not called anywhere in the visible cells.\n",
    "    \"\"\"\n",
    "    scores, targets = eval_pred\n",
    "    predicted = np.argmax(scores, axis=-1)\n",
    "    return {\"accuracy\": accuracy_score(targets, predicted)}\n",
    "# Local directory holding the ViT checkpoint and its image-processor config.\n",
    "model_dir = \"./vit-base-patch16-224\"\n",
    "feature_extractor = ViTImageProcessor.from_pretrained(model_dir)\n",
    "\n",
    "# LoRA adapters (rank 16) on the modules matched by the names below; an\n",
    "# alternative target set noted by the author is [\"query\", \"value\"].\n",
    "lora_config = LoraConfig(\n",
    "    bias=\"none\",\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0.1,\n",
    "    modules_to_save=[\"classifier\"],\n",
    "    r=16,\n",
    "    target_modules=[\"intermediate.dense\", \"output.dense\"],\n",
    ")\n",
    "\n",
    "# Load the base ViT and wrap it with the adapters in a single expression.\n",
    "model = get_peft_model(\n",
    "    ViTForImageClassification.from_pretrained(model_dir, torch_dtype=\"auto\"),\n",
    "    lora_config,\n",
    ")\n",
    "# model.print_trainable_parameters()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomTensorDataset(Dataset):\n",
    "    def __init__(self, tensors, labels, mode= \"PNG_code\", transform=None): #\"JPG\"  Tensor\n",
    "        self.tensors = tensors\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "        self.mode = mode\n",
    "        assert len(self.tensors) == len(self.labels)  # \"张量数量和标签数量必须相同\"\n",
    "    def __len__(self):\n",
    "        return len(self.tensors)\n",
    "    def __getitem__(self, idx):\n",
    "        label = self.labels[idx]\n",
    "        return {\"pixel_values\": self.tensors[idx], \"labels\": torch.tensor(label)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task name -> loader returning (train_dataset, test_dataset).\n",
    "# The position of a task in this dict is the label the router learns to predict.\n",
    "task_functions = {\n",
    "    \"mnist\": get_mnist_data,\n",
    "    \"EuroSAT\": get_EuroSAT_data,\n",
    "    \"cifar10\": get_cifar10_data,\n",
    "    \"car\": get_car_data,\n",
    "    \"fruits\": get_fruits_data,\n",
    "    \"GTSRB\": get_GTSRB_data,\n",
    "    \"DTD\": get_DTD_data,\n",
    "    \"resis\": get_resis_data,\n",
    "    \"grabage\": get_grabage_data,\n",
    "    \"plants\": get_plants_data,\n",
    "}\n",
    "# Per-task caps on how many leading samples are pooled from each split.\n",
    "MAX_TRAIN_PER_TASK = 5000\n",
    "MAX_TEST_PER_TASK = 1000\n",
    "\n",
    "train_list, train_label = [], []\n",
    "# The next cell builds its test dataset from these two lists, so they must be\n",
    "# populated here (they used to be commented out, which broke a fresh run).\n",
    "test_list, test_label = [], []\n",
    "for j, (task, load_task) in enumerate(task_functions.items()):\n",
    "    train_dataset, test_dataset = load_task()\n",
    "    n_train = min(MAX_TRAIN_PER_TASK, len(train_dataset))\n",
    "    n_test = min(MAX_TEST_PER_TASK, len(test_dataset))\n",
    "    # The target is the task index j, not the image's own class label.\n",
    "    train_list += [train_dataset[i]['pixel_values'] for i in range(n_train)]\n",
    "    train_label += [j] * n_train\n",
    "    test_list += [test_dataset[i]['pixel_values'] for i in range(n_test)]\n",
    "    test_label += [j] * n_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Wrap the pooled tensors and their task-index labels as datasets.\n",
    "train_dataset = CustomTensorDataset(tensors=train_list, labels=train_label)\n",
    "test_dataset = CustomTensorDataset(tensors=test_list, labels=test_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
      "  warnings.warn(msg)\n",
      "/tmp/ipykernel_1536254/2176203441.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model1.load_state_dict(torch.load(\"./router_model.pth\"))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchvision import datasets, transforms, models\n",
    "# Router network: ResNet18 with a 10-way head, created without pretrained weights\n",
    "# (`weights=None` replaces the deprecated `pretrained=False`) and then restored\n",
    "# from the checkpoint that the training cell below writes to ./router_model.pth.\n",
    "model1 = models.resnet18(weights=None, num_classes=10).cuda()\n",
    "# weights_only=True avoids unpickling arbitrary objects from the checkpoint file.\n",
    "model1.load_state_dict(torch.load(\"./router_model.pth\", weights_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/50], Loss: 50.5148\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 1/50 [01:09<57:09, 69.98s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.42%\n",
      "✅ Saved new best model with accuracy 99.42%\n",
      "Epoch [2/50], Loss: 42.6233\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 2/50 [02:19<55:56, 69.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.59%\n",
      "✅ Saved new best model with accuracy 99.59%\n",
      "Epoch [3/50], Loss: 39.4879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 3/50 [03:29<54:36, 69.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.58%\n",
      "Epoch [4/50], Loss: 34.2911\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 4/50 [04:40<53:44, 70.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.63%\n",
      "✅ Saved new best model with accuracy 99.63%\n",
      "Epoch [5/50], Loss: 31.7832\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 5/50 [05:50<52:45, 70.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.52%\n",
      "Epoch [6/50], Loss: 34.2322\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 6/50 [07:01<51:41, 70.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.70%\n",
      "✅ Saved new best model with accuracy 99.70%\n",
      "Epoch [7/50], Loss: 29.3944\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 7/50 [08:12<50:42, 70.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.77%\n",
      "✅ Saved new best model with accuracy 99.77%\n",
      "Epoch [8/50], Loss: 29.8419\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 8/50 [09:23<49:33, 70.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.75%\n",
      "Epoch [9/50], Loss: 24.3104\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 9/50 [10:32<47:56, 70.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.62%\n",
      "Epoch [10/50], Loss: 30.4025\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 10/50 [11:42<46:38, 69.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.57%\n",
      "Epoch [11/50], Loss: 22.1049\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 11/50 [12:51<45:24, 69.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.64%\n",
      "Epoch [12/50], Loss: 26.2783\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 12/50 [14:01<44:08, 69.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.79%\n",
      "✅ Saved new best model with accuracy 99.79%\n",
      "Epoch [13/50], Loss: 21.0047\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 13/50 [15:11<43:06, 69.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.41%\n",
      "Epoch [14/50], Loss: 21.5771\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 14/50 [16:22<42:04, 70.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.47%\n",
      "Epoch [15/50], Loss: 20.3882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 15/50 [17:32<41:01, 70.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.84%\n",
      "✅ Saved new best model with accuracy 99.84%\n",
      "Epoch [16/50], Loss: 21.9488\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 16/50 [18:43<39:53, 70.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.85%\n",
      "✅ Saved new best model with accuracy 99.85%\n",
      "Epoch [17/50], Loss: 19.9341\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 17/50 [19:53<38:40, 70.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.94%\n",
      "✅ Saved new best model with accuracy 99.94%\n",
      "Epoch [18/50], Loss: 20.3127\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 18/50 [21:03<37:22, 70.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.78%\n",
      "Epoch [19/50], Loss: 18.4443\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 19/50 [22:13<36:17, 70.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [20/50], Loss: 17.8759\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 20/50 [23:23<35:04, 70.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.86%\n",
      "Epoch [21/50], Loss: 18.5726\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 21/50 [24:34<34:03, 70.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.94%\n",
      "Epoch [22/50], Loss: 14.0819\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 22/50 [25:44<32:49, 70.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.91%\n",
      "Epoch [23/50], Loss: 18.8460\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 23/50 [26:54<31:35, 70.20s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.75%\n",
      "Epoch [24/50], Loss: 14.8107\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 24/50 [28:02<30:09, 69.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.94%\n",
      "Epoch [25/50], Loss: 14.5296\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 25/50 [29:11<28:55, 69.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.62%\n",
      "Epoch [26/50], Loss: 19.6297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 26/50 [30:21<27:49, 69.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.74%\n",
      "Epoch [27/50], Loss: 10.2360\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 27/50 [31:33<26:57, 70.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.21%\n",
      "Epoch [28/50], Loss: 14.7380\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 28/50 [32:43<25:44, 70.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [29/50], Loss: 14.8460\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 29/50 [33:53<24:29, 69.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.96%\n",
      "✅ Saved new best model with accuracy 99.96%\n",
      "Epoch [30/50], Loss: 15.2014\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 30/50 [35:02<23:13, 69.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.69%\n",
      "Epoch [31/50], Loss: 12.8336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 31/50 [36:11<22:03, 69.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.31%\n",
      "Epoch [32/50], Loss: 11.4480\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 32/50 [37:21<20:55, 69.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.98%\n",
      "✅ Saved new best model with accuracy 99.98%\n",
      "Epoch [33/50], Loss: 14.3290\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 33/50 [38:31<19:43, 69.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.89%\n",
      "Epoch [34/50], Loss: 12.5065\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 34/50 [39:39<18:29, 69.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.81%\n",
      "Epoch [35/50], Loss: 11.1200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 35/50 [40:49<17:20, 69.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.80%\n",
      "Epoch [36/50], Loss: 14.5078\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 36/50 [41:58<16:11, 69.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [37/50], Loss: 11.4677\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 37/50 [43:07<15:01, 69.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.85%\n",
      "Epoch [38/50], Loss: 11.0377\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 38/50 [44:18<13:58, 69.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [39/50], Loss: 10.6018\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 39/50 [45:28<12:45, 69.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.74%\n",
      "Epoch [40/50], Loss: 12.2402\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 40/50 [46:38<11:37, 69.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.91%\n",
      "Epoch [41/50], Loss: 11.9961\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 41/50 [47:48<10:28, 69.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.72%\n",
      "Epoch [42/50], Loss: 11.5267\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 42/50 [48:57<09:17, 69.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.81%\n",
      "Epoch [43/50], Loss: 7.1949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 43/50 [50:08<08:10, 70.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.91%\n",
      "Epoch [44/50], Loss: 10.1764\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 44/50 [51:17<06:59, 69.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [45/50], Loss: 12.5272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 45/50 [52:27<05:48, 69.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [46/50], Loss: 9.3731\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 46/50 [53:37<04:39, 69.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.91%\n",
      "Epoch [47/50], Loss: 11.0547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 47/50 [54:46<03:29, 69.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.93%\n",
      "Epoch [48/50], Loss: 8.9699\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 48/50 [55:55<02:19, 69.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.94%\n",
      "Epoch [49/50], Loss: 6.9718\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 49/50 [57:05<01:09, 69.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.90%\n",
      "Epoch [50/50], Loss: 10.4520\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [58:15<00:00, 69.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 99.95%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 2. 加载训练/验证数据\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms, models\n",
    "from torch.utils.data import DataLoader\n",
    "# Materialise [image, label] pairs so the default collate yields (images, labels) batches.\n",
    "train_dataset_ = [[i['pixel_values'], i[\"labels\"]] for i in train_dataset]\n",
    "test_dataset_ = [[i['pixel_values'], i[\"labels\"]] for i in test_dataset]\n",
    "# Fit on the training split only: the held-out split drives checkpoint selection\n",
    "# below, so mixing it into train_loader would make the reported accuracy meaningless.\n",
    "train_loader = DataLoader(train_dataset_, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(test_dataset_, batch_size=32, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# 3. Router network: ResNet18 without ImageNet weights, one output per task (10).\n",
    "# NOTE: this deliberately rebinds `model`, which until now was the LoRA-wrapped ViT.\n",
    "# The loop below and the evaluation cell pass the model output straight to\n",
    "# CrossEntropyLoss / torch.max, whereas the ViT output exposes `.logits` instead.\n",
    "model = models.resnet18(weights=None, num_classes=10).to(device)\n",
    "\n",
    "# 4. Loss function + optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "# 5. Training, with a validation pass and best-checkpoint save after every epoch\n",
    "epochs = 50\n",
    "best_acc = 0\n",
    "for epoch in tqdm(range(epochs)):\n",
    "    model.train()\n",
    "    running_loss = 0.0  # sum of per-batch losses over the epoch\n",
    "\n",
    "    for images, labels in train_loader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss:.4f}\")\n",
    "\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in val_loader:\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            outputs = model(images)\n",
    "            _, preds = torch.max(outputs, 1)\n",
    "            correct += (preds == labels).sum().item()\n",
    "            total += labels.size(0)\n",
    "\n",
    "    acc = 100 * correct / total\n",
    "    print(f\"Validation Accuracy: {acc:.2f}%\")\n",
    "    # Keep only the checkpoint with the best validation accuracy so far.\n",
    "    if acc > best_acc:\n",
    "        best_acc = acc\n",
    "        torch.save(model.state_dict(), \"./router_model.pth\")\n",
    "        print(f\"✅ Saved new best model with accuracy {acc:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the current `model` on the held-out split and list every misprediction.\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Rebuild the validation pairs and loader from the current `test_dataset`.\n",
    "# No training loader is rebuilt here: this cell never trains.\n",
    "test_dataset_ = [[i['pixel_values'], i[\"labels\"]] for i in test_dataset]\n",
    "val_loader = DataLoader(test_dataset_, batch_size=32, shuffle=False)\n",
    "with torch.no_grad():\n",
    "    for images, labels in val_loader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        outputs = model(images)\n",
    "        _, preds = torch.max(outputs, 1)\n",
    "        correct += (preds == labels).sum().item()\n",
    "        total += labels.size(0)\n",
    "        # Print (predicted, expected) for each misclassified sample in the batch.\n",
    "        for i in range(len(labels)):\n",
    "            if preds[i] != labels[i]:\n",
    "                print(preds[i], labels[i])\n",
    "\n",
    "acc = 100 * correct / total\n",
    "print(f\"Validation Accuracy: {acc:.2f}%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lyf_jiuan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
