{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa95b00d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from constants import *\n",
    "\n",
    "import json\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Build the model skeleton on the meta device from the config alone; the weights\n",
    "# are streamed in by `load_checkpoint_and_dispatch` afterwards, so there is no\n",
    "# need for `from_pretrained` to fetch and read the weight shards here.\n",
    "model_config = AutoConfig.from_pretrained(\"meta-llama/Llama-3.1-405B\")\n",
    "with init_empty_weights():\n",
    "    model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3254fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Memory budget per device: GPUs 0-7 at 70 GiB each, plus CPU RAM for whatever\n",
    "# does not fit on the GPUs. Units are spelled consistently as GiB.\n",
    "n_gpus = 8\n",
    "max_memory = {**{gpu: \"70GiB\" for gpu in range(n_gpus)}, \"cpu\": \"1000GiB\"}\n",
    "# Never split a LlamaDecoderLayer across devices.\n",
    "device_map = infer_auto_device_map(\n",
    "    model,\n",
    "    no_split_module_classes=[\"LlamaDecoderLayer\"],\n",
    "    max_memory=max_memory,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e04f0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the weight shards into the empty skeleton and place every module on the\n",
    "# device assigned in `device_map`. accelerate expects `checkpoint` to be a local\n",
    "# file or folder, so this presumably resolves to a local copy of the weights\n",
    "# (TODO confirm).\n",
    "model = load_checkpoint_and_dispatch(\n",
    "    model,\n",
    "    checkpoint=\"meta-llama/Llama-3.1-405B\",\n",
    "    device_map=device_map,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "557d0aec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same repo id as the model weights, so tokenizer and model cannot drift apart.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.1-405B\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47dc885",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse EOS as the padding token so `padding=True` batches can be built.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "# The scoring loop reads logits at attention_mask.sum() - 1, which is the last\n",
    "# real token only under right padding; pin it instead of relying on the default.\n",
    "tokenizer.padding_side = \"right\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d7eabcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model label used as the prefix of each output directory name\n",
    "# (`<dir_predictions>/<model_pred>_<task>` in the scoring loop).\n",
    "model_pred = \"Meta-Llama-3.1-405B\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ab72ccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary id -> answer letter (with its leading space), in A, B, C, D order.\n",
    "answer_tokens_llama = {362: \" A\", 426: \" B\", 356: \" C\", 423: \" D\"}\n",
    "# Guard the hard-coded ids: fail fast if the loaded tokenizer disagrees.\n",
    "for token_id, text in answer_tokens_llama.items():\n",
    "    decoded = tokenizer.decode([token_id])\n",
    "    assert decoded == text, (token_id, decoded, text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e45421",
   "metadata": {},
   "outputs": [],
   "source": [
    "softmax = torch.nn.Softmax(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "841f1d71",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Batch size used when a task's longest prompt is `reference_length` tokens; the\n",
    "# effective batch size scales inversely with the longest prompt of each task.\n",
    "batch_size_max = 10\n",
    "reference_length = 600\n",
    "# Index of the first task to score; earlier tasks are skipped (e.g. to resume an\n",
    "# interrupted run). Set to 0 to score every task.\n",
    "start_task = 49\n",
    "# Answer-letter token ids in A, B, C, D order (dicts preserve insertion order).\n",
    "answer_token_ids = list(answer_tokens_llama)\n",
    "\n",
    "# Reading logits at attention_mask.sum() - 1 picks the last real token only\n",
    "# when padding is on the right.\n",
    "assert tokenizer.padding_side == \"right\", tokenizer.padding_side\n",
    "\n",
    "tasks = [\n",
    "    line.split(\"subject=\")[1].split(\",\")[0]\n",
    "    for line in meta_predictions.split(\"\\n\")\n",
    "    if \"subject=\" in line\n",
    "]\n",
    "for task in tqdm(tasks[start_task:]):\n",
    "    tqdm.write(task)\n",
    "    res = [\n",
    "        i\n",
    "        for i in file_list\n",
    "        if i.startswith(\"mmlu:subject=%s\" % task) and \"openai_gpt-4o-2024-05-13\" in i\n",
    "    ]\n",
    "    assert len(res) == 1, res\n",
    "    (run_dir,) = res\n",
    "\n",
    "    path_instance = os.path.join(dir_predictions, run_dir, \"display_requests.json\")\n",
    "    with open(path_instance, \"r\") as f:\n",
    "        reqs = json.load(f)\n",
    "\n",
    "    # Longest tokenized prompt (= padded length of the whole task), computed\n",
    "    # without building a padded tensor.\n",
    "    longest = max(\n",
    "        (len(ids) for ids in tokenizer([r[\"request\"][\"prompt\"] for r in reqs])[\"input_ids\"]),\n",
    "        default=1,\n",
    "    )\n",
    "    # Clamp to 1: a prompt longer than batch_size_max * reference_length tokens\n",
    "    # would otherwise yield 0 and a division by zero.\n",
    "    batch_size = max(1, int(np.floor(batch_size_max * reference_length / longest)))\n",
    "\n",
    "    answers_collect = []\n",
    "    for start in tqdm(range(0, len(reqs), batch_size)):\n",
    "        batch = reqs[start:start + batch_size]\n",
    "        encoded_input = tokenizer(\n",
    "            [r[\"request\"][\"prompt\"] for r in batch], return_tensors=\"pt\", padding=True\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            output = model(**encoded_input)\n",
    "\n",
    "        last_positions = encoded_input.attention_mask.sum(-1) - 1\n",
    "        for req, logits, pos in zip(batch, output.logits, last_positions):\n",
    "            # Softmax over only the four answer logits, in float32: equal to the\n",
    "            # renormalized full-vocabulary softmax, without bf16 rounding/underflow.\n",
    "            probs = torch.softmax(logits[int(pos), answer_token_ids].float(), dim=-1)\n",
    "            answers = {\"instance_id\": req[\"instance_id\"]}\n",
    "            for j, p in enumerate(probs.tolist()):\n",
    "                answers[j] = p\n",
    "            answers_collect.append(answers)\n",
    "\n",
    "    path_save = os.path.join(dir_predictions, model_pred + \"_\" + task)\n",
    "    # exist_ok: re-running the cell overwrites results instead of raising.\n",
    "    os.makedirs(path_save, exist_ok=True)\n",
    "    with open(os.path.join(path_save, \"predictions_proba.json\"), \"w\") as final:\n",
    "        json.dump(answers_collect, final)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "7fddc940b2fd2f13da2f5a82af96e5187a221fc62c2c19af53ab202281c29800"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
