import numpy as np
import tensorflow as tf
from scipy import sparse
from tqdm.auto import tqdm

from previous_methods import compute_ll

def safe_log(p, eps=1e-6):
    """Return the natural log of `p` after clamping it into [eps, 1.0].

    The clamp keeps the logarithm finite for inputs at or below zero and
    caps inputs above one at log(1) = 0.
    """
    clamped = tf.clip_by_value(p, eps, 1.0)
    return tf.math.log(clamped)

def mlp_predict(fitted_models, X):
    """Return the predictions of the model stored under the "mlp" key.

    Raises KeyError when `fitted_models` has no "mlp" entry.
    """
    mlp_model = fitted_models["mlp"]
    return mlp_model.predict(X)

def integrated_ll_regression_uniform(preds, y_min, y_max, num_mc_samples=100, sigma=1.0):
    """
    Monte Carlo estimate of E_y[ log N(y | preds, sigma^2) ] for y ~ Uniform[y_min, y_max].

    `num_mc_samples` targets are drawn once and shared by every entry of
    `preds`; the Gaussian log-density is averaged over those draws, so the
    result has the same shape as `preds`. The draws come from
    `tf.random.uniform`, so repeated calls give different estimates unless
    the TensorFlow seed is fixed by the caller.
    """
    preds = tf.convert_to_tensor(preds, tf.float32)
    variance = sigma ** 2

    lower = tf.cast(y_min, tf.float32)
    upper = tf.cast(y_max, tf.float32)
    draws = tf.random.uniform(
        shape=[num_mc_samples], minval=lower, maxval=upper, dtype=tf.float32
    )

    # Give the draws a leading sample axis followed by one singleton axis per
    # dimension of `preds`, so they broadcast against every prediction.
    trailing_ones = [1] * len(preds.shape)
    draws = tf.reshape(draws, [-1] + trailing_ones)
    residual = draws - tf.expand_dims(preds, 0)

    log_norm = -0.5 * tf.math.log(tf.constant(2.0 * np.pi, dtype=tf.float32) * variance)
    log_density = log_norm - 0.5 * tf.square(residual) / variance
    # Average out the sample axis.
    return tf.reduce_mean(log_density, axis=0)

def elbo_loss(
    model_logits,
    base_ll,
    true_pred,
    pred_train,
    y_oh,
    posterior_type="categorical",
    eps=1e-8,
    kl_f_weight=1.0,
    tau=2.0,
    *,
    task="binary",
    return_parts=False
):
    """Negative ELBO for a categorical posterior over base models.

    The posterior q(f | x) is the softmax of `model_logits` over the last
    axis. The prior p(f) is a softmax built from `base_ll` and `pred_train`
    (see below). The per-example ELBO is

        sum_f q_f * true_pred_f  -  kl_f_weight * KL(q || p)

    and the returned loss is the negated batch mean of that quantity.

    Args:
        model_logits: Unnormalised posterior scores; presumably [batch, models].
        base_ll: Per-model prior score; presumably shape [models].
        true_pred: Per-example, per-model likelihood term weighted by q.
        pred_train: For task="binary", class probabilities whose last axis
            holds (p0, p1). For task="regression", a per-example, per-model
            score that is added to `base_ll` before tempering.
        y_oh: Unused; kept for interface compatibility.
        posterior_type: Only "categorical" is supported.
        eps: Lower clip bound applied to q in the binary task.
        kl_f_weight: Multiplier on the KL term.
        tau: Temperature dividing the regression prior scores.
        task: "binary" or "regression".
        return_parts: When true, also return the batch means of the
            q-weighted likelihood term and of the KL term.

    Returns:
        The scalar loss, or `(loss, mean_weighted_ll, mean_kl)` when
        `return_parts` is true. Note that `mean_weighted_ll` is NOT negated.

    Raises:
        ValueError: If `task` or `posterior_type` is not supported.
    """
    if task not in ("binary", "regression"):
        raise ValueError("task must be 'binary' or 'regression'")
    if posterior_type != "categorical":
        raise ValueError(f"Unsupported posterior type: {posterior_type!r}")

    # --- posterior q(f | x) ---
    q_f = tf.nn.softmax(model_logits, axis=-1)
    if task == "binary":
        # Only floor the probabilities. The previous `q_f += q_f + eps`
        # doubled q before clipping, which saturated any weight above 0.5 at
        # 1.0 and left q un-normalised.
        q_f = tf.clip_by_value(q_f, eps, 1.0)

    # --- prior p(f) ---
    if task == "binary":
        # base_ll + log p0 + log p1 from pred_train (class probs)
        log_p = safe_log(pred_train[..., 0]) + safe_log(pred_train[..., 1])
        p_f = tf.nn.softmax(base_ll + log_p, axis=-1)
    else:
        # pred_train carries the integrated log-likelihood for the current
        # datapoints; tau tempers the resulting prior.
        p_f = tf.nn.softmax((base_ll[None, :] + pred_train) / tau, axis=-1)

    # --- likelihood term and KL(q || p), identical for both tasks ---
    weighted_ll = tf.reduce_sum(true_pred * q_f, axis=1)
    kl_f = tf.reduce_sum(q_f * (safe_log(q_f) - safe_log(p_f)), axis=1)
    elbo = weighted_ll - kl_f_weight * kl_f

    loss = -tf.reduce_mean(elbo)
    if return_parts:
        return loss, tf.reduce_mean(weighted_ll), tf.reduce_mean(kl_f)
    return loss



