"""
Group-wise helpers for RL training utilities.

Public API:
    - as_torch_index(index, device=None) -> torch.LongTensor
    - group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g)

Default device policy:
    - If `device` is None:
        * In pytest (detected by env "PYTEST_CURRENT_TEST"): use CPU.
        * Else if CUDA is available: use CUDA.
        * Else: use CPU.
    - You can override via env "VERL_FORCE_DEVICE" (e.g., "cuda:0" / "cpu").

Notes:
- as_torch_index: canonicalizes arbitrary group labels to a contiguous 1-D torch.long
  tensor in range [0..G-1]. Robust to torch/numpy/list/tuple, ints/floats/bools,
  numeric strings, UUIDs, mixed object arrays. Near-integer floats (|x-round(x)|<=1e-6)
  are rounded; otherwise factorization is applied.
- group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for variance
  (denominator max(count-1, 1)). Singleton groups fallback to mean=0, std=1 for
  compatibility with common “native” conventions.
"""

from __future__ import annotations

import os
from typing import Any, Optional

import numpy as np
import torch

from verl.utils.device import get_torch_device

__all__ = ["as_torch_index", "group_mean_std"]


def _resolve_device(explicit: Optional[torch.device | str]) -> torch.device:
    """
    Resolve device according to policy described in the module docstring.
    Priority:
      1) explicit argument
      2) VERL_FORCE_DEVICE env
      3) pytest detection -> cpu
      4) cuda if available, else cpu
    """
    if explicit is not None:
        return torch.device(explicit)

    forced = os.getenv("VERL_FORCE_DEVICE")
    if forced:
        return torch.device(forced)

    # Heuristic: pytest sets PYTEST_CURRENT_TEST
    if "PYTEST_CURRENT_TEST" in os.environ:
        return torch.device("cpu")

    return get_torch_device()


def _to_1d_numpy_object_array(x: Any) -> np.ndarray:
    """Best-effort: convert arbitrary input into a 1-D numpy array; fallback to object dtype."""
    try:
        arr = np.asarray(x)
    except Exception:
        try:
            arr = np.array(list(x), dtype=object)
        except Exception:
            arr = np.array([x], dtype=object)
    if arr.ndim != 1:
        arr = arr.reshape(-1)
    return arr


def as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Convert arbitrary group labels to a contiguous 1-D torch.long tensor (0..G-1).

    Args:
        index: Any iterable of labels or tensor/ndarray.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        torch.LongTensor with shape (N,)
    """
    target = _resolve_device(device)

    # ---------- Fast path: torch.Tensor ----------
    if isinstance(index, torch.Tensor):
        t = index.reshape(-1)
        if t.dtype in (
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            getattr(torch, "uint8", torch.uint8),
            torch.bool,
        ):
            return t.to(device=target, dtype=torch.long)

        if t.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
            t64 = t.to(dtype=torch.float64)
            rounded = torch.round(t64)
            if torch.allclose(t64, rounded, rtol=0.0, atol=1e-6):
                return rounded.to(device=target, dtype=torch.long)
            arr = np.array([str(x.item()) for x in t], dtype=object)
        else:
            arr = np.array([str(x.item()) if hasattr(x, "item") else str(x) for x in t], dtype=object)

    else:
        # ---------- Non-torch: go through numpy ----------
        arr = _to_1d_numpy_object_array(index)

        # Pure integers (incl. bool)
        if arr.dtype != object and np.issubdtype(arr.dtype, np.integer):
            return torch.from_numpy(arr.astype(np.int64, copy=False)).to(device=target)

        # Floats nearly equal to integers
        if arr.dtype != object and np.issubdtype(arr.dtype, np.floating):
            arr64 = arr.astype(np.float64, copy=False)
            rounded = np.rint(arr64)
            if np.allclose(arr64, rounded, rtol=0.0, atol=1e-6):
                return torch.from_numpy(rounded.astype(np.int64)).to(device=target)
            # fall through

        # Try numeric string coercion
        try:
            coerced = arr.astype(np.int64)
            return torch.from_numpy(coerced).to(device=target)
        except Exception:
            pass

        if arr.dtype != object:
            arr = arr.astype(object)

    # ---------- Factorization (UUIDs / mixed types / arbitrary labels) ----------
    try:
        _, inv = np.unique(arr, return_inverse=True)
    except Exception:
        sarr = np.array([str(x) for x in arr], dtype=object)
        _, inv = np.unique(sarr, return_inverse=True)

    inv = inv.astype(np.int64, copy=False)
    return torch.from_numpy(inv).to(device=target)


@torch.no_grad()
def group_mean_std(
    scores: torch.Tensor,
    gidx: torch.Tensor,
    eps: float = 1e-6,
    device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-group mean/std/count in pure PyTorch.

    mean_g = sum / count
    std_g  = sqrt( max( (sum2 - sum^2/count) / max(count-1, 1), eps ) )

    Singleton groups fallback to mean=0, std=1.

    Args:
        scores: (N,) float tensor.
        gidx  : (N,) long/int tensor with group indices (0..G-1).
        eps   : Numerical floor for variance.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        mean_g: (G,) float32
        std_g : (G,) float32
        count : (G,) float32
    """
    target = _resolve_device(device)

    scores = scores.reshape(-1).to(device=target, dtype=torch.float32)
    gidx = gidx.reshape(-1).to(device=target, dtype=torch.long)

    if scores.numel() != gidx.numel():
        raise ValueError(f"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}")

    G = int(torch.max(gidx).item()) + 1 if gidx.numel() > 0 else 0
    if G == 0:
        # Return empty tensors on the selected device
        empty = torch.empty(0, device=target, dtype=torch.float32)
        return empty, empty, empty

    ones = torch.ones_like(scores, dtype=torch.float32)

    count = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, ones)
    s1 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores)
    s2 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores * scores)

    mean = s1 / count.clamp_min(1.0)
    var_num = s2 - (s1 * s1) / count.clamp_min(1.0)
    denom = (count - 1.0).clamp_min(1.0)
    var = var_num / denom
    std = torch.sqrt(torch.clamp(var, min=eps))

    # Singleton groups: mean=0, std=1
    single = count <= 1.0
    if torch.any(single):
        mean = mean.clone()
        std = std.clone()
        mean[single] = 0.0
        std[single] = 1.0

    return mean, std, count
