import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
tfd = tfp.distributions

from tqdm.notebook import tqdm


def log_normal_pdf(x, mu, sigma, eps=1e-8):
    """
    Computes the log density of a diagonal Normal distribution for each sample,
    summing over the last dimension.

    x, mu, sigma: tensors that broadcast against each other; sigma is the
        standard deviation.
    eps: small constant added to sigma inside the log and to sigma**2 in the
        denominator, guarding against log(0) and division by zero.

    Returns a tensor whose last dimension has been reduced away.
    """
    # Keep the constant as a Python float (not a float32 tensor) so it adopts the
    # dtype of the tensors it is combined with; float64 inputs then work too.
    log2pi = float(np.log(2.0 * np.pi))
    return -0.5 * tf.reduce_sum(
        log2pi + 2 * tf.math.log(sigma + eps) + tf.square(x - mu) / (tf.square(sigma) + eps),
        axis=-1
    )

def log_likelihood_binary(logits, y):
    """
    Computes the Bernoulli log likelihood for binary classification.

    logits: shape [batch, 1] (or [batch]).
    y: tensor of shape [batch] or [batch, 1] with values 0 or 1.

    Returns a tensor of shape [batch].
    """
    # Flatten both inputs so a [batch] label vector cannot broadcast against
    # [batch, 1] logits into a [batch, batch] matrix.
    logits = tf.reshape(tf.convert_to_tensor(logits), [-1])
    labels = tf.reshape(tf.cast(y, logits.dtype), [-1])
    # y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z)) is exactly the negative sigmoid
    # cross-entropy; computing it from logits stays accurate for large |z|,
    # where log(sigmoid(z) + 1e-8) would saturate.
    return -tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

def log_likelihood_multiclass(logits, y, num_classes):
    """
    Computes the categorical log likelihood for multi-class classification.

    logits: shape [batch, num_classes]
    y: tensor of shape [batch] or [batch, 1] holding class indices (sparse labels).
       Float-encoded indices (e.g. 2.0) are accepted and cast to integers.
    num_classes: total number of classes.

    Returns a tensor of shape [batch].
    """
    # tf.one_hot only accepts integer indices, so cast before encoding; this lets
    # labels stored as floats be used without a separate conversion step.
    class_ids = tf.cast(tf.reshape(y, [-1]), tf.int32)
    y_one_hot = tf.one_hot(class_ids, depth=num_classes, dtype=logits.dtype)
    # Log likelihood is the negative of the softmax cross-entropy.
    return -tf.nn.softmax_cross_entropy_with_logits(labels=y_one_hot, logits=logits)

def log_likelihood_regression(preds, y, sigma=1.0):
    """
    Computes the log likelihood for regression assuming a Gaussian likelihood.

    preds: shape [batch, 1] (or [batch])
    y: shape [batch, 1] (or [batch])
    sigma: fixed noise level (default 1.0)

    Returns a tensor of shape [batch].
    """
    # Flatten both sides so a [batch] target cannot broadcast against [batch, 1]
    # predictions into a [batch, batch] matrix of cross terms.
    preds = tf.reshape(tf.convert_to_tensor(preds), [-1])
    y = tf.reshape(tf.cast(y, preds.dtype), [-1])
    # Work in the dtype of the predictions so float64 inputs are supported.
    sigma = tf.cast(sigma, preds.dtype)
    log_norm = 0.5 * tf.math.log(2.0 * np.pi * tf.square(sigma))
    return -log_norm - 0.5 * tf.square((y - preds) / sigma)

def integrated_log_likelihood_classification(logits, binary=True):
    """
    Per-sample expected log likelihood under the model's own predictive
    distribution, i.e. the negative entropy of the predicted class probabilities.

    Binary:      p*log(p) + (1-p)*log(1-p), with p = sigmoid(logits)
    Multi-class: sum_c p(c) * log(p(c)),    with p = softmax(logits)

    logits: [batch, 1] when binary, otherwise [batch, num_classes].
    Returns a tensor of shape [batch].
    """
    if not binary:
        probs = tf.nn.softmax(logits)
        # 1e-8 keeps the log finite when a class probability underflows to 0.
        weighted_logs = probs * tf.math.log(probs + 1e-8)
        return tf.reduce_sum(weighted_logs, axis=-1)

    prob_pos = tf.sigmoid(logits)
    prob_neg = 1 - prob_pos
    pos_term = prob_pos * tf.math.log(prob_pos + 1e-8)
    neg_term = prob_neg * tf.math.log(prob_neg + 1e-8)
    # Drop the trailing singleton axis: [batch, 1] -> [batch].
    return tf.squeeze(pos_term + neg_term, axis=-1)