def _bma_feature_matrix(task, x_train, X_train_feat, fitted_models):
    """Choose the array that is fed to the posterior network."""
    if task == "regression" and X_train_feat is not None:
        return X_train_feat
    try:
        return np.asarray(x_train, dtype=np.float32)
    except Exception:
        # Deliberate best-effort fallback: when x_train cannot be cast to a
        # float array, use the flattened base-model probabilities instead.
        stacked = np.stack(
            [m.predict_proba(x_train) for m in fitted_models.values()], axis=1
        ).astype(np.float32)  # presumably [N, M, C]
        return stacked.reshape(stacked.shape[0], -1)  # [N, M*C]


def _bma_binary_terms(x_train, y_np, fitted_models, X_train_feat):
    """Precompute per-sample, per-model quantities for the binary task."""
    if (X_train_feat is not None) and (np.asarray(X_train_feat).ndim == 3):
        probs = np.asarray(X_train_feat, dtype=np.float32)
    else:
        probs = np.stack(
            [m.predict_proba(x_train) for m in fitted_models.values()],
            axis=1
        ).astype(np.float32)

    # Normalise to an explicit two-class layout [..., (p0, p1)].
    if probs.ndim == 2:
        probs = np.stack([1.0 - probs, probs], axis=-1).astype(np.float32)
    elif probs.ndim == 3 and probs.shape[-1] == 1:
        p1 = probs[..., 0]
        probs = np.stack([1.0 - p1, p1], axis=-1).astype(np.float32)
    elif not (probs.ndim == 3 and probs.shape[-1] == 2):
        raise ValueError(f"Binary expects probs/logits with C∈{{1,2}}; got shape {probs.shape}")

    pred_train_all = tf.convert_to_tensor(probs, dtype=tf.float32)  # [N, M, 2]
    y_oh = tf.one_hot(y_np.astype(int), depth=2)

    # Probability each model assigns to the true class, [N, M].
    # NOTE(review): this is a probability, not a log-probability, although
    # elbo_loss treats it as the likelihood term -- confirm this is intended.
    true_pred_train = tf.cast(
        tf.reduce_sum(pred_train_all * y_oh[:, None, :], axis=-1), tf.float32
    )

    # Per-model prior score: log-probabilities summed over samples and over
    # both classes, divided by the number of samples.
    n_rows = tf.cast(tf.shape(pred_train_all)[0], tf.float32)
    base_ll = tf.reduce_sum(safe_log(pred_train_all), axis=[0, 2]) / n_rows
    return true_pred_train, pred_train_all, y_oh, base_ll


