{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import math\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.nn.functional import normalize\n",
    "\n",
    "\n",
    "def combine_gradients(gradient_path, chunk_num=8):\n",
    "    \"\"\"Load per-chunk gradient files and flatten them into one list of per-sample rows.\n",
    "\n",
    "    Each file ``{gradient_path}/output_{idx}`` is loaded with ``torch.load`` (on CPU), its\n",
    "    contents are concatenated along dim 0, and the resulting rows are appended in chunk order.\n",
    "    NOTE(review): assumes each file holds a sequence of tensors with matching trailing dims -- confirm.\n",
    "\n",
    "    Args:\n",
    "        gradient_path: directory containing ``output_0`` ... ``output_{chunk_num - 1}``.\n",
    "        chunk_num: number of chunk files to read.\n",
    "\n",
    "    Returns:\n",
    "        list of tensors, one entry per row of the concatenated chunks.\n",
    "    \"\"\"\n",
    "    all_gradients = []\n",
    "    for idx in range(chunk_num):\n",
    "        # torch.load unpickles its input; only point it at files you produced yourself.\n",
    "        chunk = torch.load(f\"{gradient_path}/output_{idx}\", map_location='cpu')\n",
    "        all_gradients.extend(torch.cat(chunk, dim=0))\n",
    "    return all_gradients\n",
    "\n",
    "\n",
    "## configuration\n",
    "chunk_num = 8\n",
    "\n",
    "model_name = 'your_model_name'\n",
    "data_path = 'your_data_path'\n",
    "grads_path = 'your_grads_path'\n",
    "\n",
    "hyper_q = 0.001  # scale applied to influence scores before the softmax in soft sampling\n",
    "hyper_p = 0.15   # fraction of the full dataset to select\n",
    "\n",
    "## load the per-sample values (output_norm_*.json) and the training data\n",
    "grads = []\n",
    "for idx in range(chunk_num):\n",
    "    with open(f'{grads_path}/output_norm_{idx}.json', 'r') as f:\n",
    "        grads.extend(json.load(f))\n",
    "\n",
    "with open(data_path, 'r') as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# grads[i] is matched to data[i] purely by position below, so the counts must agree.\n",
    "assert len(grads) == len(data), f\"{len(grads)} per-sample values vs {len(data)} data samples\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llava_instruct: 1.648859331135432\n",
      "vqav2: 5.972603918809087\n",
      "aokvqa: 10.620936431331243\n",
      "refcoco: 2.4312766137318875\n",
      "vg: 1.4690650998121753\n",
      "gqa: 3.088610164895096\n",
      "ocrvqa: 3.500897007751465\n",
      "sharegpt: 0.6496585714760794\n"
     ]
    }
   ],
   "source": [
    "## compute task value: average per-sample value (from output_norm_*.json) of every dataset\n",
    "from collections import defaultdict\n",
    "ds2si = defaultdict(float)   # dataset -> summed per-sample value\n",
    "ds2data = defaultdict(list)  # dataset -> samples of that dataset\n",
    "\n",
    "for idx, sample in enumerate(data):\n",
    "    ds_name = sample['dataset']\n",
    "    if ds_name != 'textcaps':  # textcaps is left out of the selection entirely\n",
    "        ds2si[ds_name] += grads[idx]\n",
    "        ds2data[ds_name].append(sample)\n",
    "\n",
    "# total_si is the sum of the per-dataset averages; later cells divide by it to normalise task values\n",
    "total_si = 0\n",
    "for ds_name, si_sum in ds2si.items():\n",
    "    si_mean = si_sum / len(ds2data[ds_name])\n",
    "    print(f\"{ds_name}: {si_mean}\")\n",
    "    total_si += si_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llava_instruct: 0.056118186043480185\n",
      "vqav2: 0.20327488922232217\n",
      "aokvqa: 0.3614787964956287\n",
      "refcoco: 0.08274740649865779\n",
      "vg: 0.049998970212015\n",
      "gqa: 0.10511945838946189\n",
      "ocrvqa: 0.11915145573077567\n",
      "sharegpt: 0.022110837407658568\n"
     ]
    }
   ],
   "source": [
    "# normalised task value: each dataset's average divided by the sum of all averages\n",
    "for ds_name, si_sum in ds2si.items():\n",
    "    print(f\"{ds_name}: {si_sum / len(ds2data[ds_name]) / total_si}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "## compute instance value\n",
    "def calculate_influence_score(training_info: torch.Tensor, validation_info: torch.Tensor):\n",
    "    \"\"\"Score every training row against every validation row by their inner product.\n",
    "\n",
    "    Args:\n",
    "        training_info (torch.Tensor): training info (gradients/representations) stored in a tensor of shape N x N_DIM\n",
    "        validation_info (torch.Tensor): validation info (gradients/representations) stored in a tensor of shape N_VALID x N_DIM\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: N x N_VALID matrix whose (i, j) entry is <training_info[i], validation_info[j]>.\n",
    "    \"\"\"\n",
    "    validation_t = validation_info.transpose(0, 1)  # N_DIM x N_VALID\n",
    "    return training_info @ validation_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "## initializing gradient features\n",
    "# Bind the full gradients to a new name: `grads` still holds the per-sample values that the\n",
    "# task-value cells read, and rebinding it to a tensor would break re-running those cells.\n",
    "grad_features = torch.stack(combine_gradients(grads_path, chunk_num), dim=0)\n",
    "assert grad_features.shape[0] == len(data), f\"{grad_features.shape[0]} gradient rows vs {len(data)} data samples\"\n",
    "# unit L2 norm per row, so the inner products computed later are cosine similarities\n",
    "grads_normalized = normalize(grad_features, dim=1)\n",
    "\n",
    "## record the gradient features and corresponding index in the original data within each super category\n",
    "ds2grad = defaultdict(list)\n",
    "ds2idx = defaultdict(list)\n",
    "\n",
    "for idx, d in enumerate(data):\n",
    "    if d['dataset'] == 'textcaps':  # excluded exactly as in the task-value computation\n",
    "        continue\n",
    "    ds2grad[d['dataset']].append(grads_normalized[idx])\n",
    "    ds2idx[d['dataset']].append(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llava_instruct\n",
      "5600.297540753292\n",
      "torch.Size([157712])\n",
      "1.0000013\n",
      "vqav2\n",
      "20285.756587474873\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_55537/472782945.py:54: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  instance_weight = torch.nn.functional.softmax((influence_score * q).clone().detach())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([91781])\n",
      "1.0000247\n",
      "aokvqa\n",
      "36073.668052642315\n",
      "torch.Size([66160])\n",
      "1.0001174\n",
      "refcoco\n",
      "8257.752607311604\n",
      "torch.Size([48447])\n",
      "1.0000353\n",
      "vg\n",
      "4989.632232616973\n",
      "torch.Size([86417])\n",
      "1.0000199\n",
      "gqa\n",
      "10490.364814138833\n",
      "torch.Size([72140])\n",
      "1.0000205\n",
      "ocrvqa\n",
      "11890.683779216039\n",
      "torch.Size([80000])\n",
      "1.0000191\n",
      "sharegpt\n",
      "2206.5443858460644\n",
      "torch.Size([40688])\n",
      "1.0000079\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "## instance selection: split the budget across tasks by task value, then soft-sample inside each task\n",
    "selection_seed = 0            # seeds the sampling below so the selected subset is reproducible\n",
    "influence_chunk_size = 80000  # rows per matmul chunk; caps the score matrix at chunk x task_size to avoid OOM\n",
    "rng = np.random.default_rng(selection_seed)\n",
    "random.seed(selection_seed)   # only used by the random baseline at the bottom of the loop\n",
    "\n",
    "selected_data_size = hyper_p * len(data)\n",
    "selected_idx = []\n",
    "for ds in ds2grad.keys():\n",
    "    print(ds)\n",
    "    # this task's share of the budget is proportional to its normalised task value\n",
    "    ds_size = ds2si[ds] / len(ds2data[ds]) / total_si * selected_data_size\n",
    "    print(ds_size)\n",
    "    ds_indices = ds2idx[ds]\n",
    "\n",
    "    # one gradient row per sample of this task; stack infers the feature dim instead of hard-coding it\n",
    "    target_grads_sub_tasks = torch.stack(ds2grad[ds])\n",
    "\n",
    "    ## instance value: mean inner product of each sample's gradient with every gradient of the same task\n",
    "    ## (itself included). Rows are processed in chunks so any large task stays within memory.\n",
    "    iscores = []\n",
    "    for start in range(0, len(target_grads_sub_tasks), influence_chunk_size):\n",
    "        chunk = target_grads_sub_tasks[start:start + influence_chunk_size]\n",
    "        iscores.append(calculate_influence_score(chunk, target_grads_sub_tasks).mean(dim=1))\n",
    "    influence_score = torch.cat(iscores)\n",
    "    print(influence_score.size())\n",
    "\n",
    "    ## soft sampling: P(sample) = softmax(hyper_q * influence score) over the task\n",
    "    instance_weight = torch.nn.functional.softmax(influence_score * hyper_q, dim=0)\n",
    "    # float64 before renormalising so the probabilities passed to choice() sum to 1 as exactly as possible\n",
    "    instance_weight = instance_weight.cpu().double().numpy()\n",
    "    instance_weight = instance_weight / instance_weight.sum()\n",
    "\n",
    "    n_available = len(ds_indices)\n",
    "    if ds_size < n_available:\n",
    "        ## enough samples: draw int(ds_size) distinct samples, weighted by instance_weight\n",
    "        selected_idx.extend(rng.choice(ds_indices, size=int(ds_size), replace=False, p=instance_weight))\n",
    "    else:\n",
    "        ## not enough samples: repeat the whole task, then draw the remainder by weight\n",
    "        oversample_times = int(ds_size // n_available)\n",
    "        sampled_numbers = int(ds_size % n_available)\n",
    "        for _ in range(oversample_times):\n",
    "            selected_idx.extend(ds_indices)\n",
    "        selected_idx.extend(rng.choice(ds_indices, size=sampled_numbers, replace=False, p=instance_weight))\n",
    "\n",
    "    # diagnostic: largest weight divided by the k-th largest, k = min(task size, budget);\n",
    "    # guarded because k == 0 would silently index the smallest weight instead\n",
    "    k = int(min(n_available, ds_size))\n",
    "    if k > 0:\n",
    "        print(instance_weight.max() / np.sort(instance_weight)[-k])\n",
    "\n",
    "    ## baselines (replace the soft sampling above to use them)\n",
    "    # hard sampling: _, top = torch.topk(influence_score, int(ds_size)); selected_idx.extend(ds_indices[i] for i in top.tolist())\n",
    "    # random:        selected_idx.extend(random.sample(ds_indices, int(ds_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {'llava_instruct': 5600,\n",
       "             'vqav2': 20285,\n",
       "             'aokvqa': 36073,\n",
       "             'refcoco': 8257,\n",
       "             'vg': 4989,\n",
       "             'gqa': 10490,\n",
       "             'ocrvqa': 11890,\n",
       "             'sharegpt': 2206})"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "## materialise the selected subset (indices repeat when a task was oversampled)\n",
    "new_data = [data[idx] for idx in selected_idx]\n",
    "\n",
    "# per-dataset sample counts of the selected subset, as a quick sanity check\n",
    "Counter(d['dataset'] for d in new_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.14999293549657447"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# realised fraction of the full dataset that was selected (target: hyper_p)\n",
    "selected_fraction = len(new_data) / len(data)\n",
    "selected_fraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "save_path = f\"selected_data/Tive_q_{hyper_q}_p_{hyper_p}_{model_name}_v3_v3.json\"\n",
    "# open(..., 'w') raises FileNotFoundError when the target directory is missing, so create it first\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "with open(save_path, 'w') as f:\n",
    "    json.dump(new_data, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show the path the selected subset was written to\n",
    "str(save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
