#!/usr/bin/env python3
"""
build_reasoning_chains_from_trajectory_multiturn.py

Convert trajectory-based reasoning chains into multi-turn dialogue format.
This creates interactive conversations where the assistant searches through different bounding box regions
in the image using phrases from trajectory data to find small targets.
"""

import argparse, io, json, re, textwrap
from pathlib import Path
from typing import List, Tuple, Dict, Any
from collections import defaultdict
from datetime import datetime

from PIL import Image, ImageDraw
import datasets
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import os
import multiprocessing
from functools import partial

# Seed the module-level RNG used by random.sample / random.choice / random.random below.
# NOTE(review): the conversion runs inside multiprocessing.Pool workers (see main), and
# which worker receives which batch is decided by the pool scheduler, so per-example
# sampling may not be reproducible run-to-run despite this fixed seed — confirm.
random.seed(42)

# ────────────────────────────────────────────────────────────────────────────────
# Regex helpers and coordinate extraction
# ────────────────────────────────────────────────────────────────────────────────
# "(x, y)" pairs of unsigned decimal numbers, whitespace allowed around each part.
# The character class also accepts tokens such as "." or "1.2.3", which float() rejects.
COORD_RE = re.compile(r"\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)")
# A sentence terminator (. ! ?) followed by whitespace or end of text; the trailing
# whitespace character, if any, is part of the match.
SENT_END_RE = re.compile(r"[.!?](?:\s|$|\n)")

def extract_coordinates_from_chain(chain_text: str) -> List[Tuple[float, float, int, int]]:
    """Return every "(x, y)" coordinate pair found in *chain_text*, in textual order.

    Each item is ``(x, y, start_idx, end_idx)``, where the indices delimit the whole
    parenthesised match in *chain_text*.

    COORD_RE's number pattern also matches dot-only or multi-dot tokens such as
    "(..., ...)" or "(1.2.3, 4)". Those are not numbers, so such matches are skipped
    instead of letting ValueError abort the whole extraction.
    """
    coords = []
    for m in COORD_RE.finditer(chain_text):
        try:
            x, y = float(m.group(1)), float(m.group(2))
        except ValueError:
            # Regex matched something that is not a valid float (e.g. "...").
            continue
        coords.append((x, y, m.start(), m.end()))
    return coords

