import argparse
import prompts
import os
import pickle
import json
import tqdm
import cloudgpt_api
import torch
import datasets
import compute_results
import termcolor
from typing import Optional, Tuple, List, Dict, Union
import data_utils
import numpy as np
import json
import PIL
import asyncio
import aiohttp
import functools
import utils
import clip
import lavis
import numpy as np
import termcolor
import torch

from tools_pool import ToolsPool
import re


def parser_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: The parsed options. ``--dataset``, ``--dataset-path``,
        ``--api_key`` and ``--base_url`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser('')
    parser.add_argument("--device", type=int, default=0,
                        help='GPU to use for computation.')
    parser.add_argument("--preload", nargs='+', type=str, default=['img_features', 'captions', 'mods'],
                        help='List of properties to preload is computed once before.')
    # Single directory (no nargs): get_predeal_dict() hands this value straight to
    # os.path.join(), which raises TypeError when given the list that nargs='+' produces.
    parser.add_argument("--preload_path", type=str, default=r"results/",
                        help='preload file path.')
    # Base Model Choices
    parser.add_argument("--clip", type=str, default='ViT-B/32',
                        choices=['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'RN50x4', 'ViT-bigG-14',
                                 'ViT-B-32', 'ViT-B-16', 'ViT-L-14', 'ViT-H-14', 'ViT-g-14'],
                        help="Which CLIP text-to-image retrieval model to use")
    parser.add_argument("--blip", type=str, default='blip2_t5', choices=['blip2_t5'],
                        help="BLIP Image Caption Model to use.")
    # Dataset Arguments ['dress', 'toptee', 'shirt']
    parser.add_argument("--dataset", type=str, required=True,
                        choices=['cirr', 'circo',
                                 'fashioniq_dress', 'fashioniq_toptee', 'fashioniq_shirt',
                                 'genecis_change_attribute', 'genecis_change_object', 'genecis_focus_attribute',
                                 'genecis_focus_object'],
                        help="Dataset to use")
    parser.add_argument("--split", type=str, default='val', choices=['val', 'test'],
                        help='Dataset split to evaluate on. Some datasets require special testing protocols s.a. cirr/circo.')
    parser.add_argument("--dataset-path", type=str, required=True,
                        help="Path to the dataset")
    parser.add_argument("--preprocess-type", default="targetpad", type=str, choices=['clip', 'targetpad'],
                        help="Preprocess pipeline to use")
    # LLM & BLIP Prompt Arguments.
    # Every non-dunder attribute of the prompts module is offered as 'prompts.<name>'.
    available_prompts = [f'prompts.{x}' for x in vars(prompts) if '__' not in x]
    parser.add_argument("--llm_prompt", default='prompts.simple_modifier_prompt', type=str, choices=available_prompts,
                        help='Denotes the base prompt to use to probe the LLM. Has to be available in prompts.py')
    parser.add_argument("--blip_prompt", default='prompts.blip_prompt', type=str, choices=available_prompts,
                        help='Denotes the base prompt to use alongside BLIP. Has to be available in prompts.py')
    parser.add_argument("--gpt_cap_prompt", default='prompts.gpt4v_cap_prompt', type=str, choices=available_prompts,
                        help='Denotes the base prompt to use alongside GPT4V. Has to be available in prompts.py')
    parser.add_argument("--gpt_cir_prompt", default='prompts.mllm_structural_modifier_prompt_fashion', type=str,
                        choices=available_prompts,
                        help='Denotes the base prompt to use alongside GPT4V. Has to be available in prompts.py')
    parser.add_argument("--gpt_cir_refine_prompt", default='prompts.mllm_refine_with_tool_evidence_prompt', type=str,
                        choices=available_prompts, help='Prompt for refining descriptions with tool-generated evidence')
    parser.add_argument("--weight-path", type=str, default='',
                        help='Where to store OpenCLIP weights.')
    # NOTE(review): the default engine is not one of the listed choices. argparse does not
    # validate defaults against `choices`, so it is only reachable by omitting the flag.
    parser.add_argument("--openai_engine", default='gpt-35-turbo-1106', type=str,
                        choices=["gpt-image-1-0415-global",
                                 "gpt-4-0409",
                                 "gpt-4-turbo-128k",
                                 "gpt-4o-mini-0718",
                                 "gpt-4o-mini-0718-global",
                                 "gpt-4o-0806-global",
                                 "gpt-4o-0806",
                                 "gpt-4o-0513-global",
                                 "gpt-4o-0513",
                                 "gpt-45-0227-global",
                                 "gpt-41-0414-global",
                                 "gpt-41-mini-0414-global",
                                 "gpt-41-nano-0414-global",
                                 "gemini-2.5-pro-06-17",
                                 "qwen2.5-vl-72b-instruct",
                                 "o1-preview-0912",
                                 "o1-preview-0912-global",
                                 "o1-mini-0912",
                                 "o1-mini-0912-global",
                                 "o3-0416-global",
                                 "o3-mini-0131-global",
                                 "o4-mini-0416-global",
                                 "gpt-5-0807-global",
                                 "gpt-5-mini-0807-global",
                                 "gpt-5-nano-0807-global",
                                 ],
                        help='Openai LLM Engine to use.')
    parser.add_argument("--api_key", type=str, required=True,
                        help='Openai API Key to use.')
    parser.add_argument("--base_url", type=str, required=True,
                        help='Openai API Base URL to use.')
    parser.add_argument("--batch_size", default=512, type=int,
                        help='Batch size to use.')
    parser.add_argument("--retrieval", type=str, default='default', choices=['default'],
                        help='Type of T2I Retrieval method.')
    # Tool Used Arguments
    parser.add_argument("--tool_cache_dir", type=str, default="./tool_cache",
                        help='Directory for tool cache and downloaded images')
    return parser.parse_args()


# Parsed once at import time; get_predeal_dict() reads this module-level namespace.
args = parser_args()

# Initialize tools pool after args parsing, honouring --tool_cache_dir
# (its default is "./tool_cache", the value that used to be hard-coded here).
tools_pool = ToolsPool(
    api_key=args.api_key,
    cache_dir=args.tool_cache_dir
)


def get_predeal_dict():
    """Map every preloadable artefact to the file it is cached in.

    Reads the module-level ``args``. The 'img_features', 'captions' and 'mods'
    entries stay ``None`` unless they are named in ``args.preload``. When
    ``args.split`` is 'test' an extra 'test' entry holds the submission file
    name. The ``precomputed`` directory below ``args.preload_path`` is created
    whenever ``args.preload`` is non-empty.

    Returns:
        dict: artefact name -> file path (``None`` for artefacts not preloaded).
    """

    def precomputed(filename):
        # Every cache file lives in <preload_path>/precomputed/.
        return os.path.join(args.preload_path, 'precomputed', filename)

    requested = args.preload
    paths = dict.fromkeys(('img_features', 'captions', 'mods'))

    # '/' appears in CLIP names such as 'ViT-B/32' and cannot be part of a file name.
    run_tag = '{}_{}_{}_{}'.format(args.dataset, args.openai_engine, args.clip, args.split).replace('/', '-')
    print(run_tag)

    if len(requested) > 0:
        os.makedirs(os.path.join(args.preload_path, 'precomputed'), exist_ok=True)

    # Image features are keyed by the full run tag (dataset, engine, CLIP model, split).
    if 'img_features' in requested:
        paths['img_features'] = precomputed(run_tag + '_img_features.pkl')

    # Captions are keyed by dataset, engine and split, plus the caption prompt
    # name unless that prompt is 'prompts.blip_prompt'.
    if 'captions' in requested:
        caption_tag = '{}_{}_{}'.format(args.dataset, args.openai_engine, args.split).replace('/', '-')
        if args.gpt_cap_prompt == 'prompts.blip_prompt':
            caption_suffix = '_captions.pkl'
        else:
            caption_suffix = f'_captions_{args.gpt_cap_prompt.split(".")[-1]}.pkl'
        paths['captions'] = precomputed(caption_tag + caption_suffix)

    # Modified captions are keyed by dataset, split, both prompt names and the engine.
    if 'mods' in requested:
        mods_tag = '{}_{}'.format(args.dataset, args.split).replace('/', '-')
        cir_name = args.gpt_cir_prompt.split(".")[-1]
        refine_name = args.gpt_cir_refine_prompt.split(".")[-1]
        mods_file = precomputed(mods_tag + f'_mods_{cir_name}_refine_{refine_name}.pkl')
        if args.openai_engine != 'gpt-3.5-turbo':
            mods_file = mods_file.replace('.pkl', f'_{args.openai_engine}.pkl')
        paths['mods'] = mods_file

    if paths['mods']:
        print(f"Using mods file: {paths['mods']}")

    # Test runs additionally get a submission file name carrying both prompt names.
    if args.split == 'test':
        cir_name = args.gpt_cir_prompt.split(".")[-1]
        refine_name = args.gpt_cir_refine_prompt.split(".")[-1]
        paths['test'] = run_tag + f'{cir_name}_refine_{refine_name}_test_submission.json'

    return paths


def _strip_response_wrappers(resp):
    """Drop an optional ``<Response>`` wrapper and Markdown json fence from an LLM reply."""
    if resp.startswith('<Response>'):
        resp = resp.replace('<Response>', '').replace('</Response>', '').strip()
    if resp.startswith('```json'):
        resp = resp.replace('```json', '').replace('```', '').strip()
    return resp


