{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b13a92-92c1-4529-8c43-d294d882ffa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conda create --name sailvl --clone qwen\n",
    "# conda activate sailvl\n",
    "# module load cuda/12.2\n",
    "\n",
    "# pip install timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc076001-0f19-4bab-9d14-3b9b2b5bdb50",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import csv\n",
    "import glob\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
    "IMAGENET_STD = (0.229, 0.224, 0.225)\n",
    "\n",
    "def build_transform(input_size):\n",
    "    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD\n",
    "    transform = T.Compose([\n",
    "        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n",
    "        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),\n",
    "        T.ToTensor(),\n",
    "        T.Normalize(mean=MEAN, std=STD)\n",
    "    ])\n",
    "    return transform\n",
    "\n",
    "def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n",
    "    best_ratio_diff = float('inf')\n",
    "    best_ratio = (1, 1)\n",
    "    area = width * height\n",
    "    for ratio in target_ratios:\n",
    "        target_aspect_ratio = ratio[0] / ratio[1]\n",
    "        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n",
    "        if ratio_diff < best_ratio_diff:\n",
    "            best_ratio_diff = ratio_diff\n",
    "            best_ratio = ratio\n",
    "        elif ratio_diff == best_ratio_diff:\n",
    "            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n",
    "                best_ratio = ratio\n",
    "    return best_ratio\n",
    "\n",
    "def dynamic_preprocess(image, min_num=1, max_num=10, image_size=448, use_thumbnail=False):\n",
    "    orig_width, orig_height = image.size\n",
    "    aspect_ratio = orig_width / orig_height\n",
    "    # calculate the existing image aspect ratio\n",
    "    target_ratios = set(\n",
    "        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n",
    "        i * j <= max_num and i * j >= min_num)\n",
    "    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n",
    "    # find the closest aspect ratio to the target\n",
    "    target_aspect_ratio = find_closest_aspect_ratio(\n",
    "        aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n",
    "    # calculate the target width and height\n",
    "    target_width = image_size * target_aspect_ratio[0]\n",
    "    target_height = image_size * target_aspect_ratio[1]\n",
    "    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n",
    "    # resize the image\n",
    "    resized_img = image.resize((target_width, target_height))\n",
    "    processed_images = []\n",
    "    for i in range(blocks):\n",
    "        box = (\n",
    "            (i % (target_width // image_size)) * image_size,\n",
    "            (i // (target_width // image_size)) * image_size,\n",
    "            ((i % (target_width // image_size)) + 1) * image_size,\n",
    "            ((i // (target_width // image_size)) + 1) * image_size\n",
    "        )\n",
    "        # split the image\n",
    "        split_img = resized_img.crop(box)\n",
    "        processed_images.append(split_img)\n",
    "    assert len(processed_images) == blocks\n",
    "    if use_thumbnail and len(processed_images) != 1:\n",
    "        thumbnail_img = image.resize((image_size, image_size))\n",
    "        processed_images.append(thumbnail_img)\n",
    "    return processed_images\n",
    "\n",
    "def load_image(image_file, input_size=448, max_num=10):\n",
    "    image = Image.open(image_file).convert('RGB')\n",
    "    transform = build_transform(input_size=input_size)\n",
    "    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)\n",
    "    pixel_values = [transform(image) for image in images]\n",
    "    pixel_values = torch.stack(pixel_values)\n",
    "    return pixel_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b1b9b4a-105b-4f68-9439-47cadb6faf4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"BytedanceDouyinContent/SAIL-VL-1d6-8B\"\n",
    "model = AutoModel.from_pretrained(\n",
    "    path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    trust_remote_code=True).eval().cuda()\n",
    "tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)\n",
    "generation_config = dict(max_new_tokens=1024, do_sample=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0145ccd-ad74-4c08-92c6-c1e981ae3753",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"\"\"You are an image analysis tool specialized in facial attribute classification. \n",
    "          For the provided face image, output a JSON object with the following attributes:\n",
    "        \n",
    "            {\n",
    "              \"gender\": [\"male\", \"female\"],\n",
    "              \"age\": [\"young\", \"middle-aged\", \"senior\"],\n",
    "              \"skin_color\": [\"light\", \"medium\", \"dark\"],\n",
    "              \"ancestry\": [\"asian\", \"south_asian\", \"black\", \"latino/hispanic\", \"middle_eastern\", \"white\", \"indigenous\"],\n",
    "              \"hair_color\": [\"black\", \"brown\", \"red\", \"blonde\", \"gray\", \"other\"],\n",
    "              \"bangs\": [\"yes\", \"no\"],\n",
    "              \"bald\": [\"yes\", \"no\"],\n",
    "              \"beard\": [\"no\", \"mustache\", \"stubble\", \"full\"],\n",
    "              \"glasses\": [\"no\", \"regular\", \"sun\"],\n",
    "              \"headwear\": [\"no\", \"beanie\", \"cap\", \"hat\", \"headband\", \"hijab\", \"helmet\", \"turban\"],\n",
    "            }\n",
    "\n",
    "          Ensure the labeling is based on visible evidence only. If an attribute is unclear, return \"unknown\".\n",
    "          \n",
    "          Only output the JSON without any additional explanation or text.\n",
    "          \n",
    "          Example JSON output:\n",
    "          \n",
    "          {\n",
    "            \"gender\": \"female\",\n",
    "            \"age\": \"middle-aged\",\n",
    "            \"skin_color\": \"light\",\n",
    "            \"ancestry\": \"asian\",\n",
    "            \"hair_color\": \"black\",\n",
    "            \"bangs\": \"no\",\n",
    "            \"bald\": \"no\",\n",
    "            \"beard\": \"no\",\n",
    "            \"glasses\": \"sun\",\n",
    "            \"headwear\": \"beanie\",\n",
    "          }\n",
    "          \"\"\"\n",
    "\n",
    "default_keys = [\"gender\", \"age\", \"skin_color\", \"ancestry\", \"hair_color\", \"bangs\", \n",
    "                \"bald\", \"beard\", \"glasses\", \"headwear\"]\n",
    "\n",
    "# Function to process a single image\n",
    "def process_images(batch_images):\n",
    "    # Sample messages for batch inference\n",
    "    pixel_values = []\n",
    "    for image_path in batch_images:\n",
    "        pixel_values.append(load_image(image_path, max_num=12).to(torch.bfloat16).cuda())\n",
    "    num_patches_list = [pixel_value.size(0) for pixel_value in pixel_values]\n",
    "    pixel_values = torch.cat(pixel_values, dim=0)\n",
    "    questions = [f'<image>\\n{text}'] * len(num_patches_list)\n",
    "    \n",
    "    with torch.inference_mode():\n",
    "        responses = model.batch_chat(tokenizer, pixel_values,\n",
    "                                 num_patches_list=num_patches_list,\n",
    "                                 questions=questions,\n",
    "                                 generation_config=generation_config)\n",
    "    result = []\n",
    "    for i in range(len(batch_images)):\n",
    "        # Extracting the JSON part\n",
    "        json_output = re.search(r\"\\{.*\\}\", responses[i], re.DOTALL).group()\n",
    "        json_output = re.sub(r',\\s*}', '}', json_output)  # This removes a comma just before a closing curly brace\n",
    "        data = json.loads(json_output)\n",
    "        result.append([batch_images[i].replace('../RFW/data/', '')] + [data.get(key, \"unknown\") for key in default_keys])\n",
    "    torch.cuda.empty_cache()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c880b3-cc8e-4c02-aa95-c308dab33fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "init = 0\n",
    "files = glob.glob('../RFW/data/*/*/*')\n",
    "\n",
    "output_file = \"sailvl.csv\"\n",
    "\n",
    "# Column names\n",
    "columns = [\"File Name\", \"Gender\", \"Age\", \"Skin Color\", \"Ancestry\", \"Hair Color\", \n",
    "           \"Bangs\", \"Bald\", \"Beard\", \"Glasses\", \"Headwear\"]\n",
    "\n",
    "# Chunk settings\n",
    "chunk_size = 32  # Number of images processed per batch\n",
    "\n",
    "# Write CSV header\n",
    "with open(output_file, \"w\", newline=\"\") as csvfile:\n",
    "    writer = csv.writer(csvfile)\n",
    "    writer.writerow(columns)  # CSV Header\n",
    "\n",
    "# Process in chunks\n",
    "for start_idx in tqdm(range(init, len(files), chunk_size), desc=\"Processing Images\"):\n",
    "    end_idx = min(start_idx + chunk_size, len(files))\n",
    "    chunk = files[start_idx:end_idx]\n",
    "    results = process_images(chunk)\n",
    "    # Save chunk results to CSV\n",
    "    with open(output_file, \"a\", newline=\"\") as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        writer.writerows(results)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6956ff6e-b82e-4ac2-be5f-7aebf9785530",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
