from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any, Optional

import numpy as np
from PIL import Image

# Resolved directory containing this file, and that directory's parent.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent
# Put the parent directory at the front of sys.path (only once). This has to run
# before the `common.io` import that follows; presumably the `common` package
# lives directly under _ROOT_DIR -- TODO confirm.
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))

from common.io import maybe_mkdir


def _normalize_heatmap(x: np.ndarray) -> np.ndarray:
    """Rescale an array to [0, 1] using its 5th/95th percentiles.

    Values at or below the 5th percentile map to 0, values at or above the
    95th percentile map to 1, and everything in between is scaled linearly.
    Percentile clipping keeps a handful of outliers from washing out the rest
    of the map.

    Non-finite entries are ignored when computing the percentiles; in the
    result NaN and -inf become 0 and +inf becomes 1. Without this a single NaN
    would turn both percentiles (and therefore the whole output) into NaN.

    Args:
        x: Array of any shape.

    Returns:
        A float32 array with the same shape as ``x`` and values in [0, 1].
        An empty input is returned (as float32) unchanged.
    """
    x = x.astype(np.float32)
    if x.size == 0:
        return x

    finite = np.isfinite(x)
    if not finite.any():
        # Nothing usable to scale against: report "no heat" everywhere.
        return np.zeros_like(x)

    finite_vals = x[finite]
    lo = float(np.percentile(finite_vals, 5))
    hi = float(np.percentile(finite_vals, 95))
    if hi <= lo:
        # Constant (or near-constant) map: avoid dividing by zero.
        hi = lo + 1e-6

    x = (x - lo) / (hi - lo)
    x = np.nan_to_num(x, nan=0.0, posinf=1.0, neginf=0.0)
    return np.clip(x, 0.0, 1.0)


def overlay_heatmap(
    image: Image.Image, heatmap_hw: np.ndarray, alpha: float = 0.45
) -> Image.Image:
    """Blend a 2-D heatmap over an image as a red tint.

    The heatmap is normalized to [0, 1] with ``_normalize_heatmap``, resized
    to the image size with nearest-neighbour resampling (each heatmap cell
    stays a sharp block), and used as a per-pixel blending weight towards
    pure red. A plain red tint is used instead of a colormap to avoid extra
    plotting dependencies.

    Args:
        image: Base image; it is converted to RGB before blending.
        heatmap_hw: 2-D array of scores, shape (H, W). It does not have to
            match the image size.
        alpha: Opacity of the red layer where the normalized heat is 1.

    Returns:
        A new RGB image with the same size as ``image``.

    Raises:
        ValueError: If ``heatmap_hw`` is not 2-D.
    """
    if heatmap_hw.ndim != 2:
        raise ValueError("heatmap_hw must be (H,W)")
    heat = _normalize_heatmap(heatmap_hw)

    # The image mode is inferred from dtype/shape (2-D uint8 -> "L",
    # HxWx3 uint8 -> "RGB"); passing ``mode`` explicitly is redundant and is
    # flagged as deprecated by some Pillow releases.
    heat_img = Image.fromarray((heat * 255).astype(np.uint8)).resize(
        image.size, resample=Image.Resampling.NEAREST
    )
    heat_arr = (np.asarray(heat_img).astype(np.float32) / 255.0)[:, :, None]

    base = np.asarray(image.convert("RGB")).astype(np.float32) / 255.0
    red = np.zeros_like(base)
    red[:, :, 0] = 1.0

    # Per-pixel weight of the red layer; the base image gets the remainder.
    weight = float(alpha) * heat_arr
    out = (1.0 - weight) * base + weight * red
    out = np.clip(out * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(out)


def export_run(
    *,
    out_dir: str | Path,
    image: Image.Image,
    grid_h: int,
    grid_w: int,
    token_scores: Optional[list[float]],
    source_scores: Optional[list[float]],
    data: dict[str, Any],
) -> None:
    """Write the artifacts of one run into ``out_dir``.

    Files written:
        * ``image.png`` -- the input image.
        * ``attribution.png`` -- red heatmap overlay; only when
          ``token_scores`` is given.
        * ``data.json`` -- ``data`` serialized with ``indent=2``.

    All validation and in-memory rendering happens before anything is written,
    so a bad ``token_scores`` length or a non-serializable ``data`` does not
    leave a half-populated export directory behind.

    Args:
        out_dir: Destination directory, passed to ``maybe_mkdir``.
        image: Image to export and to draw the overlay on.
        grid_h: Number of rows in the score grid.
        grid_w: Number of columns in the score grid.
        token_scores: Flat, row-major scores of length ``grid_h * grid_w``,
            or ``None`` to skip the overlay.
        source_scores: Accepted for interface compatibility; not used by this
            function.
        data: JSON-serializable payload for ``data.json``.

    Raises:
        ValueError: If ``token_scores`` does not have ``grid_h * grid_w``
            entries.
    """
    overlay: Optional[Image.Image] = None
    if token_scores is not None:
        arr = np.asarray(token_scores, dtype=np.float32)
        if arr.size != int(grid_h * grid_w):
            raise ValueError("token_scores length does not match grid")
        heat_hw = arr.reshape(int(grid_h), int(grid_w))
        overlay = overlay_heatmap(image, heat_hw)

    # Serialize up front so an unserializable payload fails before any I/O.
    payload = json.dumps(data, indent=2)

    out = maybe_mkdir(out_dir)
    image.save(out / "image.png")
    if overlay is not None:
        overlay.save(out / "attribution.png")
    (out / "data.json").write_text(payload, encoding="utf-8")