def _build_refinement_prompt(instruction, tool_results, fallback_prompt):
    """Return the user prompt for the refinement call, chosen by ``tool_results['tool_type']``."""
    tool_type = tool_results.get('tool_type', '')

    if tool_type == 'searching':
        # Use the title of the first search image when one is available.
        search_title = ""
        if 'search' in tool_results and 'images' in tool_results['search']:
            search_images = tool_results['search']['images']
            if search_images and len(search_images) > 0:
                search_title = search_images[0].get('title', 'Search Result')

        return f'''<Input>
                                                {{
                                                    "Original Image": <original_image>,
                                                    "Manipulation text": "{instruction}",
                                                    "Tool Evidence": <search_result_image>,
                                                    "Tool Type": "search",
                                                    "Search Title": "{search_title}"
                                                }}
                                                </Input>'''

    if tool_type == 'image_editing':
        return f'''<Input>
                                                {{
                                                    "Original Image": <original_image>,
                                                    "Manipulation text": "{instruction}",
                                                    "Tool Evidence": <edited_image>,
                                                    "Tool Type": "edit"
                                                }}
                                                </Input>'''

    # Unknown tool type: fall back to the plain single-image prompt.
    return fallback_prompt


async def _run_requested_tool(tool_usage, instruction, image_path, tgt_image_path, tools_pool):
    """Run the first tool the model asked for.

    Returns the ``tool_results`` dict for the refinement call, or ``None`` when no
    tool ran (unparsable or empty request, unknown tool name, or the tool
    produced nothing usable).
    """
    if isinstance(tool_usage, str):
        try:
            tool_usage = json.loads(tool_usage)
        except json.JSONDecodeError:
            return None

    if not (isinstance(tool_usage, list) and len(tool_usage) > 0):
        return None

    tool_info = tool_usage[0]
    tool_name = tool_info.get('tool', '')
    tool_query = tool_info.get('query', '')
    for tag in ('<search>', '</search>', '<edit>', '</edit>'):
        tool_query = tool_query.replace(tag, '')
    tool_query = tool_query.strip()

    if tool_name == "searching":
        search_results = await tools_pool.search_images_with_download(tool_query, num_results=1)
        if search_results and search_results[0].get('image_available'):
            # Format for multi-image API
            formatted_data = tools_pool.format_search_results_for_llm(search_results)
            print(f"     Found visual reference, will refine description")
            return {
                'search': formatted_data,
                'tool_type': 'searching',
                'query': tool_query,
                'original_tool_usage': tool_usage  # Keep original for counting
            }
        print("     No search results found, continuing without tool")
        return None

    if tool_name == "image_editing":
        # The edit tool is only given non-URL image paths.
        if isinstance(image_path, str) and image_path.startswith('http'):
            print(f"     Cannot edit URL, need local file")
            return None

        edited_image_path = await tools_pool.edit_image(
            image_path,
            tool_query,
            output_dir=None,
            original_manipulation=instruction,
            target_image_path=tgt_image_path
        )
        if edited_image_path:
            print(f"     Image edited, will refine description")
            return {
                'edited_image': edited_image_path,
                'tool_type': 'image_editing',
                'query': tool_query,
                'original_tool_usage': tool_usage  # Keep original for counting
            }
        print("     Image editing failed, continuing without tool")
        return None

    # Any other tool name (e.g. "reasoning") has nothing to execute.
    return None


