{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e26fed-2df7-4f02-887d-0a06aa300388",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import csv\n",
    "import glob\n",
    "import json\n",
    "import torch\n",
    "\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from transformers import GenerationConfig\n",
    "from gptqmodel import GPTQModel\n",
    "\n",
    "# Load the quantized Ovis2 model.\n",
    "# Customize the device the model is loaded onto.\n",
    "load_device = \"cuda:0\"\n",
    "torch.cuda.set_device(load_device)\n",
    "# This notebook uses AIDC-AI/Ovis2-34B-GPTQ-Int8; the same code applies to any\n",
    "# GPTQ-quantized Ovis2 checkpoint (e.g. the Int4 variant) -- only model_path changes.\n",
    "# The weights and the generation config are read from the same checkpoint id.\n",
    "model_path = \"AIDC-AI/Ovis2-34B-GPTQ-Int8\"\n",
    "model = GPTQModel.load(model_path, device=load_device, trust_remote_code=True)\n",
    "model.model.generation_config = GenerationConfig.from_pretrained(model_path)\n",
    "text_tokenizer = model.get_text_tokenizer()\n",
    "visual_tokenizer = model.get_visual_tokenizer()\n",
    "\n",
    "# Instruction prompt sent with every image (process_images prepends the <image> tag).\n",
    "# It lists the allowed values for each attribute and shows an example answer; its keys\n",
    "# match default_keys, which is used when parsing the model's response.\n",
    "# NOTE(review): both the schema and the example put a trailing comma before the closing\n",
    "# '}', so they are not valid JSON; process_images strips such commas before json.loads.\n",
    "# Consider removing them here so the model is not shown malformed JSON.\n",
    "text = \"\"\"You are an image analysis tool specialized in facial attribute classification. \n",
    "          For the provided face image, output a JSON object with the following attributes:\n",
    "        \n",
    "            {\n",
    "              \"gender\": [\"male\", \"female\"],\n",
    "              \"age\": [\"young\", \"middle-aged\", \"senior\"],\n",
    "              \"skin_color\": [\"light\", \"medium\", \"dark\"],\n",
    "              \"ancestry\": [\"asian\", \"south_asian\", \"black\", \"latino/hispanic\", \"middle_eastern\", \"white\", \"indigenous\"],\n",
    "              \"hair_color\": [\"black\", \"brown\", \"red\", \"blonde\", \"gray\", \"other\"],\n",
    "              \"bangs\": [\"yes\", \"no\"],\n",
    "              \"bald\": [\"yes\", \"no\"],\n",
    "              \"beard\": [\"no\", \"mustache\", \"stubble\", \"full\"],\n",
    "              \"glasses\": [\"no\", \"regular\", \"sun\"],\n",
    "              \"headwear\": [\"no\", \"beanie\", \"cap\", \"hat\", \"headband\", \"hijab\", \"helmet\", \"turban\"],\n",
    "            }\n",
    "\n",
    "          Ensure the labeling is based on visible evidence only. If an attribute is unclear, return \"unknown\".\n",
    "          \n",
    "          Only output the JSON without any additional explanation or text.\n",
    "          \n",
    "          Example JSON output:\n",
    "          \n",
    "          {\n",
    "            \"gender\": \"female\",\n",
    "            \"age\": \"middle-aged\",\n",
    "            \"skin_color\": \"light\",\n",
    "            \"ancestry\": \"asian\",\n",
    "            \"hair_color\": \"black\",\n",
    "            \"bangs\": \"no\",\n",
    "            \"bald\": \"no\",\n",
    "            \"beard\": \"no\",\n",
    "            \"glasses\": \"sun\",\n",
    "            \"headwear\": \"beanie\",\n",
    "          }\n",
    "          \"\"\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0145ccd-ad74-4c08-92c6-c1e981ae3753",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_keys = [\"gender\", \"age\", \"skin_color\", \"ancestry\", \"hair_color\", \"bangs\",\n",
    "                \"bald\", \"beard\", \"glasses\", \"headwear\"]\n",
    "\n",
    "\n",
    "def _parse_labels(output, image_path):\n",
    "    \"\"\"Turn one decoded model response into a list of labels ordered like default_keys.\n",
    "\n",
    "    Attributes missing from the response become \"unknown\". If the response holds no\n",
    "    parsable JSON object, every attribute is \"unknown\" and a message is written via\n",
    "    tqdm, so one malformed answer does not abort the whole labeling run.\n",
    "    \"\"\"\n",
    "    data = None\n",
    "    match = re.search(r\"\\{.*\\}\", output, re.DOTALL)\n",
    "    if match is not None:\n",
    "        # The prompt's example JSON has trailing commas, which the model tends to copy;\n",
    "        # drop any comma directly before a closing brace/bracket so json.loads accepts it.\n",
    "        json_output = re.sub(r',\\s*([}\\]])', r'\\1', match.group())\n",
    "        try:\n",
    "            data = json.loads(json_output)\n",
    "        except json.JSONDecodeError:\n",
    "            data = None\n",
    "    if not isinstance(data, dict):\n",
    "        tqdm.write(f\"Could not parse model output for {image_path}: {output[:200]!r}\")\n",
    "        data = {}\n",
    "    return [data.get(key, \"unknown\") for key in default_keys]\n",
    "\n",
    "\n",
    "# Function to process a batch of images\n",
    "def process_images(batch_images, data_root='../RFW/data/'):\n",
    "    \"\"\"Label a batch of face images with the Ovis2 model in one generate call.\n",
    "\n",
    "    Args:\n",
    "        batch_images: list of image file paths.\n",
    "        data_root: prefix stripped from each path to form the \"File Name\" column.\n",
    "\n",
    "    Returns:\n",
    "        One row per image, in input order: [relative_path, *labels], with labels\n",
    "        ordered like default_keys. An empty batch returns [].\n",
    "    \"\"\"\n",
    "    if not batch_images:\n",
    "        return []\n",
    "    # Every image gets the same prompt, so build it once.\n",
    "    query = f'<image>\\n{text}'\n",
    "    batch_input_ids = []\n",
    "    batch_attention_mask = []\n",
    "    batch_pixel_values = []\n",
    "    for image_path in batch_images:\n",
    "        # The context manager closes the image file once preprocessing has read it.\n",
    "        with Image.open(image_path) as image:\n",
    "            _, input_ids, pixel_values = model.preprocess_inputs(query, [image], max_partition=9)\n",
    "        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)\n",
    "        batch_input_ids.append(input_ids.to(device=model.device))\n",
    "        batch_attention_mask.append(attention_mask.to(device=model.device))\n",
    "        batch_pixel_values.append(pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device))\n",
    "    # Left-pad (flip, right-pad, flip back) so every sequence ends at the last position,\n",
    "    # then keep only the last multimodal_max_length tokens.\n",
    "    max_length = model.config.multimodal_max_length\n",
    "    batch_input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in batch_input_ids], batch_first=True,\n",
    "                                                      padding_value=0.0).flip(dims=[1])\n",
    "    batch_input_ids = batch_input_ids[:, -max_length:]\n",
    "    batch_attention_mask = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in batch_attention_mask],\n",
    "                                                           batch_first=True, padding_value=False).flip(dims=[1])\n",
    "    batch_attention_mask = batch_attention_mask[:, -max_length:]\n",
    "    # Greedy decoding; the sampling-only options are explicitly cleared.\n",
    "    gen_kwargs = dict(\n",
    "        max_new_tokens=1024,\n",
    "        do_sample=False,\n",
    "        top_p=None,\n",
    "        top_k=None,\n",
    "        temperature=None,\n",
    "        repetition_penalty=None,\n",
    "        eos_token_id=model.generation_config.eos_token_id,\n",
    "        pad_token_id=text_tokenizer.pad_token_id,\n",
    "        use_cache=True\n",
    "    )\n",
    "    with torch.inference_mode():\n",
    "        output_ids = model.generate(batch_input_ids, pixel_values=batch_pixel_values,\n",
    "                                    attention_mask=batch_attention_mask, **gen_kwargs)\n",
    "    result = []\n",
    "    for image_path, ids in zip(batch_images, output_ids):\n",
    "        output = text_tokenizer.decode(ids, skip_special_tokens=True)\n",
    "        # Strip the dataset root only when it is a leading prefix of the path.\n",
    "        file_name = image_path[len(data_root):] if image_path.startswith(data_root) else image_path\n",
    "        result.append([file_name] + _parse_labels(output, image_path))\n",
    "    torch.cuda.empty_cache()\n",
    "    return result\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c880b3-cc8e-4c02-aa95-c308dab33fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label every image and stream the rows to CSV.\n",
    "# init is the index (into the sorted file list) of the first image to process: keep it\n",
    "# at 0 for a fresh run, or set it to resume an interrupted run without losing saved rows.\n",
    "init = 0\n",
    "# Sorted so the file order -- and therefore the meaning of init -- is stable across runs.\n",
    "files = sorted(glob.glob('../RFW/data/*/*/*'))\n",
    "\n",
    "output_file = \"ovis2_labels.csv\"\n",
    "\n",
    "# Column names, in the same order as the rows returned by process_images\n",
    "columns = [\"File Name\", \"Gender\", \"Age\", \"Skin Color\", \"Ancestry\", \"Hair Color\",\n",
    "           \"Bangs\", \"Bald\", \"Beard\", \"Glasses\", \"Headwear\"]\n",
    "\n",
    "# Chunk settings\n",
    "chunk_size = 6  # Number of images processed per batch\n",
    "\n",
    "# A fresh run (init == 0) truncates the file; a resumed run appends to it. The header is\n",
    "# written only when the file is empty, so resuming never erases rows or duplicates it.\n",
    "with open(output_file, \"w\" if init == 0 else \"a\", newline=\"\", encoding=\"utf-8\") as csvfile:\n",
    "    writer = csv.writer(csvfile)\n",
    "    if csvfile.tell() == 0:\n",
    "        writer.writerow(columns)  # CSV Header\n",
    "    # Process in chunks\n",
    "    for start_idx in tqdm(range(init, len(files), chunk_size), desc=\"Processing Images\"):\n",
    "        chunk = files[start_idx:start_idx + chunk_size]\n",
    "        writer.writerows(process_images(chunk))\n",
    "        # Flush after each chunk so completed rows reach the file even if the run is interrupted.\n",
    "        csvfile.flush()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