def _bma_regression_terms(x_train, y_np, fitted_models, eps, y_min, y_max):
    """Precompute per-sample, per-model quantities for the regression task."""
    n_rows = len(y_np)
    yhat_cols = []
    for name, mdl in fitted_models.items():
        if name.lower() == "mlp":
            cached = getattr(mdl, "yhat_train", None)
            if cached is not None:
                yhat = np.asarray(cached).ravel()
            else:
                # Predict with the matched model itself: looking the model up
                # again under the literal key "mlp" would raise KeyError for
                # keys such as "MLP" that the case-insensitive test accepts.
                yhat = np.asarray(mdl.predict(x_train)).ravel()
        else:
            if not hasattr(mdl, "predict"):
                continue
            yhat = np.asarray(mdl.predict(x_train)).ravel()
        if len(yhat) != n_rows:
            raise ValueError(f"Train preds length mismatch for model '{name}': got {len(yhat)}, expected {n_rows}.")
        yhat_cols.append(yhat)
    if not yhat_cols:
        raise ValueError("No valid predictors found in fitted_models for regression.")

    yhat_train = np.column_stack(yhat_cols).astype(np.float32)  # [N, M]
    y_vec = y_np.astype(np.float32)
    resid = y_vec[:, None] - yhat_train

    # Per-model Gaussian log-likelihood using each model's residual variance.
    sigma2 = np.maximum(np.var(resid, axis=0, ddof=1), eps)  # [M]
    ll_per_sample = (
        -0.5 * np.log(2.0 * np.pi * (sigma2[None, :] + eps))
        - 0.5 * (resid ** 2) / (sigma2[None, :] + eps)
    )
    true_pred_train = tf.convert_to_tensor(ll_per_sample, dtype=tf.float32)

    if y_min is None or y_max is None:
        # Default integration range: observed target range padded by 5%.
        y_min = float(np.min(y_vec))
        y_max = float(np.max(y_vec))
        pad = 0.05 * (y_max - y_min + 1e-8)
        y_min -= pad
        y_max += pad

    yhat_train_tf = tf.convert_to_tensor(yhat_train, dtype=tf.float32)
    int_ll_train = integrated_ll_regression_uniform(
        yhat_train_tf, y_min=y_min, y_max=y_max, num_mc_samples=64, sigma=1.0
    )
    base_ll = tf.reduce_mean(int_ll_train, axis=0)  # [M]
    return true_pred_train, yhat_train_tf, base_ll, y_min, y_max


