"""
Core utilities shared across the DQN betting components.
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import numpy as np
import torch

from .constants import EPSILON_ACTIONS, FEAT_TAU


def _resolve_world(world: str, rng: np.random.Generator) -> str:
    """
    Map a world specifier to one of the concrete worlds.

    Args:
        world: 'beta', 'beta_mixture', or 'random'.
        rng: generator used only when ``world == 'random'``; exactly one
            ``rng.random()`` draw is consumed in that case.

    Returns:
        'beta' or 'beta_mixture'. For 'random' each is picked with
        probability 1/2.

    Raises:
        ValueError: if ``world`` is not one of the three accepted strings.
    """
    concrete = ("beta", "beta_mixture")
    if world != "random":
        if world not in concrete:
            raise ValueError("world must be 'beta', 'beta_mixture', or 'random'")
        return world
    coin = rng.random()
    return concrete[0] if coin < 0.5 else concrete[1]


def _one_minus_log_uniform_gap(rng: np.random.Generator, eps_min: float, eps_max: float) -> float:
    """Return ``1 - eps`` with ``eps`` log-uniform on ``[eps_min, eps_max]``.

    Requires ``0 < eps_min <= eps_max``. Consumes exactly one ``rng.random()`` draw.
    """
    u = rng.random()
    eps = eps_min * (eps_max / eps_min) ** u
    return float(1.0 - eps)


def _sample_conc_uniform(rng: np.random.Generator, conc_range, gap_logu_eps_min=1e-6) -> float:
    """
    Sample a concentration parameter with mass near 1 as in the original script.

    Let ``(cmin, cmax) = conc_range``:

    * ``cmin >= 1``: uniform on ``[cmin, cmax]``.
    * ``cmax <= 1``: ``1 - eps`` with ``eps`` log-uniform on
      ``[max(min(gap_logu_eps_min, (1 - cmin) / 2), 1 - cmax), 1 - cmin]``.
      Samples therefore cluster just below ``cmax`` and never exceed it.
    * ``cmin < 1 < cmax``: with probability 1/2 the log-uniform gap sample
      below 1 (gap between ``gap_logu_eps_min`` and ``1 - cmin``, and the
      lower end halved if it would not be smaller than the upper end).
      Otherwise uniform on ``[1, cmax]``.

    Args:
        rng: NumPy random generator.
        conc_range: ``(min, max)`` pair of float-convertible values.
        gap_logu_eps_min: smallest gap ``1 - c`` used by the log-uniform
            branch. It must be > 0 whenever ``cmin < 1``.

    Returns:
        A float concentration sample inside ``conc_range``.

    Raises:
        ValueError: if ``max < min``, or if ``gap_logu_eps_min <= 0`` while
            ``cmin < 1``. The log-uniform branch may run in that case, so the
            check fires deterministically rather than on a coin flip.
    """
    cmin, cmax = map(float, conc_range)
    if cmax < cmin:
        raise ValueError(f"conc_range must satisfy max>=min, got {conc_range}")

    if cmin >= 1.0:
        return float(rng.uniform(cmin, cmax))

    gap_min = float(gap_logu_eps_min)
    if gap_min <= 0:
        raise ValueError("gap_logu_eps_min must be > 0 for log-uniform sampling.")
    eps_max = 1.0 - cmin  # strictly positive because cmin < 1 here

    if cmax <= 1.0:
        # The lower gap bound 1 - cmax keeps samples <= cmax. When cmax == 1 it
        # is 0 and the original min(gap, eps_max / 2) bound is used unchanged.
        eps_min = max(min(gap_min, 0.5 * eps_max), 1.0 - cmax)
        # min() guards against 1 - (1 - cmax) rounding a hair above cmax.
        return min(_one_minus_log_uniform_gap(rng, eps_min, eps_max), cmax)

    if rng.random() < 0.5:
        eps_min = gap_min if gap_min < eps_max else 0.5 * eps_max
        return _one_minus_log_uniform_gap(rng, eps_min, eps_max)
    return float(rng.uniform(1, cmax))


def safe_bounds(m: float, eps_cap: float = 1e-3) -> Tuple[float, float]:
    """
    Return (lam_max_pos, lam_max_neg) keeping 1+λ(X-m) ≥ 0 on X∈[0,1].

    For λ > 0 the wealth factor is smallest at X = 0, giving λ ≤ 1/m. For
    λ < 0 it is smallest at X = 1, giving λ ≥ -1/(1-m). Both limits are
    shrunk by the factor ``1 - eps_cap``, so the worst-case factor stays at
    ``eps_cap`` instead of 0.

    Args:
        m: reference mean; expected in the open interval (0, 1). At 0 or 1
            one of the divisions is by zero.
        eps_cap: safety margin subtracted from the full all-in bet.

    Returns:
        ``(lam_max_pos, lam_max_neg)``: the largest allowed positive bet and
        the most negative allowed bet.
    """
    budget = 1.0 - eps_cap
    upper = budget / m
    lower = -budget / (1.0 - m)
    return upper, lower


def _lambda_from_action(a: int, lam_kelly: float, lam_end: float) -> float:
    """
    Translate a discrete action index into the bet size λ_t.

    Action table:
        0 -> half of ``lam_kelly``
        1 -> ``lam_kelly``
        2 -> ``lam_end`` (the all-in endpoint)

    Raises:
        ValueError: for any other action index.
    """
    if a == 0:
        lam = 0.5 * lam_kelly
    elif a == 1:
        lam = lam_kelly
    elif a == 2:
        lam = lam_end
    else:
        raise ValueError(f"Invalid action {a}. Expected 0, 1, or 2.")
    return lam


def _lambda_from_action_batch(a, lam_kelly, lam_end):
    """
    Vectorized version of _lambda_from_action.

    Args:
        a: array-like of action indices (cast to int64). Each index must be
            0 (half Kelly), 1 (Kelly) or 2 (all-in endpoint).
        lam_kelly: Kelly bet; a scalar or an array broadcastable against ``a``.
        lam_end: all-in endpoint bet; a scalar or an array broadcastable
            against ``a``.

    Returns:
        float32 array with the broadcast shape of the inputs.

    Raises:
        ValueError: if any action index is outside {0, 1, 2}. This matches the
            scalar ``_lambda_from_action``; previously such indices were
            silently mapped to the Kelly bet.
    """
    a = np.asarray(a, dtype=np.int64)
    invalid = (a < 0) | (a > 2)
    if np.any(invalid):
        bad = np.unique(a[invalid]).tolist()
        raise ValueError(f"Invalid action(s) {bad}. Expected 0, 1, or 2.")
    lam = np.where(a == 0, 0.5 * lam_kelly, lam_kelly)
    lam = np.where(a == 2, lam_end, lam)
    return lam.astype(np.float32)


def _default_device():
    """
    Pick the preferred torch device in the order CUDA, then Apple MPS, then CPU.

    Prints "Cuda is available" to stdout when CUDA is selected. The MPS check
    is skipped on torch builds whose ``torch.backends`` has no ``mps`` attribute.
    """
    if torch.cuda.is_available():
        print("Cuda is available")
        name = "cuda"
    else:
        mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        name = "mps" if mps_ok else "cpu"
    return torch.device(name)


# Names exported by ``from <this module> import *``. Listing the
# underscore-prefixed helpers here is what makes star-imports include them
# (Python skips leading-underscore names otherwise). EPSILON_ACTIONS and
# FEAT_TAU are re-exported from ``.constants``.
__all__ = [
    "EPSILON_ACTIONS",
    "FEAT_TAU",
    "_resolve_world",
    "_sample_conc_uniform",
    "safe_bounds",
    "_lambda_from_action",
    "_lambda_from_action_batch",
    "_default_device",
]
