import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from PIL import Image
from tqdm import tqdm
import os
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import gc
from matplotlib import rcParams
import cv2
import io
import re

def build_model(args):
    """Build, load and configure the grounding model selected by ``args.model_type``.

    Model classes are imported lazily so that only the dependencies of the
    requested model need to be importable.

    Args:
        args: namespace with a ``model_type`` attribute and, optionally, a
            ``model_name_or_path`` attribute. The checkpoint override is
            honoured for ``qwen2vl``, ``qwen2_5_vl`` and ``tianxi-7b``; every
            other model calls ``load_model()`` without arguments.

    Returns:
        The model instance after ``load_model`` (skipped for ``gpt4o`` /
        ``gpt4v``) and ``set_generation_config(temperature=0, max_new_tokens=256)``.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model type.
        ImportError: if the model module or class cannot be imported.
    """
    # Local import keeps this block self-contained.
    import importlib

    model_type = args.model_type
    # Tolerate namespaces that do not define the optional override at all.
    model_name_or_path = getattr(args, "model_name_or_path", None)

    # model_type -> (module path, class name)
    registry = {
        "cogagent": ("models.cogagent", "CogAgentModel"),
        "seeclick": ("models.seeclick", "SeeClickModel"),
        "qwen1vl": ("models.qwen1vl", "Qwen1VLModel"),
        "qwen2vl": ("models.qwen2vl", "Qwen2VLModel"),
        "qwen2_5_vl": ("models.qwen2_5_vl", "Qwen2_5_VLModel"),
        "minicpmv": ("models.minicpmv", "MiniCPMVModel"),
        "internvl": ("models.internvl", "InternVLModel"),
        "gpt4o": ("models.gpt4x", "GPT4XModel"),
        "gpt4v": ("models.gpt4x", "GPT4XModel"),
        "osatlas-4b": ("models.osatlas4b", "OSAtlas4BModel"),
        "osatlas-7b": ("models.osatlas7b", "OSAtlas7BModel"),
        "uground": ("models.uground", "UGroundModel"),
        "fuyu": ("models.fuyu", "FuyuModel"),
        "showui": ("models.showui", "ShowUIModel"),
        "ariaui": ("models.ariaui", "AriaUIVLLMModel"),
        "cogagent24": ("models.cogagent24", "CogAgent24Model"),
        "uitars": ("models.uitars", "UITarsModel"),
        "ugroundv1": ("models.ugroundv1", "UGroundV1Model"),
        "tianxi-7b": ("models.tianxi7b", "LenovoAction7BModel"),
    }

    if model_type not in registry:
        raise ValueError(f"Unsupported model type {model_type}.")

    module_name, class_name = registry[model_type]
    module = importlib.import_module(module_name)
    try:
        model_cls = getattr(module, class_name)
    except AttributeError as err:
        # Keep the exception type a `from ... import ...` statement would raise.
        raise ImportError(f"cannot import name {class_name!r} from {module_name!r}") from err

    model = model_cls()

    if model_type in ("gpt4o", "gpt4v"):
        # The GPT-4 wrappers are used without a load_model() call.
        pass
    elif model_type in ("qwen2vl", "qwen2_5_vl"):
        if model_name_or_path:
            model.load_model(model_name_or_path=model_name_or_path)
        else:
            model.load_model()
    elif model_type == "tianxi-7b":
        model.set_verbose(True)
        # Use the caller-provided checkpoint when given, else the default one.
        model.load_model(
            model_name_or_path=model_name_or_path or "tangliang/TianXi_Action_Grounding_7B",
            device="cuda",
        )
    else:
        model.load_model()

    model.set_generation_config(temperature=0, max_new_tokens=256)
    return model

# Select correct and wrong samples
def select_samples(results, num_samples=3, sample_id=None, img_path=None):
    """Select correct and wrong samples from an evaluation result.

    Selection priority: an explicit ``sample_id``, then an ``img_path``
    substring filter, then a seeded random draw.

    Args:
        results: evaluation result dict; its ``'details'`` entry is the list
            of per-sample dicts (each with a ``'correctness'`` field).
        num_samples: maximum number of samples returned per category.
        sample_id: specific sample ID to use. IDs are compared as strings so
            that a command-line value such as ``"5"`` matches an integer ID.
        img_path: specific image path to use (substring match on
            ``sample['img_path']``).

    Returns:
        Tuple ``(correct_samples, wrong_samples)`` of lists.
    """
    details = results['details']

    if sample_id is not None and sample_id != "":
        # Compare as strings: the CLI passes a str while logs may store ints.
        wanted = str(sample_id)
        selected_sample = next(
            (s for s in details
             if s.get('id') is not None and str(s.get('id')) == wanted),
            None,
        )
        if selected_sample:
            print(f"Using specified sample ID: {sample_id}")
            if selected_sample['correctness'] == 'correct':
                return [selected_sample], []
            return [], [selected_sample]
        print(f"Warning: sample with ID {sample_id} not found, falling back to default selection")

    if img_path:
        selected_samples = [s for s in details if img_path in s['img_path']]
        if selected_samples:
            print(f"Using specified image path: {img_path}")
            correct_samples = [s for s in selected_samples if s['correctness'] == 'correct']
            wrong_samples = [s for s in selected_samples if s['correctness'] == 'wrong']
            return correct_samples[:num_samples], wrong_samples[:num_samples]
        print(f"Warning: no samples with image path containing {img_path}, falling back to default selection")

    # Default selection: seeded random draw from each category.
    correct_samples = [s for s in details if s['correctness'] == 'correct']
    wrong_samples = [s for s in details if s['correctness'] == 'wrong']

    def pick(pool):
        # Draw indices rather than the dicts themselves so numpy never has to
        # coerce the sample dicts into an array; an empty pool yields [].
        k = min(num_samples, len(pool))
        if k <= 0:
            return []
        chosen = np.random.choice(len(pool), k, replace=False)
        return [pool[i] for i in chosen]

    # Fixed seed keeps the default selection reproducible across runs.
    np.random.seed(42)
    selected_correct = pick(correct_samples)
    selected_wrong = pick(wrong_samples)

    return selected_correct, selected_wrong


