{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '0'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'  #'0,1,2,3,4,5,6,7'\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from transformers import ViTForImageClassification, ViTFeatureExtractor, Trainer, TrainingArguments, ViTImageProcessor,AutoImageProcessor\n",
    "import os\n",
    "from PIL import Image\n",
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "import pyarrow.parquet as pq\n",
    "import io\n",
    "from torchvision import datasets, transforms, models\n",
    "from sklearn.model_selection import train_test_split\n",
    "from peft import LoraConfig, TaskType, get_peft_model,PeftModel, PeftConfig\n",
    "#from transformers import LoRAConfig, LoRAAdapter\n",
    "from utils import get_mnist_data, get_EuroSAT_data,CustomTensorDataset, get_cifar10_data, get_car_data,get_fruits_data, get_GTSRB_data, get_DTD_data, get_resis_data, get_grabage_data, get_plants_data\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def set_random():\n",
    "    random.seed(42)\n",
    "    np.random.seed(42)\n",
    "    torch.manual_seed(42)\n",
    "    # 如果使用 GPU，还需固定 GPU 相关随机数\n",
    "    torch.cuda.manual_seed(42)\n",
    "    torch.cuda.manual_seed_all(42)  # 用于多 GPU 情况\n",
    "    # 确保卷积操作等确定性（针对特定卷积算法）\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "def test_result():\n",
    "    sum1 = 0\n",
    "    set_random()\n",
    "    for i in tqdm(range(min(1000, len(test_dataset)))):\n",
    "        with torch.no_grad():\n",
    "            outputs = model(test_dataset[i][\"pixel_values\"].unsqueeze(0).cuda())\n",
    "        logits = outputs.logits\n",
    "        predicted_class = torch.argmax(logits, dim=-1).item()\n",
    "        if predicted_class != int(test_dataset[i]['labels']):\n",
    "            # print(predicted_class, int(test_dataset[i]['labels']))\n",
    "            sum1 += 1\n",
    "    print(1-sum1/min(1000, len(test_dataset)))\n",
    "    return 1-sum1/min(1000, len(test_dataset))\n",
    "    \n",
    "\n",
    "def compute_accuracy(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return {\"accuracy\": accuracy_score(labels, predictions)}\n",
    "model_dir = \"/mnt/data_B/lyf/vit-base-patch16-224\"\n",
    "model = ViTForImageClassification.from_pretrained(model_dir, torch_dtype=\"auto\")\n",
    "feature_extractor = ViTImageProcessor.from_pretrained(model_dir)\n",
    "\n",
    "# lora_config = LoraConfig(\n",
    "#         #task_type=TaskType.CAUSAL_LM,\n",
    "#         #task_type=TaskType.SEQ_CLS,\n",
    "#         task_type=TaskType.FEATURE_EXTRACTION,\n",
    "#         target_modules=[ \"intermediate.dense\", \"output.dense\"],\n",
    "#         inference_mode=False,  # 训练模式\n",
    "#         r=8,  # Lora 秩\n",
    "#         lora_alpha=32,  # 等效于lr=lr*lora_alpha/r\n",
    "#         lora_dropout=0.1\n",
    "    # )\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "        r=16,\n",
    "        lora_alpha=16,\n",
    "        target_modules=[ \"intermediate.dense\", \"output.dense\"],  #[\"query\", \"value\"]\n",
    "        lora_dropout=0.1,\n",
    "        bias=\"none\",\n",
    "        modules_to_save=[\"classifier\"],\n",
    "    )\n",
    "\n",
    "model = get_peft_model(model, lora_config).cuda()\n",
    "# model.print_trainable_parameters()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name,param in model.named_parameters():\n",
    "    if \"lora\" in name:\n",
    "        param.data.fill_(0)\n",
    "        print(name,param) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "task = \"multi_task\"  #0.978  # GTSRB DTD 0.737 resis 0.911  grabage 0.633  plants 0.892\n",
    "def load_model_data(task):\n",
    "    path_dict = {\"mnist\": \"./mnist_checkpoint/checkpoint-20\",\"EuroSAT\": \"./EuroSAT_checkpoint/checkpoint-43\",\"cifar10\": \"./cifar10_checkpoint/checkpoint-16\", \"car\":\"./car_checkpoint/checkpoint-180\", \"fruits\":\"./fruits_checkpoint/checkpoint-158\", \"GTSRB\":\"./GTSRB_checkpoint/checkpoint-92\", \"DTD\":\"./DTD_checkpoint/checkpoint-61\", \"resis\":\"./resis_checkpoint/checkpoint-157\", \"grabage\":\"./grabage_checkpoint/checkpoint-25\", \"plants\":\"./plants_checkpoint/checkpoint-113\"}\n",
    "    task_functions = {\"mnist\": get_mnist_data,\"EuroSAT\": get_EuroSAT_data,\"cifar10\": get_cifar10_data,\"car\": get_car_data,\"fruits\": get_fruits_data,\"GTSRB\": get_GTSRB_data,\"DTD\": get_DTD_data,\"resis\": get_resis_data,\"grabage\": get_grabage_data, \"plants\": get_plants_data}\n",
    "    base_model =  ViTForImageClassification.from_pretrained(model_dir, torch_dtype=\"auto\")\n",
    "    peft_config = PeftConfig.from_pretrained(path_dict[task])\n",
    "    # 加载 LoRA 适配器到主模型\n",
    "    model = PeftModel.from_pretrained(base_model, path_dict[task]).to(\"cuda:0\")\n",
    "    train_dataset, test_dataset = task_functions[task]()\n",
    "    return model, train_dataset, test_dataset\n",
    "# model, train_dataset, test_dataset = load_model_data(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_dataset), len(test_dataset), "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_dataset, test_dataset = get_GTSRB_data()  \n",
    "# train_dataset, test_dataset = get_DTD_data()\n",
    "# train_dataset, test_dataset = get_resis_data()  \n",
    "# train_dataset, test_dataset = get_grabage_data() \n",
    "train_dataset, test_dataset = get_mnist_data()  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_functions = {\"mnist\": get_mnist_data,\"EuroSAT\": get_EuroSAT_data,\"cifar10\": get_cifar10_data,\"car\": get_car_data,\"fruits\": get_fruits_data,\"GTSRB\": get_GTSRB_data,\"DTD\": get_DTD_data,\"resis\": get_resis_data,\"grabage\": get_grabage_data, \"plants\": get_plants_data}\n",
    "train_dataset, test_dataset = get_mnist_data()\n",
    "for j, name in enumerate([\"EuroSAT\", \"cifar10\", \"car\", \"fruits\",\"GTSRB\", \"DTD\",\"resis\", \"grabage\",\"plants\"]):\n",
    "    _train_dataset, _test_dataset = task_functions[name]()\n",
    "    train_dataset += _train_dataset\n",
    "    test_dataset += _test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"./{task}_checkpoint\",          # 输出目录\n",
    "    #evaluation_strategy=\"epoch\",    # 每个 epoch 后评估一次\n",
    "    #save_steps=10,\n",
    "    save_total_limit=1,\n",
    "    learning_rate=5e-3,              # 学习率2e-5\n",
    "    per_device_train_batch_size=64, # 训练批次大小\n",
    "    per_device_eval_batch_size=64,  # 验证批次大小\n",
    "    num_train_epochs=1,             # 训练轮数\n",
    "    weight_decay=0.01,              # 权重衰减\n",
    "    logging_dir=\"./logs\",           # 日志目录\n",
    "    logging_steps=10,               # 每 10 步记录一次日志\n",
    ")\n",
    "\n",
    "# for i in range(15):\n",
    "trainer = Trainer(\n",
    "    model=model,                         # 模型\n",
    "    args=training_args,                  # 训练参数\n",
    "    train_dataset=train_dataset,         # 训练集\n",
    "    eval_dataset=test_dataset,            # 验证集\n",
    "    tokenizer=feature_extractor,         # 特征提取器（用于处理输入）\n",
    "    compute_metrics=compute_accuracy      # 评估指标\n",
    ")\n",
    "\n",
    "# 开始训练\n",
    "\n",
    "trainer.train()\n",
    "    # test_result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dropout_consist_layers(N,task):\n",
    "    for layer in range(13-N):\n",
    "        temp_matrix = []\n",
    "        for i in range(N):\n",
    "            temp_matrix.append([model.base_model.model.vit.encoder.layer[layer+i].intermediate.dense.lora_A.default.weight.clone(), model.base_model.model.vit.encoder.layer[i].output.dense.lora_A.default.weight.clone()])\n",
    "            model.base_model.model.vit.encoder.layer[layer+i].intermediate.dense.lora_A.default.weight.fill_(0)\n",
    "            model.base_model.model.vit.encoder.layer[layer+i].output.dense.lora_A.default.weight.fill_(0)\n",
    "        result = test_result()\n",
    "        with open(f'./result_Neuron-{task}.txt',\"a\") as f:\n",
    "            f.write(f\"连续去除第{layer}到{layer+N-1}层的lora参数\" + \"-->错了\" + str(result) + \" \\n\")\n",
    "        for i in range(N):\n",
    "            \n",
    "            model.base_model.model.vit.encoder.layer[layer+i].intermediate.dense.lora_A.default.weight.copy_(temp_matrix[i][0])\n",
    "            model.base_model.model.vit.encoder.layer[layer+i].output.dense.lora_A.default.weight.copy_(temp_matrix[i][1])\n",
    "        with open(f'./result_Neuron-{task}.txt',\"a\") as f:\n",
    "            f.write(\" \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 神经元去除\n",
    "# for task in [\"mnist\", \"EuroSAT\", \"cifar10\"]:\n",
    "for task in [ \"plants\"]: #\"GTSRB\", \"DTD\",\"resis\", \"grabage\"\n",
    "    model, train_dataset, test_dataset = load_model_data(task)\n",
    "    for i in range(1,13):\n",
    "        dropout_consist_layers(i,task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, train_dataset, test_dataset = load_model_data(\"fruits\")\n",
    "for layer in range(5,12):\n",
    "    print(layer)\n",
    "    model.base_model.model.vit.encoder.layer[layer].intermediate.dense.lora_A.default.weight.fill_(0)\n",
    "    model.base_model.model.vit.encoder.layer[layer].output.dense.lora_A.default.weight.fill_(0)\n",
    "test_result()\n",
    "\n",
    "#dropout_layer_dict = {0:[2,2], 1:[2,2], 2:[6,11], 3:[1,1], 4:[5,11]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "#指定层进行模型融合-获取每个任务指定层的lora参数\n",
    "def get_every_task_param():\n",
    "    dropout_layer_dict = {\"mnist\":[2,2], \"EuroSAT\":[2,2], \"cifar10\":[6,11], \"car\":[1,1], \"fruits\":[5,11], \"GTSRB\":[0,0], \"DTD\":[0,0], \"resis\":[0,0], \"grabage\":[5,10], \"plants\":[11,11]}\n",
    "    every_task_param = []\n",
    "    for task in [ \"mnist\", \"EuroSAT\", \"cifar10\", \"car\", \"fruits\",\"GTSRB\", \"DTD\",\"resis\", \"grabage\",\"plants\"]:\n",
    "        path_dict = {\"mnist\": \"./mnist_checkpoint/checkpoint-20\",\"EuroSAT\": \"./EuroSAT_checkpoint/checkpoint-43\",\"cifar10\": \"./cifar10_checkpoint/checkpoint-16\", \"car\":\"./car_checkpoint/checkpoint-180\", \"fruits\":\"./fruits_checkpoint/checkpoint-158\", \"GTSRB\":\"./GTSRB_checkpoint/checkpoint-92\", \"DTD\":\"./DTD_checkpoint/checkpoint-61\", \"resis\":\"./resis_checkpoint/checkpoint-157\", \"grabage\":\"./grabage_checkpoint/checkpoint-25\", \"plants\":\"./plants_checkpoint/checkpoint-113\"}\n",
    "        base_model =  ViTForImageClassification.from_pretrained(model_dir, torch_dtype=\"auto\")\n",
    "        peft_config = PeftConfig.from_pretrained(path_dict[task])\n",
    "        # 加载 LoRA 适配器到主模型\n",
    "        model = PeftModel.from_pretrained(base_model, path_dict[task]).to(\"cuda:0\")\n",
    "        # for layer in range(dropout_layer_dict[task][0], dropout_layer_dict[task][1]+1):\n",
    "        #     model.base_model.model.vit.encoder.layer[layer].intermediate.dense.lora_A.default.weight.fill_(0)\n",
    "        #     model.base_model.model.vit.encoder.layer[layer].output.dense.lora_A.default.weight.fill_(0)\n",
    "        #     model.base_model.model.vit.encoder.layer[layer].intermediate.dense.lora_B.default.weight.fill_(0)\n",
    "        #     model.base_model.model.vit.encoder.layer[layer].output.dense.lora_B.default.weight.fill_(0)\n",
    "        lora_params = {k: v for k, v in model.state_dict().items() if 'lora'  in k and \"attention\" not in k}\n",
    "        del model, base_model\n",
    "        torch.cuda.empty_cache()\n",
    "        every_task_param.append(lora_params)\n",
    "    return every_task_param\n",
    "every_task_param = get_every_task_param()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "every_task_param[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge(weight_list, model): \n",
    "    merge_lora_dict = dict()\n",
    "    topk_vals, topk_indices = torch.topk(weight_list, k=2)\n",
    "    indices = topk_indices.cpu().numpy()\n",
    "    for key in every_task_param[0]:\n",
    "        temp = 0 \n",
    "        # weight_list = [0.2, 0.2, 0.2, 0.2, 0.2,0.2, 0.2, 0.2, 0.2, 0.2]\n",
    "        # weight_list = [0.1] * 10\n",
    "        for i,num in enumerate(weight_list):\n",
    "        # for i,num in zip(indices,topk_vals):\n",
    "            temp += num * every_task_param[i][key] \n",
    "        merge_lora_dict[key] = temp  #ours\n",
    "        #merge_lora_dict[key] = temp * A\n",
    "    for name,param in model.named_parameters():\n",
    "        if \"lora\" in name and \"attention\" not in name:\n",
    "            param.data = merge_lora_dict[name]\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor([0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# def ties_merge(model):\n",
    "model, train_dataset, test_dataset = load_model_data(\"mnist\")\n",
    "every_task_param = get_every_task_param()\n",
    "all_tensor_list = []\n",
    "merge_lora_dict = dict()\n",
    "k = 0.1\n",
    "for key in every_task_param[0]:\n",
    "    for i in range(len(every_task_param)):\n",
    "        all_tensor_list.append(every_task_param[i][key].flatten())\n",
    "        # print(every_task_param[i][key].flatten().shape)\n",
    "delete_value,_ = torch.cat(all_tensor_list).abs().kthvalue(int(k*torch.cat(all_tensor_list).shape[0]))\n",
    "\n",
    "for key in every_task_param[0]:\n",
    "    sum_tensor = torch.stack([every_task_param[i][key] for i in range(len(every_task_param))], dim=0).sum(dim=0)\n",
    "    # merge_lora_dict[key] = sum([every_task_param[i][key] for i in range(len(every_task_param))])/len(every_task_param)\n",
    "    for i in range(len(every_task_param)):\n",
    "        mask = every_task_param[i][key].abs() >= delete_value\n",
    "        every_task_param[i][key] = every_task_param[i][key] * mask\n",
    "        sign_match = torch.sign(sum_tensor) == torch.sign(every_task_param[i][key])\n",
    "        every_task_param[i][key] = every_task_param[i][key] * sign_match\n",
    "    merge_lora_dict[key] = torch.stack([every_task_param[i][key] for i in range(len(every_task_param))], dim=0).sum(dim=0)/len(every_task_param)\n",
    "for name,param in model.named_parameters():\n",
    "    if \"lora\" in name and \"attention\" not in name:\n",
    "        param.data = merge_lora_dict[name]\n",
    "    # return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ti123(model):\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.9000,  0.7000, -0.6000, -0.8000, -0.1000])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = torch.tensor([ 9.0,  7.0, -6.0, -8.0, -1.0])/10\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([12., 10., -3., -5.,  2.])\n",
      "tensor([ 1., -0.,  0., -4.,  0.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "a = torch.tensor([ 1.0, -2.0,  3.0, -4.0,  5.0])\n",
    "b = torch.tensor([ 9.0,  7.0, -6.0, -8.0, -1.0])\n",
    "print(sum(a,b))\n",
    "# 计算符号是否一致（都为正 or 都为负）\n",
    "sign_match = torch.sign(a) == torch.sign(b)\n",
    "\n",
    "# 保留 a 中符号一致的位置，其它位置设为 0\n",
    "result = a * sign_match\n",
    "\n",
    "print(result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.2000, 1.5000, 1.8000],\n",
      "        [1.2000, 1.5000, 1.8000]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "tensor_list = [\n",
    "    torch.tensor([[1.0, 2.0, 3.0],[1.0, 2.0, 3.0]]),\n",
    "    torch.tensor([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]),\n",
    "    torch.tensor([[7.0, 8.0, 9.0],[7.0, 8.0, 9.0] ])\n",
    "]\n",
    "\n",
    "result = torch.stack(tensor_list, dim=0).sum(dim=0)/10\n",
    "print(result)  # tensor([12., 15., 18.])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14745600"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat(all_tensor_list).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, train_dataset, test_dataset = load_model_data(\"fruits\")\n",
    "for layer in range(5, 12):\n",
    "        model.base_model.model.vit.encoder.layer[layer].intermediate.dense.lora_A.default.weight.fill_(0)\n",
    "        model.base_model.model.vit.encoder.layer[layer].output.dense.lora_A.default.weight.fill_(0)\n",
    "        model.base_model.model.vit.encoder.layer[layer].intermediate.dense.lora_B.default.weight.fill_(0)\n",
    "        model.base_model.model.vit.encoder.layer[layer].output.dense.lora_B.default.weight.fill_(0)\n",
    "lora_params = {k: v for k, v in model.state_dict().items() if 'lora'  in k and \"attention\" not in k}\n",
    "test_result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge(weight_list, model): #分层融合\n",
    "    merge_lora_dict = dict()\n",
    "    dropout_layer_dict = {0:[2,2], 1:[2,2], 2:[6,11], 3:[1,1], 4:[5,11], 5:[0,0], 6:[0,0], 7:[0,0], 8:[5,10], 9:[11,11]}\n",
    "    value, index = torch.max(weight_list, dim=0)\n",
    "    for key in every_task_param[0]:\n",
    "        temp = 0 \n",
    "        layer = int(re.findall(r'\\d+', key)[0])\n",
    "        if  layer >= dropout_layer_dict[int(index)][0] and layer <= dropout_layer_dict[int(index)][1]:\n",
    "            for i,num in enumerate(weight_list):\n",
    "                temp += num * every_task_param[i][key] \n",
    "        else:\n",
    "            temp = every_task_param[int(index)][key]    \n",
    "        merge_lora_dict[key] = temp  #ours\n",
    "        #merge_lora_dict[key] = temp * A\n",
    "    for name,param in model.named_parameters():\n",
    "        if \"lora\" in name and \"attention\" not in name:\n",
    "            param.data = merge_lora_dict[name]\n",
    "            # param.data = every_task_param[int(index)][name]\n",
    "            #param.data = lora_params[name]\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/home/xiaoyuanxin/miniconda3/envs/lyf_jiuan/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
      "  warnings.warn(msg)\n",
      "/tmp/ipykernel_70093/3672699077.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  module_route.load_state_dict(torch.load(\"./router_model.pth\"))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "module_route = models.resnet18(pretrained=False, num_classes=10).cuda()  # 修改分类数\n",
    "module_route.load_state_dict(torch.load(\"./router_model.pth\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_route.eval()\n",
    "with torch.no_grad():\n",
    "    model, train_dataset, test_dataset = load_model_data(\"cifar10\")\n",
    "    weight_list = F.softmax(module_route(test_dataset[19][\"pixel_values\"].unsqueeze(0).cuda())[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "every_task_param[i][key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "li = [every_task_param[i][\"base_model.model.vit.encoder.layer.0.output.dense.lora_A.default.weight\"] for i in range(5)]\n",
    "li * torch.tensor([1,1,1,1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "text = \"abc123def456gh789\"\n",
    "numbers = int(re.findall(r'\\d+', text)[0])\n",
    "print(numbers)  # ['123', '456', '789']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in every_task_param[4]:\n",
    "    print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "value, index = torch.max(weight_list, dim=0)\n",
    "value, int(index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for task in [ \"mnist\", \"EuroSAT\", \"cifar10\", \"car\", \"fruits\",\"GTSRB\", \"DTD\",\"resis\", \"grabage\",\"plants\"]:\n",
    "    #for task in [ \"fruits\"]:\n",
    "    model, train_dataset, test_dataset = load_model_data(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merge(weight_list, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, train_dataset, test_dataset = load_model_data(\"mnist\")  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:12<00:00, 80.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist 0.724\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:13<00:00, 76.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EuroSAT 0.902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:12<00:00, 79.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10 0.951\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:01<00:00, 72.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "car 0.29000000000000004\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:14<00:00, 70.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fruits 0.779\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:12<00:00, 79.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GTSRB 0.756\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 972/972 [00:17<00:00, 56.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DTD 0.6707818930041152\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:13<00:00, 73.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "resis 0.819\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30/30 [00:02<00:00, 14.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grabage 0.6666666666666667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:17<00:00, 57.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "plants 0.85\n",
      "145.70899844169617\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from time import time\n",
    "import re\n",
    "                                                  #                                             0.936  0.98 0.978 0. 56 0.807\n",
    "# for name in [ \"mnist\", \"EuroSAT\", \"cifar10\", \"car\", \"fruits\",\"GTSRB\", \"DTD\",\"resis\", \"grabage\",\"plants\"]:   #0,53 0.69 0.91 0. 54 0.80;  \n",
    "#     print(name)                                                                     选择模型单专家0.944  0.979 0.981 0.58  0.829\n",
    "                                                                                  #未分层    #0.944  0.979 0.981 0.58  0.829\n",
    "t1 = time() \n",
    "# merged_model, train_dataset, test_dataset = load_model_data(name)   \n",
    "# merged_model = model\n",
    "temp = 0                                                                           #分层   #0.944  0.982  0.983 0.59  0.829\n",
    "module_route.eval()\n",
    "with torch.no_grad():\n",
    "    for task in [ \"mnist\", \"EuroSAT\", \"cifar10\", \"car\", \"fruits\",\"GTSRB\", \"DTD\",\"resis\", \"grabage\",\"plants\"]:\n",
    "\n",
    "        model, train_dataset, test_dataset = load_model_data(task)   \n",
    "        every_task_param = get_every_task_param()  #ties\n",
    "        all_tensor_list = []\n",
    "        merge_lora_dict = dict()\n",
    "        k = 0.1\n",
    "        for key in every_task_param[0]:\n",
    "            for i in range(len(every_task_param)):\n",
    "                all_tensor_list.append(every_task_param[i][key].flatten())\n",
    "                # print(every_task_param[i][key].flatten().shape)\n",
    "        delete_value,_ = torch.cat(all_tensor_list).abs().kthvalue(int(k*torch.cat(all_tensor_list).shape[0]))\n",
    "\n",
    "        for key in every_task_param[0]:\n",
    "            sum_tensor = torch.stack([every_task_param[i][key] for i in range(len(every_task_param))], dim=0).sum(dim=0)\n",
    "            # merge_lora_dict[key] = sum([every_task_param[i][key] for i in range(len(every_task_param))])/len(every_task_param)\n",
    "            for i in range(len(every_task_param)):\n",
    "                mask = every_task_param[i][key].abs() >= delete_value\n",
    "                every_task_param[i][key] = every_task_param[i][key] * mask\n",
    "                sign_match = torch.sign(sum_tensor) == torch.sign(every_task_param[i][key])\n",
    "                every_task_param[i][key] = every_task_param[i][key] * sign_match\n",
    "            # merge_lora_dict[key] = torch.stack([every_task_param[i][key] for i in range(len(every_task_param))], dim=0).sum(dim=0)/len(every_task_param)\n",
    "            tensor_list =  torch.stack([every_task_param[i][key] for i in range(len(every_task_param))], dim=0)   #\"max\"\n",
    "            merge_lora_dict[key] = tensor_list.gather(dim=0, index=tensor_list.abs().argmax(dim=0).unsqueeze(0)).squeeze(0)   #\"max\"\n",
    "        for name,param in model.named_parameters():\n",
    "            if \"lora\" in name and \"attention\" not in name:\n",
    "                param.data = merge_lora_dict[name]\n",
    "\n",
    "\n",
    "        # merge_model = ties_merge(model)\n",
    "        # merge_model = ti123(model)\n",
    "        merged_model = model\n",
    "        sum1 = 0\n",
    "        set_random()\n",
    "        for i in tqdm(range(min(1000, len(test_dataset)))):\n",
    "            # weight_list = F.softmax(module_route(test_dataset[i][\"pixel_values\"].unsqueeze(0).cuda())[0])\n",
    "            # merged_model = merge(weight_list, model)\n",
    "            with torch.no_grad():\n",
    "                outputs = merged_model(test_dataset[i][\"pixel_values\"].unsqueeze(0).cuda())\n",
    "            logits = outputs.logits\n",
    "            predicted_class = torch.argmax(logits, dim=-1).item()\n",
    "            if predicted_class != int(test_dataset[i]['labels']):\n",
    "                sum1 += 1\n",
    "        print(task, 1-sum1/min(1000, len(test_dataset)))\n",
    "        with open(f'./result_ft.txt',\"a\") as f:\n",
    "            f.write(f\"{round((1-sum1/min(1000, len(test_dataset))) * 100, 1)}%\")\n",
    "            f.write(\"   \")\n",
    "        temp += 1-sum1/min(1000, len(test_dataset))\n",
    "    with open(f'./result_ft.txt',\"a\") as f:\n",
    "        f.write(f\"{round(temp /10 * 100, 1)}%\")\n",
    "        f.write(\"\\n\")\n",
    "    print(time()-t1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two largest entries of weight_list: their values and their positions.\n",
    "topk_vals, topk_indices = weight_list.topk(k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first (index, value) pair of weight_list only.\n",
    "first_pair = next(iter(enumerate(weight_list)), None)\n",
    "if first_pair is not None:\n",
    "    i, num = first_pair\n",
    "    print(i, num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show only the first (index, value) pair of the top-k result.\n",
    "topk_pairs = zip(topk_indices.cpu().numpy(), topk_vals)\n",
    "first_pair = next(topk_pairs, None)\n",
    "if first_pair is not None:\n",
    "    i, num = first_pair\n",
    "    print(i, num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-k indices as a NumPy array, copied to the CPU first.\n",
    "topk_indices.to(\"cpu\").numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example fraction to render as a percentage.\n",
    "decimal_number = 0.756\n",
    "\n",
    "# Scale to percent and keep one decimal place.\n",
    "percentage = round(decimal_number * 100, 1)\n",
    "\n",
    "# Append the percent sign for display.\n",
    "formatted_percentage = f\"{percentage}%\"\n",
    "\n",
    "print(formatted_percentage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, train_dataset, test_dataset = load_model_data(\"car\")\n",
    "module_route.eval()\n",
    "\n",
    "# Any router prediction different from this index is printed below.\n",
    "# NOTE(review): 3 is presumably the position of the car task in the router's\n",
    "# output -- confirm against the router's training label order.\n",
    "expected_route = 3\n",
    "# The original loop called range() with no argument, which raises TypeError;\n",
    "# iterate over every sample of the test split instead.\n",
    "num_samples = len(test_dataset)\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for sample_idx in range(num_samples):\n",
    "        # Index the dataset once so both fields come from the same item.\n",
    "        sample = test_dataset[sample_idx]\n",
    "        images = sample[\"pixel_values\"].unsqueeze(0).cuda()\n",
    "        labels = sample[\"labels\"].unsqueeze(0).cuda()\n",
    "        outputs = module_route(images)\n",
    "        preds = torch.max(outputs, 1).indices\n",
    "        # NOTE(review): labels come from the dataset while preds come from\n",
    "        # module_route -- confirm both use the same label space.\n",
    "        correct += (preds == labels).sum().item()\n",
    "        # Separate loop names so the outer sample index is not overwritten.\n",
    "        for pred, label in zip(preds, labels):\n",
    "            if int(pred) != expected_route:\n",
    "                print(pred, label)\n",
    "        total += labels.size(0)\n",
    "\n",
    "# Guard the division so an empty test split reports 0 instead of raising.\n",
    "acc = 100 * correct / total if total else 0.0\n",
    "print(f\"Validation Accuracy: {acc:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every parameter of the model together with its tensor shape.\n",
    "for param_name, param in model.named_parameters():\n",
    "    print(param_name, param.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open one GTSRB test image and convert it to RGB mode.\n",
    "sample_image_path = \"/mnt/data_B/lyf/vision_classify_ds/GTSRB/Final_Test/Images/12626.ppm\"\n",
    "image = Image.open(sample_image_path).convert(\"RGB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the single picture held in `image` into PyTorch tensors.\n",
    "# The original passed `images`, a name the image-loading cell does not set.\n",
    "# NOTE(review): confirm that `image` is the intended input.\n",
    "encoding = feature_extractor(images=image, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "# Read the carBrands50 train.csv file into a DataFrame.\n",
    "car_train_csv = \"/mnt/data_B/lyf/vision_classify_ds/carBrands50/train.csv\"\n",
    "df_train = pd.read_csv(car_train_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join every entry of the image:FILE column onto the dataset root directory.\n",
    "car_root = \"/mnt/data_B/lyf/vision_classify_ds/carBrands50\"\n",
    "train_list = []\n",
    "for file_entry in df_train[\"image:FILE\"]:\n",
    "    train_list.append(os.path.join(car_root, file_entry))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the DataFrame read from the carBrands50 train.csv file.\n",
    "df_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arithmetic mean of five hand-entered values.\n",
    "# NOTE(review): these look like accuracy percentages copied from separate\n",
    "# runs -- confirm where they come from.\n",
    "(77.3+ 82+84+84.8+ 85.3)/5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the category column as a plain Python list.\n",
    "[*df_train[\"category\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print each model parameter's name next to its shape.\n",
    "for param_name, param in model.named_parameters():\n",
    "    print(param_name, param.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.version\n",
    "\n",
    "\n",
    "# `torch.version` is a module object, so displaying it only shows the module\n",
    "# repr; the installed PyTorch version string is `torch.__version__`.\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Print the CUDA version first, then the cuDNN version.\n",
    "cuda_version = torch.version.cuda  # CUDA version\n",
    "print(cuda_version)\n",
    "cudnn_version = torch.backends.cudnn.version()\n",
    "print(cudnn_version)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "# Collect the image path and label of every row listed in the CSV files found\n",
    "# under the training image folders, then hold out 20% of them for testing.\n",
    "start_path = \"/mnt/data_B/lyf/vision_classify_ds/GTSRB/Final_Training/Images/\"\n",
    "images = []\n",
    "labels = []\n",
    "# os.listdir returns entries in arbitrary order; sorting keeps the sample\n",
    "# order, and therefore the seeded split below, identical from run to run.\n",
    "for class_dir in sorted(os.listdir(start_path)):\n",
    "    if not class_dir.startswith(\"0\"):\n",
    "        continue\n",
    "    class_path = os.path.join(start_path, class_dir)\n",
    "    for entry in sorted(os.listdir(class_path)):\n",
    "        if not entry.endswith(\"csv\"):\n",
    "            continue\n",
    "        df = pd.read_csv(os.path.join(class_path, entry))\n",
    "        # Each row is a single ';'-separated string in the first column:\n",
    "        # its first field is the image file name and its last field the label.\n",
    "        for row in df.iloc[:, 0]:\n",
    "            fields = row.split(\";\")\n",
    "            images.append(os.path.join(class_path, fields[0]))\n",
    "            labels.append(int(fields[-1]))\n",
    "\n",
    "train_list, test_list, train_labels, test_labels = train_test_split(\n",
    "    images, labels, test_size=0.2, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import pandas as pd\n",
    "\n",
    "plants_root = \"/mnt/data_B/lyf/vision_classify_ds/plants\"\n",
    "df = pd.read_csv('/mnt/data_B/lyf/vision_classify_ds/plants/plants_train.csv')\n",
    "# Full image paths from the Image column and integer labels from CATEGORY.\n",
    "image_list = [os.path.join(plants_root, name) for name in df[\"Image\"]]\n",
    "label_list = list(map(int, df[\"CATEGORY\"]))\n",
    "\n",
    "# Open and convert every listed image once; the index is printed first, so\n",
    "# the last number shown identifies the file being processed if one fails.\n",
    "for i in range(len(image_list)):\n",
    "    print(i)\n",
    "    Image.open(image_list[i]).convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display df_test.\n",
    "# NOTE(review): no nearby cell assigns df_test -- confirm where it is defined,\n",
    "# otherwise this cell raises NameError on a fresh kernel.\n",
    "df_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display whatever DataFrame is currently bound to df; several cells assign\n",
    "# this name, so the content depends on which of them ran last.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lyf_jiuan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
