{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f73261c0-9c56-4517-aa34-086bec04c475",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.evaluator import Evaluator\n",
    "from src.engines import HuggingFaceEngine, VLLMEngine\n",
    "from src.scoring import DetectionScoringMethod, FuzzingScoringMethod\n",
    "from src.samplers import TopSampler, RandomSampler, QuantileSampler\n",
    "\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acfd4168",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate models whose features get interpreted; the index picks the active one.\n",
    "subject_models = [\"Qwen/Qwen2.5-0.5B\", \"google/gemma-2-2b\"]\n",
    "model_to_interpret = subject_models[0]\n",
    "# Tokenizer of the interpreted model, reused by the later cells.\n",
    "context_tokenizer = AutoTokenizer.from_pretrained(model_to_interpret)\n",
    "\n",
    "# Candidate interpreter models; the index picks the active one.\n",
    "interpreter_models = [\"Qwen/Qwen2.5-7B-Instruct\", \"Qwen/Qwen3-14B-AWQ\"]\n",
    "interpreter_model = interpreter_models[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44553161",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference backend for the interpreter model: \"hf\" or \"vllm\".\n",
    "engine_type = \"vllm\"\n",
    "\n",
    "if engine_type == \"hf\":\n",
    "    # Load the interpreter model onto the GPU together with its tokenizer,\n",
    "    # then wrap both in the HuggingFace engine.\n",
    "    model = AutoModelForCausalLM.from_pretrained(interpreter_model).to(\"cuda\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(interpreter_model)\n",
    "\n",
    "    engine = HuggingFaceEngine(\n",
    "        model=model,\n",
    "        tokenizer=tokenizer,\n",
    "        device=\"cuda\",\n",
    "    )\n",
    "elif engine_type == \"vllm\":\n",
    "    engine = VLLMEngine(\n",
    "        model_name_or_path=interpreter_model,\n",
    "        max_model_len=4096,\n",
    "        dtype=\"auto\",\n",
    "    )\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `engine` in a later cell.\n",
    "    raise ValueError(\n",
    "        f\"Unknown engine_type {engine_type!r}; expected 'hf' or 'vllm'.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86054f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Samplers handed to the evaluator for positive and negative examples.\n",
    "pos_sampler = TopSampler(sample_size=64, top_quantile=0.5, random_sample=True)\n",
    "neg_sampler = RandomSampler(64)\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    inference_engine=engine,\n",
    "    pos_sampler=pos_sampler,\n",
    "    neg_sampler=neg_sampler,\n",
    "    examples_in_prompt=8,\n",
    "    batch_size=32,\n",
    "    show_progress=True\n",
    ")\n",
    "\n",
    "# Two scoring methods, both configured with the interpreted model's tokenizer.\n",
    "detection_scorer = DetectionScoringMethod(\n",
    "    context_tokenizer=context_tokenizer,\n",
    "    max_example_length=24,\n",
    ")\n",
    "\n",
    "fuzzing_scorer = FuzzingScoringMethod(\n",
    "    context_tokenizer=context_tokenizer,\n",
    "    max_example_length=24,\n",
    "    correct_pos_ratio=0.5,\n",
    "    incorrect_pos_ratio=0.25,\n",
    "    incorrect_neg_ratio=0.25,\n",
    ")\n",
    "\n",
    "# Load the features to score. The file is opened in a `with` block so the read\n",
    "# handle is closed before the same path is rewritten at the end of this cell.\n",
    "# NOTE: unpickling can execute arbitrary code; only load result files you trust.\n",
    "name = 'post'\n",
    "results_path = f'results/{name}.pkl'\n",
    "with open(results_path, 'rb') as results_file:\n",
    "    features = pickle.load(results_file)\n",
    "\n",
    "features_to_evaluate = list(features.values())\n",
    "# Score the last entry of each feature's `interpretations` list.\n",
    "interpretations_list = [\n",
    "    feature.interpretations[-1] for feature in features_to_evaluate\n",
    "]\n",
    "\n",
    "generation_kwargs = {\n",
    "    'max_tokens': 32,\n",
    "    'temperature': 0.75\n",
    "}\n",
    "engine.chat_template_args['enable_thinking'] = False\n",
    "engine.chat_template_args['tokenize'] = False\n",
    "\n",
    "detection_metrics = evaluator.run(\n",
    "    features_to_evaluate, \n",
    "    interpretations_list, \n",
    "    detection_scorer,\n",
    "    **generation_kwargs\n",
    ")\n",
    "\n",
    "fuzzing_metrics = evaluator.run(\n",
    "    features_to_evaluate, \n",
    "    interpretations_list, \n",
    "    fuzzing_scorer,\n",
    "    **generation_kwargs\n",
    ")\n",
    "\n",
    "# Attach both scores to the last interpretation of every scored feature.\n",
    "for feature_id in fuzzing_metrics.keys():\n",
    "    result = {\n",
    "        'detection': detection_metrics[feature_id]['score'], \n",
    "        'fuzzing': fuzzing_metrics[feature_id]['score']\n",
    "    }\n",
    "    # print(f\"Feature {feature_id}: {features[feature_id].interpretations[-1].value} ({result})\")\n",
    "    features[feature_id].interpretations[-1].score = result\n",
    "\n",
    "# Write the scored features back to the file they were loaded from.\n",
    "with open(results_path, 'wb') as results_file:\n",
    "    pickle.dump(features, results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61529e46-0a1f-4aa3-8b9e-e6073c10ed44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one feature key at random; uncomment the override to inspect a fixed one.\n",
    "feature_keys = list(features.keys())\n",
    "index = np.random.choice(feature_keys, 1)[0]\n",
    "# index = 39\n",
    "\n",
    "# Print the two stored scores next to the interpretation text.\n",
    "interp = features[index].interpretations[-1]\n",
    "det, fuz = (interp.score[key] for key in ('detection', 'fuzzing'))\n",
    "print(f\"({det:.2f}, {fuz:.2f}) - {interp.value}\")\n",
    "\n",
    "# Print whatever the feature's `show` method returns for these settings.\n",
    "rendered = features[index].show(\n",
    "    tokenizer=context_tokenizer,\n",
    "    context_window=15,\n",
    "    max_example_length=100,\n",
    "    examples_per_quantile=5,\n",
    ")\n",
    "print(rendered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60460bd9-1458-4a77-8ff8-8923fea52834",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