async def process_single_item_async_with_tools(sys_prompt, instruction, image_path, tgt_image_path, args, tools_pool):
    """Describe the target image for one query, refining once with tool evidence if requested.

    The model is first queried with the reference image and the manipulation
    text. If its reply requests the "searching" or "image_editing" tool and the
    tool yields a result, the model is queried a second time with the refinement
    prompt and the tool output; otherwise the first reply is returned as is.

    Args:
        sys_prompt: System prompt for the initial query.
        instruction: Manipulation text for this query.
        image_path: Reference image handed to the LLM and to the edit tool.
        tgt_image_path: Passed to ``tools_pool.edit_image`` as ``target_image_path``.
        args: Namespace providing ``openai_engine``, ``api_key``, ``base_url`` and
            ``gpt_cir_refine_prompt``.
        tools_pool: Object providing ``search_images_with_download``,
            ``format_search_results_for_llm`` and ``edit_image``.

    Returns:
        A 4-tuple ``(description, metadata, reflection, tool_usage)``.
        ``metadata`` is a JSON string holding the original/refined descriptions,
        thoughts, reflections and tool details; without a tool it is the falsy
        ``Thoughts`` value itself when the model returned none. ``tool_usage`` is
        the model's tool request when a tool ran, otherwise ``"None"``. If every
        attempt fails the result is ``(instruction, "", "", "None")``.
    """
    base_user_prompt = f'''
    <Input>
        {{
            "Original Image": <image_url>
            "Manipulation text": {instruction}.
        }}
    </Input>
    '''

    max_retries = 6
    max_tool_iterations = 2

    tool_results = {}
    tool_was_used = False

    # Track the pre-tool answer so the refined one can be compared against it.
    original_description = None
    original_thought = ""
    original_reflection = ""

    loop = asyncio.get_running_loop()

    for tool_iteration in range(max_tool_iterations + 1):
        for attempt in range(max_retries):
            try:
                if tool_results and tool_was_used:
                    # Multi-image call with the tool output for the final refinement.
                    print(f"   → Generating refined description with tool results...")

                    # NOTE: eval() of a CLI value; parser_args() restricts it to
                    # 'prompts.<name>' via `choices`. Do not feed it untrusted input.
                    sys_prompt_refine = eval(args.gpt_cir_refine_prompt)

                    func_with_args = functools.partial(
                        cloudgpt_api.attempt_openai_completion_with_tools,
                        sys_prompt=sys_prompt_refine,
                        user_prompt=_build_refinement_prompt(instruction, tool_results, base_user_prompt),
                        original_image=image_path,
                        tool_results=tool_results,
                        engine=args.openai_engine,
                        api_key=args.api_key,
                        base_url=args.base_url
                    )
                else:
                    # Standard single image call for initial analysis
                    func_with_args = functools.partial(
                        cloudgpt_api.attempt_openai_completion_CoT,
                        sys_prompt=sys_prompt,
                        user_prompt=base_user_prompt,
                        image=image_path,
                        engine=args.openai_engine,
                        api_key=args.api_key,
                        base_url=args.base_url
                    )

                # The API helper is invoked through the default executor.
                resp = await loop.run_in_executor(None, func_with_args)
                resp_dict = json.loads(_strip_response_wrappers(resp))

                # Extract components
                thought = resp_dict.get('Thoughts', '')
                reflection = resp_dict.get('Reflections', '')
                tool_usage = resp_dict.get('Tool Usage', 'None')
                modified_caption = resp_dict.get('Target Image Description', instruction)

                # Remember the first pre-tool answer, whichever iteration produced it.
                if original_description is None and not tool_was_used:
                    original_description = modified_caption
                    original_thought = thought
                    original_reflection = reflection

                # A tool already ran, so this reply is the refinement: return it.
                if tool_was_used:
                    original_img_desc = resp_dict.get('Original Image Description', None)
                    tool_evidence_desc = resp_dict.get('Tool-Generated Visual Evidence', None)
                    refined_thought = resp_dict.get('Thoughts', None)
                    refined_reflection = resp_dict.get('Reflections', None)

                    refined_description = modified_caption or original_description or instruction

                    # str() keeps this purely informational log from raising on
                    # None / non-string values and discarding a valid refinement.
                    print(f"\n   ===== DESCRIPTION COMPARISON =====")
                    print(f"   Original Description: {str(original_description)[:150]}...")
                    print(f"   Refined Description:  {str(refined_description)[:150]}...")
                    if tool_evidence_desc:
                        print(f"   Tool Evidence: {str(tool_evidence_desc)[:150]}...")
                    print(f"   ===================================\n")

                    combined_info = {
                        'original_description': original_description,
                        'refined_description': refined_description,
                        'original_thought': original_thought,
                        'refined_thought': refined_thought,
                        'original_reflection': original_reflection,
                        'refined_reflection': refined_reflection,
                        'tool_used': tool_results.get('tool_type', None),
                        'tool_query': tool_results.get('query', None),
                        # New fields from refinement response
                        'original_img_desc': original_img_desc,
                        'tool_evidence_desc': tool_evidence_desc
                    }

                    return refined_description, json.dumps(combined_info), refined_reflection, tool_results.get(
                        'original_tool_usage', tool_usage)

                # Execute the requested tool (never on the last iteration).
                if tool_usage != "None" and tool_iteration < max_tool_iterations:
                    new_tool_results = await _run_requested_tool(
                        tool_usage, instruction, image_path, tgt_image_path, tools_pool)
                    if new_tool_results:
                        tool_results = new_tool_results
                        tool_was_used = True
                        # Continue to the next iteration, which runs the refinement call.
                        break

                # No tool ran (none requested, unknown/unparsable request, tool failed, or
                # last iteration): this reply is final. Returning here avoids re-sending
                # the identical request when the tool request cannot be executed.
                if not modified_caption:
                    modified_caption = instruction

                # Keep consistent structure with None values for unused fields
                no_tool_info = {
                    'original_description': modified_caption,
                    'refined_description': modified_caption,
                    'original_thought': thought,
                    'refined_thought': None,
                    'original_reflection': reflection,
                    'refined_reflection': None,
                    'tool_used': None,
                    'tool_query': None,
                    'original_img_desc': None,
                    'tool_evidence_desc': None
                }

                return modified_caption, json.dumps(no_tool_info) if thought else thought, reflection, "None"

            except json.JSONDecodeError as e:
                print(f"Attempt {attempt + 1}/{max_retries} failed: Invalid JSON - {str(e)[:50]}")
            except Exception as e:
                # Best-effort: any API/tool failure is logged and the attempt retried.
                print(f"Attempt {attempt + 1}/{max_retries} failed: {str(e)[:100]}")

            if attempt < max_retries - 1:
                await asyncio.sleep(1)

    # Fallback
    return instruction, "", "", "None"


