import math
from typing import List


def cycle_lengths_geom(
    K_init: float, n: int, gamma: float, *, as_int: bool = True
) -> List[int] | List[float]:
    """
    Return n cycle lengths with the same geometric growth rate as theory:
      K_{k+1} / K_k = gamma^(-2/3)
    and K_0 = K_init.

    Args:
        K_init: initial (first) cycle length K_0.
        n: number of cycles (n >= 0).
        gamma: discount factor in (0, 1).
        as_int: ceil-round to integers if True.

    Returns:
        List of n lengths.
    """
    if n <= 0:
        return []
    if not (0 < gamma < 1):
        raise ValueError("gamma must be in (0, 1)")
    if K_init <= 0:
        raise ValueError("K_init must be > 0")

    # Geometric ratio that matches ((n-1-k)/1.5) scaling
    g = gamma ** (-2 / 3)

    seq = [K_init * (g**k) for k in range(n)]
    if as_int:
        return [max(1, int(math.ceil(x))) for x in seq]
    return seq


def ada_eps_descending(n: int) -> List[float]:
    """
    Build a decreasing schedule of n adaptive epsilon values.

    The k-th value (k = 1 .. n) is 1 / k^2, so the list starts at 1.0.

    Args:
        n: number of values to produce; n <= 0 yields an empty list.

    Returns:
        List of n adaptive epsilon values.
    """
    if n <= 0:
        return []

    eps_values: List[float] = []
    for step in range(1, n + 1):
        eps_values.append(1 / (step**2))
    return eps_values


def create_lr(s=1.0, c=1.0, n_warmup=0):
    """
    Build a learning-rate config whose rate decays like 1 / (s + c * n).

    The "rate_fct" entry maps an iteration number n to
      1 / s                          if n < n_warmup
      1 / (s + c * (n - n_warmup))   otherwise.

    Args:
        s: constant offset in the denominator; 1 / s is the warmup rate.
        c: slope multiplying the number of post-warmup iterations.
        n_warmup: number of iterations for which rate_fct returns 1 / s.

    Returns:
        Dict with keys "initial_rate", "mode", "mode_kwargs" (holding
        "rate_fct", "iteration_num", "final_rate") and "current_rate".

    NOTE(review): "initial_rate" / "current_rate" are 1 / (s + 1), which is
    not the warmup rate 1 / s; it coincides with rate_fct(n_warmup + 1) only
    when c == 1. Confirm which value the consumer of this dict expects.
    """
    start_rate = 1 / (s + 1)

    def rate_at(n):
        # Flat during warmup, then hyperbolic decay counted from n_warmup.
        if n < n_warmup:
            return 1 / s
        return 1 / (s + c * (n - n_warmup))

    schedule_kwargs = {
        "rate_fct": rate_at,
        "iteration_num": 1,
        "final_rate": 0,
    }
    return {
        "initial_rate": start_rate,
        "mode": "rate",
        "mode_kwargs": schedule_kwargs,
        "current_rate": start_rate,
    }


def create_const_lr(rate=0.1):
    """
    Build a learning-rate config whose rate never changes.

    Args:
        rate: the value used for "initial_rate", "current_rate",
            "final_rate" and returned by "rate_fct" for every iteration.

    Returns:
        Dict with keys "initial_rate", "mode", "mode_kwargs" (holding
        "rate_fct", "iteration_num", "final_rate") and "current_rate".
    """

    def constant_rate(n):
        # The iteration number is accepted but ignored.
        return rate

    schedule_kwargs = {
        "rate_fct": constant_rate,
        "iteration_num": 1,
        "final_rate": rate,
    }
    return {
        "initial_rate": rate,
        "mode": "rate",
        "mode_kwargs": schedule_kwargs,
        "current_rate": rate,
    }


# --- Helper: map (state, action) -> special-log index (eval) ---
def build_sa_indexer_with_var(which_state_actions_focus):
    """
    Build converters between (s,a) pairs and special-log indices for a layout
    in which every pair owns two consecutive series: value, then variance.

    Pairs are numbered in this order:
        for each state s (by its position in states):
            for each a in actions_per_state[position of s]:
                pair number p -> index 2*p   is the (s,a) value series
                                 index 2*p+1 is the (s,a) var series

    Args:
        which_state_actions_focus: 2-item (states, actions_per_state), where
            actions_per_state[i] holds the actions for states[i].

    Returns:
        Tuple (sa_to_index_value, sa_to_index_var, index_to_sa):
            sa_to_index_value(s, a) -> index of the value series.
            sa_to_index_var(s, a)   -> index of the variance series.
            index_to_sa(i)          -> ((s, a), "value" or "var").
        Each lookup raises KeyError for a pair / position that was not built.
    """
    states, actions_per_state = which_state_actions_focus

    # Flatten to the sequential pair order; actions are looked up by the
    # state's position, not by the state itself.
    ordered_pairs = [
        (state, action)
        for state_pos, state in enumerate(states)
        for action in actions_per_state[state_pos]
    ]
    pair_position = {pair: pos for pos, pair in enumerate(ordered_pairs)}
    pair_at = dict(enumerate(ordered_pairs))

    def sa_to_index_value(s, a):
        # Even slot of the pair's two-series group.
        return 2 * pair_position[(s, a)]

    def sa_to_index_var(s, a):
        # Odd slot, directly after the value series.
        return 2 * pair_position[(s, a)] + 1

    def index_to_sa(i):
        # Accepts either a value index or a var index.
        pair_pos, is_var = divmod(i, 2)
        return pair_at[pair_pos], ("var" if is_var else "value")

    return sa_to_index_value, sa_to_index_var, index_to_sa


def indices_for_pairs_values_only(which_state_actions_focus, pairs):
    """
    Translate a list of (s,a) pairs into their value-series indices.

    Only the value index of each pair is returned; the matching _var series
    indices are left out.

    Args:
        which_state_actions_focus: passed unchanged to
            build_sa_indexer_with_var to build the index mapping.
        pairs: iterable of (s, a) 2-tuples to look up.

    Returns:
        List of value indices, in the same order as pairs.
    """
    value_index_of = build_sa_indexer_with_var(which_state_actions_focus)[0]

    value_indices = []
    for s, a in pairs:
        value_indices.append(value_index_of(s, a))
    return value_indices
