import json
from typing import List, Dict, Any, Tuple
import os
import random
from PIL import Image, ImageDraw
from io import BytesIO
import base64
import wandb
import re
from collections import defaultdict
from datetime import datetime

# Fix the global RNG seed so that phrase sampling, the train/val shuffle and
# the sampling of examples for visualization/W&B are reproducible across runs.
random.seed(42)

def normalize_bbox(bbox: List[float], image_width: int, image_height: int) -> List[int]:
    """Return bbox as integer pixel values.

    A box whose every component lies in [0, 1] is treated as a relative
    [x, y, w, h] box and scaled by the image size; any other box is taken to
    be in pixels already and is only truncated to integers.
    """
    looks_relative = all(0 <= value <= 1 for value in bbox)
    if not looks_relative:
        # Pixel coordinates: just drop the fractional part.
        return [int(value) for value in bbox]

    # Relative coordinates: horizontal terms scale with the width,
    # vertical terms with the height.
    left, top, box_w, box_h = bbox
    scaled = (
        left * image_width,
        top * image_height,
        box_w * image_width,
        box_h * image_height,
    )
    return [int(value) for value in scaled]

def normalize_trace_points(trace_points: List[List[float]], image_width: int, image_height: int) -> List[Tuple[int, int]]:
    """Return the trace points as integer (x, y) pixel tuples.

    Points with fewer than two components are skipped. A point whose x and y
    both lie in [0, 1] is treated as relative and scaled by the image size;
    any other point is taken to be in pixels and only truncated to integers.
    """
    pixel_points: List[Tuple[int, int]] = []
    for entry in trace_points:
        if len(entry) < 2:
            continue  # not enough components to form a point

        px, py = entry[0], entry[1]
        if 0 <= px <= 1 and 0 <= py <= 1:
            # Relative coordinates: scale up to the image size.
            px, py = px * image_width, py * image_height
        pixel_points.append((int(px), int(py)))
    return pixel_points

def build_trajectory_reasoning(phrases_data: List[Dict], target_info: Dict, image_width: int = 640, image_height: int = 480) -> str:
    """
    Build a think/answer reasoning chain that walks through phrase regions and ends at the small target.

    Up to eight phrases are sampled at random and ordered from the farthest to
    the nearest region (distance between bbox centre and target centre), so the
    chain starts with scene context and its last phrase step - whose text says
    the region is close to the target - really is the closest sampled region.
    A final step names the target itself.

    Args:
        phrases_data: List of dicts, each providing a 'phrase' string and a 'bbox'
            ([x, y, w, h], relative or pixel values; trace points are not used)
        target_info: Dict providing 'bbox' ([x, y, w, h], used as-is and assumed to be
            in pixels), 'category', 'token' and 'area_ratio'
        image_width: Width of the image for coordinate normalization
        image_height: Height of the image for coordinate normalization

    Returns:
        Formatted reasoning chain string: '<think>' block followed by '<answer>token</answer>'
    """
    # Sample a subset of phrases for the reasoning chain (to avoid overly long chains)
    max_phrases = min(8, len(phrases_data))
    selected_phrases = random.sample(phrases_data, max_phrases)

    # Centre of the target; its bbox is assumed to be in absolute coordinates already
    target_bbox = target_info['bbox']
    target_center_x = target_bbox[0] + target_bbox[2] / 2
    target_center_y = target_bbox[1] + target_bbox[3] / 2

    # Convert every sampled bbox to pixels once and record its distance to the target
    ordered_steps = []
    for phrase_data in selected_phrases:
        abs_bbox = normalize_bbox(phrase_data['bbox'], image_width, image_height)
        center_x = abs_bbox[0] + abs_bbox[2] / 2
        center_y = abs_bbox[1] + abs_bbox[3] / 2
        distance = ((center_x - target_center_x) ** 2 + (center_y - target_center_y) ** 2) ** 0.5
        ordered_steps.append((distance, phrase_data['phrase'], abs_bbox))

    # Farthest region first, nearest last: the final phrase step must be the
    # one closest to the target for its "close to ..." wording to hold.
    ordered_steps.sort(key=lambda step: step[0], reverse=True)

    # Build reasoning steps using only bbox information
    reasoning_steps = []
    last_index = len(ordered_steps) - 1
    for i, (_, phrase, bbox) in enumerate(ordered_steps):
        x, y, w, h = bbox
        bbox_text = f"[{x}, {y}, {w}, {h}]"
        phrase_lc = phrase.lower()

        if i == 0:
            step_text = f"Starting to examine the image, I notice {phrase_lc}. This region occupies the bounding box {bbox_text}, providing context for understanding the overall scene layout."
        elif i == last_index:
            # Make the last step lead toward the target
            step_text = f"Examining {phrase_lc}, I focus on the region with bounding box {bbox_text}. This area appears to be close to where smaller objects might be located. Looking more carefully at the details within this bounded region."
        else:
            step_text = f"Moving to examine {phrase_lc}, the region with bounding box {bbox_text} shows {phrase_lc}. This helps establish the spatial relationships in the scene."

        reasoning_steps.append(step_text)

    # Add a final step that discovers the target using its bbox (one decimal per value)
    target_bbox_text = f"[{target_bbox[0]:.1f}, {target_bbox[1]:.1f}, {target_bbox[2]:.1f}, {target_bbox[3]:.1f}]"
    target_text = f"Upon closer inspection of the areas examined, I can identify a small {target_info['category']} located within bounding box {target_bbox_text}. This {target_info['token']} is quite small with an area ratio of {target_info['area_ratio']:.6f}, making it a challenging target to spot initially."
    reasoning_steps.append(target_text)

    # Combine all reasoning steps and format as think/answer chain
    full_reasoning = " ".join(reasoning_steps)
    formatted_chain = f"<think>\n{full_reasoning}\n</think>\n<answer>{target_info['token']}</answer>"

    return formatted_chain