async def process_batch_async(sys_prompt, relative_captions, ref_image_paths, target_image_paths, args):
    """Run ``process_single_item_async_with_tools`` concurrently for every query of a batch.

    Args:
        sys_prompt: System prompt forwarded to every item.
        relative_captions: Manipulation texts, indexed like ``ref_image_paths``.
        ref_image_paths: Reference image paths; their count drives the number of tasks.
        target_image_paths: Target image paths, indexed like ``ref_image_paths``.
        args: Namespace providing ``api_key`` (and optionally ``tool_cache_dir``);
            also forwarded to every item.

    Returns:
        list: One result per reference image, in input order.
    """
    # One pool shared by all items of this batch. The configured cache directory is
    # forwarded when present so that --tool_cache_dir also applies to this pool.
    pool_kwargs = {'api_key': args.api_key}
    cache_dir = getattr(args, 'tool_cache_dir', None)
    if cache_dir:
        pool_kwargs['cache_dir'] = cache_dir
    tools_pool = ToolsPool(**pool_kwargs)

    # Build one coroutine per reference image; indexing (rather than zip) keeps a
    # length mismatch between the inputs loud instead of silently truncating.
    tasks = [
        process_single_item_async_with_tools(
            sys_prompt, relative_captions[i], image_path, target_image_paths[i], args, tools_pool
        )
        for i, image_path in enumerate(ref_image_paths)
    ]

    # Run all tasks concurrently
    return await asyncio.gather(*tasks)