def extract_final_answer(chain_text: str) -> str | None:
    """Extract the final answer from <answer> tags."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", chain_text, re.DOTALL)
    if m:
        return m.group(1).strip()
    return None

def normalize_bbox(bbox: List[float], image_width: int, image_height: int) -> List[int]:
    """
    Return *bbox* ``[x, y, w, h]`` as integer pixel coordinates.

    If every value lies in [0, 1], the box is treated as normalized: x and w are scaled
    by *image_width*, and y and h by *image_height*. Otherwise the values are assumed
    to be pixels already and each one is passed through ``int()``. Conversion truncates
    toward zero rather than rounding.

    Note that a tiny pixel box such as ``[0, 0, 1, 1]`` is indistinguishable from a
    normalized box and will be scaled.
    """
    looks_normalized = all(0 <= value <= 1 for value in bbox)
    if not looks_normalized:
        return list(map(int, bbox))

    left, top, box_w, box_h = bbox
    return [
        int(left * image_width),
        int(top * image_height),
        int(box_w * image_width),
        int(box_h * image_height),
    ]

def bbox_to_center_point(bbox: List[int]) -> Tuple[int, int]:
    """Return the centre ``(cx, cy)`` of an ``[x, y, w, h]`` box.

    Half-sizes use floor division, so odd widths/heights round down.
    """
    left, top, width, height = bbox
    return left + width // 2, top + height // 2

def _get_bbox_crop(image: Image.Image, bbox: List[int], crop_size: int = 512, draw_bbox: bool = True) -> Image.Image:
    """
    Return a ``crop_size`` x ``crop_size`` image centred around *bbox* ``[x, y, w, h]``.

    Steps:
    - The box origin is shifted (its size is kept) so it starts inside the image.
    - The box is padded on every side by half its larger dimension, then clipped to the
      image bounds. The crop window is forced to be at least 1x1.
    - If *draw_bbox* is true, the box is outlined in red (4 px) in crop coordinates.
    - The crop is resized with LANCZOS to a square; aspect ratio is not preserved.
    """
    box_x, box_y, box_w, box_h = bbox
    img_w, img_h = image.size
    pad = max(box_w, box_h) // 2

    # Shift the origin back into the image without changing the box size.
    box_x = max(0, min(box_x, img_w - box_w))
    box_y = max(0, min(box_y, img_h - box_h))

    left = max(0, box_x - pad)
    top = max(0, box_y - pad)
    right = min(img_w, box_x + box_w + pad)
    bottom = min(img_h, box_y + box_h + pad)
    # Never hand PIL an empty crop window.
    right = right if right > left else left + 1
    bottom = bottom if bottom > top else top + 1

    region = image.crop((left, top, right, bottom))

    if draw_bbox:
        # Translate the box into the crop's coordinate frame before outlining it.
        rel_x, rel_y = box_x - left, box_y - top
        ImageDraw.Draw(region).rectangle(
            [rel_x, rel_y, rel_x + box_w, rel_y + box_h],
            outline="red",
            width=4,
        )

    return region.resize((crop_size, crop_size), Image.Resampling.LANCZOS)

def img_to_bytes_png(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the raw file bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        return sink.getvalue()

# ────────────────────────────────────────────────────────────────────────────────
# Core conversion routine
# ────────────────────────────────────────────────────────────────────────────────

def _next_sentence_end(txt: str, start_idx: int) -> int:
    """Return the end index of the next SENT_END_RE match at or after *start_idx*.

    That index lies just past the terminating punctuation plus the one whitespace
    character the pattern may consume. Returns ``len(txt)`` when no terminator follows.
    """
    terminator = SENT_END_RE.search(txt, pos=start_idx)
    if terminator is None:
        return len(txt)
    return terminator.end()

def generate_reasoning_text(phrase: str, bbox: List[int], target_category: str, step_num: int, total_steps: int) -> str:
    """
    Build the ``<think>`` narration for one search step.

    Wording by step:
    - Step 1 uses a fixed "start" sentence.
    - Step ``total_steps`` (when not also step 1) uses a fixed "might contain" sentence.
    - Every other step draws one of five templates with ``random.choice``, so its output
      depends on the module-level RNG state.
    """
    left, top, width, height = bbox
    region = f"[{left}, {top}, {width}, {height}]"
    lowered = phrase.lower()

    if step_num == 1:
        return f"I'll start by examining {lowered}. This region with bounding box {region} should help me understand the overall scene structure."
    if step_num == total_steps:
        return f"As I continue searching, I focus on {lowered} with bounding box {region}. This area might contain the {target_category} I'm looking for."

    middle_templates = (
        f"I need to examine {lowered} to understand the scene layout. Let me check the bounding box region {region} to see what's there.",
        f"Looking at {lowered}, I should investigate this region with bounding box {region} as it might provide clues about where the {target_category} could be located.",
        f"To systematically search for the {target_category}, I'll examine {lowered} within the bounding box {region}.",
        f"The phrase '{phrase}' suggests this could be relevant. Let me focus on the bounding box {region} to get a better view of this region.",
        f"Continuing my search, I notice {lowered}. I should examine the bounding box {region} to see if it contains any small objects.",
    )
    return random.choice(middle_templates)

def convert_trajectory_to_dialogue(phrases_data: List[Dict],
                                   img_path: Path,
                                   question: str,
                                   target_info: Dict,
                                   max_turns: int = 5,
                                   crop_size: int = 512,
                                   shuffle_prob: float = 0.3,
                                   sys_prompt: str | None = None,
                                   draw_bbox: bool = True,
                                   offset: int = 50) -> dict | None:
    """
    Convert trajectory-based phrase data to multi-turn dialogue format using bounding boxes.

    Pipeline:
    - Sample up to ``max_turns`` phrases and order them by the distance of their bbox
      centre to the target bbox centre, nearest first.
    - Clamp each bbox to the image and drop duplicates.
    - Each surviving region becomes an assistant ``search_bbox`` tool call followed by a
      user observation carrying a crop of that region.
    - A final assistant turn answers with ``target_info['token']``.
    - With probability ``shuffle_prob``, two tool-call/observation pairs (and their
      images) are swapped.

    Args:
        phrases_data: items read for 'phrase' and 'bbox' ([x, y, w, h], normalized or
            absolute pixels — see normalize_bbox).
        img_path: path of the source image.
        question: user question shown after the full image.
        target_info: read for 'bbox', 'category' and 'token'.
        max_turns: maximum number of search tool calls.
        crop_size: side length of the square observation crops.
        shuffle_prob: probability of swapping two tool-call/observation pairs.
        sys_prompt: content of the system message.
        draw_bbox: outline the searched region in red on each crop.
        offset: not used by this function; kept for interface compatibility.

    Returns:
        ``{"messages": [...], "images": [{"bytes": png}, ...]}``, where ``images[0]`` is
        the full image and ``images[k]`` is the crop for the k-th observation. Returns
        None when *phrases_data* is empty.
    """
    # Nothing to convert: return before touching the image file.
    if not phrases_data:
        return None

    # Decode inside `with` so the file handle is released; convert() returns a copy
    # that remains usable after the file is closed.
    with Image.open(img_path) as source:
        img = source.convert("RGB")
    image_width, image_height = img.size

    # NOTE(review): the target bbox is used as-is (not passed through normalize_bbox),
    # so it is assumed to be in absolute pixels — confirm against the targets file.
    target_bbox = target_info['bbox']
    target_center_x = target_bbox[0] + target_bbox[2] / 2
    target_center_y = target_bbox[1] + target_bbox[3] / 2

    def distance_to_target(phrase_data):
        """Euclidean distance from a phrase's bbox centre to the target centre."""
        abs_bbox = normalize_bbox(phrase_data['bbox'], image_width, image_height)
        center_x, center_y = bbox_to_center_point(abs_bbox)
        return ((center_x - target_center_x) ** 2 + (center_y - target_center_y) ** 2) ** 0.5

    # Random subset, then nearest-to-target first.
    # NOTE(review): the last-step template says the region "might contain" the target,
    # which reads as if farthest-first ordering was intended — confirm.
    selected_phrases = random.sample(phrases_data, min(max_turns, len(phrases_data)))
    selected_phrases.sort(key=distance_to_target)

    # Pass 1: clamp, de-duplicate and crop every candidate region before emitting any
    # turn. A failed crop therefore cannot leave a tool call without its observation,
    # and step numbers below reflect only the regions actually used.
    steps = []  # (phrase, bbox, crop PNG bytes)
    seen_bboxes = set()
    for phrase_data in selected_phrases:
        if len(steps) >= max_turns:
            break
        phrase = phrase_data['phrase']
        x, y, w, h = normalize_bbox(phrase_data['bbox'], image_width, image_height)
        # Keep the box inside the image: shift the origin first, then shrink the size.
        x = max(0, min(x, image_width - w))
        y = max(0, min(y, image_height - h))
        w = min(w, image_width - x)
        h = min(h, image_height - y)
        bbox = [x, y, w, h]

        bbox_key = tuple(bbox)
        if bbox_key in seen_bboxes:
            continue
        seen_bboxes.add(bbox_key)

        try:
            crop_img = _get_bbox_crop(img, bbox, crop_size, draw_bbox)
        except Exception as e:
            # Best effort: skip regions that cannot be cropped, but report them.
            print(f"Error cropping bbox {bbox} for image {img_path} with dimensions {img.size}: {e}")
            print(f"Original bbox: {phrase_data['bbox']}, Normalized bbox: {bbox}")
            continue
        steps.append((phrase, bbox, img_to_bytes_png(crop_img)))

    if draw_bbox:
        observation_text = "<observation>\nHere is the crop of the image focused on the bounding box region, with a red rectangle highlighting the area:\n<image>\n</observation>"
    else:
        observation_text = "<observation>\nHere is the crop of the image focused on the bounding box region:\n<image>\n</observation>"

    # Pass 2: one (assistant tool call, user observation) pair per kept region.
    turns = []
    for step_num, (phrase, bbox, crop_png) in enumerate(steps, start=1):
        reasoning_text = generate_reasoning_text(
            phrase, bbox, target_info['category'],
            step_num, len(steps)
        )
        think_block = f"<think> {reasoning_text} </think>\n"
        tool_json = {"name": "search_bbox",
                     "arguments": {"bbox": bbox}}
        action_block = "<tool_call>\n" + json.dumps(tool_json) + "\n</tool_call>"
        turns.append({"role": "assistant", "content": think_block + action_block})
        turns.append({"role": "user", "content": observation_text, "_img_bytes": crop_png})

    # Add final answer
    final_think = "<think> Based on all the information I've gathered, I'll now provide my final answer. </think>\n"
    answer_block = f"<answer> {target_info['token']} </answer>"
    turns.append({"role": "assistant", "content": final_think + answer_block})

    # Assemble message list; image bytes travel separately in image_list, in the same
    # order as the "<image>" placeholders.
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": f"<image>\n{question}"}
    ] + [{k: v for k, v in t.items() if k != "_img_bytes"} for t in turns]

    image_list = [{"bytes": img_to_bytes_png(img)}] + [
        {"bytes": t["_img_bytes"]} for t in turns if "_img_bytes" in t
    ]

    # Optional shuffling of intermediate turns.
    if len(messages) > 5 and random.random() < shuffle_prob:
        # Tool-call pairs sit at even indices 2, 4, ... (after system + first user),
        # before the final answer.
        tool_call_indices = []
        for i in range(2, len(messages) - 2, 2):
            if (i < len(messages) and "tool_call" in messages[i].get("content", "") and
                    i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
                tool_call_indices.append(i)

        if len(tool_call_indices) > 1:
            # Swap two (tool call, observation) pairs.
            idx1, idx2 = random.sample(tool_call_indices, 2)
            messages[idx1], messages[idx2] = messages[idx2], messages[idx1]
            messages[idx1 + 1], messages[idx2 + 1] = messages[idx2 + 1], messages[idx1 + 1]

            # Pair at message index 2 + 2j owns image_list[1 + j]; swap those too.
            img_idx1 = (idx1 - 2) // 2 + 1
            img_idx2 = (idx2 - 2) // 2 + 1
            if img_idx1 < len(image_list) and img_idx2 < len(image_list):
                image_list[img_idx1], image_list[img_idx2] = image_list[img_idx2], image_list[img_idx1]

    return {"messages": messages, "images": image_list}

def load_trajectory_data(trajectory_file: str) -> Dict:
    """Parse and return the UTF-8 JSON document stored at *trajectory_file*."""
    with open(trajectory_file, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def load_small_targets_data(targets_file: str) -> Dict:
    """Parse and return the UTF-8 JSON document stored at *targets_file*."""
    with open(targets_file, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def create_multiturn_training_examples(trajectory_data: Dict, targets_data: Dict) -> List[Dict]:
    """Pair every trajectory sample with every small target annotated on the same image.

    Targets are matched on ``str(image_id)``, so int and str ids line up. Each pair
    yields one example dict with the keys id, image_id, image_path, question,
    target_info and phrases_data. Samples without any target are skipped.
    """
    # Group targets by stringified image id.
    by_image: Dict[str, List[Dict]] = defaultdict(list)
    for tgt in targets_data['small_targets']:
        by_image[str(tgt['image_id'])].append(tgt)

    examples = []
    for sample in trajectory_data['data']:
        # .get avoids inserting empty lists into the defaultdict.
        for target in by_image.get(str(sample['image_id']), []):
            examples.append({
                "id": f"traj_mt_{sample['image_id']}_{target['annotation_id']}",
                "image_id": sample['image_id'],
                "image_path": sample['image_path'],
                "question": f"Can you help me locate the {target['category']} in this image? Please examine different areas systematically to find this small object.",
                "target_info": target,
                "phrases_data": sample['phrases_data'],
            })
    return examples

def process_chain_batch(chains_batch, sys_prompt, max_turns, crop_size, shuffle_prob, draw_bbox, offset):
    """Convert a batch of training examples into dialogues (pool worker entry point).

    For each example, 'phrases_data', 'question', 'image_path' and 'target_info' are
    read, and the remaining arguments are forwarded positionally to
    convert_trajectory_to_dialogue. Falsy results (None for examples without phrases)
    are dropped.
    """
    def _convert(example):
        phrases = example["phrases_data"]
        user_question = example["question"]
        image_file = Path(example["image_path"])
        target = example["target_info"]
        return convert_trajectory_to_dialogue(
            phrases, image_file, user_question, target,
            max_turns, crop_size, shuffle_prob, sys_prompt, draw_bbox, offset
        )

    return [dialogue for dialogue in map(_convert, chains_batch) if dialogue]


def process_paths_batch(batch_data):
    """Write a batch of dialogue images to disk and replace their bytes with file paths.

    Args:
        batch_data: either ``(batch_idx, batch, images_folder)`` or
            ``(batch_idx, batch, images_folder, start_idx)``.
            - Each row of *batch* is a dict with "messages" and "images" (a list of
              ``{"bytes": ...}``).
            - Row ``i`` is numbered ``start_idx + i``.
            - When *start_idx* is omitted it defaults to ``batch_idx * len(batch)``,
              which is a correct global row index only if every batch has the same
              length. A shorter final batch would collide with earlier file names.

    Returns:
        One dict per row, ``{"messages": [...], "images": [path, ...]}``.
        - Each message has its "_img_bytes" key (if any) removed.
        - Each image is written to ``{images_folder}/trajectory_mt_{row}_{k}.png``.
    """
    if len(batch_data) == 4:
        _, batch, images_folder, start_idx = batch_data
    else:
        batch_idx, batch, images_folder = batch_data
        start_idx = batch_idx * len(batch)

    rows_with_paths = []
    for i, row in enumerate(batch):
        idx = start_idx + i

        # Save images
        image_paths = []
        for img_idx, img in enumerate(row["images"]):
            img_path = f"{images_folder}/trajectory_mt_{idx}_{img_idx}.png"
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            with open(img_path, "wb") as f:
                f.write(img["bytes"])
            image_paths.append(img_path)

        # Copy messages without any inline image payload.
        messages = [{k: v for k, v in msg.items() if k != "_img_bytes"} for msg in row["messages"]]
        rows_with_paths.append({"messages": messages, "images": image_paths})

    return rows_with_paths

def main():
    """Build the multi-turn trajectory dataset end to end.

    Steps:
    1. Load the trajectory and small-target JSON files.
    2. Expand them into one training example per (image, target) pair.
    3. Convert the examples to dialogues in a process pool.
    4. Write PNGs for the first ``max_visualization_examples`` dialogues.
    5. Save all rows as parquet, plus a JSON sample of the first 10 rows, in a
       timestamped output directory.
    """
    # Configuration
    trajectory_file = "/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/small_targets_trajectory_refactored_20250825_192014.json"
    targets_file = "/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/coco_bbox_analysis_20250824_232913/small_targets_details.json"

    # Extract trajectory file info and create timestamped output directory
    trajectory_basename = os.path.splitext(os.path.basename(trajectory_file))[0]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"trajectory_multiturn_reasoning_chains_{trajectory_basename}_{timestamp}"
    images_folder = f"{output_dir}/images"
    max_turns = 5
    crop_size = 512
    shuffle_prob = 0.3
    draw_bbox = True
    offset = 50
    max_visualization_examples = 10  # Only save visualizations for first 10 examples

    # System prompt for multi-turn trajectory reasoning with bounding boxes
    sys_prompt = textwrap.dedent("""You are a helpful assistant tasked with locating small objects in images. You should systematically reason through the problem step by step by checking and verifying relevant image regions, while grounding reasoning steps to specific bounding box regions in the image:
- At each turn, first clearly reason about ONE area or element in the image enclosed in <think> </think> tags.
- After reasoning, either:
  a) Output a search action formatted precisely as:
     <tool_call>
     {"name": "search_bbox", "arguments": {"bbox": [x, y, w, h]}}
     </tool_call>
  b) If confident you've found the correct answer, output your final answer enclosed in <answer> object_name </answer> tags.
- Only answer if you are confident about the answer. If you are not confident, output a search action. You should not always end after one turn.
- You should not repeat the same bounding box in a tool call more than once. Bounding boxes must be unique across tool calls, including boxes that are the same or nearly identical (e.g., differing by only a few pixels).
- If unclear, infer based on likely context or purpose.
- Your final answer should be a single word or phrase identifying the target object.
- Verify each step by examining multiple possible solutions before selecting a final answer.""")

    print("Loading trajectory data...")
    trajectory_data = load_trajectory_data(trajectory_file)

    print("Loading small targets data...")
    targets_data = load_small_targets_data(targets_file)

    print("Creating training examples for all trajectory data...")
    training_examples = create_multiturn_training_examples(
        trajectory_data, targets_data
    )

    print(f"Generated {len(training_examples)} training examples")

    # Process chains to multi-turn format
    print("Converting to multi-turn dialogue format...")
    num_processes = max(1, multiprocessing.cpu_count() - 1)
    batch_size = max(1, len(training_examples) // num_processes)
    batches = [training_examples[i:i+batch_size] for i in range(0, len(training_examples), batch_size)]

    process_fn = partial(process_chain_batch,
                         sys_prompt=sys_prompt,
                         max_turns=max_turns,
                         crop_size=crop_size,
                         shuffle_prob=shuffle_prob,
                         draw_bbox=draw_bbox,
                         offset=offset)

    with multiprocessing.Pool(processes=num_processes) as pool:
        batch_results = list(tqdm(pool.imap(process_fn, batches),
                                  total=len(batches),
                                  desc="Converting to dialogue"))

    # Flatten results.
    # NOTE(review): every dialogue carries PNG bytes for the full image and all crops
    # (encoded in the workers and sent back through the pool), but only the first
    # max_visualization_examples rows keep them — the rest is discarded below.
    dialogue_rows = [item for sublist in batch_results for item in sublist]
    print(f"Successfully converted {len(dialogue_rows)} examples to multi-turn format")

    # Save images and create path-based dataset (limited visualizations)
    print(f"Saving images for first {max_visualization_examples} examples and creating path-based dataset...")
    os.makedirs(images_folder, exist_ok=True)

    # Limit the number of examples with visualizations
    visualization_rows = dialogue_rows[:max_visualization_examples]
    non_visualization_rows = dialogue_rows[max_visualization_examples:]

    final_rows = []
    if visualization_rows:
        # One row per task. process_paths_batch numbers files as
        # batch_idx * len(batch) + i, which is only a valid global index when all
        # batches have equal length. With single-row batches, batch_idx *is* the global
        # row index, so no two rows can write to the same PNG file name.
        tasks = [(row_idx, [row], images_folder) for row_idx, row in enumerate(visualization_rows)]
        with multiprocessing.Pool(processes=min(num_processes, len(tasks))) as pool:
            batch_results = list(tqdm(pool.imap(process_paths_batch, tasks),
                                      total=len(tasks),
                                      desc="Saving images for visualization examples"))

        # Flatten visualization results (imap preserves task order).
        final_rows.extend(item for sublist in batch_results for item in sublist)

    # For non-visualization examples, just remove image bytes and keep messages.
    # NOTE(review): these messages still contain "<image>" placeholders while "images"
    # is empty, so such rows carry no images for their placeholders downstream —
    # confirm this is intended.
    for row in non_visualization_rows:
        final_rows.append({
            "messages": [{k: v for k, v in msg.items() if k != "_img_bytes"} for msg in row["messages"]],
            "images": [],
        })

    # Create output directory and save results
    os.makedirs(output_dir, exist_ok=True)

    # Save as parquet
    output_parquet = f"{output_dir}/trajectory_multiturn_dataset.parquet"
    datasets.Dataset.from_pandas(pd.DataFrame(final_rows)).to_parquet(output_parquet)
    print(f"Saved multi-turn dataset to {output_parquet}")

    # Also save as JSON for inspection
    output_json = f"{output_dir}/trajectory_multiturn_dataset.json"
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(final_rows[:10], f, indent=2)  # Save first 10 examples
    print(f"Saved sample examples to {output_json}")

    print(f"Total multi-turn examples created: {len(final_rows)}")
    print(f"Visualization images saved for first {len(visualization_rows)} examples in: {images_folder}")
    print(f"Remaining {len(non_visualization_rows)} examples saved without visualizations to save space")

# Script entry point. The guard also keeps worker processes that re-import this module
# (e.g. under the "spawn" start method) from re-running main().
if __name__ == "__main__":
    main()