def load_trajectory_data(trajectory_file: str) -> Dict:
    """Read the UTF-8 trajectory JSON file and return its parsed content."""
    with open(trajectory_file, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def load_small_targets_data(targets_file: str) -> Dict:
    """Read the UTF-8 small-targets JSON file and return its parsed content."""
    with open(targets_file, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def create_reasoning_chains(trajectory_data: Dict, targets_data: Dict) -> List[Dict]:
    """
    Create reasoning chains by matching trajectory data with small targets.

    Samples and targets are matched on the string form of 'image_id'. One
    training example is produced for every target of every matching sample;
    samples without any target are skipped.

    Args:
        trajectory_data: Loaded trajectory JSON data; its 'data' list holds samples with
            'image_id', 'image_path', 'phrases_data' and 'full_caption'
        targets_data: Loaded small targets JSON data; its 'small_targets' list holds
            targets with 'image_id', 'annotation_id', 'category' plus the fields
            needed by build_trajectory_reasoning

    Returns:
        List of training examples with reasoning chains
    """
    # Index targets by image_id for faster lookup
    targets_by_image = defaultdict(list)
    for target in targets_data['small_targets']:
        targets_by_image[str(target['image_id'])].append(target)

    training_examples = []

    for sample in trajectory_data['data']:
        # .get() does not insert an empty list into the defaultdict
        available_targets = targets_by_image.get(str(sample['image_id']))
        if not available_targets:
            continue

        image_path = sample['image_path']

        # Default dimensions, replaced by the real ones when the image can be read
        image_width, image_height = 640, 480
        try:
            if os.path.exists(image_path):
                with Image.open(image_path) as img:
                    image_width, image_height = img.size
        except Exception as exc:
            # Best effort: keep the defaults, but do not hide the problem and
            # do not swallow KeyboardInterrupt/SystemExit as a bare except would.
            print(f"Warning: could not read image size for {image_path}: {exc}; using {image_width}x{image_height}")

        # Create a training example for each target in this image
        for selected_target in available_targets:
            reasoning_chain = build_trajectory_reasoning(
                sample['phrases_data'],
                selected_target,
                image_width,
                image_height
            )

            # Question about finding the small target
            question = f"Can you identify and locate the {selected_target['category']} in this image? Please provide detailed reasoning about how you locate it."

            training_examples.append({
                "id": f"traj_{sample['image_id']}_{selected_target['annotation_id']}",
                "image_id": sample['image_id'],
                "image_path": image_path,
                "question": question,
                "reasoning_chain": reasoning_chain,
                "target_info": selected_target,
                "full_caption": sample['full_caption'],
                "phrases_count": len(sample['phrases_data'])
            })

    return training_examples

def create_sft_format(training_examples: List[Dict], system_prompt: str) -> List[Dict]:
    """Turn training examples into SFT entries (system/user/assistant messages plus metadata)."""

    def to_entry(example: Dict) -> Dict:
        # Keys are inserted in the order they should appear in the saved JSON.
        entry = {"id": example["id"]}
        entry["metadata"] = {
            "image_id": example["image_id"],
            "target_category": example["target_info"]["category"],
            "target_area_ratio": example["target_info"]["area_ratio"],
            "phrases_count": example["phrases_count"],
        }
        entry["messages"] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["reasoning_chain"]},
        ]
        entry["images"] = [example["image_path"]]
        entry["target_info"] = example["target_info"]
        return entry

    return [to_entry(example) for example in training_examples]

