"""Data generation utilities."""

from __future__ import annotations

from typing import Iterable, List, Sequence, Tuple

import numpy as np
from scipy.stats import cauchy, norm


def _true_quantile(dist_type: str, tau: float, mu: float) -> float:
    if dist_type == "normal":
        return mu + norm.ppf(tau)
    if dist_type == "uniform":
        low, high = mu - 1, mu + 1
        return low + (high - low) * tau
    if dist_type == "cauchy":
        return mu + cauchy.ppf(tau)
    raise ValueError(f"Unsupported distribution: {dist_type}")


def generate_stream(
    dist_type: str, tau: float, n_samples: int, mu: float = 0.0
) -> Tuple[np.ndarray, float]:
    """Draw an i.i.d. sample and report the matching population quantile.

    Samples come from NumPy's global random state.

    Args:
        dist_type: ``"normal"``, ``"uniform"`` or ``"cauchy"``.
        tau: Quantile level passed on to ``_true_quantile``.
        n_samples: Number of draws; must be positive.
        mu: Location of the distribution (mean, midpoint or shift).

    Returns:
        ``(samples, true_quantile)`` -- the draws as a float array and the
        population ``tau``-quantile as a Python float.

    Raises:
        ValueError: If ``n_samples`` is not positive or ``dist_type`` is
            not supported.
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be positive.")

    # Each sampler is wrapped in a lambda so that only the selected one
    # actually consumes numbers from the global random state.
    samplers = {
        "normal": lambda: np.random.normal(mu, 1.0, n_samples),
        "uniform": lambda: np.random.uniform(mu - 1.0, mu + 1.0, n_samples),
        "cauchy": lambda: np.random.standard_cauchy(n_samples) + mu,
    }
    draw = samplers.get(dist_type)
    if draw is None:
        raise ValueError(f"Unsupported distribution: {dist_type}")

    samples = draw()
    quantile = _true_quantile(dist_type, tau, mu)
    return samples.astype(float), float(quantile)


def generate_federated_streams(
    dist_type: str,
    tau: float,
    sample_sizes: Sequence[int],
    biases: Sequence[float],
) -> Tuple[List[np.ndarray], List[float]]:
    """Generate one stream per site, each with its own size and location.

    Site ``i`` is produced by ``generate_stream`` with ``sample_sizes[i]``
    draws and ``mu = biases[i]``.

    Returns:
        ``(streams, quantiles)`` -- two lists in site order: the sampled
        arrays and the corresponding population quantiles.

    Raises:
        ValueError: If ``sample_sizes`` and ``biases`` differ in length.
    """
    if len(sample_sizes) != len(biases):
        raise ValueError("sample_sizes and biases must have identical length.")

    # Sites are generated in order so the random draws are consumed in the
    # same sequence as the inputs.
    per_site = [
        generate_stream(dist_type, tau, int(size), mu=float(shift))
        for size, shift in zip(sample_sizes, biases)
    ]
    streams = [stream for stream, _ in per_site]
    quantiles = [quantile for _, quantile in per_site]
    return streams, quantiles


def sample_allocations(
    n_sites: int, target_samples: int, source_prop: float | Sequence[float]
) -> List[int]:
    """Replicate ``get_n_n_sample`` with support for per-source ratios."""
    target_samples = int(target_samples)
    if target_samples <= 0:
        raise ValueError("target_samples must be positive.")
    if isinstance(source_prop, Sequence) and not isinstance(source_prop, (str, bytes)):
        if len(source_prop) != n_sites - 1:
            raise ValueError("source_prop must have length n_sites - 1.")
        sources = [int(round(target_samples * float(r))) for r in source_prop]
    else:
        src = int(round(target_samples * float(source_prop)))
        sources = [src] * (n_sites - 1)
    return [target_samples] + sources


def proportional_chains(sample_sizes: Sequence[int], base_k: int) -> List[int]:
    """Scale the chain count of each site by its size relative to the first.

    The first entry of ``sample_sizes`` is the reference: it always gets
    exactly ``base_k`` chains. Every other entry gets ``base_k`` times its
    size ratio to the reference, rounded, and never fewer than one chain.

    Raises:
        ValueError: If ``sample_sizes`` is empty.
    """
    sizes = np.asarray(sample_sizes, dtype=float)
    if not len(sizes):
        raise ValueError("sample_sizes cannot be empty.")

    # Guard the denominator so a reference size below one cannot inflate
    # (or divide by zero in) the ratios.
    reference = max(sizes[0], 1.0)
    scaled = np.round(base_k * (sizes / reference)).astype(int)
    counts = np.maximum(scaled, 1)
    # The reference site is pinned to ``base_k`` regardless of rounding.
    counts[0] = int(base_k)
    return counts.tolist()


def split_stream(data: Sequence[float], n_chunks: int) -> List[np.ndarray]:
    """Cut a stream into ``n_chunks`` equally sized consecutive pieces.

    The input is flattened to a 1-D float array. Trailing elements that do
    not fill a complete round of ``n_chunks`` are dropped so every chunk has
    the same length; if nothing remains, ``n_chunks`` empty arrays are
    returned.

    Raises:
        ValueError: If ``n_chunks`` is not positive.
    """
    if n_chunks <= 0:
        raise ValueError("n_chunks must be positive.")

    flat = np.asarray(data, dtype=float).ravel()
    # Largest prefix length that is an exact multiple of ``n_chunks``.
    usable = len(flat) - len(flat) % n_chunks
    if usable == 0:
        return [np.array([], dtype=float) for _ in range(n_chunks)]
    return np.array_split(flat[:usable], n_chunks)

