from typing import Union
import torch
import pyro
from pyro.infer import Predictive
from sbsep.cbnn import CBNN
from sbsep.sb import SBMixture


def compute_pred(
    model: Union[CBNN, SBMixture],
    nsamples=100,
    guide_based=False,
    *,
    num_points=50,
    margin=0.1,
    lower_q=0.05,
    upper_q=0.95,
):
    """Draw predictive samples of ``model`` on a 1-D grid and summarise them.

    The grid is ``torch.linspace(model.xa - margin, model.xb + margin,
    num_points)``. ``nsamples`` draws of the ``f"{model.name}#obs"`` site are
    taken with :class:`pyro.infer.Predictive`. If ``guide_based`` is false,
    no guide is used. If it is true, latent variables are drawn from
    ``model.guide``.

    Args:
        model: A ``CBNN`` or ``SBMixture`` exposing ``xa``, ``xb``, ``name``
            and, when ``guide_based`` is true, ``guide``.
        nsamples: Number of predictive samples to draw.
        guide_based: If true, pass ``model.guide`` to ``Predictive``.
        num_points: Number of grid points (default 50).
        margin: Padding added below ``model.xa`` and above ``model.xb``
            (default 0.1).
        lower_q: Quantile level of the lower band (default 0.05).
        upper_q: Quantile level of the upper band (default 0.95).

    Returns:
        Tuple ``(x_test, preds_mean, preds_lower, preds_upper)``. ``x_test``
        is the 1-D torch grid. The other three are NumPy arrays reduced over
        the sample dimension (dim 0).

    Note:
        Calls ``pyro.clear_param_store()`` first, a global side effect.
    """
    # NOTE(review): this wipes the global Pyro param store. PyroModule-based
    # guides re-register their own values on access. A guide that keeps its
    # parameters only in the param store would be reset before predicting.
    # Confirm this is intended when guide_based=True.
    pyro.clear_param_store()
    predictive = Predictive(
        model, guide=model.guide if guide_based else None, num_samples=nsamples
    )
    x_test = torch.linspace(model.xa - margin, model.xb + margin, num_points)

    # Presumably the model is called as (x, y); y=None leaves the obs site
    # unobserved so it gets sampled. x gets a trailing feature dim of size 1.
    with torch.no_grad():  # outputs are only summarised; no autograd graph needed
        preds = predictive(x_test.unsqueeze(-1), None)

    # Drop the (presumed) trailing singleton dim of the obs samples, and move
    # to CPU so the .numpy() conversions below also work for CUDA tensors.
    obs_pred = preds[f"{model.name}#obs"].detach().cpu().squeeze(-1)

    preds_mean = obs_pred.mean(dim=0).numpy()

    # One quantile call sorts the samples once for both band edges. q must
    # share the input's dtype, matching how a scalar q is handled internally.
    levels = torch.tensor([lower_q, upper_q], dtype=obs_pred.dtype)
    band = torch.quantile(obs_pred, levels, dim=0)
    preds_lower = band[0].numpy()
    preds_upper = band[1].numpy()

    return x_test, preds_mean, preds_lower, preds_upper