def integrated_log_likelihood_regression(preds, y_min, y_max, num_mc_samples=10):
    """
    Approximates the integrated (marginalized) log likelihood for regression by Monte Carlo.

    Targets are sampled uniformly from [y_min, y_max); for each prediction the
    unit-variance Gaussian log density is averaged over those samples. The same
    samples are shared by every prediction in the batch.

    preds: tensor of predictions, shape [batch, 1]
    y_min, y_max: scalars defining the integration range.
    num_mc_samples: number of uniform samples drawn.

    Returns a tensor of shape [batch].
    """
    R = num_mc_samples
    work_dtype = preds.dtype
    # Align the range bounds with the prediction dtype so that float64
    # predictions can be combined with bounds computed in another precision.
    y_samples = tf.random.uniform(
        shape=[R],
        minval=tf.cast(y_min, work_dtype),
        maxval=tf.cast(y_max, work_dtype),
        dtype=work_dtype,
    )
    pred_expanded = tf.expand_dims(preds, axis=0)            # shape: [1, batch, 1]
    y_samples_expanded = tf.reshape(y_samples, [R, 1, 1])    # shape: [R, 1, 1]
    # Python-float constant: it takes on the dtype of the tensor term instead of
    # forcing float32 the way tf.math.log(2 * np.pi) would.
    log2pi = float(np.log(2.0 * np.pi))
    ll = -0.5 * log2pi - 0.5 * tf.square(y_samples_expanded - pred_expanded)
    ll = tf.squeeze(ll, axis=-1)  # shape: [R, batch]
    integrated_ll = tf.reduce_mean(ll, axis=0)  # shape: [batch]
    return integrated_ll


def normalize_integrated_ll(v):
    """
    Normalizes a vector of integrated log likelihood values.

    Given v = [L(θ₁), L(θ₂), ..., L(θ_N)],
    we approximate the partition function Z as:
         Z ≈ mean(exp(v))
    and return v_norm = v - log(Z).

    log(Z) is evaluated as logsumexp(v) - log(N), which avoids the overflow and
    underflow of exponentiating v directly (very negative entries would
    otherwise make Z collapse to 0).
    """
    v = tf.convert_to_tensor(v)
    count = tf.cast(tf.size(v), v.dtype)
    log_Z = tf.reduce_logsumexp(v) - tf.math.log(count)
    return v - log_Z


def splits_fn(Xs, Ys, K, m, n):
    """
    Draws K bootstrap-style (train, test) splits from Xs / Ys.

    For every split, m test indices and n train indices are sampled
    independently with replacement from range(len(Xs)), so the two parts may
    overlap and may contain repeated rows.

    Returns a list of K entries, each [X_train, Y_train, X_test, Y_test], where
    the train arrays have n rows and the test arrays have m rows.
    """
    def draw_split():
        pool_size = len(Xs)
        # Test indices are drawn before train indices; keeping this order keeps
        # results reproducible under a fixed NumPy seed.
        held_out = np.random.choice(pool_size, m, replace=True)
        fitted = np.random.choice(pool_size, n, replace=True)
        return [Xs[fitted], Ys[fitted], Xs[held_out], Ys[held_out]]

    return [draw_split() for _ in range(K)]

def get_phi(h, gX, gXstar):
    """
    Evaluates the posterior network h and returns its Gaussian parameters.

    The rows of gX are averaged into a single summary vector, which is repeated
    once per row of gXstar and passed to h together with gXstar. The output of h
    is split in half along the last axis into (mu, raw sigma); the raw sigma is
    mapped through softplus and offset by 1e-6 so it stays strictly positive.

    Returns (mu, sigma).
    """
    num_queries = tf.shape(gXstar)[0]
    summary = tf.reduce_mean(gX, axis=0, keepdims=True)   # one row summarising gX
    summary_per_query = tf.tile(summary, [num_queries, 1])  # one copy per gXstar row

    # h takes the per-row embeddings and the matching summary as two inputs.
    mu, raw_sigma = tf.split(h([gXstar, summary_per_query]), num_or_size_splits=2, axis=-1)
    return mu, tf.nn.softplus(raw_sigma) + 1e-6

