"""Optional DINOv3 region clustering.

This module is isolated on purpose: reviewers can ignore it.
The default demo uses simple grid-based regions.
"""

from __future__ import annotations

import math
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent

# Put the parent of this file's directory at the front of sys.path (only if it
# is not already listed) so the `inference.sources` import below can resolve
# independently of how/where this file is executed.
if all(entry != str(_ROOT_DIR) for entry in sys.path):
    sys.path.insert(0, str(_ROOT_DIR))

from inference.sources import Source, VisionTokenMap


@dataclass(frozen=True)
class DINOAttentionResult:
    """Immutable per-token attention prior returned by `extract_dinov3_attention`."""

    # Flat vector of length grid_h * grid_w, laid out row-major over the grid.
    attention_weights: np.ndarray
    # Number of grid rows the attention map was pooled to.
    grid_h: int
    # Number of grid columns the attention map was pooled to.
    grid_w: int


@dataclass(frozen=True)
class DINOConfig:
    """Immutable bundle of DINOv3 loading options (plus the clustering cut height)."""

    # Hugging Face model id or local path handed to `from_pretrained`.
    model_id: str
    # Cache directory for `from_pretrained`; only forwarded when truthy.
    cache_dir: Optional[str]
    # False restricts loading to local files (`local_files_only=True`).
    allow_download: bool
    # Agglomerative-clustering distance threshold; attention extraction sets it to 0.0.
    distance_threshold: float