# Get bbox from a sample
def get_bbox_from_sample(sample):
    """Return the sample's ``'bbox'`` value, or ``None`` if it is missing or empty."""
    if 'bbox' not in sample:
        return None
    box = sample['bbox']
    return box if box else None

# Draw bbox on image
def draw_bbox(ax, bbox, color='lime', linewidth=0.5, label=None):
    """Outline a bounding box on a matplotlib Axes.

    Args:
        ax: matplotlib Axes to draw on
        bbox: [x_min, y_min, x_max, y_max]; nothing is drawn when it is empty/None
        color: edge color of the rectangle
        linewidth: edge line width
        label: legend label; falls back to "GT BBox" when not given
    """
    if not bbox:
        return

    left, top, right, bottom = bbox

    # Imported lazily, only when something is actually drawn.
    from matplotlib.patches import Rectangle

    legend_text = label or "GT BBox"
    outline = Rectangle(
        (left, top),
        right - left,
        bottom - top,
        linewidth=linewidth,
        edgecolor=color,
        facecolor='none',
        label=legend_text,
    )
    ax.add_patch(outline)

# Configure matplotlib CJK font fallback
def setup_cjk_font():
    """Configure matplotlib to prefer an installed CJK-capable font.

    The first candidate found in matplotlib's font registry is moved to the
    front of ``font.sans-serif``. Assigning to rcParams never fails for a
    missing font, so availability is checked against the registry instead of
    relying on an exception. Also disables the Unicode minus sign and sets
    the global font size to 12.
    """
    from matplotlib import font_manager

    # Candidate CJK-capable fonts, in order of preference.
    font_list = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'WenQuanYi Zen Hei', 'Hiragino Sans GB']

    installed = {entry.name for entry in font_manager.fontManager.ttflist}

    for font_name in font_list:
        if font_name not in installed:
            continue
        # Put the font first without duplicating it in the fallback list.
        fallbacks = [name for name in plt.rcParams['font.sans-serif'] if name != font_name]
        plt.rcParams['font.sans-serif'] = [font_name] + fallbacks
        print(f"Successfully set CJK-compatible font: {font_name}")
        break
    else:
        print("Warning: no CJK-capable font found; CJK text may not render correctly")

    # Fix minus sign display
    plt.rcParams['axes.unicode_minus'] = False

    # Global font size
    plt.rcParams['font.size'] = 12

# Run the font setup once at import time; it mutates the global matplotlib
# rcParams, so it applies to figures created after this module is imported.
setup_cjk_font()

def save_points(points, save_path):
    """Serialize the predicted points as JSON into ``save_path``."""
    import json
    serialized = json.dumps(points)
    with open(save_path, 'w') as out_file:
        out_file.write(serialized)
    print(f"Saved predicted points to: {save_path}")

def load_points(load_path):
    """Read a JSON file of predicted points and return the decoded value."""
    import json
    with open(load_path, 'r') as in_file:
        raw_text = in_file.read()
    loaded = json.loads(raw_text)
    print(f"Loaded {len(loaded)} predicted points from {load_path}")
    return loaded

