{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc88b481",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload edited Python modules (e.g. the project's `src` package)\n",
    "# before each cell executes. %reload_ext loads the extension if needed and, unlike\n",
    "# %load_ext, does not print an \"already loaded\" notice when this cell is re-run.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa440433",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the project-local `src` package (one directory up) importable.\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "# Set before transformers / tokenizers are imported further down.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "# torch.cuda.get_device_name() raises on machines without a CUDA device, so only\n",
    "# query it when CUDA is available; otherwise log None.\n",
    "cuda_available = torch.cuda.is_available()\n",
    "cuda_device_name = torch.cuda.get_device_name() if cuda_available else None\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"torch.cuda.is_available()={cuda_available}, {torch.cuda.device_count()=}, \"\n",
    "    f\"torch.cuda.get_device_name()={cuda_device_name!r}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b8766a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Checkpoint to load. Alternatives (swap one into model_key):\n",
    "#   meta-llama/Llama-3.2-3B, meta-llama/Llama-3.1-8B,\n",
    "#   meta-llama/Llama-3.1-70B-Instruct, meta-llama/Llama-3.3-70B-Instruct,\n",
    "#   meta-llama/Llama-3.1-405B-Instruct,\n",
    "#   google/gemma-2-9b-it, google/gemma-3-12b-it,\n",
    "#   deepseek-ai/DeepSeek-R1-Distill-Llama-8B,\n",
    "#   allenai/OLMo-2-1124-7B-Instruct, allenai/OLMo-7B-0424-hf,\n",
    "#   Qwen/Qwen2-7B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B-Instruct,\n",
    "#   Qwen/Qwen2.5-72B-Instruct,\n",
    "#   Qwen/Qwen3-1.7B, Qwen/Qwen3-4B, Qwen/Qwen3-8B, Qwen/Qwen3-14B, Qwen/Qwen3-32B\n",
    "model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# Optional explicit multi-GPU placement; pass as device_map= to ModelandTokenizer.\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "310827ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Load weights in bfloat16 with device_map=\"auto\".\n",
    "# Alternatives: device_map=device_map (see the commented helper in the previous\n",
    "# cell), or quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True)\n",
    "# (or load_in_4bit=True).\n",
    "# NOTE(review): eager attention is presumably required by this model/analysis; confirm.\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17359f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, CountingTask, YesNoTask, SelectFirstTask\n",
    "from src.selection.data import SelectionSample, CountingSample, YesNoSample\n",
    "\n",
    "#################################################################################\n",
    "# Experiment knobs (the sample classes imported above are used by later cells).\n",
    "TASK_CLS = SelectOneTask  # alternatives: CountingTask, YesNoTask, SelectFirstTask\n",
    "prompt_template_idx = 0\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Other datasets in the same folder: profession.json, nationality.json\n",
    "task_data_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "select_task = TASK_CLS.load(path=task_data_path)\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f921352",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token, generate_with_patch\n",
    "\n",
    "# Draw one random \"fruit\" sample (filter_by_lm_prediction presumably keeps only\n",
    "# samples the model answers correctly; confirm) and show prompt >> answer.\n",
    "# NOTE(review): template 3 and 3 distractors are hard-coded here, whereas the\n",
    "# saved samples use the prompt_template_idx / N_DISTRACTORS knobs.\n",
    "sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=3,\n",
    "    n_distractors=3,\n",
    ")\n",
    "expected_answer = mt.tokenizer.decode(sample.ans_token_id)\n",
    "print(f'\"{sample.prompt()}\" >> {expected_answer}')\n",
    "\n",
    "# Greedy 20-token continuation of the same prompt (remove_prefix presumably strips it).\n",
    "gen = generate_with_patch(\n",
    "    mt=mt,\n",
    "    inputs=sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    n_gen_per_prompt=1,\n",
    "    remove_prefix=True,\n",
    ")\n",
    "print(f'Generation:\"{gen[0]}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764d792e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the `prediction` attribute of the sample drawn above (presumably set\n",
    "# by filter_by_lm_prediction; confirm).\n",
    "sample.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00d0e4c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "\n",
    "# Pick the counterfactual sampler keyed by this task's name.\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "print(counterfactual_sampler)\n",
    "\n",
    "# One (patch, clean) pair: patch drawn from \"fruit\", clean from \"vehicle\".\n",
    "# Other knobs tried here: distinct_options=True, n_options=6, and overriding\n",
    "# each sample's default_option_style (\"single_line\" / \"numbered\").\n",
    "patch_sample, clean_sample = counterfactual_sampler(\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=3,\n",
    "    option_style=OPTION_STYLE,\n",
    ")\n",
    "\n",
    "print(\"-\" * 100)\n",
    "for shown_sample in (patch_sample, clean_sample):\n",
    "    print(shown_sample.prompt(), \">>\", mt.tokenizer.decode(shown_sample.ans_token_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a8c3bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoded tracked type-object token and answer token of the clean sample.\n",
    "tuple(\n",
    "    mt.tokenizer.decode(token_id)\n",
    "    for token_id in (\n",
    "        clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "        clean_sample.ans_token_id,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568ebee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the `prediction` attribute of both samples in the counterfactual pair.\n",
    "patch_sample.prediction, clean_sample.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c85b70a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from dataclasses_json import DataClassJsonMixin\n",
    "from typing import Union\n",
    "\n",
    "@dataclass\n",
    "class CounterFactualSamplePair(DataClassJsonMixin):\n",
    "    \"\"\"A (patch, clean) sample pair that can be round-tripped through JSON.\n",
    "\n",
    "    Save with `detensorize()` followed by `to_dict()`; load with `from_dict()`.\n",
    "    `detensorize()` tags each sample's metadata with a \"sample_type\" key, which\n",
    "    `from_dict()` uses to pick the concrete sample class to rebuild.\n",
    "    \"\"\"\n",
    "\n",
    "    patch_sample: Union[SelectionSample, CountingSample, YesNoSample]\n",
    "    clean_sample: Union[SelectionSample, CountingSample, YesNoSample]\n",
    "\n",
    "    @staticmethod\n",
    "    def sample_type_to_class():\n",
    "        \"\"\"Return the mapping from serialized \"sample_type\" tag to sample class.\"\"\"\n",
    "        return {\n",
    "            \"selection\": SelectionSample,\n",
    "            \"counting\": CountingSample,\n",
    "            \"yes_no\": YesNoSample,\n",
    "        }\n",
    "\n",
    "    def detensorize(self):\n",
    "        \"\"\"Tag both samples with their \"sample_type\" and detensorize them in place.\n",
    "\n",
    "        Mutates `patch_sample` and `clean_sample` (their `metadata` dicts and\n",
    "        whatever each sample's own `detensorize()` converts); returns None.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: if a sample's class is not listed in `sample_type_to_class()`.\n",
    "        \"\"\"\n",
    "        # Derived from sample_type_to_class() so the two mappings cannot drift apart.\n",
    "        # Matching on the class *name* keeps working if %autoreload replaces a\n",
    "        # class object.\n",
    "        name_to_type = {\n",
    "            cls.__name__: sample_type\n",
    "            for sample_type, cls in self.sample_type_to_class().items()\n",
    "        }\n",
    "        for sample in (self.patch_sample, self.clean_sample):\n",
    "            class_name = type(sample).__name__\n",
    "            if class_name not in name_to_type:\n",
    "                raise KeyError(f\"Unsupported sample class for serialization: {class_name}\")\n",
    "            sample.metadata[\"sample_type\"] = name_to_type[class_name]\n",
    "        self.patch_sample.detensorize()\n",
    "        self.clean_sample.detensorize()\n",
    "\n",
    "    @staticmethod\n",
    "    def _sample_from_dict(sample_dict):\n",
    "        \"\"\"Rebuild one sample from its serialized dict without mutating `sample_dict`.\"\"\"\n",
    "        # Copy metadata so removing the tag leaves the caller's dict intact.\n",
    "        metadata = dict(sample_dict[\"metadata\"])\n",
    "        sample_type = metadata.pop(\"sample_type\")\n",
    "        sample_cls = CounterFactualSamplePair.sample_type_to_class()[sample_type]\n",
    "        return sample_cls.from_dict({**sample_dict, \"metadata\": metadata})\n",
    "\n",
    "    @staticmethod\n",
    "    def from_dict(d):\n",
    "        \"\"\"Rebuild a pair from the output of `detensorize()` + `to_dict()`.\n",
    "\n",
    "        `d` is not modified, so the same loaded JSON can be converted repeatedly.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: if a sample lacks a known \"sample_type\" tag.\n",
    "        \"\"\"\n",
    "        return CounterFactualSamplePair(\n",
    "            patch_sample=CounterFactualSamplePair._sample_from_dict(d[\"patch_sample\"]),\n",
    "            clean_sample=CounterFactualSamplePair._sample_from_dict(d[\"clean_sample\"]),\n",
    "        )\n",
    "\n",
    "\n",
    "cf_pair = CounterFactualSamplePair(\n",
    "    patch_sample=patch_sample,\n",
    "    clean_sample=clean_sample,\n",
    ")\n",
    "# Note: this mutates the live patch_sample / clean_sample objects in place.\n",
    "cf_pair.detensorize()\n",
    "with open(\"cf_pair_debug.json\", \"w\") as f:\n",
    "    json.dump(cf_pair.to_dict(), f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "735339e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back the pair written above as a plain JSON dict (not yet typed samples).\n",
    "with open(\"cf_pair_debug.json\", \"r\") as fh:\n",
    "    cf_pair_data = json.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2ff5536",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw serialized patch sample as read from cf_pair_debug.json.\n",
    "cf_pair_data[\"patch_sample\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdda579f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "# Rebuild typed samples from the JSON dict. Pass a deep copy so this cell can be\n",
    "# re-run: from_dict may pop the \"sample_type\" tags out of its input.\n",
    "cf_pair_loaded = CounterFactualSamplePair.from_dict(copy.deepcopy(cf_pair_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49057d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuilt patch sample; should match the in-memory patch_sample serialized above.\n",
    "cf_pair_loaded.patch_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38955bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare with the in-memory pair printed earlier: prompt, answer, prediction.\n",
    "reloaded_patch = cf_pair_loaded.patch_sample\n",
    "print(reloaded_patch.prompt(), \">>\", mt.tokenizer.decode(reloaded_patch.ans_token_id))\n",
    "reloaded_patch.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c1e423",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for generated pairs:\n",
    "#   DEFAULT_RESULTS_DIR/selection/samples/training/<model short name>/<task name>/objects\n",
    "model_short_name = mt.name.rsplit(\"/\", 1)[-1]\n",
    "sample_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR, \"selection\", \"samples\", \"training\",\n",
    "    model_short_name, select_task.task_name, \"objects\",\n",
    ")\n",
    "\n",
    "os.makedirs(sample_save_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33958ba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "import random\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 8\n",
    "\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "\n",
    "# Draw `validation_limit` counterfactual pairs and write each one to\n",
    "# sample_save_path/00001.json, 00002.json, ... (overwriting same-named files).\n",
    "# Other sampler knobs: patch_n_distractors / clean_n_distractors, n_options.\n",
    "for pair_idx in range(validation_limit):\n",
    "    print(f\"sample {pair_idx + 1} / {validation_limit}\")\n",
    "    patch, clean = counterfactual_sampler(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "    cf_pair = CounterFactualSamplePair(patch_sample=patch, clean_sample=clean)\n",
    "    # detensorize() mutates `patch`/`clean` (e.g. tags their metadata), so\n",
    "    # validation_set holds the same mutated objects that are written to disk.\n",
    "    cf_pair.detensorize()\n",
    "    out_file = os.path.join(sample_save_path, f\"{pair_idx + 1:05d}.json\")\n",
    "    with open(out_file, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72131711",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "from src.functional import free_gpu_cache\n",
    "\n",
    "# Reload previously generated pairs instead of sampling new ones. The imports\n",
    "# above are repeated so this cell also works when the generation cell is skipped.\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 8\n",
    "\n",
    "sample_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"training\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\"\n",
    ")\n",
    "\n",
    "# Sort first: os.listdir order is arbitrary, so sorting makes the shuffle\n",
    "# reproducible whenever the global `random` state has been seeded.\n",
    "sample_files = sorted(\n",
    "    os.path.join(sample_load_path, f)\n",
    "    for f in os.listdir(sample_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "random.shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "if len(sample_files) < validation_limit:\n",
    "    logger.warning(\n",
    "        f\"Only {len(sample_files)} sample files found in {sample_load_path}; \"\n",
    "        f\"expected {validation_limit}\"\n",
    "    )\n",
    "\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9272b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one (clean, patch) pair from validation_set (needs >= 6 entries).\n",
    "inspect_idx = 5\n",
    "clean, patch = validation_set[inspect_idx]\n",
    "for shown in (clean, patch):\n",
    "    print(shown.prompt(), \">>\", mt.tokenizer.decode(shown.ans_token_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdfe511c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
