from __future__ import annotations

import argparse
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from PIL import Image

# Prepend the parent of this script's directory (resolved to an absolute path)
# to sys.path, once, so the `common` and `inference` packages imported below
# can be found regardless of the current working directory.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent
if not any(entry == str(_ROOT_DIR) for entry in sys.path):
    sys.path.insert(0, str(_ROOT_DIR))

from common.checks import ensure_dir, ensure_exists
from common.logging import setup_logging
from inference.attention_features import features_from_attentions
from inference.model import MODEL_ID, Qwen8B
from inference.sources import (
    SourceSpec,
    VisionTokenMap,
    build_source_membership,
    build_sources,
)
from inference.text_targets import targets_in_span
from inference.theta import load_theta, uniform_theta
from inference.viz import export_run


@dataclass(frozen=True)
class InferConfig:
    """Immutable settings for a single inference + attribution run.

    Instances are built by ``_parse_args`` from the command line; the field
    defaults below match the CLI defaults.
    """

    # Passed to the Qwen8B backend as ``weights_dir``; None when not given.
    weights_dir: Optional[Path]
    # Input image; opened and converted to RGB before prompting.
    image: Path
    # Question text used to build the model prompt.
    question: str
    # Trained theta checkpoint; when None, uniform feature weights are used.
    theta: Optional[Path]
    # Destination directory handed to ``export_run``.
    out_dir: Path
    # Generation budget; must be large enough for the </think> span to close.
    max_new_tokens: int = 256
    # Source grouping mode forwarded to SourceSpec (e.g. "block", "dinov3").
    source_mode: str = "block"
    # Block size forwarded to SourceSpec (presumably only used in "block"
    # mode — TODO confirm against inference.sources).
    block_h: int = 2
    block_w: int = 2
    # Region count forwarded to SourceSpec.
    num_regions: int = 48
    # Seed forwarded to SourceSpec.
    source_seed: int = 0
    # Target selection mode forwarded to ``targets_in_span``.
    target_mode: str = "whole"
    # Force uniform theta even when ``theta`` is provided.
    use_uniform_theta: bool = False
    # Forwarded as ``allow_download`` to the Qwen and DINOv3 loaders.
    allow_download: bool = False
    # Redistribute each region's score across its tokens by DINOv3 attention.
    use_dino_attention: bool = False
    # DINOv3 model id used for attention extraction.
    dinov3_model: str = "facebook/dinov3-vitl16-pretrain-lvd1689m"
    # Cache directory for DINOv3 attention extraction; None when not given.
    dinov3_cache_dir: Optional[Path] = None


