# Camera translation applied before the auto-framing search starts.
_INITIAL_CAMERA_MOVE = -0.68 - 0.5
# Upper bound on auto-framing iterations so the search always terminates.
_MAX_FRAMING_STEPS = 20


def _query_vlm(text):
    """Send one text-only user message to the VLM.

    Returns:
        (content, answer): the raw reply and the part of it that follows the
        last '</think>' marker (the whole reply when the marker is absent).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
            ]
        }
    ]
    content = vlm_request(messages)
    return content, content.split('</think>')[-1]


def json_generation(caption, entities=None):
    """Ask the VLM for a layout JSON describing `caption`.

    Args:
        caption: Text substituted for '<caption>' in the prompt templates.
        entities: Entity list substituted (JSON-encoded) for '<entities>' in
            the layout prompt. When None, the VLM is first queried with
            `identity_prompt` and the list is parsed from its answer.

    Returns:
        dict with keys 'caption', 'entities', 'ans_json' (parsed layout) and
        'content' (raw reply to the layout request).
    """
    if entities is None:
        _, answer = _query_vlm(identity_prompt.replace('<caption>', caption))
        entities = extract_and_parse_list(answer)

    layout_prompt = gen_prompt_new.replace('<caption>', caption).replace('<entities>', json.dumps(entities))
    content, answer = _query_vlm(layout_prompt)
    ans_json = extract_and_parse_json(answer)

    return {
        'caption': caption,
        'entities': entities,
        'ans_json': ans_json,
        'content': content,
    }


def find_nonzero_bounding_box(vector):
    """Bounding box of the non-zero region of a 2-D NumPy array.

    Args:
        vector: A 2-D NumPy array.

    Returns:
        (x_min, y_min, x_max, y_max) with inclusive bounds, where x is the
        column index and y is the row index; None when every element is zero.

    Raises:
        TypeError: `vector` is not a NumPy array.
        ValueError: `vector` is not two-dimensional.
    """
    if not isinstance(vector, np.ndarray):
        raise TypeError("输入必须是 NumPy 数组")
    if vector.ndim != 2:
        raise ValueError("输入数组必须是二维的")

    # np.nonzero returns (row indices, column indices) for a 2-D array.
    y_indices, x_indices = np.nonzero(vector)
    if len(y_indices) == 0:
        return None

    return (np.min(x_indices), np.min(y_indices), np.max(x_indices), np.max(y_indices))


def entity_center(x_min, y_min, x_max, y_max, shape, step=0.1, c_max=0.05, c_min=0.0):
    """Decide whether the camera should take another framing step.

    Args:
        x_min, y_min, x_max, y_max: Bounding box in pixel coordinates
            (x = column, y = row).
        shape: Image shape; only the first two entries (height, width) are
            used, so (H, W) and (H, W, C) are both accepted.
        step: Value returned when the box is clear of every border.
        c_max: Border margin as a fraction of the image size.
        c_min: Unused; kept so existing call sites remain valid.

    Returns:
        `step` when the box lies strictly inside the image shrunk by `c_max`
        on every side, otherwise 0.
    """
    h, w = shape[:2]
    # x runs along the width and y along the height; comparing x with h (as
    # the previous version did) is only correct for square images.
    inside_x = x_min > w * c_max and x_max < w * (1 - c_max)
    inside_y = y_min > h * c_max and y_max < h * (1 - c_max)
    return step if inside_x and inside_y else 0


def layout_normalize(ans_json):
    """Rescale a layout in place so that its scene size becomes 1.

    Every entity 'position' and 'size' component is divided by the original
    'scene_size'; 'camera_pitch_angle' is left untouched. The same (mutated)
    dict is returned.
    """
    scene_size = ans_json['scene_parameters']['scene_size']

    ans_json['scene_parameters']['scene_size'] = 1
    for entity in ans_json['entity_layout']:
        entity['position'] = [p / scene_size for p in entity['position']]
        entity['size'] = [s / scene_size for s in entity['size']]
    return ans_json


def generate_scene(ans_json_for_scene, total_move=None, total_move_y=None):
    """Build a DiffusionScene from a layout and render its depth maps.

    Args:
        ans_json_for_scene: Layout dict with 'scene_parameters' and
            'entity_layout'. It is deep-copied, so the caller's dict is not
            modified.
        total_move: Camera translation along the second component of the
            translation vector. When None, the camera starts at
            `_INITIAL_CAMERA_MOVE` and keeps stepping while `entity_center`
            reports that the rendered content is clear of the image border
            (at most `_MAX_FRAMING_STEPS` steps).
        total_move_y: Optional extra translation along the third component of
            the translation vector, applied before `total_move`.

    Returns:
        (depth_all, total_move): the final `scene.render(...)` result and the
        total translation that was applied.

    Note:
        The original inline comments state that rotation/translation vectors
        are ordered (x, z, y) — TODO confirm against DiffusionScene.
    """
    import copy
    import warnings

    ans_json = layout_normalize(copy.deepcopy(ans_json_for_scene))
    scene_size = ans_json['scene_parameters']['scene_size'] / 2
    cam_pitch_angle = 90 - ans_json['scene_parameters']['camera_pitch_angle']
    print(cam_pitch_angle)
    entities = ans_json['entity_layout']

    # Centre the floor on the entities' extent. Taking min()/max() over the
    # entities avoids the former 100 / 0 sentinels, which produced a wrong
    # midpoint whenever every entity lay below zero on this axis.
    if entities:
        y_min = min(e['position'][1] - e['size'][2] / 2 for e in entities)
        y_max = max(e['position'][1] + e['size'][2] / 2 for e in entities)
        floor_offset = -(y_max + y_min) / 2
    else:
        floor_offset = 0.0

    # Build the scene
    scene = DiffusionScene(scene_size=scene_size, fov=(60, 60))
    scene.move_camera(rotation_angle=cam_pitch_angle, rotation_axis=[1, 0, 0], translation=[0, 0, 0])
    scene.build_floor(scale_x=1, scale_y=1, floor_offset=floor_offset)
    for i, entity in enumerate(entities):
        scene.add_box(id=f"box_{i}", size=entity['size'], origin=entity['position'], prompt=entity['entity_name'])

    def translate_camera(offset):
        # Pure translation: zero rotation about the same axis as above.
        scene.move_camera(rotation_angle=0, rotation_axis=[1, 0, 0], translation=offset)

    def framing_step():
        # Render without the floor and measure how the content fills the frame.
        depth = scene.render(single=True, floor=False, render_floor=False, depth_max=4 * scene_size)
        bbox = find_nonzero_bounding_box(depth[-1])
        if bbox is None:
            # Nothing visible: stop searching instead of failing on unpacking None.
            warnings.warn("generate_scene: the render is empty; stopping camera auto-framing")
            return 0
        return entity_center(*bbox, depth[-1].shape)

    if total_move_y is not None:
        translate_camera([0, 0, total_move_y])

    if total_move is None:
        total_move = _INITIAL_CAMERA_MOVE
        translate_camera([0, total_move, 0])
        move = framing_step()
        steps = 0
        while move != 0 and steps < _MAX_FRAMING_STEPS:
            translate_camera([0, move, 0])
            total_move += move
            move = framing_step()
            steps += 1
    else:
        translate_camera([0, total_move, 0])

    depth_all = scene.render(single=True, floor=False, depth_max=4 * scene_size)
    return depth_all, total_move
# %% Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths and parameters
flux_path = "black-forest-labs/FLUX.1-schnell"
condition_size = 512
target_size = 512
condition_type = "eligen_loose"

# Placeholders: replace with the actual LoRA checkpoint locations.
depth_path = 'depth_lora_path'
semantic_path = 'semantic_lora_path'

# Directory that receives the intermediate images written by generate().
# The default is the original author's location; override it through the
# environment when running elsewhere.
save_tmp_root = os.environ.get(
    'OMINICONTROL_SHOW_DIR',
    '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/show',
)

model_config = {
    'condition_type': condition_type,
    'inter_controller_type': None,
    'eligen_depth_attn': False,
    'latent_lora': ['semantic']
}

# Load model
pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16)
print("Load Flux model successfully")

pipe.transformer.load_lora_adapter(depth_path, adapter_name='depth')
pipe.transformer.load_lora_adapter(semantic_path, adapter_name="semantic",)

pipe.transformer.set_adapters(['semantic', 'depth'])
print("Load Flux lora successfully")
active_adapters = pipe.get_active_adapters()
print(active_adapters)

pipe = pipe.to(device=device, dtype=torch.bfloat16)

# %% Layout generation
prompt = ""
# None lets json_generation() extract the entities from the prompt; assign a
# list of entity names to skip that step. (The previous `None or []` always
# evaluated to [], which silently disabled the extraction.)
entities = None
seed = 42
seed_everything(seed)

data = json_generation(prompt, entities)
print(data)

total_move = None
depth_all, total_move = generate_scene(data['ans_json'], total_move=total_move)
# Store the camera offset and seed next to the layout so the run can be repeated.
data['ans_json']['total_move'] = total_move
data['seed'] = seed
print(total_move)

os.makedirs(f'test/{prompt[:50]}', exist_ok=True)
json_path = f'test/{prompt[:50]}/data.json'
with open(json_path, 'w') as f:
    json.dump(data, f, indent=4)

# %% Preview: the last element of the render result
display(Image.fromarray(depth_all[-1]))

# %% Build the condition
caption = data['caption']
eligen_entity_prompts = [entity['entity_name'] for entity in data['ans_json']["entity_layout"]]

# Every element except the last is used as a condition image.
condition_imgs = []
for depth in depth_all[:-1]:
    condition_imgs.append(Image.fromarray(depth).convert("RGB").resize((condition_size, condition_size)))

# Process masks
eligen_entity_masks = []
eligen_entity_masks_pil = []
for img in condition_imgs:
    # Create downsampled mask for model input
    mask = np.array(img.resize((condition_size//8, condition_size//8)))
    mask = np.where(mask > 0, 1, 0).astype(np.uint8)
    mask_tensor = torch.from_numpy(mask).to(device=pipe.device, dtype=pipe.dtype)
    eligen_entity_masks.append(mask_tensor.unsqueeze(0))

    # Create full resolution mask for visualization
    mask_pil = np.where(np.array(img) > 0, 1, 0).astype(np.uint8)
    eligen_entity_masks_pil.append(Image.fromarray(mask_pil*255))

# Stack the condition images into one tensor (order is kept as rendered).
condition_imgs = torch.stack([T.ToTensor()(img) for img in condition_imgs])

# Create final condition object
condition_data = {
    "condition": condition_imgs,
    "eligen_entity_prompts": eligen_entity_prompts,
    "eligen_entity_masks": eligen_entity_masks,
    'eligen_entity_masks_pil': eligen_entity_masks_pil,
}

condition_ = Condition(
    condition_type=condition_type,
    condition=condition_data,
    position_delta=[0, 0],
)

# %% Generate the image
# Build the generate() arguments
num_inference_steps = 4
# num_inference_steps = 20
generate_kwargs = {
    "prompt": prompt,
    "default_lora": True,
    "num_inference_steps": num_inference_steps,
    "conditions": [condition_],
    "height": condition_size,
    "width": condition_size,
    "eligen_entity_prompts": eligen_entity_prompts,
    "eligen_entity_masks": eligen_entity_masks,
}

# Run generation
image = generate(
    pipe,
    model_config=model_config,
    optimize=False,
    layout=data['ans_json'],
    eligen_entity_masks_pil=eligen_entity_masks_pil,
    start=0,
    end=2,
    optim_step=1,
    save_tmp_image=f'{save_tmp_root}/{prompt[:50]}',
    **generate_kwargs
).images[0].resize((target_size, target_size))
display(image)

# # Optional: overlay the entity masks on the result
# if eligen_entity_masks_pil:
#     image_mask = visualize_masks(image, eligen_entity_masks_pil, eligen_entity_prompts)
#     display(image_mask)