from __future__ import annotations

import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np

# Directory that contains this file, and the directory one level above it.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent
# Put the parent directory at the front of sys.path, skipping the insert when
# it is already listed so repeated imports do not add duplicates.
# NOTE(review): presumably this is what lets the lazy
# `inference.dinov3_clustering` import in build_sources resolve when the
# package is not installed -- confirm against the repository layout.
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))


@dataclass(frozen=True)
class VisionTokenMap:
    """Minimal vision token map used in the submission package.

    The source builders in this module read only ``grid_h`` and ``grid_w``;
    they treat the grid as ``grid_h * grid_w`` cells and number a cell at
    (row, col) as ``row * grid_w + col``.
    """

    # Not read anywhere in this module.
    # NOTE(review): presumably the positions of the visual tokens inside the
    # model's input sequence -- confirm with the code that builds this map.
    token_positions: list[int]
    # Number of rows in the visual token grid.
    grid_h: int
    # Number of columns in the visual token grid.
    grid_w: int


@dataclass(frozen=True)
class SourceSpec:
    """How we group visual tokens into sources (regions).

    ``build_sources`` dispatches on ``mode`` and reads only the fields that
    belong to the selected mode:

    * ``"token"``   -- one source per grid cell; no other field is used.
    * ``"block"``   -- rectangular tiles of ``block_h`` x ``block_w`` cells.
    * ``"voronoi"`` -- ``num_regions`` randomly seeded regions, with ``seed``
      driving the random choice of seed cells.
    * ``"dinov3"``  -- clustering delegated to
      ``inference.dinov3_clustering``; requires ``image_path``.
    """

    # Grouping strategy: block | voronoi | token | dinov3
    mode: str = "block"
    # Tile height and width in grid cells (block mode).
    block_h: int = 2
    block_w: int = 2
    # Number of regions to create (voronoi mode).
    num_regions: int = 12
    # Seed for the random generator that picks region seeds (voronoi mode).
    seed: int = 0

    # Optional DINOv3 settings. They are forwarded unchanged to
    # dinov3_cluster_sources (the threshold is cast to float first); their
    # exact semantics are defined by that helper, not by this module.
    image_path: Optional[str] = None
    dinov3_model: str = "facebook/dinov3-vitl16-pretrain-lvd1689m"
    dinov3_cache_dir: Optional[str] = None
    dinov3_allow_download: bool = False
    dinov3_distance_threshold: float = 0.5


@dataclass(frozen=True)
class Source:
    """A named group of visual-token grid cells."""

    # Label of the group; the builders in this module use prefixes such as
    # "tok_", "block_" and "rrs_".
    name: str
    # Cells belonging to the group, as indices into the flattened grid
    # (0..grid_h*grid_w-1).
    token_indices: list[int]


def build_sources(token_map: VisionTokenMap, spec: SourceSpec) -> list[Source]:
    """Group the cells of ``token_map``'s grid into sources according to ``spec``.

    Args:
        token_map: Supplies the grid dimensions (``grid_h``, ``grid_w``).
        spec: Selects the grouping strategy through ``spec.mode`` and carries
            the parameters of that strategy.

    Returns:
        The list of sources produced by the selected strategy.

    Raises:
        ValueError: If ``spec.mode`` is not one of ``"token"``, ``"block"``,
            ``"voronoi"`` or ``"dinov3"``, or if the mode is ``"dinov3"`` and
            ``spec.image_path`` is ``None``.
    """
    mode = spec.mode

    if mode == "token":
        # Every grid cell becomes its own single-token source.
        cell_count = token_map.grid_h * token_map.grid_w
        singles = []
        for cell in range(cell_count):
            singles.append(Source(name=f"tok_{cell}", token_indices=[cell]))
        return singles

    if mode == "block":
        tile_h, tile_w = spec.block_h, spec.block_w
        return _build_block_sources(token_map, block_h=tile_h, block_w=tile_w)

    if mode == "voronoi":
        region_count, rng_seed = spec.num_regions, spec.seed
        return _build_voronoi_sources(
            token_map, num_regions=region_count, seed=rng_seed
        )

    if mode == "dinov3":
        image_path = spec.image_path
        if image_path is None:
            raise ValueError("dinov3 mode requires --image (used as image_path)")

        # Imported lazily so the other modes do not need this module.
        from inference.dinov3_clustering import dinov3_cluster_sources

        options = {
            "token_map": token_map,
            "image_path": image_path,
            "dinov3_model": spec.dinov3_model,
            "cache_dir": spec.dinov3_cache_dir,
            "allow_download": spec.dinov3_allow_download,
            "distance_threshold": float(spec.dinov3_distance_threshold),
        }
        return dinov3_cluster_sources(**options)

    raise ValueError(f"Unknown source mode: {mode}")