def generate_block_heatmap(image, points, block_size=None, min_blocks=20, max_blocks=100, colormap='coolwarm'):
    """
    Render a block-count heatmap of predicted points on top of an image.

    The image is divided into square blocks, the points falling into each
    block are counted, and the log-scaled counts are drawn semi-transparently
    over the image together with a colorbar and a title.

    Args:
        image: original image (numpy array, height x width [x channels])
        points: iterable of predicted points (x, y) in pixel coordinates;
            points outside the image are ignored
        block_size: block edge length in pixels, derived from the image size
            if None
        min_blocks: divisor giving the largest automatic block size
        max_blocks: divisor giving the smallest automatic block size
        colormap: matplotlib colormap name

    Returns:
        The rendered figure as a numpy array decoded from an in-memory PNG.
        Its size is set by the figure (12x10 in at 150 dpi, tightly cropped),
        not by the input image.
    """
    h, w = image.shape[:2]

    # Determine block size
    if block_size is None:
        block_size = max(min(w // min_blocks, h // min_blocks), w // max_blocks)
    # Small images (or a caller-supplied 0) would give a zero block size and
    # a division by zero below, so never go under one pixel.
    block_size = max(1, int(block_size))

    print(f"Generating heatmap: block size {block_size}x{block_size} pixels")

    # Number of blocks (ceiling division so partial edge blocks are counted)
    n_blocks_x = -(-w // block_size)
    n_blocks_y = -(-h // block_size)

    print(f"Heatmap blocks: {n_blocks_x}x{n_blocks_y}")

    # Count points per block
    block_counts = np.zeros((n_blocks_y, n_blocks_x), dtype=int)
    for x, y in points:
        if 0 <= x < w and 0 <= y < h:
            block_counts[int(y // block_size), int(x // block_size)] += 1

    # Stats
    non_zero_blocks = np.count_nonzero(block_counts)
    max_count = np.max(block_counts)
    mean_count = np.mean(block_counts[block_counts > 0]) if non_zero_blocks > 0 else 0
    print(f"Heatmap stats: {non_zero_blocks}/{n_blocks_x * n_blocks_y} non-empty blocks, max: {max_count}, mean: {mean_count:.2f}")

    alpha = 0.6

    # Draw on an explicit figure so callers' current figure is left alone.
    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        ax.imshow(image)

        if max_count > 0:
            # Expand block counts to pixel resolution, cropping the overhang
            # of partial edge blocks.
            heatmap_expanded = np.repeat(
                np.repeat(block_counts, block_size, axis=0), block_size, axis=1
            )[:h, :w]
            # Log scale, normalized to [0, 1] by the busiest block.
            log_counts = np.log1p(heatmap_expanded) / np.log1p(max_count)
            im = ax.imshow(log_counts, cmap=colormap, alpha=alpha, interpolation='bilinear')
            cbar = fig.colorbar(im, ax=ax)
            cbar.set_label('Predicted points (log scale)')

        ax.set_title(f"Block-count heatmap (based on {len(points)} points)")
        ax.axis('off')
        fig.tight_layout()

        # Render to an in-memory PNG and decode it back into an array.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
        buffer.seek(0)
        final_image = np.array(Image.open(buffer))
    finally:
        plt.close(fig)

    return final_image


# Save scatter figure (borderless, no legend)
def save_scatter_figure(img_np, points, save_path, point_size=3):
    """Save the image with predicted points scattered on top (no frame, ticks or legend).

    Args:
        img_np: image as a numpy array
        points: sequence of (x, y) pixel coordinates; may be empty
        save_path: output file path; its parent directory is created if needed
        point_size: scatter marker size
    """
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        ax.imshow(img_np)
        if points is not None and len(points) > 0:
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            ax.scatter(xs, ys, c='yellow', s=point_size, alpha=0.5)
        # No frame, no ticks
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        fig.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

# Save heatmap (with colorbar, borderless)
def save_heatmap_figure(img_np, points, save_path, *, use_block_heatmap=False,
                        block_size=None, min_heatmap_blocks=20, max_heatmap_blocks=100,
                        colormap='coolwarm', alpha=0.7):
    """Save the image overlaid with a point-density heatmap and a colorbar.

    Args:
        img_np: image as a numpy array (height x width [x channels])
        points: iterable of (x, y) pixel coordinates; points outside the
            image are ignored
        save_path: output file path; its parent directory is created if needed
        use_block_heatmap: if True, count points per square block and show
            log-scaled counts; otherwise accumulate per pixel and apply a
            Gaussian blur
        block_size: block edge length in pixels (block mode), derived from
            the image size if None
        min_heatmap_blocks: divisor giving the largest automatic block size
        max_heatmap_blocks: divisor giving the smallest automatic block size
        colormap: matplotlib colormap name
        alpha: opacity of the heatmap layer
    """
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    h, w = img_np.shape[:2]
    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        ax.imshow(img_np)

        if use_block_heatmap:
            if block_size is None:
                block_size = max(min(w // min_heatmap_blocks, h // min_heatmap_blocks),
                                 max(1, w // max_heatmap_blocks))
            # Guard against a caller-supplied 0 (division by zero below).
            block_size = max(1, int(block_size))

            # Ceiling division so partial edge blocks are counted.
            n_blocks_x = -(-w // block_size)
            n_blocks_y = -(-h // block_size)
            block_counts = np.zeros((n_blocks_y, n_blocks_x), dtype=float)
            for x, y in points:
                if 0 <= x < w and 0 <= y < h:
                    block_counts[int(y // block_size), int(x // block_size)] += 1.0

            # Expand to pixel resolution (cropping edge overhang), then log scale.
            heat = np.repeat(
                np.repeat(block_counts, block_size, axis=0), block_size, axis=1
            )[:h, :w]
            peak = heat.max() if heat.max() > 0 else 1.0
            heat = np.log1p(heat) / np.log1p(peak)
            im = ax.imshow(heat, cmap=colormap, alpha=alpha, interpolation='bilinear',
                           vmin=0.0, vmax=1.0)
        else:
            # Pixel-grid density + Gaussian blur
            heat = np.zeros((h, w), dtype=np.float32)
            for x, y in points:
                if 0 <= x < w and 0 <= y < h:
                    heat[int(y), int(x)] += 1.0
            if heat.max() > 0:
                # Blur radius scales with the larger image dimension.
                heat = cv2.GaussianBlur(heat, (0, 0), sigmaX=max(w, h) / 100)
            peak = heat.max() if heat.max() > 0 else 1.0
            im = ax.imshow(heat, cmap=colormap, alpha=alpha, interpolation='bilinear',
                           vmin=0, vmax=peak)

        # Colorbar
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Predicted point density', rotation=90)

        # No frame, no ticks
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

        fig.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

# Save overlay (green gt_bbox, red base point, borderless)
def save_overlay_figure(img_np, bbox, base_point, save_path):
    """Save the image with the ground-truth bbox (green) and base prediction (red dot).

    Args:
        img_np: image as a numpy array
        bbox: [x_min, y_min, x_max, y_max] or None/empty to skip the box
        base_point: (x, y) pixel coordinates or None/empty to skip the dot
        save_path: output file path; its parent directory is created if needed
    """
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        ax.imshow(img_np)
        if bbox:
            draw_bbox(ax, bbox, color='lime', linewidth=1.5, label=None)
        if base_point:
            ax.plot(base_point[0], base_point[1], 'ro', markersize=4)
        # No frame, no ticks
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        fig.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def visualize_model_attention_points(model, image_path, instruction,
                                    num_iterations=100, M=10, N=10,
                                    min_blocks=1, max_blocks=None,
                                    save_path=None, point_size=1,
                                    use_heatmap=True, show_both=False,
                                    alpha=0.7, existing_points=None,
                                    save_points_path=None, colormap='coolwarm',
                                    use_block_heatmap=False, block_size=None,
                                    min_heatmap_blocks=20, max_heatmap_blocks=100,
                                    bbox=None,
                                    scatter_save_path=None,
                                    heatmap_save_path=None,
                                    overlay_save_path=None):
    """
    Sample model predictions on randomly masked images and save three figures:
    scatter, heatmap, overlay.

    The image is split into an M x N grid. Each iteration reveals a random
    subset of grid cells on a black canvas, queries the model, and records
    the predicted point in pixel coordinates of the full image.

    Args:
        model: object exposing ``ground_only_positive(instruction=, image=)``;
            the response is read with ``.get("point")`` and the point is
            scaled by the image width/height (i.e. treated as normalized)
        image_path: path of the image to analyze
        instruction: instruction text passed to the model
        num_iterations: number of masked predictions to run
        M: grid rows
        N: grid columns
        min_blocks: minimum number of revealed cells per iteration
        max_blocks: maximum number of revealed cells per iteration
            (defaults to half of the grid cells)
        save_path: legacy single output path; when none of the three explicit
            paths is given, "_scatter", "_heatmap" and "_overlay" PNG paths
            are derived from it
        point_size: scatter marker size (at least 1 is used)
        use_heatmap: accepted for compatibility; not used by this function
        show_both: accepted for compatibility; not used by this function
        alpha: heatmap opacity
        existing_points: previously collected points to extend; the caller's
            list is copied, not modified
        save_points_path: if set, all collected points are saved there as JSON
        colormap: matplotlib colormap name for the heatmap
        use_block_heatmap: use the block-count heatmap instead of the blurred one
        block_size: block size in pixels for the block heatmap
        min_heatmap_blocks: see ``save_heatmap_figure``
        max_heatmap_blocks: see ``save_heatmap_figure``
        bbox: ground-truth [x_min, y_min, x_max, y_max] drawn on the overlay
        scatter_save_path: output path of the scatter figure
        heatmap_save_path: output path of the heatmap figure (needs > 5 points)
        overlay_save_path: output path of the overlay figure

    Returns:
        Tuple ``(all_points, base_point)``: the list of collected points and
        the full-image prediction in pixels (or None).
    """
    # Load image
    image = Image.open(image_path).convert('RGB')
    img_np = np.array(image)
    h, w = img_np.shape[:2]

    # Derive three output paths from legacy save_path
    if save_path and not (scatter_save_path or heatmap_save_path or overlay_save_path):
        base, _ = os.path.splitext(save_path)
        scatter_save_path = f"{base}_scatter.png"
        heatmap_save_path = f"{base}_heatmap.png"
        overlay_save_path = f"{base}_overlay.png"

    # Grid cell size
    block_h = h // M
    block_w = w // N

    if max_blocks is None:
        max_blocks = (M * N) // 2
    if max_blocks < min_blocks:
        # np.random.randint needs a non-empty range (e.g. a 1x1 grid gives
        # a default max of 0 against a min of 1).
        max_blocks = min_blocks

    # Base prediction (full image)
    print("Getting base prediction on full image...")
    base_response = model.ground_only_positive(instruction=instruction, image=image_path)
    base_point_normalized = base_response.get("point", None)
    base_point = None
    if base_point_normalized is not None:
        base_point = [int(base_point_normalized[0] * w), int(base_point_normalized[1] * h)]
        print(f"Base point (normalized): {base_point_normalized}")
        print(f"Base point (pixels): {base_point}")

    # Copy so the caller's list is not mutated by the appends below.
    all_points = [] if existing_points is None else list(existing_points)
    valid_predictions = 0
    total_predictions = 0

    for iter_idx in tqdm(range(num_iterations), desc="Collecting model predictions"):
        masked_img = np.zeros_like(img_np)
        num_blocks = np.random.randint(min_blocks, max_blocks + 1)
        all_blocks = list(range(M * N))
        np.random.shuffle(all_blocks)
        for block_idx in all_blocks[:num_blocks]:
            i, j = divmod(block_idx, N)
            # The last row/column absorbs the remainder so that pixels near
            # the bottom/right edge can be revealed too.
            y0 = i * block_h
            y1 = h if i == M - 1 else (i + 1) * block_h
            x0 = j * block_w
            x1 = w if j == N - 1 else (j + 1) * block_w
            masked_img[y0:y1, x0:x1] = img_np[y0:y1, x0:x1]
        try:
            masked_img_pil = Image.fromarray(masked_img)
            response = model.ground_only_positive(instruction=instruction, image=masked_img_pil)
            total_predictions += 1
            point_normalized = response.get("point", None)
            if point_normalized is not None:
                point = (int(point_normalized[0] * w), int(point_normalized[1] * h))
                all_points.append(point)
                valid_predictions += 1
                if iter_idx < 5:
                    print(f"Iter {iter_idx+1}: predicted point (pixels) = {point}")
        except Exception as e:
            # Best effort: one failed prediction must not abort the sampling run.
            print(f"Iter {iter_idx+1} error: {e}")
        # Periodically free memory held by discarded masked images.
        if iter_idx % 20 == 19:
            gc.collect()

    if save_points_path:
        save_points(all_points, save_points_path)

    print(f"Ran {total_predictions} predictions, {valid_predictions} valid")
    print(f"Collected {len(all_points)} predicted points")

    # Save figures
    if scatter_save_path and len(all_points) > 0:
        save_scatter_figure(img_np, all_points, scatter_save_path, point_size=max(1, point_size))
        print(f"Scatter saved: {scatter_save_path}")

    if heatmap_save_path and len(all_points) > 5:
        save_heatmap_figure(
            img_np, all_points, heatmap_save_path,
            use_block_heatmap=use_block_heatmap,
            block_size=block_size,
            min_heatmap_blocks=min_heatmap_blocks,
            max_heatmap_blocks=max_heatmap_blocks,
            colormap=colormap,
            alpha=alpha
        )
        print(f"Heatmap saved: {heatmap_save_path}")

    if overlay_save_path:
        save_overlay_figure(img_np, bbox, base_point, overlay_save_path)
        print(f"Overlay saved: {overlay_save_path}")

    return all_points, base_point


# Generate accumulation animation frames
def generate_accumulation_animation(model, image_path, instruction, output_dir,
                                   num_iterations=100, frames=10, show_both=False,
                                   colormap='coolwarm', use_block_heatmap=False,
                                   block_size=None, min_heatmap_blocks=20,
                                   max_heatmap_blocks=100, bbox=None, **kwargs):
    """
    Generate animation frames that visualize the accumulation of predicted points.

    The ``num_iterations`` sampling budget is spread over ``frames`` frames so
    that exactly ``num_iterations`` masked predictions have been run by the
    last frame. Each frame is written to ``output_dir`` as ``frame_NNN.png``.

    Args:
        model: VLM exposing ``ground_only_positive(instruction=, image=)``
        image_path: image path
        instruction: instruction text
        output_dir: output directory (created if missing)
        num_iterations: total iterations across all frames
        frames: number of frames
        show_both: show original+scatter+heatmap triptych (used once more
            than 5 points are available; otherwise a single view is drawn)
        colormap: colormap for heatmap
        use_block_heatmap: use block-based heatmap
        block_size: block size in pixels
        min_heatmap_blocks: forwarded to generate_block_heatmap as min_blocks
        max_heatmap_blocks: forwarded to generate_block_heatmap as max_blocks
        bbox: [x_min, y_min, x_max, y_max], drawn in the triptych view
        **kwargs: forwarded to visualize_model_attention_points; 'point_size',
            'alpha' and 'use_heatmap' are also read here for rendering
    """
    os.makedirs(output_dir, exist_ok=True)

    # The image and the full-image prediction are the same for every frame,
    # so load/compute them once instead of once per frame.
    img_np = np.array(Image.open(image_path).convert('RGB'))
    h, w = img_np.shape[:2]

    base_response = model.ground_only_positive(instruction=instruction, image=image_path)
    base_point_normalized = base_response.get("point", None)
    base_point = None
    if base_point_normalized is not None:
        base_point = [
            int(base_point_normalized[0] * w),
            int(base_point_normalized[1] * h)
        ]

    point_size = kwargs.get('point_size', 5)
    density_alpha = kwargs.get('alpha', 0.7)

    def plot_base_point(ax):
        # Mark the full-image prediction with a small red dot.
        if base_point:
            ax.plot(base_point[0], base_point[1], 'ro', markersize=1, label='Base prediction')

    def block_heatmap_image():
        # NOTE(review): the returned image is a rendered figure whose pixel
        # size differs from the input image, so base point / bbox drawn on
        # top of it in image coordinates will not line up — confirm intent.
        return generate_block_heatmap(
            img_np, all_points,
            block_size=block_size,
            min_blocks=min_heatmap_blocks,
            max_blocks=max_heatmap_blocks,
            colormap=colormap
        )

    all_points = []
    done_iterations = 0

    # Generate each frame
    for frame in range(1, frames + 1):
        # Cumulative target for this frame; the difference to the previous
        # target is what still has to be sampled. This spreads the remainder
        # of num_iterations / frames and never exceeds the total budget.
        curr_iterations = (frame * num_iterations) // frames
        new_iterations = curr_iterations - done_iterations
        done_iterations = curr_iterations
        print(f"\nGenerating frame {frame}/{frames} (iterations: {curr_iterations})")

        if new_iterations > 0:
            new_points, _ = visualize_model_attention_points(
                model, image_path, instruction,
                num_iterations=new_iterations,
                save_path=None,
                colormap=colormap,
                **kwargs
            )
            all_points.extend(new_points)

        xs = [p[0] for p in all_points]
        ys = [p[1] for p in all_points]
        # Density plots need a handful of points to be meaningful.
        enough_points = len(all_points) > 5

        if show_both and enough_points:
            # Triptych: original, scatter, heatmap
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 10))
        else:
            fig, ax = plt.subplots(figsize=(12, 10))

        try:
            if show_both and enough_points:
                # 1. Original
                ax1.imshow(img_np)
                plot_base_point(ax1)
                ax1.set_title("Original image")
                ax1.axis('off')
                if bbox:
                    draw_bbox(ax1, bbox)

                # 2. Scatter
                ax2.imshow(img_np)
                ax2.scatter(xs, ys, c='yellow', s=point_size, alpha=0.5, label='Predicted points')
                plot_base_point(ax2)
                ax2.set_title(f"Predicted points - frame {frame}/{frames} ({len(all_points)} points total)")
                ax2.axis('off')
                if bbox:
                    draw_bbox(ax2, bbox)

                # 3. Heatmap
                if use_block_heatmap:
                    ax3.imshow(block_heatmap_image())
                    ax3.set_title(f"Block heatmap - frame {frame}/{frames}")
                else:
                    ax3.imshow(img_np)
                    sns.kdeplot(x=xs, y=ys, cmap=colormap,
                                fill=True, alpha=density_alpha,
                                thresh=0.05,
                                levels=20, ax=ax3)
                    ax3.set_title(f"KDE heatmap - frame {frame}/{frames}")
                plot_base_point(ax3)
                ax3.axis('off')
                if bbox:
                    draw_bbox(ax3, bbox)
            else:
                # Single view
                ax.imshow(img_np)

                if use_block_heatmap and enough_points:
                    ax.imshow(block_heatmap_image())
                    ax.set_title(f"Model attention block heatmap - frame {frame}/{frames} (based on {len(all_points)} points)")
                elif kwargs.get('use_heatmap', True) and enough_points:
                    sns.kdeplot(x=xs, y=ys, cmap=colormap,
                                fill=True, alpha=density_alpha,
                                thresh=0.05,
                                levels=20, ax=ax)
                    ax.set_title(f"Model attention heatmap - frame {frame}/{frames} (based on {len(all_points)} points)")
                else:
                    ax.scatter(xs, ys, c='yellow', s=point_size,
                               alpha=0.5, label='Predicted points')
                    ax.set_title(f"Predicted points - frame {frame}/{frames} ({len(all_points)} points total)")

                # Drawn last so it stays visible above the density layer.
                plot_base_point(ax)
                ax.axis('off')

            fig.tight_layout()
            frame_path = os.path.join(output_dir, f"frame_{frame:03d}.png")
            fig.savefig(frame_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the frame figure, even if rendering fails.
            plt.close(fig)

        print(f"Frame {frame}/{frames} saved: {frame_path}")

    print(f"\nGenerated {frames} frames at: {output_dir}")
    print("You can compose them into a GIF or video via:")
    print(f"ffmpeg -framerate 2 -i {output_dir}/frame_%03d.png -vf \"fps=2,scale=800:-1\" {output_dir}/attention_animation.gif")

# Process a single sample
def process_sample(model, sample, args):
    """Run the attention-point visualization (and optional animation) for one sample.

    Args:
        model: model passed through to visualize_model_attention_points
        sample: dict with 'img_path' and 'prompt_to_evaluate'; optional 'id',
            'correctness' and 'bbox' entries are used for naming and overlays
        args: parsed options; reads filename, output_dir, save_points,
            continue_from, num_iterations, grid_m, grid_n, min_blocks,
            max_blocks, point_size, use_scatter, alpha, colormap,
            use_block_heatmap, block_size, min_heatmap_blocks,
            max_heatmap_blocks, generate_animation and, when animating,
            animation_dir, animation_frames and show_both

    Returns:
        The list of predicted points collected for this sample.
    """
    image_path = sample['img_path']
    instruction = sample['prompt_to_evaluate']
    bbox = get_bbox_from_sample(sample)

    # Unique, readable file prefix. IDs may be ints or missing in the log, so
    # fall back to the image name and coerce to str before sanitizing.
    raw_id = sample.get('id')
    if raw_id is None:
        raw_id = os.path.splitext(os.path.basename(image_path))[0]
    safe_id = re.sub(r'[^a-zA-Z0-9_-]+', '_', str(raw_id))
    prefix = "correct" if sample.get('correctness') == 'correct' else "wrong"
    filename_part = f"{args.filename}_" if args.filename else ""
    base_prefix = f"{filename_part}{safe_id}_{prefix}"

    scatter_path = os.path.join(args.output_dir, f"{base_prefix}_scatter.png")
    heatmap_path = os.path.join(args.output_dir, f"{base_prefix}_heatmap.png")
    overlay_path = os.path.join(args.output_dir, f"{base_prefix}_overlay.png")

    # Optional: path to save points
    save_points_path = None
    if args.save_points:
        points_dir = os.path.join(args.output_dir, "points")
        os.makedirs(points_dir, exist_ok=True)
        save_points_path = os.path.join(points_dir, f"{base_prefix}_points.json")

    # Load existing points (best effort: start from scratch if that fails)
    existing_points = None
    if args.continue_from:
        try:
            existing_points = load_points(args.continue_from)
        except Exception as e:
            print(f"Failed to load points file: {e}")

    print(f"Processing sample: {instruction}")
    print(f"Image path: {image_path}")
    if bbox:
        print(f"gt_bbox: {bbox}")

    # Run and save three figures
    all_points, base_point = visualize_model_attention_points(
        model, image_path, instruction,
        num_iterations=args.num_iterations,
        M=args.grid_m, N=args.grid_n,
        min_blocks=args.min_blocks,
        max_blocks=args.max_blocks,
        save_path=None,                      # no composite figure
        point_size=args.point_size,
        use_heatmap=not args.use_scatter,    # affects density type
        show_both=False,                     # no composite
        alpha=args.alpha,
        existing_points=existing_points,
        save_points_path=save_points_path,
        colormap=args.colormap,
        use_block_heatmap=args.use_block_heatmap,
        block_size=args.block_size,
        min_heatmap_blocks=args.min_heatmap_blocks,
        max_heatmap_blocks=args.max_heatmap_blocks,
        bbox=bbox,
        scatter_save_path=scatter_path,
        heatmap_save_path=heatmap_path,
        overlay_save_path=overlay_path
    )

    # Optional animation
    if args.generate_animation:
        anim_dir = os.path.join(args.animation_dir, f"{base_prefix}")
        os.makedirs(anim_dir, exist_ok=True)
        generate_accumulation_animation(
            model, image_path, instruction,
            anim_dir,
            num_iterations=args.num_iterations,
            frames=args.animation_frames,
            show_both=args.show_both,
            colormap=args.colormap,
            use_block_heatmap=args.use_block_heatmap,
            block_size=args.block_size,
            min_heatmap_blocks=args.min_heatmap_blocks,
            max_heatmap_blocks=args.max_heatmap_blocks,
            point_size=args.point_size,
            use_heatmap=not args.use_scatter,
            alpha=args.alpha,
            bbox=bbox
        )

    return all_points


# Randomly select samples
def random_select_samples(results, count, scope='all', seed=42):
    """
    Randomly pick up to 'count' samples filtered by scope.

    Args:
        results: iterable of sample dicts, or a result dict whose 'details'
            entry holds that list (the shape select_samples expects)
        count: number of samples to draw; capped at the number of candidates
        scope: 'correct' | 'wrong' | 'all'. 'all' keeps only 'correct' and
            'wrong' samples (e.g. 'wrong_format' is excluded); any other
            value applies no correctness filter
        seed: seed of the private random generator, for reproducible picks

    Returns:
        List of selected sample dicts (empty if nothing qualifies).
    """
    import random

    # Iterating a result dict would walk its keys and silently select
    # nothing, so unwrap the per-sample list first.
    if isinstance(results, dict):
        results = results.get('details') or []

    allowed_by_scope = {
        'correct': ('correct',),
        'wrong': ('wrong',),
        'all': ('correct', 'wrong'),
    }
    allowed = allowed_by_scope.get(scope)

    # 1) keep well-formed samples whose correctness matches the scope
    candidates = [
        s for s in results
        if isinstance(s, dict)
        and 'img_path' in s
        and 'prompt_to_evaluate' in s
        and (allowed is None or s.get('correctness', None) in allowed)
    ]

    if not candidates:
        return []

    # 2) random sample from a private generator (global RNG state untouched)
    rng = random.Random(seed)
    return rng.sample(candidates, min(count, len(candidates)))


def main():
    """
    Command-line entry point for the attention-point visualization.

    Two modes are supported:
      1. Direct mode: both ``--image_path`` and ``--instruction`` are given;
         that single image/instruction pair is visualized.
      2. Log mode: ``--log_path`` points to a JSON evaluation log (either a
         list of samples or a dict with a 'details' list). Samples are
         chosen randomly (``--random_pick``) or through ``select_samples``
         and each one is handed to ``process_sample``.

    When neither mode's inputs are supplied, an error message is printed
    and the function returns before the model is loaded.
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(description='Model attention visualization based on point distributions')
    parser.add_argument('--model_type', type=str, default='uitars', help='model type')
    parser.add_argument('--model_name_or_path', type=str, default=None,
                        help='optional model name or checkpoint path (model default if not set)')
    parser.add_argument('--log_path', type=str, required=False, help='evaluation log path (JSON)')
    parser.add_argument('--output_dir', type=str, default='attention_points_results', help='output directory')
    parser.add_argument('--grid_m', type=int, default=3, help='grid rows (vertical split count)')
    parser.add_argument('--grid_n', type=int, default=3, help='grid cols (horizontal split count)')
    parser.add_argument('--image_path', type=str, help='image path to analyze')
    parser.add_argument('--instruction', type=str, help='instruction text to run')
    parser.add_argument('--filename', type=str, default='', help='filename prefix for outputs')
    parser.add_argument('--sample_id', type=str, default=None, help='specific sample ID to analyze')
    parser.add_argument('--img_path', type=str, default=None, help='filter samples by image path substring')
    parser.add_argument('--num_samples', type=int, default=1, help='number of samples to select')
    parser.add_argument('--num_iterations', type=int, default=100, help='iterations per sample')
    parser.add_argument('--min_blocks', type=int, default=1, help='min blocks per iteration')
    parser.add_argument('--max_blocks', type=int, default=None, help='max blocks per iteration')
    parser.add_argument('--use_scatter', action='store_true', help='use scatter instead of heatmap')
    parser.add_argument('--show_both', action='store_true', help='show original + scatter + heatmap')
    parser.add_argument('--point_size', type=int, default=5, help='scatter point size')
    parser.add_argument('--alpha', type=float, default=0.7, help='heatmap alpha')
    parser.add_argument('--save_iterations', action='store_true', help='save heatmap of each iteration')
    parser.add_argument('--iter_output_dir', type=str, default='iteration_heatmaps', help='directory for per-iteration heatmaps')
    # multi-iteration
    parser.add_argument('--continue_from', type=str, default=None, help='continue from a previously saved points JSON')
    parser.add_argument('--save_points', action='store_true', help='save collected points to JSON')
    # colormap
    parser.add_argument('--colormap', type=str, default='coolwarm',
                        help='colormap for heatmap, e.g., coolwarm, RdBu, bwr, hot, viridis')
    # block heatmap
    parser.add_argument('--use_block_heatmap', action='store_true', help='use block-based heatmap instead of KDE')
    parser.add_argument('--block_size', type=int, default=None, help='block size in pixels (auto if not set)')
    parser.add_argument('--min_heatmap_blocks', type=int, default=20, help='min number of blocks across width')
    parser.add_argument('--max_heatmap_blocks', type=int, default=100, help='max number of blocks across width')

    # animation: these attributes are read unconditionally by the per-sample
    # processing code, so they must always exist on the parsed namespace.
    parser.add_argument('--generate_animation', action='store_true',
                        help='also generate an accumulation animation')
    parser.add_argument('--animation_dir', type=str, default='attention_animations',
                        help='directory for animation outputs')
    parser.add_argument('--animation_frames', type=int, default=10,
                        help='number of frames in the accumulation animation')

    # random sampling control
    parser.add_argument('--random_pick', action='store_true', help='randomly select samples from evaluation log')
    parser.add_argument('--sample_scope', type=str, default='all', choices=['correct', 'wrong', 'all'],
                        help='scope for random sampling: correct/wrong/all')
    parser.add_argument('--random_seed', type=int, default=42, help='random seed')

    args = parser.parse_args()

    # Validate the requested mode before the (potentially expensive) model load.
    direct_mode = bool(args.image_path and args.instruction)
    if not direct_mode and not args.log_path:
        print("Error: either provide image_path and instruction, or provide log_path")
        return

    # Load model
    model_args = argparse.Namespace(
        model_type=args.model_type,
        model_name_or_path=args.model_name_or_path,
    )
    model = build_model(model_args)
    print(f"Model {args.model_type} loaded")

    # Prepare output dirs
    os.makedirs(args.output_dir, exist_ok=True)
    if args.generate_animation:
        os.makedirs(args.animation_dir, exist_ok=True)

    # Mode 1: direct image + instruction
    if direct_mode:
        # Load existing points if any; a failed load is reported and the run
        # continues without previously collected points.
        existing_points = None
        if args.continue_from:
            try:
                existing_points = load_points(args.continue_from)
            except Exception as e:
                print(f"Failed to load points file: {e}")

        # Path to save points if requested
        save_points_path = None
        if args.save_points:
            points_dir = os.path.join(args.output_dir, "points")
            os.makedirs(points_dir, exist_ok=True)
            save_points_path = os.path.join(points_dir, "direct_points.json")

        output_name = f"{args.filename}_attention_points.png" if args.filename else "attention_points.png"
        visualize_model_attention_points(
            model, args.image_path, args.instruction,
            num_iterations=args.num_iterations,
            M=args.grid_m, N=args.grid_n,
            min_blocks=args.min_blocks,
            max_blocks=args.max_blocks,
            save_path=os.path.join(args.output_dir, output_name),
            point_size=args.point_size,
            use_heatmap=not args.use_scatter,
            show_both=args.show_both,
            alpha=args.alpha,
            existing_points=existing_points,
            save_points_path=save_points_path,
            colormap=args.colormap,
            use_block_heatmap=args.use_block_heatmap,
            block_size=args.block_size,
            min_heatmap_blocks=args.min_heatmap_blocks,
            max_heatmap_blocks=args.max_heatmap_blocks
        )

        if args.generate_animation:
            generate_accumulation_animation(
                model, args.image_path, args.instruction, args.animation_dir,
                num_iterations=args.num_iterations,
                frames=args.animation_frames,
                show_both=args.show_both,
                colormap=args.colormap,
                M=args.grid_m,
                N=args.grid_n,
                min_blocks=args.min_blocks,
                max_blocks=args.max_blocks,
                point_size=args.point_size,
                use_heatmap=not args.use_scatter,
                alpha=args.alpha
            )
        return

    # Mode 2: from evaluation log
    # Load from JSON (supports dict with 'details' or a list)
    with open(args.log_path, 'r', encoding='utf-8') as f:
        raw = json.load(f)
    if isinstance(raw, dict) and 'details' in raw:
        results = raw['details']
    elif isinstance(raw, list):
        results = raw
    else:
        # Unrecognized layout: treat as an empty log.
        results = []
    print(f"Loaded {len(results)} evaluation samples from {args.log_path}")

    # A) random sampling mode
    if args.random_pick:
        selected = random_select_samples(
            results,
            count=args.num_samples,
            scope=args.sample_scope,
            seed=args.random_seed
        )
        print(f"Randomly selected {len(selected)} samples (scope: {args.sample_scope}, requested: {args.num_samples}, seed: {args.random_seed})")

        if not selected:
            print("No eligible samples for random sampling (check the log or adjust sample_scope).")
            return

        for i, sample in enumerate(selected):
            sid = sample.get('id', f'idx_{i}')
            print(f"\nProcessing random sample {i+1}/{len(selected)}: {sid} | correctness={sample.get('correctness')}")
            process_sample(model, sample, args)
        return

    # B) legacy selection by ID/path or preset selector
    correct_samples, wrong_samples = select_samples(
        results,
        num_samples=args.num_samples,
        sample_id=args.sample_id,
        img_path=args.img_path
    )

    print(f"Selected {len(correct_samples)} correct and {len(wrong_samples)} wrong samples")

    # Correct samples
    print("\nProcessing correct samples...")
    for i, sample in enumerate(correct_samples):
        print(f"\nProcessing correct sample {i+1}/{len(correct_samples)}")
        process_sample(model, sample, args)

    # Wrong samples
    print("\nProcessing wrong samples...")
    for i, sample in enumerate(wrong_samples):
        print(f"\nProcessing wrong sample {i+1}/{len(wrong_samples)}")
        process_sample(model, sample, args)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