def MLLM_CIR(device: torch.device, args: argparse.Namespace, query_dataset: torch.utils.data.Dataset,
             preload_dict: Dict[str, Union[str, None]], **kwargs) -> Tuple[torch.Tensor, List[str], list]:
    if preload_dict['mods'] is None or not os.path.exists(preload_dict['mods']):
        # GENERATE NEW RESULTS
        all_modified_captions = []
        all_original_captions = []  # Track original descriptions
        all_refined_captions = []  # Track refined descriptions
        all_thoughts, all_reflations, all_relative_captions = [], [], []
        gt_img_ids, query_ids, target_names, reference_names = [], [], [], []

        all_original_img_descs = []  # Original image descriptions from refinement
        all_tool_evidence_descs = []  # Tool-generated visual evidence descriptions
        all_refined_thoughts = []  # Refined thoughts
        all_refined_reflections = []  # Refined reflections
        all_tool_used = []  # Tool types used
        all_tool_queries = []  # Tool queries used

        query_loader = torch.utils.data.DataLoader(
            dataset=query_dataset, batch_size=args.batch_size, num_workers=8,
            pin_memory=False, collate_fn=data_utils.collate_fn, shuffle=False)

        query_iterator = tqdm.tqdm(query_loader, position=0, desc='Predicting Target captions with MLLM...')
        # count each tools used
        tot_search_cnt, tot_edit_cnt, tot_reason_cnt = 0, 0, 0

        for batch in query_iterator:
            # --- Batch data preparation (remains the same) ---
            batch_reference_names, batch_target_names = [], []
            if 'genecis' in args.dataset:
                ref_image_path = batch[0]
                relative_captions = batch[1]
            else:
                ref_image_path = batch['reference_image_path']
                reference_names.extend(batch['reference_name'])
                batch_reference_names.extend(batch['reference_name'])
                if 'fashioniq' not in args.dataset:
                    relative_captions = batch['relative_caption']
                else:
                    rel_caps = np.array(batch['relative_captions']).T.flatten().tolist()
                    relative_captions = [f"{rel_caps[i].strip('.?, ')} and {rel_caps[i + 1].strip('.?, ')}" for i in
                                         range(0, len(rel_caps), 2)]

                if 'target_name' in batch:
                    target_names.extend(batch['target_name'])
                    target_image_path = batch['target_image_path']
                    batch_target_names.extend(batch['target_name'])
                else:  # as a placeholder
                    target_names.extend(batch['reference_name'])
                    target_image_path = batch['reference_image_path']
                    batch_target_names.extend(batch['reference_name'])
                if 'gt_img_ids' in batch:
                    gt_img_ids.extend(np.array(batch['gt_img_ids']).T.tolist())
                if 'pair_id' in batch:
                    query_ids.extend(batch['pair_id'])

            sys_prompt = eval(args.gpt_cir_prompt)

            # --- ASYNCHRONOUS API CALLS ---
            results = asyncio.run(
                process_batch_async(sys_prompt, relative_captions, ref_image_path, target_image_path, args))

            # --- Unpack the results ---
            modified_captions, thoughts, reflations, tool_useds = zip(*results)

            # --- Post-processing and logging ---
            for i, (caption, thought, reflection, tool_used_list) in enumerate(
                    zip(modified_captions, thoughts, reflations, tool_useds)):
                # Initialize with None for all optional fields
                original_desc = caption
                refined_desc = caption
                original_img_desc = None
                tool_evidence_desc = None
                refined_thought = None
                refined_reflection = None
                tool_type = None
                tool_query = None

                # Try to extract all fields from thought metadata
                try:
                    if thought and thought.startswith('{'):
                        thought_data = json.loads(thought)
                        if 'original_description' in thought_data:
                            original_desc = thought_data.get('original_description', caption)
                            refined_desc = thought_data.get('refined_description', caption)
                            # Extract all new fields, keeping None if not present
                            original_img_desc = thought_data.get('original_img_desc', None)
                            tool_evidence_desc = thought_data.get('tool_evidence_desc', None)
                            refined_thought = thought_data.get('refined_thought', None)
                            refined_reflection = thought_data.get('refined_reflection', None)
                            tool_type = thought_data.get('tool_used', None)
                            tool_query = thought_data.get('tool_query', None)
                            # Restore original thought for backward compatibility
                            thought = thought_data.get('original_thought', thought)
                            reflection = thought_data.get('original_reflection', reflection)
                except:
                    pass

                # Store all fields
                all_original_captions.append(original_desc)
                all_refined_captions.append(refined_desc)
                all_original_img_descs.append(original_img_desc)
                all_tool_evidence_descs.append(tool_evidence_desc)
                all_refined_thoughts.append(refined_thought)
                all_refined_reflections.append(refined_reflection)
                all_tool_used.append(tool_type)
                all_tool_queries.append(tool_query)

                # Tool counting and logging (keep original logic)
                if tool_used_list != "None":
                    if type(tool_used_list) is str:
                        try:
                            tool_used_list = json.loads(tool_used_list)
                        except:
                            pass

                    if isinstance(tool_used_list, list) and len(tool_used_list) > 0:
                        tool_name = tool_used_list[0].get('tool', '')
                        tool_used_query = tool_used_list[0].get('query', '')

                        if tool_name == "searching":
                            tot_search_cnt += 1
                            tool_used_query = tool_used_query.replace('<search>', '').replace('</search>', '')
                        elif tool_name == "reasoning":
                            tot_reason_cnt += 1
                            tool_used_query = tool_used_query.replace('<reasoning>', '').replace('</reasoning>', '')
                        elif tool_name == "image_editing":
                            tot_edit_cnt += 1
                            tool_used_query = tool_used_query.replace('<edit>', '').replace('</edit>', '')

                        if tool_name != "reasoning":
                            # Enhanced logging with all fields
                            print("\n" + "=" * 80)
                            print(
                                f" Reference: {batch_reference_names[i] if len(batch_reference_names) > i else 'N/A'}")
                            print(f" Manipulation: {relative_captions[i]}")
                            print(f" Tool Used: {tool_name} - Query: {tool_used_query[:100]}...")

                            if original_img_desc:
                                print(f"\n [Refinement] Original Image Desc: {original_img_desc[:150]}...")
                            if tool_evidence_desc:
                                print(f" [Refinement] Tool Evidence: {tool_evidence_desc[:150]}...")

                            print(f"\n Initial Target Description: {original_desc[:150]}...")
                            print(f" Refined Target Description: {refined_desc[:150]}...")

                            if original_desc != refined_desc:
                                print(f" ✓ Description was refined using {tool_name}")

                            print(f"\n Target: {batch_target_names[i] if len(batch_target_names) > i else 'N/A'}")
                            print("=" * 80 + "\n")

            all_modified_captions.extend(modified_captions)
            all_thoughts.extend(thoughts)
            all_reflations.extend(reflations)
            all_relative_captions.extend(relative_captions)

            print(f"\n Progress: {len(all_modified_captions)} items processed")
            print(f"    Search: {tot_search_cnt} |  Edit: {tot_edit_cnt} |  Reason: {tot_reason_cnt}")

        # --- Saving results ---
        if preload_dict['mods'] is not None:
            res_dict = {
                'target_names': target_names,
                'targets': gt_img_ids,
                'reference_names': reference_names,
                'query_ids': query_ids,
                'start_captions': [],
                'thoughts': all_thoughts,
                'reflections': all_reflations,
                'modified_captions': all_modified_captions,
                'original_descriptions': all_original_captions,
                'refined_descriptions': all_refined_captions,
                'instructions': all_relative_captions,
                # New fields - all can be None
                'original_img_descs': all_original_img_descs,
                'tool_evidence_descs': all_tool_evidence_descs,
                'refined_thoughts': all_refined_thoughts,
                'refined_reflections': all_refined_reflections,
                'tool_used': all_tool_used,
                'tool_queries': all_tool_queries
            }
            pickle.dump(res_dict, open(preload_dict['mods'], 'wb'))

            print(f"\n===== Final Statistics =====")
            print(f"Total items processed: {len(all_modified_captions)}")
            print(f"Items with tools used: {sum(1 for t in all_tool_used if t is not None)}")
            print(f"  - Search: {sum(1 for t in all_tool_used if t == 'searching')}")
            print(f"  - Edit: {sum(1 for t in all_tool_used if t == 'image_editing')}")
            print(
                f"Items refined: {sum(1 for i, (orig, ref) in enumerate(zip(all_original_captions, all_refined_captions)) if orig != ref)}")
            print(f"Items with tool evidence: {sum(1 for t in all_tool_evidence_descs if t is not None)}")
            print("============================\n")

    else:
        # LOAD FROM PICKLE
        print(f'Loading predicted target image captions from {preload_dict["mods"]}!')
        res_dict = pickle.load(open(preload_dict['mods'], 'rb'))

        # Handle old format without separate descriptions
        if 'original_descriptions' not in res_dict:
            res_dict['original_descriptions'] = res_dict.get('modified_captions', [])
            res_dict['refined_descriptions'] = res_dict.get('modified_captions', [])

        # Handle new fields for backward compatibility - use None for missing
        if 'original_img_descs' not in res_dict:
            num_items = len(res_dict['modified_captions'])
            res_dict['original_img_descs'] = [None] * num_items
            res_dict['tool_evidence_descs'] = [None] * num_items
            res_dict['refined_thoughts'] = [None] * num_items
            res_dict['refined_reflections'] = [None] * num_items
            res_dict['tool_used'] = [None] * num_items
            res_dict['tool_queries'] = [None] * num_items

        # Extract all fields
        target_names = res_dict['target_names']
        gt_img_ids = res_dict['targets']
        reference_names = res_dict['reference_names']
        query_ids = res_dict['query_ids']
        all_thoughts = res_dict['thoughts']
        all_reflations = res_dict['reflections']
        all_modified_captions = res_dict['modified_captions']
        all_original_captions = res_dict['original_descriptions']
        all_refined_captions = res_dict['refined_descriptions']
        all_relative_captions = res_dict['instructions']
        # New fields
        all_original_img_descs = res_dict.get('original_img_descs', [None] * len(all_modified_captions))
        all_tool_evidence_descs = res_dict.get('tool_evidence_descs', [None] * len(all_modified_captions))
        all_refined_thoughts = res_dict.get('refined_thoughts', [None] * len(all_modified_captions))
        all_refined_reflections = res_dict.get('refined_reflections', [None] * len(all_modified_captions))
        all_tool_used = res_dict.get('tool_used', [None] * len(all_modified_captions))
        all_tool_queries = res_dict.get('tool_queries', [None] * len(all_modified_captions))

        # Summary statistics:
        print(f"\n===== Loaded Statistics =====")
        print(f"Total items loaded: {len(all_modified_captions)}")
        print(f"Items with tools used: {sum(1 for t in all_tool_used if t is not None)}")
        print(f"  - Search: {sum(1 for t in all_tool_used if t == 'searching')}")
        print(f"  - Edit: {sum(1 for t in all_tool_used if t == 'image_editing')}")
        print(
            f"Items refined: {sum(1 for i, (orig, ref) in enumerate(zip(all_original_captions, all_refined_captions)) if orig != ref)}")
        print(f"Items with tool evidence: {sum(1 for t in all_tool_evidence_descs if t is not None)}")
        print("============================\n")

    # IMPORTANT: This return statement must be at the function level, not inside if/else
    return target_names, gt_img_ids, reference_names, query_ids, [], all_thoughts, all_reflations, all_modified_captions, all_original_captions, all_refined_captions, all_relative_captions