def train_adaptive_BMA(
    task,
    posterior_net,
    posterior_type,
    optimizer,
    x_train,
    y_train,
    fitted_models,
    batch_size=64,
    epochs=100,
    kl_f_weight=1.0,
    *,
    X_train_feat=None,     # to pass precomputed numeric features
    eps=1e-8,
    y_min=None,
    y_max=None
):
    """Train `posterior_net` to weight the base models by minimising `elbo_loss`.

    Per-sample, per-model likelihood and prior terms are precomputed once from
    `fitted_models`; each epoch then iterates shuffled mini-batches, evaluates
    `elbo_loss` on the network's logits and applies one optimizer step per
    batch. A summary line is printed after every epoch.

    Args:
        task: "binary" or "regression".
        posterior_net: Callable model with `trainable_variables`; called on a
            feature batch and expected to return per-model logits.
        posterior_type: Forwarded to `elbo_loss`.
        optimizer: Object exposing `apply_gradients`.
        x_train: Training inputs handed to the base models and, unless
            overridden (see `X_train_feat`), to the posterior network.
        y_train: Training targets (anything with `to_numpy` or array-like).
        fitted_models: Mapping name -> fitted base model. Binary models need
            `predict_proba`; regression models need `predict` (models without
            it are skipped). A model whose lower-cased name is "mlp" may
            expose cached predictions as `yhat_train`.
        batch_size: Mini-batch size.
        epochs: Number of passes over the data.
        kl_f_weight: Weight of the KL term in the ELBO.
        X_train_feat: For regression, features fed to the posterior network
            in place of `x_train`. For binary, a 3-D array is used as the
            precomputed base-model probabilities.
        eps: Variance floor for the regression likelihood.
        y_min, y_max: Integration range for the regression prior; when either
            is missing both are derived from the padded range of `y_train`.

    Returns:
        The trained `posterior_net` (updated in place).

    Raises:
        ValueError: For an unknown `task`, unsupported probability shapes, a
            prediction-length mismatch, or no usable regression models.
    """
    if task not in ("binary", "regression"):
        raise ValueError("task must be 'binary' or 'regression'")

    X_feat = _bma_feature_matrix(task, x_train, X_train_feat, fitted_models)
    idxs_train = np.arange(X_feat.shape[0], dtype=np.int32)
    y_np = y_train.to_numpy() if hasattr(y_train, "to_numpy") else np.asarray(y_train)

    # ---------- dataset ----------
    # Row indices travel with each batch so the precomputed terms can be
    # gathered for exactly the rows in that batch.
    ds = (
        tf.data.Dataset
          .from_tensor_slices((idxs_train, X_feat, y_np))
          .shuffle(X_feat.shape[0])
          .batch(batch_size)
    )

    # ---------- Precompute base likelihood terms ----------
    pred_train_all = y_oh = yhat_train_tf = None
    if task == "binary":
        true_pred_train, pred_train_all, y_oh, base_ll = _bma_binary_terms(
            x_train, y_np, fitted_models, X_train_feat
        )
    else:
        true_pred_train, yhat_train_tf, base_ll, y_min, y_max = _bma_regression_terms(
            x_train, y_np, fitted_models, eps, y_min, y_max
        )

    # ---------- Training loop ----------
    for epoch in range(1, epochs + 1):
        epoch_loss_total = 0.0
        epoch_neg_ell_total = 0.0
        epoch_kl_total = 0.0
        epoch_count = 0.0

        for idx_batch, x_batch, _y_batch in ds:
            true_pred_batch = tf.gather(true_pred_train, idx_batch)
            if task == "binary":
                pred_batch = tf.gather(pred_train_all, idx_batch)
                y_oh_batch = tf.gather(y_oh, idx_batch)
            else:
                # The prior term is re-estimated with fresh MC draws per batch.
                yhat_batch = tf.gather(yhat_train_tf, idx_batch)
                pred_batch = integrated_ll_regression_uniform(
                    yhat_batch, y_min=y_min, y_max=y_max, num_mc_samples=64, sigma=1.0
                )
                y_oh_batch = None

            with tf.GradientTape() as tape:
                logits_f = posterior_net(x_batch)
                loss, ell_mean, kl_mean = elbo_loss(
                    model_logits=logits_f,
                    base_ll=base_ll,
                    true_pred=true_pred_batch,
                    pred_train=pred_batch,
                    y_oh=y_oh_batch,
                    posterior_type=posterior_type,
                    kl_f_weight=kl_f_weight,
                    task=task,
                    return_parts=True
                )

            grads = tape.gradient(loss, posterior_net.trainable_variables)
            # Skip variables that received no gradient.
            grads_and_vars = [
                (g, v)
                for g, v in zip(grads, posterior_net.trainable_variables)
                if g is not None
            ]
            optimizer.apply_gradients(grads_and_vars)

            # Weight batch means by batch size so a short last batch does not
            # skew the epoch averages.
            bsz_f = float(tf.shape(x_batch)[0])
            epoch_loss_total += float(loss) * bsz_f
            # elbo_loss returns +E[ll]; negate it so the value matches the
            # "-E[ll]" label printed below.
            epoch_neg_ell_total += -float(ell_mean) * bsz_f
            epoch_kl_total += float(kl_mean) * bsz_f
            epoch_count += bsz_f

        denom = max(1.0, epoch_count)
        epoch_loss = epoch_loss_total / denom
        epoch_neg_ell = epoch_neg_ell_total / denom
        epoch_kl = epoch_kl_total / denom
        print(f"Epoch {epoch}/{epochs} — ELBO loss = {epoch_loss:.4f}, -E[ll] = {epoch_neg_ell:.4f}, KL = {epoch_kl:.4f}")

    return posterior_net