def visualize_reasoning_chain(example: Dict, output_dir: str):
    """
    Draw the target and the reasoning-step bboxes of an example on its image and save the result.

    The target bbox is outlined in red; every step bbox found in the reasoning
    chain text is outlined in its own colour and numbered, and a small legend is
    drawn in the bottom-right corner. The picture is saved as
    '<output_dir>/<example id>_visualization.jpg'. Nothing is done when the image
    file does not exist, and any error is reported with print instead of raised.

    Args:
        example: Training example with 'id', 'image_path', 'reasoning_chain' and
            'target_info' (providing 'bbox' as [x, y, w, h] and 'category')
        output_dir: Directory for the visualization; created if missing
    """
    try:
        image_path = example["image_path"]
        if not os.path.exists(image_path):
            return

        with Image.open(image_path) as img:
            if img.mode != 'RGB':
                img = img.convert('RGB')

            draw = ImageDraw.Draw(img)

            # Define colors for different reasoning steps
            colors = [
                "blue", "green", "orange", "purple", "cyan",
                "magenta", "yellow", "lime", "pink", "brown"
            ]

            # Draw target bbox in red
            target = example["target_info"]
            x, y, w, h = target["bbox"]
            draw.rectangle([x, y, x+w, y+h], outline="red", width=4)
            # Add target label with background for better visibility
            label = f"TARGET: {target['category']}"
            draw.rectangle([x-2, y-25, x+len(label)*8, y-5], fill="red")
            draw.text((x, y-22), label, fill="white")

            # Extract bbox coordinates from the reasoning chain. Numbers may be
            # negative or carry decimals: the target bbox is written with one
            # decimal, so an integer-only pattern would never match it and the
            # "drop the last match" step below would remove a real reasoning
            # step instead of the target.
            number = r'(-?\d+(?:\.\d+)?)'
            bbox_pattern = r'\[' + r',\s*'.join([number] * 4) + r'\]'
            bbox_matches = re.findall(bbox_pattern, example["reasoning_chain"])

            # The chain ends with the target bbox (already drawn in red): exclude it
            step_boxes = [
                tuple(int(float(value)) for value in match)
                for match in bbox_matches[:-1]
            ]

            # Draw reasoning bboxes with different colors and numbers
            for i, (x, y, w, h) in enumerate(step_boxes):
                # Use different colors, cycling through the color list
                color = colors[i % len(colors)]

                # Draw bbox with thicker outline
                draw.rectangle([x, y, x+w, y+h], outline=color, width=3)

                # Add step number with colored background
                step_num = str(i + 1)
                text_width = len(step_num) * 10 + 4
                draw.rectangle([x-2, y-20, x+text_width, y-2], fill=color)
                draw.text((x, y-18), step_num, fill="white")

                # Also add a small colored circle at the center for easy identification
                center_x = x + w // 2
                center_y = y + h // 2
                circle_radius = 8
                draw.ellipse([center_x-circle_radius, center_y-circle_radius,
                             center_x+circle_radius, center_y+circle_radius],
                             fill=color, outline="white", width=2)
                draw.text((center_x-4, center_y-6), step_num, fill="white")

            # Add legend in the bottom right corner
            img_width, img_height = img.size
            legend_x = img_width - 200
            legend_y = img_height - 100

            # Draw legend background
            draw.rectangle([legend_x-10, legend_y-10, img_width-10, img_height-10],
                          fill="white", outline="black", width=2)

            # Add legend title
            draw.text((legend_x, legend_y), "Reasoning Steps:", fill="black")

            # Add color legend for the first few steps (at most 5, never negative)
            legend_steps = min(len(step_boxes), 5)
            for i in range(legend_steps):
                color = colors[i % len(colors)]
                legend_item_y = legend_y + 15 + i * 12
                # Small colored square
                draw.rectangle([legend_x, legend_item_y, legend_x+10, legend_item_y+10],
                              fill=color)
                # Step number and label
                draw.text((legend_x + 15, legend_item_y), f"Step {i+1}", fill="black")

            # Add target in legend, right below the listed steps
            target_y = legend_y + 15 + legend_steps * 12
            draw.rectangle([legend_x, target_y, legend_x+10, target_y+10], fill="red")
            draw.text((legend_x + 15, target_y), "Target", fill="black")

            # Save visualization
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, f"{example['id']}_visualization.jpg")
            img.save(output_path)

    except Exception as e:
        # Visualization is best effort: report the problem and carry on
        print(f"Error creating visualization for {example['id']}: {e}")

