import typing as tp

import numpy as np
from scipy import stats


def beta_binomial_p_gen(
  size: int,
  alpha_bias: float = 0,
  beta_bias: float = 0,
  *,
  low: int = 1,
  high: int = 100,
) -> tuple[np.ndarray, np.ndarray]:
  """Draw integer features per row and a Beta-distributed probability from them.

  Each row of ``xs`` holds two integers drawn uniformly from ``[low, high)``.
  The first column plus ``alpha_bias`` is the Beta ``a`` shape. The second
  column plus ``beta_bias`` is the ``b`` shape. One probability is drawn from
  ``Beta(a, b)`` per row.

  Args:
    size: Number of rows to generate.
    alpha_bias: Offset added to the first column before it is used as ``a``.
    beta_bias: Offset added to the second column before it is used as ``b``.
    low: Inclusive lower bound of the integer features (default 1).
    high: Exclusive upper bound of the integer features (default 100).

  Returns:
    ``(xs, ps)``: ``xs`` has shape ``(size, 2)`` and ``ps`` has shape
    ``(size,)``.

  Notes:
    ``np.random.randint`` draws from NumPy's global RNG. ``stats.beta.rvs`` is
    called without ``random_state``, so SciPy also falls back to that global
    RNG, and ``np.random.seed`` makes the output reproducible.
  """
  xs = tp.cast(np.ndarray, np.random.randint(low, high, (size, 2)))
  alphas = xs[:, 0] + alpha_bias
  betas = xs[:, 1] + beta_bias
  ps = stats.beta.rvs(alphas, betas, size=size)
  return xs, tp.cast(np.ndarray, ps)


def generate_binomial_data(
  size: int,
  n: int,
  p_gen: tp.Callable[[int], tuple[np.ndarray, np.ndarray]] = beta_binomial_p_gen,
):
  """Simulate ``n`` Bernoulli trials for each probability produced by ``p_gen``.

  Args:
    size: Number of rows requested from ``p_gen``.
    n: Number of Bernoulli trials per row.
    p_gen: Callable returning ``(xs, ps)`` for a given ``size``.

  Returns:
    ``(xs, bernoulli_trials, ps)``. ``bernoulli_trials`` has shape
    ``(len(ps), n)``, and row ``i`` uses success probability ``ps[i]``.
  """
  features, probs = p_gen(size)

  # Turning probs into a column lets each row's probability broadcast across
  # all n trial columns.
  per_row_probs = probs[:, None]
  trial_shape = (len(probs), n)
  trials = stats.bernoulli.rvs(per_row_probs, size=trial_shape)

  return features, trials, probs


def generate_geometric_data(
  size: int,
  k_max: int,
  *,
  p_gen: tp.Callable[[int], tuple[np.ndarray, np.ndarray]] = beta_binomial_p_gen,
):
  """Simulate right-censored geometric outcomes from Bernoulli trial sequences.

  For each row, ``k_max`` Bernoulli trials are drawn via
  ``generate_binomial_data``. The outcome is the 0-based index of the first
  success, i.e. the number of failures observed before it. Rows with no
  success among the ``k_max`` trials are censored at ``k_max``.

  Args:
    size: Number of rows requested from ``p_gen``.
    k_max: Number of Bernoulli trials per row; also the censoring value.
    p_gen: Callable returning ``(xs, ps)`` for a given ``size``.

  Returns:
    ``(xs, ys, ps)``: ``ys`` is an integer array with one entry per row, each
    in ``[0, k_max]``.
  """
  xs, bernoulli_trials, ps = generate_binomial_data(size, k_max, p_gen=p_gen)

  if bernoulli_trials.shape[1] == 0:
    # argmax raises on an empty axis. With zero trials no row can succeed,
    # so every row is censored at k_max.
    ys = np.full(bernoulli_trials.shape[0], k_max, dtype=np.intp)
  else:
    # argmax returns the first maximal index, i.e. the first success.
    ys = bernoulli_trials.argmax(axis=1)
    # Rows with no success would otherwise report index 0; censor them.
    no_success = bernoulli_trials.sum(axis=1) == 0
    ys[no_success] = k_max

  return xs, ys, ps


def get_k_min(y, *, no_success_value: float = -1):
  """Return the 1-based index of the first success in each row of ``y``.

  Assumes ``y`` holds 0/1 trial outcomes, one row per sequence (TODO confirm
  with callers). Since ``argmax`` picks the first maximal entry, the first 1
  in a row is treated as the first success.

  Args:
    y: 2-D array-like of trial outcomes; lists are accepted.
    no_success_value: Value reported for rows with no success (sum == 0).

  Returns:
    Float array of length ``len(y)``.
  """
  trials = np.asarray(y)

  if trials.shape[1] == 0:
    # argmax raises on an empty axis; with no trials no row has a success.
    return np.full(trials.shape[0], no_success_value, dtype=float)

  k_mins = (trials.argmax(axis=1) + 1).astype(float)

  # Rows without a success would report 1 (argmax of all zeros is 0).
  no_success = trials.sum(axis=1) == 0
  k_mins[no_success] = no_success_value

  return k_mins


def geometric_quantile(alpha: float, p: np.ndarray, *, eps: float = 1e-6):
  """Return the ``alpha``-quantile of a geometric distribution on ``{1, 2, ...}``.

  The quantile is the smallest integer ``k`` with ``1 - (1 - p) ** k >= alpha``.
  ``p`` is first clipped to ``[eps, 1 - eps]``, so ``p == 0`` gives a large
  finite quantile instead of ``nan``.

  SciPy's ``geom.ppf`` evaluates the ratio with ``log1p``, avoiding the
  cancellation in ``1 - alpha`` and ``1 - p``. It also steps the ceiling back
  by one when rounding pushed the ratio just past an exact integer. The naive
  ``ceil(log(1 - alpha) / log(1 - p))`` form does neither.

  Args:
    alpha: Quantile level in ``[0, 1]``.
    p: Success probabilities (scalar or array); broadcast against ``alpha``.
    eps: Clipping margin applied to ``p``.

  Returns:
    Float quantiles with the broadcast shape of ``alpha`` and ``p``.
    ``alpha == 1`` gives ``inf``; ``alpha`` outside ``[0, 1]`` gives ``nan``.
  """
  p = np.clip(p, eps, 1 - eps)
  return stats.geom.ppf(alpha, p)


def geometric_oracle_cis(ps: np.ndarray, *, alpha: float):
  """Build one-sided intervals from known geometric success probabilities.

  Presumably ``ps`` are the true per-row probabilities returned by the data
  generators, hence "oracle" (verify against callers). The lower bound is
  fixed at 0. The upper bound is the ``1 - alpha`` quantile of
  ``Geometric(ps)``, so each interval covers the outcome with probability at
  least ``1 - alpha``.

  Args:
    ps: Success probabilities, one per row.
    alpha: Miscoverage level.

  Returns:
    Array of shape ``(len(ps), 2)`` with columns ``(lower, upper)``.
  """
  lbs = np.zeros_like(ps)
  # Uses the module-level ``stats``; no function-local scipy import needed.
  ubs = stats.geom.ppf(1 - alpha, ps)

  return np.column_stack((lbs, ubs))