def _parse_args() -> InferConfig:
    """Parse command-line flags into an ``InferConfig``.

    Optional flags take their defaults from ``InferConfig``'s field defaults.
    A dataclass keeps simple field defaults as class attributes, so the CLI
    and the dataclass share one source of truth and cannot drift apart.

    Returns:
        The populated, immutable run configuration.
    """
    defaults = InferConfig  # class attributes hold the dataclass field defaults
    p = argparse.ArgumentParser(
        description=(
            "Generate a <think> trace for an image/question pair and "
            "attribute it to image regions."
        )
    )
    p.add_argument(
        "--weights_dir",
        type=Path,
        default=None,
        help="Local weights directory passed to the Qwen backend.",
    )
    p.add_argument("--image", type=Path, required=True, help="Input image path.")
    p.add_argument(
        "--question", type=str, required=True, help="Question about the image."
    )
    p.add_argument(
        "--theta",
        type=Path,
        default=None,
        help="Trained theta checkpoint; uniform weights are used if omitted.",
    )
    p.add_argument(
        "--out_dir", type=Path, required=True, help="Directory for outputs."
    )
    p.add_argument(
        "--max_new_tokens",
        type=int,
        default=defaults.max_new_tokens,
        help="Generation budget; must allow the </think> span to close.",
    )

    p.add_argument(
        "--source_mode",
        type=str,
        default=defaults.source_mode,
        help="Source grouping mode (e.g. 'block', 'dinov3').",
    )
    p.add_argument(
        "--block_h",
        type=int,
        default=defaults.block_h,
        help="Block height forwarded to the source builder.",
    )
    p.add_argument(
        "--block_w",
        type=int,
        default=defaults.block_w,
        help="Block width forwarded to the source builder.",
    )
    p.add_argument(
        "--num_regions",
        type=int,
        default=defaults.num_regions,
        help="Number of regions forwarded to the source builder.",
    )
    p.add_argument(
        "--source_seed",
        type=int,
        default=defaults.source_seed,
        help="Seed forwarded to the source builder.",
    )

    p.add_argument(
        "--target_mode",
        type=str,
        default=defaults.target_mode,
        help="How target ranges are selected inside the thought span.",
    )
    p.add_argument(
        "--use_uniform_theta",
        action="store_true",
        help="Ignore --theta and use uniform feature weights.",
    )
    p.add_argument(
        "--allow_download",
        action="store_true",
        help="Allow model loaders to download weights.",
    )

    p.add_argument(
        "--use_dino_attention",
        action="store_true",
        help="Redistribute region scores within regions by DINOv3 attention.",
    )
    p.add_argument(
        "--dinov3_model",
        type=str,
        default=defaults.dinov3_model,
        help="DINOv3 model id used for attention extraction.",
    )
    p.add_argument(
        "--dinov3_cache_dir",
        type=Path,
        default=defaults.dinov3_cache_dir,
        help="Cache directory for DINOv3 weights.",
    )

    a = p.parse_args()
    # argparse has already applied `type=` conversions and store_true yields
    # bools, so the parsed values can be passed through unchanged.
    return InferConfig(
        weights_dir=a.weights_dir,
        image=a.image,
        question=a.question,
        theta=a.theta,
        out_dir=a.out_dir,
        max_new_tokens=a.max_new_tokens,
        source_mode=a.source_mode,
        block_h=a.block_h,
        block_w=a.block_w,
        num_regions=a.num_regions,
        source_seed=a.source_seed,
        target_mode=a.target_mode,
        use_uniform_theta=a.use_uniform_theta,
        allow_download=a.allow_download,
        use_dino_attention=a.use_dino_attention,
        dinov3_model=a.dinov3_model,
        dinov3_cache_dir=a.dinov3_cache_dir,
    )


def _spread_source_scores(
    scores_by_target: np.ndarray,
    sources,
    num_tokens: int,
    dino_attn: Optional[np.ndarray] = None,
) -> list[list[float]]:
    """Spread each source's score over the vision tokens it covers.

    Args:
        scores_by_target: 2-D array indexed ``[target, source]``; the source
            axis follows the order of ``sources``.
        sources: objects exposing ``token_indices`` (assumed to be flat
            vision-token indices in ``[0, num_tokens)`` — TODO confirm).
        num_tokens: length of each output heat vector.
        dino_attn: optional per-token weights. When given and a region's
            weight mass exceeds 1e-8, the region score is split in proportion
            to these weights; otherwise it is split uniformly.

    Returns:
        One list of ``num_tokens`` floats per target row.
    """
    heatmaps: list[list[float]] = []
    for t_idx in range(int(scores_by_target.shape[0])):
        heat = np.zeros((num_tokens,), dtype=np.float32)
        for s_idx, src in enumerate(sources):
            region_score = float(scores_by_target[t_idx, s_idx])
            indices = [int(tok) for tok in src.token_indices]
            if not indices:
                continue
            attn_sum = 0.0
            if dino_attn is not None:
                attn_sum = sum(float(dino_attn[tok]) for tok in indices)
            if attn_sum > 1e-8:
                # Attention-weighted split; the weights sum to 1 per region.
                for tok in indices:
                    heat[tok] += region_score * float(dino_attn[tok]) / attn_sum
            else:
                # Uniform split: no DINO weights, or a region with ~zero mass.
                per_tok = region_score / len(indices)
                for tok in indices:
                    heat[tok] += per_tok
        heatmaps.append(heat.tolist())
    return heatmaps


