{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b13a92-92c1-4529-8c43-d294d882ffa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conda create --name ola --clone qwen\n",
    "# conda activate ola\n",
    "# module load cuda/12.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "279ec1e3-571f-48f9-9e5b-497e610d1903",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['LOWRES_RESIZE'] = '384x32'\n",
    "os.environ['HIGHRES_BASE'] = '0x32'\n",
    "os.environ['VIDEO_RESIZE'] = \"0x64\"\n",
    "os.environ['VIDEO_MAXRES'] = \"480\"\n",
    "os.environ['VIDEO_MINRES'] = \"288\"\n",
    "os.environ['MAXRES'] = '1536'\n",
    "os.environ['MINRES'] = '0'\n",
    "os.environ['REGIONAL_POOL'] = '2x'\n",
    "os.environ['FORCE_NO_DOWNSAMPLE'] = '1'\n",
    "os.environ['LOAD_VISION_EARLY'] = '1'\n",
    "os.environ['SKIP_LOAD_VIT'] = '1'\n",
    "    \n",
    "\n",
    "import torch\n",
    "import json\n",
    "import glob\n",
    "import csv\n",
    "import re\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import transformers\n",
    "from tqdm import tqdm\n",
    "from typing import Dict, Optional, Sequence, List\n",
    "from ola.conversation import conv_templates, SeparatorStyle, Conversation\n",
    "from ola.model.builder import load_pretrained_model\n",
    "from ola.utils import disable_torch_init\n",
    "from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token\n",
    "from ola.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image\n",
    "from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN\n",
    "\n",
    "model_path = \"THUdyh/Ola-7b\"\n",
    "tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None)\n",
    "model = model.to('cuda').eval()\n",
    "model = model.bfloat16()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0145ccd-ad74-4c08-92c6-c1e981ae3753",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"\"\"For the provided face image, output a JSON object with the following attributes:\n",
    "        \n",
    "            {\n",
    "              \"gender\": [\"male\", \"female\"],\n",
    "              \"age\": [\"young\", \"middle-aged\", \"senior\"],\n",
    "              \"skin_color\": [\"light\", \"medium\", \"dark\"],\n",
    "              \"ancestry\": [\"asian\", \"south_asian\", \"black\", \"latino/hispanic\", \"middle_eastern\", \"white\", \"indigenous\"],\n",
    "              \"hair_color\": [\"black\", \"brown\", \"red\", \"blonde\", \"gray\", \"other\"],\n",
    "              \"bangs\": [\"yes\", \"no\"],\n",
    "              \"bald\": [\"yes\", \"no\"],\n",
    "              \"beard\": [\"no\", \"mustache\", \"stubble\", \"full\"],\n",
    "              \"glasses\": [\"no\", \"regular\", \"sun\"],\n",
    "              \"headwear\": [\"no\", \"beanie\", \"cap\", \"hat\", \"headband\", \"hijab\", \"helmet\", \"turban\"],\n",
    "            }\n",
    "\n",
    "          Ensure the labeling is based on visible evidence only. If an attribute is unclear, return \"unknown\".\n",
    "          \n",
    "          Only output the JSON without any additional explanation or text.\n",
    "          \n",
    "          Example JSON output:\n",
    "          \n",
    "          {\n",
    "            \"gender\": \"female\",\n",
    "            \"age\": \"middle-aged\",\n",
    "            \"skin_color\": \"light\",\n",
    "            \"ancestry\": \"asian\",\n",
    "            \"hair_color\": \"black\",\n",
    "            \"bangs\": \"no\",\n",
    "            \"bald\": \"no\",\n",
    "            \"beard\": \"no\",\n",
    "            \"glasses\": \"sun\",\n",
    "            \"headwear\": \"beanie\",\n",
    "          }\n",
    "          \"\"\"\n",
    "\n",
    "default_keys = [\"gender\", \"age\", \"skin_color\", \"ancestry\", \"hair_color\", \"bangs\", \n",
    "                \"bald\", \"beard\", \"glasses\", \"headwear\"]\n",
    "\n",
    "# Function to process a single image\n",
    "def process_images(image_paths):\n",
    "    # Load images\n",
    "    images = [Image.open(path) for path in image_paths]\n",
    "    image_sizes = [img.size for img in images]\n",
    "    # Prompt setup\n",
    "    qs = DEFAULT_IMAGE_TOKEN + \"\\n\" + text\n",
    "    conv = Conversation(\n",
    "        system=\"\"\"<|im_start|>system\n",
    "You are an image analysis tool specialized in facial attribute classification.\"\"\",\n",
    "        roles=(\"<|im_start|>user\", \"<|im_start|>tool\"),\n",
    "        version=\"qwen\",\n",
    "        messages=[],\n",
    "        offset=0,\n",
    "        sep_style=SeparatorStyle.CHATML,\n",
    "        sep=\"<|im_end|>\",\n",
    "    )\n",
    "    conv.append_message(conv.roles[0], qs)\n",
    "    conv.append_message(conv.roles[1], None)\n",
    "    prompt = conv.get_prompt()\n",
    "    # Tokenize prompt and repeat it for all images in batch\n",
    "    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=\"pt\").to('cuda')\n",
    "    input_ids = input_ids.unsqueeze(0).repeat(len(images), 1)  # (batch_size, seq_len)\n",
    "    pad_token_ids = 151643\n",
    "    attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')\n",
    "    # Image preprocessing\n",
    "    image_processor.do_resize = False\n",
    "    image_processor.do_center_crop = False\n",
    "    image_tensor, image_highres_tensor = [], []\n",
    "    for img in images:\n",
    "        tensor, highres = process_anyres_highres_image(img, image_processor)\n",
    "        image_tensor.append(tensor)\n",
    "        image_highres_tensor.append(highres)\n",
    "    image_tensor = torch.stack(image_tensor).bfloat16().to(\"cuda\")  # (batch_size, C, H, W)\n",
    "    image_highres_tensor = torch.stack(image_highres_tensor).bfloat16().to(\"cuda\")\n",
    "    # Dummy speech inputs (repeated for each sample in batch)\n",
    "    batch_size = len(images)\n",
    "    speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')] * batch_size\n",
    "    speech_lengths = [torch.LongTensor([3000]).to('cuda')] * batch_size\n",
    "    speech_wavs = [torch.zeros([1, 480000]).to('cuda')] * batch_size\n",
    "    speech_chunks = [torch.LongTensor([1]).to('cuda')] * batch_size\n",
    "    # Stopping criteria\n",
    "    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
    "    keywords = [stop_str]\n",
    "    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
    "    # Inference\n",
    "    with torch.inference_mode():\n",
    "        output_ids = model.generate(\n",
    "            inputs=input_ids,\n",
    "            images=image_tensor,\n",
    "            images_highres=image_highres_tensor,\n",
    "            image_sizes=image_sizes,\n",
    "            modalities=['image'],\n",
    "            speech=speechs,\n",
    "            speech_lengths=speech_lengths,\n",
    "            speech_chunks=speech_chunks,\n",
    "            speech_wav=speech_wavs,\n",
    "            attention_mask=attention_masks,\n",
    "            use_cache=True,\n",
    "            stopping_criteria=[stopping_criteria],\n",
    "            do_sample=False,\n",
    "            num_beams=1,\n",
    "            max_new_tokens=1024,\n",
    "        )\n",
    "    # Decode outputs\n",
    "    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n",
    "    result = []\n",
    "    for i in range(len(image_paths)):\n",
    "        # Extracting the JSON part\n",
    "        output = outputs[i]\n",
    "        output = output.strip()\n",
    "        if output.endswith(stop_str):\n",
    "            output = output[:-len(stop_str)]\n",
    "        json_output = re.search(r\"\\{.*\\}\", output.strip(), re.DOTALL).group()\n",
    "        json_output = re.sub(r',\\s*}', '}', json_output)  # This removes a comma just before a closing curly brace\n",
    "        data = json.loads(json_output)\n",
    "        result.append([image_paths[i].replace('../ICCV 2025/RFW/data/', '')] + [data.get(key, \"unknown\") for key in default_keys])\n",
    "    torch.cuda.empty_cache()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c880b3-cc8e-4c02-aa95-c308dab33fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "init = 28572\n",
    "files = glob.glob('../ICCV 2025/RFW/data/*/*/*')\n",
    "\n",
    "output_file = \"ola.csv\"\n",
    "\n",
    "# Column names\n",
    "columns = [\"File Name\", \"Gender\", \"Age\", \"Skin Color\", \"Ancestry\", \"Hair Color\", \n",
    "           \"Bangs\", \"Bald\", \"Beard\", \"Glasses\", \"Headwear\"]\n",
    "\n",
    "# Chunk settings\n",
    "chunk_size = 1  # Number of images processed per batch\n",
    "\n",
    "# # Write CSV header\n",
    "# with open(output_file, \"w\", newline=\"\") as csvfile:\n",
    "#     writer = csv.writer(csvfile)\n",
    "#     writer.writerow(columns)  # CSV Header\n",
    "\n",
    "# Process in chunks\n",
    "for start_idx in tqdm(range(init, len(files), chunk_size), desc=\"Processing Images\"):\n",
    "    end_idx = min(start_idx + chunk_size, len(files))\n",
    "    chunk = files[start_idx:end_idx]\n",
    "    results = process_images(chunk)\n",
    "    # Save chunk results to CSV\n",
    "    with open(output_file, \"a\", newline=\"\") as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        writer.writerows(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "179ba76e-9ea3-418d-876e-344698220319",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
