{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '0'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'  #'0,1,2,3,4,5,6,7'\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from transformers import ViTForImageClassification, ViTFeatureExtractor, Trainer, TrainingArguments, ViTImageProcessor,AutoImageProcessor\n",
    "import os\n",
    "from PIL import Image\n",
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "import pyarrow.parquet as pq\n",
    "import io\n",
    "from torchvision import datasets, transforms, models\n",
    "from sklearn.model_selection import train_test_split\n",
    "from peft import LoraConfig, TaskType, get_peft_model,PeftModel, PeftConfig\n",
    "#from transformers import LoRAConfig, LoRAAdapter\n",
    "from utils import get_mnist_data, get_EuroSAT_data,CustomTensorDataset, get_cifar10_data, get_car_data,get_fruits_data, get_GTSRB_data, get_DTD_data, get_resis_data, get_grabage_data, get_plants_data\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def set_random():\n",
    "    random.seed(42)\n",
    "    np.random.seed(42)\n",
    "    torch.manual_seed(42)\n",
    "    # 如果使用 GPU，还需固定 GPU 相关随机数\n",
    "    torch.cuda.manual_seed(42)\n",
    "    torch.cuda.manual_seed_all(42)  # 用于多 GPU 情况\n",
    "    # 确保卷积操作等确定性（针对特定卷积算法）\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "def test_result():\n",
    "    sum1 = 0\n",
    "    set_random()\n",
    "    for i in tqdm(range(min(1000, len(test_dataset)))):\n",
    "        with torch.no_grad():\n",
    "            outputs = model(test_dataset[i][\"pixel_values\"].unsqueeze(0).cuda())\n",
    "        logits = outputs.logits\n",
    "        predicted_class = torch.argmax(logits, dim=-1).item()\n",
    "        if predicted_class != int(test_dataset[i]['labels']):\n",
    "            # print(predicted_class, int(test_dataset[i]['labels']))\n",
    "            sum1 += 1\n",
    "    print(1-sum1/min(1000, len(test_dataset)))\n",
    "    return 1-sum1/min(1000, len(test_dataset))\n",
    "    \n",
    "\n",
    "def compute_accuracy(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return {\"accuracy\": accuracy_score(labels, predictions)}\n",
    "model_dir = \"./vit-base-patch16-224\"\n",
    "model = ViTForImageClassification.from_pretrained(model_dir, torch_dtype=\"auto\")\n",
    "feature_extractor = ViTImageProcessor.from_pretrained(model_dir)\n",
    "\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "        r=16,\n",
    "        lora_alpha=16,\n",
    "        target_modules=[ \"intermediate.dense\", \"output.dense\"],  #[\"query\", \"value\"]\n",
    "        lora_dropout=0.1,\n",
    "        bias=\"none\",\n",
    "        modules_to_save=[\"classifier\"],\n",
    "    )\n",
    "\n",
    "model = get_peft_model(model, lora_config).cuda()\n",
    "# model.print_trainable_parameters()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_dataset, test_dataset = get_GTSRB_data()  \n",
    "# train_dataset, test_dataset = get_DTD_data()\n",
    "# train_dataset, test_dataset = get_resis_data()  \n",
    "# train_dataset, test_dataset = get_grabage_data() \n",
    "train_dataset, test_dataset = get_mnist_data()   #select task you train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"./{task}_checkpoint\",          # 输出目录\n",
    "    #evaluation_strategy=\"epoch\",    # 每个 epoch 后评估一次\n",
    "    #save_steps=10,\n",
    "    save_total_limit=1,\n",
    "    learning_rate=5e-3,              # 学习率2e-5\n",
    "    per_device_train_batch_size=64, # 训练批次大小\n",
    "    per_device_eval_batch_size=64,  # 验证批次大小\n",
    "    num_train_epochs=1,             # 训练轮数\n",
    "    weight_decay=0.01,              # 权重衰减\n",
    "    logging_dir=\"./logs\",           # 日志目录\n",
    "    logging_steps=10,               # 每 10 步记录一次日志\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,                         # 模型\n",
    "    args=training_args,                  # 训练参数\n",
    "    train_dataset=train_dataset,         # 训练集\n",
    "    eval_dataset=test_dataset,            # 验证集\n",
    "    tokenizer=feature_extractor,         # 特征提取器（用于处理输入）\n",
    "    compute_metrics=compute_accuracy      # 评估指标\n",
    ")\n",
    "\n",
    "# 开始训练\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_result()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lyf_jiuan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
