{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c1bc1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import os \n",
    "import torch\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"models/R1-1.5B\", )\n",
    "model = AutoModelForCausalLM.from_pretrained(\"models/R1-1.5B\", torch_dtype = torch.bfloat16)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "649060e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "from tqdm import  tqdm\n",
    "from torch.nn import functional as F\n",
    "train_data_path = \"train_outputs/datasets_compression_dataset_results_models_R1-1.5B_none_ZSA_[0:10].json\"\n",
    "\n",
    "with open(train_data_path, \"r\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7600468a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_ll(text):\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "    input_ids = inputs[\"input_ids\"]\n",
    "    attention_mask = inputs[\"attention_mask\"]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        logits = outputs.logits  # shape: [batch, seq_len, vocab_size]\n",
    "\n",
    "    \n",
    "    # Shift so that tokens <t+1> are predicted from tokens <t>\n",
    "    shift_logits = logits[:, :-1, :]\n",
    "    shift_labels = input_ids[:, 1:]\n",
    "\n",
    "    # Compute log-probabilities of the correct next tokens\n",
    "    log_probs = F.log_softmax(shift_logits, dim=-1)\n",
    "    token_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)\n",
    "\n",
    "    # Mask padding if any\n",
    "    token_log_probs = token_log_probs * attention_mask[:, 1:]\n",
    "\n",
    "    # Total log-likelihood\n",
    "    log_likelihood = token_log_probs.sum().item()\n",
    "\n",
    "    return log_likelihood\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b48a20",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = {}\n",
    "for x in tqdm(data):\n",
    "    if x[\"problem\"] not in train_data:\n",
    "        train_data[x[\"problem\"]] = []\n",
    "    \n",
    "    question = x[\"problem\"]\n",
    "    summary = x[\"summary\"]\n",
    "    prediction = x[\"responses\"][0]\n",
    "    accuracy = x[\"accuracy\"][0]\n",
    "    assistant_response = summary + \"</think>\" + prediction \n",
    "\n",
    "    prompt = [{\n",
    "                \"role\": \"user\",\n",
    "                \"content\": f\"Please reason step by step, and put your final answer within \\\\boxed{{}}. Question: {question}\",\n",
    "            },\n",
    "            {\n",
    "                \"role\": \"assistant\",\n",
    "                \"content\": assistant_response,\n",
    "            }\n",
    "            ]\n",
    "\n",
    "    templated_prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)\n",
    "    # templated_prompt += assistant_response + tokenizer.eos_token\n",
    "\n",
    "    ll = calculate_ll(templated_prompt)\n",
    "    train_data[question].append((templated_prompt, accuracy, ll))\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dbba510",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_prompts = []  # list of (question, prompt, accuracy, ll)\n",
    "\n",
    "\n",
    "\n",
    "file_path = \"train_data.jsonl\"\n",
    "\n",
    "open(file_path, \"w\").close()\n",
    "\n",
    "for question, items in train_data.items():\n",
    "    # keep only correct answers (accuracy == 1)\n",
    "    correct_items = [x for x in items if x[1] == 1]\n",
    "\n",
    "    if correct_items:  # if there is at least one correct answer\n",
    "        best = max(correct_items, key=lambda x: x[2])  # highest log-likelihood\n",
    "        pr, rs = best[0].split(\"<think>\")\n",
    "        pr = pr+\"<think>\"\n",
    "    \n",
    "\n",
    "        d = {\n",
    "            \"question\": pr,\n",
    "            \"response\" : rs,\n",
    "        }\n",
    "        best_prompts.append(d)\n",
    "        with open(file_path, \"a\") as f:\n",
    "            f.write(json.dumps(d) + \"\\n\")\n",
    "    else:\n",
    "        # optional: skip or handle questions with no correct answers\n",
    "        pass\n",
    "\n",
    "\n",
    "\n",
    "print(len(best_prompts))\n",
    "best_prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b69760d1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