def _mean_normalized_integrated_ll(int_ll_vec):
    """Normalizes integrated log likelihoods with normalize_integrated_ll and averages them."""
    return tf.reduce_mean(normalize_integrated_ll(int_ll_vec))


def train_posterior(Xs, Ys, g, h, optimizer, J, m, n, K=10, tau=1.0, lambda_val=1.0,
                    classification=True, binary=True, num_classes=None,
                    num_mc_samples=10, dtype='float32'):
    """
    Trains the posterior network h. In this version, we compute proper log probabilities
    (rather than losses) for the observed likelihood, and we combine them with a normalized
    integrated likelihood (acting as a prior).

    For each split:
       L_j = (observed_train_log_prob + observed_test_log_prob)
             - lambda_val * (posterior_lp - prior_lp)

    where:
      - posterior_lp is computed from the posterior over θ,
      - prior_lp is computed from normalized integrated likelihoods.

    The loss we optimize is the negative of this quantity, averaged over the J
    splits of an iteration, plus tau times the variance of the per-split losses.

    Xs, Ys: data indexed by integer arrays inside splits_fn.
    g: callable mapping inputs to embeddings; only h's variables are updated here.
    h: posterior network, called through get_phi; updated in place.
    optimizer: object providing apply_gradients.
    J: number of random splits drawn per iteration.
    m, n: sizes handed to splits_fn, which draws m test rows and n train rows.
    K: number of optimisation iterations (one gradient step each).
    tau: weight of the across-split loss variance.
    lambda_val: weight of the (posterior_lp - prior_lp) term.
    classification, binary, num_classes: select the likelihood; num_classes is
        required for multi-class classification.
    num_mc_samples: Monte Carlo samples for the regression integrated likelihood.
    dtype: floating dtype used for the noise sample and for label casts.

    Raises ValueError if multi-class classification is requested without num_classes.

    Returns the updated model h and a history list (final loss per iteration).
    """
    if classification and (not binary) and num_classes is None:
        raise ValueError("For multi-class classification, num_classes must be provided.")

    if not classification:
        # Integration range for the regression prior term: the observed target range.
        Ys_tensor = tf.cast(tf.convert_to_tensor(Ys), dtype)
        y_min = tf.reduce_min(Ys_tensor)
        y_max = tf.reduce_max(Ys_tensor)

    history = []

    for k in tqdm(range(K)):
        splits = splits_fn(Xs, Ys, J, m, n)
        split_losses = []

        with tf.GradientTape() as tape:
            for j in range(J):
                tmp_X, tmp_Y, tmp_Xstar, tmp_Ystar = splits[j]

                # Compute embeddings.
                gX = g(tmp_X)          # train split: n rows (see splits_fn)
                gXstar = g(tmp_Xstar)  # test split: m rows (see splits_fn)

                mu, sigma = get_phi(h, gX, gXstar)

                # Reparameterization trick: one θ per test-split row.
                eps = tf.random.normal(shape=tf.shape(mu), dtype=dtype)
                theta = mu + sigma * eps

                # Posterior log probability for θ (one value per test-split row).
                posterior_log_probs = log_normal_pdf(theta, mu, sigma)

                # The train split is scored with a single θ: the average over rows.
                theta_avg = tf.reduce_mean(theta, axis=0, keepdims=True)

                if classification and binary:
                    # Column-vector labels line up with the [rows, 1] logits, so
                    # [N]-shaped label arrays cannot broadcast into a matrix.
                    y_train = tf.reshape(tf.cast(tmp_Y, dtype), [-1, 1])
                    y_test = tf.reshape(tf.cast(tmp_Ystar, dtype), [-1, 1])

                    train_out = tf.matmul(gX, theta_avg, transpose_b=True)  # [rows, 1]
                    # Each test row uses its own θ: a row-wise dot product, rather
                    # than building the full rows x rows product to read its diagonal.
                    test_out = tf.reduce_sum(gXstar * theta, axis=-1, keepdims=True)  # [rows, 1]

                    observed_train_ll = tf.reduce_mean(log_likelihood_binary(train_out, y_train))
                    observed_test_ll = tf.reduce_mean(log_likelihood_binary(test_out, y_test))

                    int_train_ll_vec = integrated_log_likelihood_classification(train_out, binary=True)
                    int_test_ll_vec = integrated_log_likelihood_classification(test_out, binary=True)

                elif classification:
                    # θ holds one weight matrix [d, num_classes] per test row, flattened.
                    d_rep = tf.shape(gXstar)[1]
                    theta_multi = tf.reshape(theta, [tf.shape(theta)[0], d_rep, num_classes])
                    theta_avg_multi = tf.reshape(theta_avg, [d_rep, num_classes])

                    train_out = tf.matmul(gX, theta_avg_multi)  # [rows, num_classes]
                    test_out = tf.einsum('nd, ndc -> nc', gXstar, theta_multi)  # [rows, num_classes]

                    observed_train_ll = tf.reduce_mean(
                        log_likelihood_multiclass(train_out, tmp_Y, num_classes))
                    observed_test_ll = tf.reduce_mean(
                        log_likelihood_multiclass(test_out, tmp_Ystar, num_classes))

                    int_train_ll_vec = integrated_log_likelihood_classification(train_out, binary=False)
                    int_test_ll_vec = integrated_log_likelihood_classification(test_out, binary=False)

                else:
                    # Regression. Targets become column vectors so they subtract
                    # element-wise from the [rows, 1] predictions, both in the
                    # likelihood and in the MSE reported after the step.
                    y_train = tf.reshape(tf.cast(tmp_Y, dtype), [-1, 1])
                    y_test = tf.reshape(tf.cast(tmp_Ystar, dtype), [-1, 1])

                    train_out = tf.matmul(gX, theta_avg, transpose_b=True)  # [rows, 1]
                    test_out = tf.reduce_sum(gXstar * theta, axis=-1, keepdims=True)  # [rows, 1]

                    observed_train_ll = tf.reduce_mean(
                        log_likelihood_regression(train_out, y_train, sigma=1.0))
                    observed_test_ll = tf.reduce_mean(
                        log_likelihood_regression(test_out, y_test, sigma=1.0))

                    int_train_ll_vec = integrated_log_likelihood_regression(
                        train_out, y_min, y_max, num_mc_samples=num_mc_samples)
                    int_test_ll_vec = integrated_log_likelihood_regression(
                        test_out, y_min, y_max, num_mc_samples=num_mc_samples)

                # Compute prior log probability from normalized integrated likelihoods.
                prior_lp = (_mean_normalized_integrated_ll(int_train_ll_vec)
                            + _mean_normalized_integrated_ll(int_test_ll_vec))
                posterior_lp = tf.reduce_mean(posterior_log_probs)
                kl = posterior_lp - prior_lp

                observed_ll = observed_train_ll + observed_test_ll
                # Our objective: maximize observed_ll - lambda * kl.
                # Loss is the negative of this quantity.
                loss_j = -(observed_ll - lambda_val * kl)
                split_losses.append(loss_j)

            stacked_losses = tf.stack(split_losses)
            split_loss = tf.reduce_mean(stacked_losses)
            L_var = tf.math.reduce_variance(stacked_losses)
            final_loss = split_loss + tau * L_var

        grads = tape.gradient(final_loss, h.trainable_variables)
        optimizer.apply_gradients(zip(grads, h.trainable_variables))
        history.append(final_loss.numpy())

        # Progress report. The per-split quantities (predictions, labels, observed
        # log likelihoods, KL) are those of the last split of this iteration.
        if not classification:
            mse_train = tf.reduce_mean(tf.square(train_out - y_train))
            mse_test = tf.reduce_mean(tf.square(test_out - y_test))
            print("--- iteration %d, loss %.4f, split var: %.4f, train MSE: %.4f, test MSE: %.4f, KL: %.4f ---" %
                  (k, float(final_loss), float(L_var), float(mse_train), float(mse_test), float(kl)))
        else:
            # For classification, print the negative log probabilities (losses).
            train_loss_val = -observed_train_ll
            test_loss_val = -observed_test_ll
            print("--- iteration %d, loss %.4f, split var: %.4f, train loss: %.4f, test loss: %.4f, KL: %.4f ---" %
                  (k, float(final_loss), float(L_var), float(train_loss_val), float(test_loss_val), float(kl)))

    return h, history