def build_source_membership(
    sources: list[Source], *, num_visual_tokens: int
) -> np.ndarray:
    """Build the binary source-by-token membership matrix.

    Args:
        sources: Sources whose ``token_indices`` name the visual tokens they
            contain.
        num_visual_tokens: Number of columns of the matrix.

    Returns:
        A ``float32`` array of shape ``(len(sources), num_visual_tokens)``
        where entry ``[s, v]`` is 1.0 when token ``v`` is listed in source
        ``s`` and 0.0 otherwise. Repeated indices are harmless.

    Raises:
        IndexError: If a token index lies outside
            ``0 <= index < num_visual_tokens``.
    """
    n_tokens = int(num_visual_tokens)
    membership = np.zeros((len(sources), n_tokens), dtype=np.float32)
    for s_idx, source in enumerate(sources):
        row = membership[s_idx]  # view: writes land in `membership`
        for raw in source.token_indices:
            idx = int(raw)
            # NumPy would silently wrap a negative index to the end of the
            # row and mark the wrong token, so reject it explicitly, the same
            # way an index that is too large is rejected.
            if not 0 <= idx < n_tokens:
                raise IndexError(
                    f"token index {idx} in source #{s_idx} is out of range "
                    f"for {n_tokens} visual tokens"
                )
            row[idx] = 1.0
    return membership


def _build_block_sources(
    token_map: VisionTokenMap, *, block_h: int, block_w: int
) -> list[Source]:
    """Tile the token grid into rectangular blocks, one source per block.

    Blocks are laid out from the top-left corner in steps of ``block_h`` rows
    and ``block_w`` columns. Blocks touching the bottom or right edge are
    clipped to the grid, so they may be smaller than the requested size.

    Args:
        token_map: Supplies the grid dimensions (``grid_h``, ``grid_w``).
        block_h: Block height in grid rows; must be at least 1.
        block_w: Block width in grid columns; must be at least 1.

    Returns:
        Sources named ``block_<top row>_<left column>``, ordered row by row.
        Each lists its cells as flattened indices ``row * grid_w + col``.

    Raises:
        ValueError: If ``block_h`` or ``block_w`` is smaller than 1.
    """
    h, w = int(token_map.grid_h), int(token_map.grid_w)
    step_h, step_w = int(block_h), int(block_w)
    # A negative step would make range() empty and silently yield no sources,
    # and a zero step fails with an unrelated range() message; reject both.
    if step_h < 1 or step_w < 1:
        raise ValueError("block_h and block_w must be >= 1")

    out: list[Source] = []
    for r0 in range(0, h, step_h):
        r1 = min(r0 + step_h, h)  # clip the last row of blocks to the grid
        for c0 in range(0, w, step_w):
            c1 = min(c0 + step_w, w)  # clip the last column of blocks
            toks = [r * w + c for r in range(r0, r1) for c in range(c0, c1)]
            out.append(Source(name=f"block_{r0}_{c0}", token_indices=toks))
    return out


def _build_voronoi_sources(
    token_map: VisionTokenMap, *, num_regions: int, seed: int
) -> list[Source]:
    """Partition the token grid into regions around randomly chosen seed cells.

    ``num_regions`` distinct cells are drawn without replacement from a
    ``numpy.random.RandomState`` seeded with ``seed``. Every grid cell then
    joins the seed with the smallest Manhattan distance (row difference plus
    column difference); on a tie the seed drawn earliest wins.

    Args:
        token_map: Supplies the grid dimensions (``grid_h``, ``grid_w``).
        num_regions: Number of regions to create.
        seed: Seed for the random generator that picks the seed cells.

    Returns:
        ``num_regions`` sources named ``rrs_<i>``, in the order the seeds
        were drawn. Each lists its cells as flattened indices in ascending
        order.

    Raises:
        ValueError: If ``num_regions`` is below 1 or exceeds the number of
            grid cells.
    """
    rows, cols = int(token_map.grid_h), int(token_map.grid_w)
    cell_count = rows * cols
    region_count = int(num_regions)
    if region_count < 1:
        raise ValueError("num_regions must be >= 1")
    if region_count > cell_count:
        raise ValueError("num_regions cannot exceed number of visual tokens")

    generator = np.random.RandomState(int(seed))
    drawn = generator.choice(cell_count, size=region_count, replace=False)
    centers = [divmod(cell, cols) for cell in drawn.tolist()]

    members: list[list[int]] = [[] for _ in centers]
    for cell in range(cell_count):
        row, col = divmod(cell, cols)
        distances = [abs(row - c_row) + abs(col - c_col) for c_row, c_col in centers]
        # list.index returns the first occurrence, so ties go to the seed
        # that was drawn earliest.
        nearest = distances.index(min(distances))
        members[nearest].append(cell)

    return [
        Source(name=f"rrs_{region}", token_indices=cells)
        for region, cells in enumerate(members)
    ]