# Script entry point: load the retrieval model and image preprocessors, build
# the evaluation dataset(s) selected by --dataset, generate target captions via
# MLLM_CIR, and print the dataset-specific retrieval metrics.
#
# NOTE(review): `args` is read here but never assigned in this block; it is
# presumably a module-level global (e.g. the result of parser_args()) - confirm.
# Every name bound below is also a module global, so functions elsewhere in
# this file may depend on them; rename with care.
if __name__ == "__main__":
    # --- Select the compute device: the requested GPU when CUDA is available, otherwise CPU.
    termcolor.cprint(f'Starting evaluation on {args.dataset.upper()} (split: {args.split})\n', color='green',
                     attrs=['bold'])
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

    ### Get preload dictionary - CALL ONCE HERE
    # The same dictionary is handed to every stage below through input_kwargs.
    preload_dict = get_predeal_dict()

    ### Load CLIP model, BLIP model & Preprocessing.
    print(f'Loading CLIP {args.clip}... ', end='')

    # Hyphenated model names are loaded through open_clip; every other name
    # goes through the `clip` package in the else branch.
    if args.clip in ['ViT-bigG-14', 'ViT-B-32', 'ViT-B-16', 'ViT-L-14', 'ViT-H-14', 'ViT-g-14']:
        # Imported lazily so open_clip is only required when one of these models is requested.
        import open_clip

        # Pretrained-weights tag passed to open_clip for each supported architecture.
        pretraining = {
            'ViT-B-32': 'laion2b_s34b_b79k',
            'ViT-B-16': 'laion2b_s34b_b88k',
            'ViT-L-14': 'laion2b_s32b_b82k',
            'ViT-H-14': 'laion2b_s32b_b79k',
            'ViT-g-14': 'laion2b_s34b_b88k',
            'ViT-bigG-14': 'laion2b_s39b_b160k'
        }

        clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(args.clip,
                                                                               pretrained=pretraining[args.clip])
        # Inference only: eval mode, gradients disabled.
        clip_model = clip_model.eval().requires_grad_(False).to(device)
        tokenizer = open_clip.get_tokenizer(args.clip)
        # Attach the tokenizer to the model object so it travels with the model.
        clip_model.tokenizer = tokenizer
    else:
        clip_model, clip_preprocess = clip.load(args.clip, device=device, jit=False)
        # Cast to float32 and freeze for inference.
        clip_model = clip_model.float().eval().requires_grad_(False).to(device)

    print('Done.')

    # Choose the image transform handed to the dataset constructors below.
    # NOTE(review): there is no else branch, so `preprocess` stays unbound for
    # any other --preprocess_type value; the FashionIQ/CIRR/CIRCO branches
    # would then fail with NameError. Confirm argparse restricts the choices.
    if args.preprocess_type == 'targetpad':
        print('Target pad preprocess pipeline is used.')
        # The size argument is taken from the first transform of the CLIP pipeline.
        preprocess = data_utils.targetpad_transform(1.25, clip_preprocess.transforms[0].size)
    elif args.preprocess_type == 'clip':
        print('CLIP preprocess pipeline is used.')
        preprocess = clip_preprocess

    import omegaconf

    # Only the BLIP visual preprocessors are built here (from the
    # "pretrain_flant5xxl" default config); no BLIP model is loaded in this block.
    model_cls = lavis.common.registry.registry.get_model_class(args.blip)
    preprocess_cfg = omegaconf.OmegaConf.load(model_cls.default_config_path("pretrain_flant5xxl")).preprocess
    vis_processors, _ = lavis.models.load_preprocess(preprocess_cfg)

    # --- Load Evaluation Datasets.
    # The three lists are kept index-aligned: one (query, target, pairing) triple per retrieval setup.
    target_datasets, query_datasets, pairings = [], [], []
    if 'fashioniq' in args.dataset.lower():
        # The garment category is the last underscore-separated token, e.g. 'fashioniq_dress' -> 'dress'.
        dress_type = args.dataset.split('_')[-1]
        target_datasets.append(
            datasets.FashionIQDataset(args.dataset_path, args.split, [dress_type], 'classic', preprocess,
                                      blip_transform=vis_processors['eval']))
        query_datasets.append(
            datasets.FashionIQDataset(args.dataset_path, args.split, [dress_type], 'relative', preprocess,
                                      blip_transform=vis_processors['eval']))
        pairings.append(dress_type)
        compute_results_function = compute_results.fiq

    elif args.dataset.lower() == 'cirr':
        # The CLI split 'test' is passed to the CIRR dataset as 'test1'.
        split = 'test1' if args.split == 'test' else args.split
        target_datasets.append(datasets.CIRRDataset(args.dataset_path, split, 'classic', preprocess,
                                                    blip_transform=vis_processors['eval']))
        query_datasets.append(datasets.CIRRDataset(args.dataset_path, split, 'relative', preprocess,
                                                   blip_transform=vis_processors['eval']))
        compute_results_function = compute_results.cirr
        pairings.append('default')

    elif args.dataset.lower() == 'circo':
        target_datasets.append(datasets.CIRCODataset(args.dataset_path, args.split, 'classic', preprocess,
                                                     blip_transform=vis_processors['eval']))
        query_datasets.append(datasets.CIRCODataset(args.dataset_path, args.split, 'relative', preprocess,
                                                    blip_transform=vis_processors['eval']))
        compute_results_function = compute_results.circo
        pairings.append('default')

    elif 'genecis' in args.dataset.lower():
        # Strip the leading 'genecis_' to get the task name, e.g. 'change_attribute'.
        data_split = '_'.join(args.dataset.lower().split('_')[1:])
        prop_file = os.path.join(args.dataset_path, 'genecis', data_split + '.json')

        # Object tasks read COCO val2017 images; attribute tasks read Visual Genome images.
        if 'object' in args.dataset.lower():
            datapath = os.path.join(args.dataset_path, 'coco2017', 'val2017')
            genecis_dataset = datasets.COCOValSubset(root_dir=datapath, val_split_path=prop_file, data_split=data_split)
        elif 'attribute' in args.dataset.lower():
            datapath = os.path.join(args.dataset_path, 'Visual_Genome', 'VG_All')
            genecis_dataset = datasets.VAWValSubset(image_dir=datapath, val_split_path=prop_file, data_split=data_split)

        # The same dataset object serves as both the target set and the query set.
        target_datasets.append(genecis_dataset)
        query_datasets.append(genecis_dataset)
        compute_results_function = compute_results.genecis
        pairings.append('default')

    # preload_dict was already built once near the top of this block and is reused below.

    # --- Evaluate performances.
    for query_dataset, target_dataset, pairing in zip(query_datasets, target_datasets, pairings):
        termcolor.cprint(f'\n------ Evaluating Retrieval Setup: {pairing}', color='yellow', attrs=['bold'])

        ### General Input Arguments.
        # Shared keyword arguments; later stages extend this dict with their outputs.
        input_kwargs = {
            'device': device, 'args': args, 'query_dataset': query_dataset, 'target_dataset': target_dataset,
            'preload_dict': preload_dict,  # Use the preload_dict from above
        }

        # --- Predict target captions
        # NOTE(review): none of the eleven unpacked values is read again in this
        # block, and none is added to input_kwargs; presumably the later stages
        # pick the captions up through preload_dict - confirm.
        target_names, targets, reference_names, query_ids, start_captions, all_thoughts, all_reflations, modified_captions, original_captions, refined_captions, instructions = MLLM_CIR(
            **input_kwargs)

        # NOTE(review): this repeats the banner already printed at the top of the loop body.
        termcolor.cprint(f'\n------ Evaluating Retrieval Setup: {pairing}', color='yellow', attrs=['bold'])


        ### Compute Target Image Features
        print(f'Extracting target image features using CLIP: {args.clip}.')
        index_features, index_names, index_ranks, aux_data = utils.extract_image_features(
            device, args, target_dataset, clip_model, preload=preload_dict['img_features'])
        # Normalize the index features along the last dimension.
        index_features = torch.nn.functional.normalize(index_features.float(), dim=-1)
        input_kwargs.update({'index_features': index_features, 'index_names': index_names, 'index_ranks': index_ranks})

        ### Compute Method-specific Query Features.
        # This part can be interchanged with any other method implementation.
        print(f'Generating conditional query predictions (CLIP: {args.clip}, BLIP: {args.blip}).')
        out_dict = utils.generate_predictions_gpt(**input_kwargs)
        # NOTE(review): looks redundant - out_dict is merged again after the
        # GeneCIS adjustment below, before input_kwargs is next read.
        input_kwargs.update(out_dict)

        ### Compute Dataset-specific Retrieval Scores.
        # This part is dataset-specific and declared above.
        print('Computing final retrieval metrics.')
        if args.dataset == 'genecis_focus_attribute':
            # For this task only: average the predicted features with the
            # normalized reference features, then renormalize the result.
            aux_data['ref_features'] = torch.nn.functional.normalize(aux_data['ref_features'].float().to(device))
            out_dict['predicted_features'] = torch.nn.functional.normalize(
                (out_dict['predicted_features'].float() + aux_data['ref_features']) / 2, dim=-1)

        # Merge again so the possibly adjusted predicted_features reach the metric function.
        input_kwargs.update(out_dict)
        result_metrics = compute_results_function(**input_kwargs)

        # Print metrics.
        print('\n')
        if result_metrics is not None:
            termcolor.cprint(f'Metrics for {args.dataset.upper()} ({args.split})- {pairing}', attrs=['bold'])
            # One line per metric, prefixed with the pairing name and rounded to two decimals.
            for k, v in result_metrics.items():
                print(f"{pairing}_{k} = {v:.2f}")
        else:
            # The metric function returned None, so there is nothing to print for this setup.
            termcolor.cprint(f'No explicit metrics available for {args.dataset.upper()} ({args.split}) - {pairing}.',
                             attrs=['bold'])
