{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b13a92-92c1-4529-8c43-d294d882ffa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment setup: run these commands in a shell before starting this notebook's kernel.\n",
    "# They are kept here only as comments and are not executed by the notebook.\n",
    "\n",
    "# New conda env \"qwen\", cloned from the existing \"ovis\" env.\n",
    "# conda create --name qwen --clone ovis\n",
    "# conda activate qwen\n",
    "# Load the CUDA 12.2 toolkit through the cluster's module system (site-specific).\n",
    "# module load cuda/12.2\n",
    "# pip install qwen-vl-utils==0.0.8\n",
    "# pip install torchvision==0.19.0\n",
    "# NOTE(review): presumably runtime deps that the --no-deps AutoAWQ install below skips -- verify.\n",
    "# pip install gekko zstandard\n",
    "# AutoAWQ from the GitHub default branch (unpinned); --no-deps stops pip from installing\n",
    "# or upgrading AutoAWQ's own dependencies in this env.\n",
    "# pip install git+https://github.com/casper-hansen/AutoAWQ.git --no-deps\n",
    "\n",
    "# Presumably an alternative to the GitHub install above (PyPI release) -- confirm which is intended.\n",
    "# pip install autoawq --no-deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd6f7a4-362d-4d8c-bbc0-58cca90610ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import csv\n",
    "import glob\n",
    "import json\n",
    "import torch\n",
    "\n",
    "from tqdm import tqdm\n",
    "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\n",
    "from qwen_vl_utils import process_vision_info\n",
    "\n",
    "# Single source of truth for the checkpoint; used for both the model and its processor.\n",
    "MODEL_ID = \"Qwen/Qwen2.5-VL-32B-Instruct-AWQ\"\n",
    "\n",
    "# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.\n",
    "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
    "# Batched generate() with a decoder-only model needs left padding, so every prompt ends\n",
    "# immediately before its generated tokens. Set it explicitly instead of relying on the\n",
    "# checkpoint's tokenizer config; right padding can corrupt completions of shorter prompts.\n",
    "processor.tokenizer.padding_side = \"left\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e26fed-2df7-4f02-887d-0a06aa300388",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction prompt sent with every face image. The schema and the example reply are written\n",
    "# as strict JSON (no trailing commas), because the reply is later decoded with json.loads.\n",
    "text = \"\"\"You are an image analysis tool specialized in facial attribute classification. \n",
    "          For the provided face image, output a JSON object with the following attributes:\n",
    "        \n",
    "            {\n",
    "              \"gender\": [\"male\", \"female\"],\n",
    "              \"age\": [\"young\", \"middle-aged\", \"senior\"],\n",
    "              \"skin_color\": [\"light\", \"medium\", \"dark\"],\n",
    "              \"ancestry\": [\"asian\", \"south_asian\", \"black\", \"latino/hispanic\", \"middle_eastern\", \"white\", \"indigenous\"],\n",
    "              \"hair_color\": [\"black\", \"brown\", \"red\", \"blonde\", \"gray\", \"other\"],\n",
    "              \"bangs\": [\"yes\", \"no\"],\n",
    "              \"bald\": [\"yes\", \"no\"],\n",
    "              \"beard\": [\"no\", \"mustache\", \"stubble\", \"full\"],\n",
    "              \"glasses\": [\"no\", \"regular\", \"sun\"],\n",
    "              \"headwear\": [\"no\", \"beanie\", \"cap\", \"hat\", \"headband\", \"hijab\", \"helmet\", \"turban\"]\n",
    "            }\n",
    "\n",
    "          Ensure the labeling is based on visible evidence only. If an attribute is unclear, return \"unknown\".\n",
    "          \n",
    "          Only output the JSON without any additional explanation or text.\n",
    "          \n",
    "          Example JSON output:\n",
    "          \n",
    "          {\n",
    "            \"gender\": \"female\",\n",
    "            \"age\": \"middle-aged\",\n",
    "            \"skin_color\": \"light\",\n",
    "            \"ancestry\": \"asian\",\n",
    "            \"hair_color\": \"black\",\n",
    "            \"bangs\": \"no\",\n",
    "            \"bald\": \"no\",\n",
    "            \"beard\": \"no\",\n",
    "            \"glasses\": \"sun\",\n",
    "            \"headwear\": \"beanie\"\n",
    "          }\n",
    "          \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0145ccd-ad74-4c08-92c6-c1e981ae3753",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_keys = [\"gender\", \"age\", \"skin_color\", \"ancestry\", \"hair_color\", \"bangs\",\n",
    "                \"bald\", \"beard\", \"glasses\", \"headwear\"]\n",
    "\n",
    "\n",
    "def _parse_labels(output_text):\n",
    "    \"\"\"Parse one model reply into a dict of attribute labels.\n",
    "\n",
    "    Takes the span from the first `{` to the last `}` in the reply and decodes it as JSON.\n",
    "    Returns an empty dict when there is no such span, the span is not valid JSON, or it\n",
    "    decodes to something other than a JSON object.\n",
    "    \"\"\"\n",
    "    match = re.search(r\"\\{.*\\}\", output_text, re.DOTALL)\n",
    "    if match is None:\n",
    "        return {}\n",
    "    # Strict JSON rejects a comma directly before `}`; strip it so such replies still parse.\n",
    "    json_output = re.sub(r',\\s*}', '}', match.group())\n",
    "    try:\n",
    "        data = json.loads(json_output)\n",
    "    except json.JSONDecodeError:\n",
    "        return {}\n",
    "    return data if isinstance(data, dict) else {}\n",
    "\n",
    "\n",
    "def process_images(batch_images, max_new_tokens=128):\n",
    "    \"\"\"Label a batch of face images with the attributes listed in `default_keys`.\n",
    "\n",
    "    Args:\n",
    "        batch_images: list of image paths. Each path is sent to the model together with the\n",
    "            prompt `text`, and is reported with the '../RFW/data/' prefix removed.\n",
    "        max_new_tokens: generation budget per image (default 128).\n",
    "\n",
    "    Returns:\n",
    "        One row per image, in input order: [file name] + one value per key of\n",
    "        `default_keys`. Keys missing from a reply are filled with \"unknown\". A reply\n",
    "        that cannot be parsed yields an all-\"unknown\" row (and a printed warning)\n",
    "        instead of aborting the batch.\n",
    "    \"\"\"\n",
    "    if not batch_images:\n",
    "        return []\n",
    "    # One single-turn chat per image: the image followed by the shared prompt.\n",
    "    messages = [\n",
    "        [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": [\n",
    "                    {\"type\": \"image\", \"image\": image_path},\n",
    "                    {\"type\": \"text\", \"text\": text},\n",
    "                ],\n",
    "            }\n",
    "        ]\n",
    "        for image_path in batch_images\n",
    "    ]\n",
    "    texts = [\n",
    "        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)\n",
    "        for msg in messages\n",
    "    ]\n",
    "    image_inputs, video_inputs = process_vision_info(messages)\n",
    "    inputs = processor(\n",
    "        text=texts,\n",
    "        images=image_inputs,\n",
    "        videos=video_inputs,\n",
    "        padding=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    # With device_map=\"auto\" the model may span several GPUs. model.device is the device of\n",
    "    # its first parameters (presumably the input embeddings).\n",
    "    inputs = inputs.to(model.device)\n",
    "    with torch.inference_mode():\n",
    "        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)\n",
    "    # generate() returns prompt + completion; keep only the newly generated tokens.\n",
    "    generated_ids_trimmed = [\n",
    "        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
    "    ]\n",
    "    output_texts = processor.batch_decode(\n",
    "        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
    "    )\n",
    "    result = []\n",
    "    for image_path, output_text in zip(batch_images, output_texts):\n",
    "        data = _parse_labels(output_text)\n",
    "        if not data:\n",
    "            # Don't lose the whole run to one malformed reply; report it and keep going.\n",
    "            tqdm.write(f\"Could not parse model output for {image_path}: {output_text!r}\")\n",
    "        file_name = image_path.replace('../RFW/data/', '')\n",
    "        result.append([file_name] + [data.get(key, \"unknown\") for key in default_keys])\n",
    "    # Return unused cached GPU memory before the next batch.\n",
    "    torch.cuda.empty_cache()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c880b3-cc8e-4c02-aa95-c308dab33fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the first image to process; set it > 0 to resume an interrupted run.\n",
    "init = 0\n",
    "# Sorted because glob order is filesystem-dependent: `init` must name the same image on every run.\n",
    "files = sorted(glob.glob('../RFW/data/*/*/*'))\n",
    "\n",
    "# Named after the checkpoint actually loaded (Qwen2.5-VL-32B-Instruct-AWQ), not a 7B model.\n",
    "output_file = \"qwen32b_labels.csv\"\n",
    "\n",
    "# Column names (same order as the rows returned by process_images)\n",
    "columns = [\"File Name\", \"Gender\", \"Age\", \"Skin Color\", \"Ancestry\", \"Hair Color\",\n",
    "           \"Bangs\", \"Bald\", \"Beard\", \"Glasses\", \"Headwear\"]\n",
    "\n",
    "# Chunk settings\n",
    "chunk_size = 32  # Number of images processed per batch\n",
    "\n",
    "# Create the CSV with its header only on a fresh run. When resuming (init > 0), the rows\n",
    "# already written must survive, so the file is not reopened in truncating \"w\" mode.\n",
    "if init == 0:\n",
    "    with open(output_file, \"w\", newline=\"\", encoding=\"utf-8\") as csvfile:\n",
    "        csv.writer(csvfile).writerow(columns)\n",
    "\n",
    "# Process in chunks; each chunk is appended right away so progress survives a crash.\n",
    "for start_idx in tqdm(range(init, len(files), chunk_size), desc=\"Processing Images\"):\n",
    "    chunk = files[start_idx:start_idx + chunk_size]\n",
    "    results = process_images(chunk)\n",
    "    with open(output_file, \"a\", newline=\"\", encoding=\"utf-8\") as csvfile:\n",
    "        csv.writer(csvfile).writerows(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6956ff6e-b82e-4ac2-be5f-7aebf9785530",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
