"""Bias and privacy-level generators shared by multiple scenarios."""

from __future__ import annotations

from itertools import product
from typing import List, Sequence

import numpy as np


def monotone_biases(
    step_small: float,
    step_big: float,
    n_sites: int,
    max_configs: int | None = 9,
) -> List[List[float]]:
    """Reproduce the discrete monotone bias grid used in the original scripts.

    Each source site takes one of the levels ``0``, ``step_small`` or
    ``step_big``. Only configurations whose source biases are non-decreasing
    from one site to the next are kept. The target site (index 0) always has
    bias ``0.0``.

    Configurations are ordered primarily by the level of the last (largest)
    source site, then by the second-to-last, and so on.

    Args:
        step_small: First non-zero bias level.
        step_big: Second non-zero bias level.
        n_sites: Total number of sites, including the target site.
        max_configs: Maximum number of configurations to return. The default
            of 9 matches the original scripts; with four sites (ten monotone
            configurations) it drops the final one, where every source site
            sits at ``step_big``. Pass ``None`` to return all configurations.

    Returns:
        A list of bias vectors, each of length ``n_sites`` and starting with 0.

    Raises:
        ValueError: If ``n_sites < 2`` or ``max_configs`` is negative.
    """
    if n_sites < 2:
        raise ValueError("n_sites must be at least 2 (target + source sites).")
    if max_configs is not None and max_configs < 0:
        raise ValueError("max_configs must be non-negative or None.")
    # Deduplicate while preserving order. Equal steps (or a zero step) would
    # otherwise make product() yield the same configuration several times.
    levels = list(dict.fromkeys([0.0, float(step_small), float(step_big)]))
    n_source = n_sites - 1
    combos = [
        combo
        for combo in product(levels, repeat=n_source)
        if all(a <= b for a, b in zip(combo, combo[1:]))
    ]
    order = {value: idx for idx, value in enumerate(levels)}
    # Sort by the last site's level first, then the one before it, etc.
    combos.sort(key=lambda combo: tuple(order[val] for val in reversed(combo)))
    all_biases = [[0.0] + list(combo) for combo in combos]
    if max_configs is None:
        return all_biases
    return all_biases[:max_configs]


def continuous_biases(
    end_val: float, num_points: int, n_sites: int, start_exp: float = -5.0
) -> List[List[float]]:
    """Build ``num_points`` bias configurations on a natural-log-spaced grid.

    Matches ``generate_biases_cont`` from the original scripts. The grid values
    run from ``exp(start_exp)`` up to ``end_val`` (up to floating-point
    rounding), equally spaced in natural-log space. In every configuration the
    target site gets ``0.0`` and all ``n_sites - 1`` source sites share the
    same grid value.

    Raises:
        ValueError: If ``end_val`` is not positive (its log is undefined).
    """
    if end_val <= 0:
        raise ValueError("end_val must be positive for log-space sampling.")
    # base=e with a stop exponent of ln(end_val) makes the last point end_val.
    grid = np.logspace(start_exp, np.log(end_val), num_points, base=np.e)
    configs: List[List[float]] = []
    for value in grid:
        configs.append([0.0] + [float(value)] * (n_sites - 1))
    return configs


def zero_biases(n_sites: int) -> List[List[float]]:
    """Produce a grid holding exactly one configuration: ``n_sites`` zeros."""
    all_zero = [0.0 for _ in range(n_sites)]
    return [all_zero]


def constant_biases(value: float, n_sites: int) -> List[List[float]]:
    """Produce a single configuration with one shared offset on all sources.

    The target site keeps ``0.0``; each of the ``n_sites - 1`` source sites
    gets ``float(value)``.
    """
    offset = float(value)
    source_part = [offset] * (n_sites - 1)
    return [[0.0, *source_part]]


def privacy_vectors(
    target_r: float, r_start: float, r_end: float, n: int, n_sites: int
) -> List[List[float]]:
    """Recover the sequence of ``rs`` vectors used in the heterogeneous-r experiments.

    The target site keeps ``target_r``. All source sites share one value,
    swept linearly from ``r_start`` to ``r_end`` (both inclusive) in ``n``
    steps.

    Args:
        target_r: Value assigned to the target site (index 0) in every vector.
        r_start: First source-site value of the sweep.
        r_end: Last source-site value of the sweep.
        n: Number of vectors, i.e. sweep points.
        n_sites: Total number of sites, including the target site.

    Returns:
        ``n`` vectors whose first entry is ``target_r``, followed by
        ``n_sites - 1`` copies of the current sweep value. Every entry is a
        plain Python ``float``.
    """
    # Cast like the source entries so an int or NumPy scalar target_r does
    # not leak a non-float into the List[List[float]] result.
    target = float(target_r)
    source_values = np.linspace(r_start, r_end, n)
    return [[target] + [float(val)] * (n_sites - 1) for val in source_values]

