"""
Beta-Binomial regression model for ASR prediction.

This module implements a Bayesian Beta-Binomial regression model for predicting
Attack Success Rate (ASR) as a count of successful attacks out of N trials.

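The model is specified as:
    s ~ BetaBinomial(N, α, β)

where s is the number of successful attacks, N the number of trials, x the
input feature (e.g. a capability difference), and:
    α = μ * ν
    β = (1-μ) * ν
    μ = sigmoid(w * x + b)

with priors:
    w ~ HalfNormal(σ_w)      # non-negative slope
    b ~ N(0, σ_b²)           # symmetric uncertainty for intercept
    ν ~ LogNormal(0, σ_ν²)   # strictly positive concentration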

The concentration parameter ν controls the dispersion:
    - Large ν: Less dispersion (closer to binomial)
    - Small ν: More dispersion (overdispersion)
"""

import arviz as az  # Import arviz
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy.special import expit, logit
from scipy.stats import beta


def _check_counts(successes, trials):
    """Raise ValueError unless 0 <= successes <= trials holds element-wise."""
    if np.any(trials < 0):
        raise ValueError("trials must be non-negative")
    if np.any(successes < 0) or np.any(successes > trials):
        raise ValueError("successes must satisfy 0 <= successes <= trials")


def fit_beta_binomial_model(
    x_data,
    successes,
    trials,
    prior_params=None,
    n_samples=2000,
    tune=1000,
    target_accept=0.9,
    random_seed=42,
    progressbar=True,
):
    """
    Fit Bayesian Beta-Binomial regression model using PyMC.

    Args:
        x_data: Input features (capability differences), one per observation.
        successes: Number of successful attacks per observation (count data).
        trials: Number of trials per observation (must broadcast against
            ``successes``).
        prior_params: Optional mapping of prior scales. Any key left out
            falls back to its default of 1.0:
                - 'sigma_w': Scale of the HalfNormal prior on the slope w
                - 'sigma_b': Sigma of the Normal(0, sigma_b) prior on b
                - 'sigma_nu': Sigma of the LogNormal(0, sigma_nu) prior on ν
        n_samples: Number of posterior draws (per chain).
        tune: Number of tuning steps (per chain).
        target_accept: Target acceptance rate for NUTS.
        random_seed: Random seed for reproducibility.
        progressbar: Whether to display progress bars during sampling and
            log-likelihood computation.

    Returns:
        model: PyMC model
        idata: ArviZ InferenceData object containing posterior samples
            (including the deterministic mu, alpha and beta) and the
            pointwise log likelihood of ``y_obs``.

    Raises:
        ValueError: If any trial count is negative or any success count lies
            outside ``[0, trials]``.
    """
    # Merge onto the defaults rather than replacing them, so callers can
    # override a single scale; the caller's mapping is never mutated.
    priors = {"sigma_w": 1.0, "sigma_b": 1.0, "sigma_nu": 1.0}
    if prior_params is not None:
        priors.update(prior_params)

    x_data = np.asarray(x_data, dtype=float)
    successes = np.asarray(successes, dtype=int)
    trials = np.asarray(trials, dtype=int)
    # Out-of-support counts have zero likelihood; reject them up front instead
    # of letting them fail later inside PyMC sampling.
    _check_counts(successes, trials)

    with pm.Model() as model:
        # Slope is constrained non-negative; intercept is symmetric around 0.
        w = pm.HalfNormal("w", sigma=priors["sigma_w"])
        b = pm.Normal("b", mu=0, sigma=priors["sigma_b"])
        # Concentration ν > 0: large ν ≈ binomial, small ν = overdispersed.
        nu = pm.LogNormal("nu", mu=0, sigma=priors["sigma_nu"])

        mu = pm.Deterministic("mu", pm.math.sigmoid(w * x_data + b))
        alpha = pm.Deterministic("alpha", mu * nu)
        beta_param = pm.Deterministic("beta", (1 - mu) * nu)

        pm.BetaBinomial(
            "y_obs", n=trials, alpha=alpha, beta=beta_param, observed=successes
        )

        idata = pm.sample(
            n_samples,
            tune=tune,
            target_accept=target_accept,
            random_seed=random_seed,
            return_inferencedata=True,
            progressbar=progressbar,
        )
        # Extends idata in place with a log_likelihood group for y_obs.
        pm.compute_log_likelihood(idata, model=model, progressbar=progressbar)

    return model, idata


# Percentile pairs bracketing the central ±1σ, ±2σ and ±3σ mass of a normal
# distribution; used to label credible intervals "1sigma", "2sigma", "3sigma".
_SIGMA_PERCENTILES = {1: (15.87, 84.13), 2: (2.28, 97.72), 3: (0.13, 99.87)}


def _posterior_draws(posterior, n_samples, rng):
    """Return flattened (w, b, nu) posterior draws, optionally subsampled.

    When ``n_samples`` is given and smaller than the number of stored draws,
    that many draws are picked without replacement using ``rng``.
    """
    # Dataset.sizes is the stable name -> length mapping (the mapping form of
    # Dataset.dims is deprecated in xarray).
    n_total = posterior.sizes["chain"] * posterior.sizes["draw"]
    w = np.asarray(posterior["w"].values).reshape(-1)
    b = np.asarray(posterior["b"].values).reshape(-1)
    nu = np.asarray(posterior["nu"].values).reshape(-1)
    if n_samples is not None and n_samples < n_total:
        idx = rng.choice(n_total, n_samples, replace=False)
        w, b, nu = w[idx], b[idx], nu[idx]
    return w, b, nu


def _sigma_intervals(values):
    """Map "<k>sigma" to per-column (lower, upper) percentiles of ``values``."""
    return {
        f"{k}sigma": (
            np.percentile(values, lower, axis=0),
            np.percentile(values, upper, axis=0),
        )
        for k, (lower, upper) in _SIGMA_PERCENTILES.items()
    }


def _prediction_metrics(
    idata, mean_prob, mean_count, n_trials, successes_true, trials_true
):
    """Compare predictive means against observed counts (RMSE / R²)."""
    # asarray lets callers pass plain lists (list / list would raise).
    successes_true = np.asarray(successes_true, dtype=float)
    trials_true = np.asarray(trials_true, dtype=float)

    p_true = successes_true / trials_true
    residual = p_true - mean_prob
    rmse_prob = np.sqrt(np.mean(residual**2))

    ss_total = np.sum((p_true - np.mean(p_true)) ** 2)
    ss_residual = np.sum(residual**2)
    r2_prob = 1 - (ss_residual / ss_total) if ss_total != 0 else 0

    # Rescale predicted counts (made for n_trials) to each true trial count.
    count_pred = mean_count * (trials_true / n_trials)
    rmse_count = np.sqrt(np.mean((successes_true - count_pred) ** 2))

    metrics = {
        "rmse_prob": rmse_prob,
        "r2_prob": r2_prob,
        "rmse_count": rmse_count,
    }
    # This is the mean pointwise log likelihood stored in idata (i.e. of the
    # data the model was fitted on), not a score of x_new.
    if hasattr(idata, "log_likelihood") and "y_obs" in idata.log_likelihood:
        metrics["log_likelihood"] = idata.log_likelihood["y_obs"].mean().item()
    return metrics


def predict(
    idata,
    x_new,
    n_trials=50,
    successes_true=None,
    trials_true=None,
    return_params=False,
    n_samples=None,
    random_seed=None,
):
    """
    Generate posterior predictive samples from the fitted model.

    For every posterior draw (w, b, nu) and every point x in ``x_new`` this
    draws p ~ Beta(mu * nu, (1 - mu) * nu) with mu = sigmoid(w * x + b), and
    then a count ~ Binomial(n_trials, p).

    Args:
        idata: ArviZ InferenceData object from fit_beta_binomial_model.
        x_new: New input point(s) for prediction (scalar or 1-D array-like).
        n_trials: Number of trials per predicted point (positive integer).
        successes_true: True success counts at ``x_new`` (optional; metrics
            are computed only when ``trials_true`` is also given).
        trials_true: True trial counts at ``x_new`` (optional).
        return_params: Whether to also return the alpha and beta samples.
        n_samples: Number of posterior draws to use. None, or a value not
            smaller than the number of stored draws, uses all of them.
        random_seed: Seed for a private ``numpy.random.Generator``. When None
            (default), NumPy's global random state is used, so
            ``np.random.seed`` controls reproducibility.

    Returns:
        Dictionary containing:
            - mean_prob: Mean predicted success proportion per point
            - mean_count: Mean predicted success count per point
            - samples_prob: Predictive proportions (count / n_trials), shape
              (n_draws, len(x_new))
            - samples_count: Predictive counts, same shape
            - intervals_prob, intervals_count: ``{"1sigma": (lower, upper),
              "2sigma": ..., "3sigma": ...}`` percentile intervals per point
            - alpha_samples, beta_samples: Beta parameters per draw and point
              (only if return_params is True)
            - metrics: rmse_prob, r2_prob, rmse_count and, when idata holds a
              ``y_obs`` log likelihood, its mean as log_likelihood (only if
              both successes_true and trials_true are given)

    Raises:
        ValueError: If n_trials is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be a positive integer, got {n_trials!r}")

    # np.random (module) and Generator share beta/binomial/choice, so either
    # works as the source of randomness below.
    rng = np.random if random_seed is None else np.random.default_rng(random_seed)

    x_new = np.atleast_1d(np.asarray(x_new, dtype=float))
    w_samples, b_samples, nu_samples = _posterior_draws(
        idata.posterior, n_samples, rng
    )

    # Shape (n_draws, n_points): one row per posterior draw.
    mu_samples = expit(w_samples[:, None] * x_new[None, :] + b_samples[:, None])
    # expit rounds to exactly 1.0 for large positive logits (and 0.0 for very
    # negative ones); that would make beta (or alpha) zero and the Beta draw
    # raise. Clamp only those saturated values into the open interval (0, 1).
    mu_samples = np.clip(mu_samples, np.finfo(float).tiny, np.nextafter(1.0, 0.0))

    alpha_samples = mu_samples * nu_samples[:, None]
    beta_samples = (1 - mu_samples) * nu_samples[:, None]

    # Vectorized posterior predictive draws over every (draw, point) pair.
    p_draws = rng.beta(alpha_samples, beta_samples)
    count_samples = rng.binomial(n_trials, p_draws)
    samples_prob = count_samples / n_trials

    mean_prob = np.mean(samples_prob, axis=0)
    mean_count = np.mean(count_samples, axis=0)

    result = {
        "mean_prob": mean_prob,
        "mean_count": mean_count,
        "samples_prob": samples_prob,
        "samples_count": count_samples,
        "intervals_prob": _sigma_intervals(samples_prob),
        "intervals_count": _sigma_intervals(count_samples),
    }

    if return_params:
        result.update({"alpha_samples": alpha_samples, "beta_samples": beta_samples})

    if successes_true is not None and trials_true is not None:
        result["metrics"] = _prediction_metrics(
            idata, mean_prob, mean_count, n_trials, successes_true, trials_true
        )

    return result