if __name__ == "__main__":
    # Script entry point: load trajectory and small-target data, build the
    # reasoning chains, split them into train/val, save everything as JSON and
    # optionally create visualizations and a W&B report.
    import html  # only needed to escape text for the W&B HTML preview

    # Configuration
    # Alternative input:
    # trajectory_file = "/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/small_targets_trajectory_with_vis_20250825_080253.json"
    trajectory_file = "/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/small_targets_trajectory_node3_chunk3_20250826_005702.json"
    targets_file = "/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/coco_bbox_analysis_20250824_232913/small_targets_details.json"

    # Extract trajectory file info and create timestamped output directory
    trajectory_basename = os.path.splitext(os.path.basename(trajectory_file))[0]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"trajectory_reasoning_chains_{trajectory_basename}_{timestamp}"

    val_split = 0.1
    create_visualizations = True
    log_to_wandb = False

    # System prompt for trajectory-based reasoning
    system_prompt = """You are an assistant that helps locate small objects in images through systematic spatial reasoning. You must examine the image carefully, following a logical path through different regions to identify and locate small targets.

All reasoning processes must be enclosed within '<think>' tags, with each step referencing specific bounding box regions in the image:

<think>
[Step-by-step spatial reasoning with bounding box references] [x1, y1, w1, h1]. [Further examination] [x2, y2, w2, h2]. [Final analysis leading to target discovery] [x3, y3, w3, h3].
</think>

The final answer should identify the small target object in '<answer>' tags:
<answer> target_object </answer>

Your task is to:
- Systematically examine different bounding box regions of the image
- Use spatial relationships between regions to navigate toward smaller details
- Identify small objects that might be easily missed within their bounding boxes
- Provide precise reasoning about the target's location and characteristics using bounding box coordinates
"""

    print("Loading trajectory data...")
    trajectory_data = load_trajectory_data(trajectory_file)

    print("Loading small targets data...")
    targets_data = load_small_targets_data(targets_file)

    print("Creating reasoning chains for all trajectory data...")
    training_examples = create_reasoning_chains(trajectory_data, targets_data)

    print(f"Generated {len(training_examples)} training examples")

    # Train/val split: shuffle in place, the first val_size examples form the validation set
    random.shuffle(training_examples)
    val_size = int(len(training_examples) * val_split)
    train_examples = training_examples[val_size:]
    val_examples = training_examples[:val_size]

    # Convert to SFT format
    train_sft = create_sft_format(train_examples, system_prompt)
    val_sft = create_sft_format(val_examples, system_prompt)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Save training data
    train_path = os.path.join(output_dir, "trajectory_reasoning_train.json")
    with open(train_path, 'w', encoding='utf-8') as f:
        json.dump(train_sft, f, indent=2)
    print(f"Saved training data: {train_path} ({len(train_sft)} examples)")

    # Save validation data
    val_path = os.path.join(output_dir, "trajectory_reasoning_val.json")
    with open(val_path, 'w', encoding='utf-8') as f:
        json.dump(val_sft, f, indent=2)
    print(f"Saved validation data: {val_path} ({len(val_sft)} examples)")

    # Save raw examples (in shuffled order) for analysis
    raw_path = os.path.join(output_dir, "raw_examples.json")
    with open(raw_path, 'w', encoding='utf-8') as f:
        json.dump(training_examples, f, indent=2)
    print(f"Saved raw examples: {raw_path}")

    # Create visualizations for a random sample of at most 20 examples
    if create_visualizations:
        print("Creating visualizations...")
        vis_dir = os.path.join(output_dir, "visualizations")
        sample_examples = random.sample(training_examples, min(20, len(training_examples)))
        for example in sample_examples:
            visualize_reasoning_chain(example, vis_dir)
        print(f"Created visualizations in: {vis_dir}")

    # Log to W&B if requested
    if log_to_wandb:
        print("Logging to W&B...")
        wandb.init(project="trajectory-vlm", name="trajectory_reasoning_chains")

        # Guard the average against an empty dataset (would divide by zero)
        if training_examples:
            avg_phrases = sum(ex["phrases_count"] for ex in training_examples) / len(training_examples)
        else:
            avg_phrases = 0.0

        # Log dataset statistics
        wandb.log({
            "total_examples": len(training_examples),
            "train_examples": len(train_examples),
            "val_examples": len(val_examples),
            "avg_phrases_per_image": avg_phrases,
            "unique_target_categories": len(set(ex["target_info"]["category"] for ex in training_examples))
        })

        # Create sample HTML for viewing
        html_content = "<html><body><h1>Trajectory Reasoning Chains</h1>"
        sample_size = min(10, len(training_examples))
        samples = random.sample(training_examples, sample_size)

        for example in samples:
            target_info = example["target_info"]
            # Escape every interpolated text field so that '<think>' tags and
            # any markup-like characters in the data are shown literally.
            image_id = html.escape(str(example['image_id']))
            category = html.escape(str(target_info['category']))
            token = html.escape(str(target_info['token']))
            question = html.escape(example["question"])
            reasoning = html.escape(example["reasoning_chain"])

            html_content += f"""
            <div style="border:1px solid #ccc; margin:20px; padding:15px;">
                <h3>Image ID: {image_id}</h3>
                <p><b>Target:</b> {category} ({token})</p>
                <p><b>Area Ratio:</b> {target_info['area_ratio']:.6f}</p>
                <p><b>Question:</b> {question}</p>
                <p><b>Reasoning Chain:</b></p>
                <pre style="background:#f5f5f5; padding:10px; white-space:pre-wrap;">{reasoning}</pre>
            </div>
            """

        html_content += "</body></html>"
        wandb.log({"sample_chains": wandb.Html(html_content)})
        wandb.finish()

    print("Done!")