def main() -> None:
    """Run generation plus attribution for one image/question pair.

    Pipeline:
    1. Generate a response with a forced ``<think>`` span.
    2. Re-run the model on prompt + thought with attentions returned.
    3. Turn the attentions into per-(target, source) features.
    4. Score the features with theta.
    5. Spread source scores onto vision tokens.
    6. Write everything via ``export_run``.

    Raises:
        RuntimeError: if no ``<think>...</think>`` span is found, or the model
            does not return attentions.
    """
    cfg = _parse_args()
    logger = setup_logging("INFO")

    ensure_exists(cfg.image, what="image")
    if cfg.theta is not None:
        ensure_exists(cfg.theta, what="theta checkpoint")
    if cfg.weights_dir is not None:
        ensure_dir(cfg.weights_dir, what="weights_dir")

    if not torch.cuda.is_available():
        logger.warning(
            "CUDA not available. Qwen3-VL-8B inference is likely too slow on CPU."
        )

    # convert() returns a new, fully loaded image, so the file handle can be
    # closed immediately instead of lingering until garbage collection.
    with Image.open(str(cfg.image)) as raw_image:
        image = raw_image.convert("RGB")

    backend = Qwen8B(
        weights_dir=None if cfg.weights_dir is None else str(cfg.weights_dir),
        allow_download=cfg.allow_download,
        attn_implementation="eager",  # readable path: we want attentions returned directly
    )

    inputs, prepared_image = backend.build_inputs_and_image(
        image=image, question=cfg.question, force_think=True
    )
    full_ids, prompt_len = backend.generate(
        inputs, max_new_tokens=int(cfg.max_new_tokens), stop_at_end_think=True
    )

    spans = backend.find_spans(full_ids, prompt_len=int(prompt_len))
    thought_span = spans.thought_span(seq_len=len(full_ids))
    if thought_span is None:
        raise RuntimeError(
            "No <think>...</think> span found. Increase --max_new_tokens."
        )
    thought_start, thought_end = thought_span

    # Truncate to the end of the thought span for attribution computation.
    ids_for_features = full_ids[: int(thought_end)]

    # Build vision token map and region sources.
    token_positions, grid_h, grid_w = backend.build_vision_token_map(
        full_ids, inputs=inputs
    )
    token_map = VisionTokenMap(
        token_positions=token_positions, grid_h=int(grid_h), grid_w=int(grid_w)
    )

    # Prefer the dedicated --dinov3_cache_dir (also used for the attention
    # extractor below); fall back to --weights_dir when it is not given.
    source_cache_dir = (
        cfg.dinov3_cache_dir if cfg.dinov3_cache_dir is not None else cfg.weights_dir
    )
    source_spec = SourceSpec(
        mode=cfg.source_mode,
        block_h=cfg.block_h,
        block_w=cfg.block_w,
        num_regions=cfg.num_regions,
        seed=cfg.source_seed,
        image_path=str(cfg.image) if cfg.source_mode == "dinov3" else None,
        dinov3_cache_dir=None if source_cache_dir is None else str(source_cache_dir),
        dinov3_allow_download=cfg.allow_download,
    )
    sources = build_sources(token_map, source_spec)
    membership = build_source_membership(
        sources, num_visual_tokens=token_map.grid_h * token_map.grid_w
    )

    # Prepare tensors for forward pass that returns attentions.
    ids_t = torch.tensor([ids_for_features], device=backend.device, dtype=torch.long)
    attn_mask_t = torch.ones_like(ids_t, dtype=torch.long)
    inputs_for_features = type(inputs)(
        input_ids=ids_t,
        attention_mask=attn_mask_t,
        pixel_values=inputs.pixel_values,
        image_grid_thw=inputs.image_grid_thw,
    )
    position_ids = backend.compute_position_ids(
        input_ids=ids_t, attention_mask=attn_mask_t, inputs=inputs_for_features
    )

    with torch.no_grad():
        out = backend.forward_with_attentions(
            input_ids=ids_t,
            attention_mask=attn_mask_t,
            position_ids=position_ids,
            inputs=inputs_for_features,
        )
    if out.attentions is None:
        raise RuntimeError(
            "Model did not return attentions (expected eager attention path)."
        )

    targets = targets_in_span(
        tokenizer=backend.tokenizer,
        input_ids=ids_for_features,
        span=(int(thought_start), int(thought_end)),
        mode=str(cfg.target_mode),
    )
    if not targets:
        # Fall back to treating the whole thought span as a single target.
        targets = [(int(thought_start), int(thought_end))]

    feats = (
        features_from_attentions(
            attentions=out.attentions,
            target_token_ranges=targets,
            vision_token_positions=token_positions,
            source_membership=membership,
        )
        .cpu()
        .numpy()
    )  # (T, S, F)
    # The per-layer attentions are not needed past this point; drop the
    # reference before anything else (e.g. DINOv3) is loaded.
    del out

    num_features = int(feats.shape[2])
    if cfg.use_uniform_theta or cfg.theta is None:
        if cfg.theta is None and not cfg.use_uniform_theta:
            logger.warning(
                "No --theta provided; using uniform weights (train one with training/train_estimator.py)"
            )
        theta = uniform_theta(num_features)
    else:
        theta = load_theta(cfg.theta, expected_num_features=num_features)

    scores_by_target = np.tensordot(feats, theta.weights, axes=([2], [0]))  # (T, S)

    dino_attn: Optional[np.ndarray] = None
    if cfg.use_dino_attention:
        from inference.dinov3_clustering import extract_dinov3_attention

        logger.info("Extracting DINOv3 attention for patch score redistribution...")
        dino_result = extract_dinov3_attention(
            image_path=str(cfg.image),
            grid_h=int(token_map.grid_h),
            grid_w=int(token_map.grid_w),
            dinov3_model=str(cfg.dinov3_model),
            cache_dir=str(cfg.dinov3_cache_dir) if cfg.dinov3_cache_dir else None,
            allow_download=cfg.allow_download,
        )
        dino_attn = dino_result.attention_weights

    token_scores_by_target = _spread_source_scores(
        scores_by_target,
        sources,
        num_tokens=token_map.grid_h * token_map.grid_w,
        dino_attn=dino_attn,
    )

    thought_text = backend.decode(
        ids_for_features[int(thought_start) : int(thought_end)]
    ).strip()

    data = {
        "model": MODEL_ID,
        "question": cfg.question,
        "thought": thought_text,
        "theta": theta.name,
        "grid_h": int(token_map.grid_h),
        "grid_w": int(token_map.grid_w),
        # Plain ints keep the payload JSON-serializable even if the source
        # builder returns numpy integer indices.
        "sources": [
            {"name": s.name, "token_indices": [int(t) for t in s.token_indices]}
            for s in sources
        ],
        "targets": [{"range": [int(a), int(b)]} for (a, b) in targets],
        "source_scores_by_target": scores_by_target.tolist(),
        "token_scores_by_target": token_scores_by_target,
        "use_dino_attention": cfg.use_dino_attention,
        "dino_attention_weights": dino_attn.tolist() if dino_attn is not None else None,
    }

    # Only the first target's maps are rendered; the full set lives in `data`.
    export_run(
        out_dir=cfg.out_dir,
        image=prepared_image,
        grid_h=int(token_map.grid_h),
        grid_w=int(token_map.grid_w),
        token_scores=token_scores_by_target[0] if token_scores_by_target else None,
        source_scores=scores_by_target[0].tolist() if scores_by_target.size else None,
        data=data,
    )
    logger.info("Wrote outputs to: %s", str(cfg.out_dir))


if __name__ == "__main__":
    # Script entry point: inference only runs when executed directly, not on import.
    main()