def dinov3_cluster_sources(
    *,
    token_map: VisionTokenMap,
    image_path: str,
    dinov3_model: str,
    cache_dir: Optional[str],
    allow_download: bool,
    distance_threshold: float,
) -> list[Source]:
    """Group the visual-token grid into regions by clustering DINOv3 patch features.

    The image is encoded with the DINOv3 model. Its patch embeddings (CLS and
    register tokens dropped) are average-pooled onto the ``grid_h x grid_w`` grid
    of ``token_map`` and L2-normalised per token. They are then grouped with
    average-linkage agglomerative clustering (euclidean distance), cut at
    ``distance_threshold``.

    Args:
        token_map: Supplies the target grid size via ``grid_h`` / ``grid_w``.
        image_path: Path of the image to encode.
        dinov3_model: Hugging Face model id or local path.
        cache_dir: Optional cache directory passed to ``from_pretrained``.
        allow_download: When False, only local files are used and the HF offline
            environment variables are set (unless already set).
        distance_threshold: Linkage distance above which clusters are not merged.

    Returns:
        One ``Source`` per cluster, ordered by cluster label. Each is named
        ``dinov3_<label>`` and holds the flat (row-major) token indices it covers.

    Raises:
        RuntimeError: scikit-learn cannot be imported.
        ValueError: The token grid is empty, or the DINO patch tokens cannot be
            arranged into a 2-D grid.
    """
    try:
        from sklearn.cluster import AgglomerativeClustering
    except Exception as e:
        raise RuntimeError(
            "DINOv3 clustering requires scikit-learn. Install with: pip install scikit-learn"
        ) from e

    grid_h = int(token_map.grid_h)
    grid_w = int(token_map.grid_w)
    if grid_h <= 0 or grid_w <= 0:
        # Fail before the expensive model load rather than deep in pooling/clustering.
        raise ValueError(f"Token grid must be non-empty, got {grid_h}x{grid_w}")

    cfg = DINOConfig(
        model_id=str(dinov3_model),
        cache_dir=None if cache_dir is None else str(cache_dir),
        allow_download=bool(allow_download),
        distance_threshold=float(distance_threshold),
    )

    if not cfg.allow_download:
        # NOTE(review): this mutates the process-wide environment and persists
        # after the call; `local_files_only` below is what keeps this load offline.
        os.environ.setdefault("HF_HUB_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    kwargs = {"local_files_only": not cfg.allow_download}
    if cfg.cache_dir:
        kwargs["cache_dir"] = cfg.cache_dir

    processor = AutoImageProcessor.from_pretrained(cfg.model_id, **kwargs)
    model = AutoModel.from_pretrained(cfg.model_id, **kwargs)

    with Image.open(image_path) as img:
        image = img.convert("RGB")

    processed = processor(images=image, return_tensors="pt")
    # Match device *and* dtype so half-precision checkpoints also accept the input.
    pixel_values = processed["pixel_values"].to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    # Token layout is [CLS, register tokens..., patch tokens...]; keep patches only.
    num_register = int(getattr(model.config, "num_register_tokens", 0) or 0)
    patch_embeddings = outputs.last_hidden_state[0, 1 + num_register :, :]
    num_patches = int(patch_embeddings.shape[0])
    hidden_dim = int(patch_embeddings.shape[1])

    # Derive the patch grid from the processed image size when the config
    # exposes an int patch size and it accounts for every patch token; this
    # supports non-square inputs. Otherwise require an exact square grid.
    patch_h: Optional[int] = None
    patch_w: Optional[int] = None
    patch_size = getattr(model.config, "patch_size", None)
    if isinstance(patch_size, int) and patch_size > 0:
        cand_h = int(pixel_values.shape[-2]) // patch_size
        cand_w = int(pixel_values.shape[-1]) // patch_size
        if cand_h * cand_w == num_patches:
            patch_h, patch_w = cand_h, cand_w
    if patch_h is None or patch_w is None:
        side = math.isqrt(num_patches)  # exact, unlike int(n ** 0.5)
        if side * side != num_patches:
            raise ValueError(f"DINO patch count {num_patches} is not a perfect square")
        patch_h = patch_w = side

    # Pool patch embeddings onto the Qwen visual token grid:
    # (P, D) -> (1, D, patch_h, patch_w) -> (1, D, grid_h, grid_w) -> (grid_h*grid_w, D).
    patches_2d = (
        patch_embeddings.reshape(patch_h, patch_w, hidden_dim)
        .permute(2, 0, 1)
        .unsqueeze(0)
    )
    pooled = torch.nn.functional.adaptive_avg_pool2d(patches_2d, (grid_h, grid_w))
    token_embeds = pooled[0].permute(1, 2, 0).reshape(grid_h * grid_w, hidden_dim)

    # .float() first: NumPy has no bfloat16, so .numpy() would fail on bf16 tensors.
    x = token_embeds.detach().float().cpu().numpy().astype(np.float32)
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # leave all-zero vectors as-is instead of dividing by 0
    x = x / norms

    if x.shape[0] < 2:
        # AgglomerativeClustering requires at least two samples; a single
        # token is trivially its own region.
        labels = np.zeros(x.shape[0], dtype=np.int64)
    else:
        clustering = AgglomerativeClustering(
            n_clusters=None,
            metric="euclidean",
            linkage="average",
            distance_threshold=float(cfg.distance_threshold),
        )
        labels = clustering.fit_predict(x)

    label_to_tokens: dict[int, list[int]] = {}
    for idx, lab in enumerate(labels.tolist()):
        label_to_tokens.setdefault(int(lab), []).append(idx)

    return [
        Source(name=f"dinov3_{lab}", token_indices=toks)
        for lab, toks in sorted(label_to_tokens.items())
    ]


def extract_dinov3_attention(
    *,
    image_path: str,
    grid_h: int,
    grid_w: int,
    dinov3_model: str,
    cache_dir: Optional[str],
    allow_download: bool,
) -> DINOAttentionResult:
    """Extract DINOv3 last-layer CLS attention as spatial prior for patch scores.

    Takes the CLS token's attention to all patch tokens (register tokens
    excluded) from the last transformer layer and averages it across heads. The
    resulting patch map is average-pooled to the ``grid_h x grid_w`` grid.

    Args:
        image_path: Path of the image to encode.
        grid_h: Target grid height (rows); must be positive.
        grid_w: Target grid width (columns); must be positive.
        dinov3_model: Hugging Face model id or local path.
        cache_dir: Optional cache directory passed to ``from_pretrained``.
        allow_download: When False, only local files are used and the HF offline
            environment variables are set (unless already set).

    Returns:
        ``DINOAttentionResult`` whose ``attention_weights`` is a float32 vector
        of length ``grid_h * grid_w`` (row-major), clamped to be non-negative.

    Raises:
        ValueError: Non-positive grid size, or the patch tokens cannot be
            arranged into a 2-D grid.
        RuntimeError: The model returned no attention maps.
    """
    if grid_h <= 0 or grid_w <= 0:
        # Fail before the expensive model load rather than inside pooling.
        raise ValueError(f"Target grid must be non-empty, got {grid_h}x{grid_w}")

    cfg = DINOConfig(
        model_id=str(dinov3_model),
        cache_dir=None if cache_dir is None else str(cache_dir),
        allow_download=bool(allow_download),
        distance_threshold=0.0,  # unused here; only clustering reads it
    )

    if not cfg.allow_download:
        # NOTE(review): this mutates the process-wide environment and persists
        # after the call; `local_files_only` below is what keeps this load offline.
        os.environ.setdefault("HF_HUB_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    kwargs = {"local_files_only": not cfg.allow_download}
    if cfg.cache_dir:
        kwargs["cache_dir"] = cfg.cache_dir

    processor = AutoImageProcessor.from_pretrained(cfg.model_id, **kwargs)
    # Fused SDPA / flash kernels never materialise the attention probabilities,
    # so request the eager implementation to make `output_attentions` reliable.
    model = AutoModel.from_pretrained(cfg.model_id, attn_implementation="eager", **kwargs)

    with Image.open(image_path) as img:
        image = img.convert("RGB")

    processed = processor(images=image, return_tensors="pt")
    # Match device *and* dtype so half-precision checkpoints also accept the input.
    pixel_values = processed["pixel_values"].to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, output_attentions=True)

    attentions = outputs.attentions
    if not attentions:  # None or an empty tuple
        raise RuntimeError("DINOv3 model did not return attentions")

    last_layer_attn = attentions[-1]  # (batch, num_heads, seq_len, seq_len) per HF convention

    # Token layout is [CLS, register tokens..., patch tokens...].
    num_register = int(getattr(model.config, "num_register_tokens", 0) or 0)
    cls_idx = 0
    patch_start = 1 + num_register

    # CLS query row -> patch-token keys, averaged over heads (in fp32 for safety).
    cls_to_patches = last_layer_attn[0, :, cls_idx, patch_start:]  # (num_heads, num_patches)
    attn_mean = cls_to_patches.float().mean(dim=0)  # (num_patches,)
    num_patches = int(attn_mean.shape[0])

    # Derive the patch grid from the processed image size when the config
    # exposes an int patch size and it accounts for every patch token; this
    # supports non-square inputs. Otherwise require an exact square grid.
    patch_h: Optional[int] = None
    patch_w: Optional[int] = None
    patch_size = getattr(model.config, "patch_size", None)
    if isinstance(patch_size, int) and patch_size > 0:
        cand_h = int(pixel_values.shape[-2]) // patch_size
        cand_w = int(pixel_values.shape[-1]) // patch_size
        if cand_h * cand_w == num_patches:
            patch_h, patch_w = cand_h, cand_w
    if patch_h is None or patch_w is None:
        side = math.isqrt(num_patches)  # exact, unlike int(n ** 0.5)
        if side * side != num_patches:
            raise ValueError(f"DINO patch count {num_patches} is not a perfect square")
        patch_h = patch_w = side

    # (num_patches,) -> (1, 1, patch_h, patch_w) -> pooled (1, 1, grid_h, grid_w) -> flat.
    attn_4d = attn_mean.reshape(1, 1, patch_h, patch_w)
    pooled = torch.nn.functional.adaptive_avg_pool2d(attn_4d, (grid_h, grid_w))
    attn_pooled = pooled.reshape(grid_h * grid_w)

    attn_np = attn_pooled.detach().cpu().numpy().astype(np.float32)
    # Attention probabilities should already be >= 0; clamp defensively.
    attn_np = np.maximum(attn_np, 0.0)

    return DINOAttentionResult(
        attention_weights=attn_np,
        grid_h=grid_h,
        grid_w=grid_w,
    )
