import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
tfd = tfp.distributions

from tqdm.notebook import tqdm

# ============================
# LOG-LIKELIHOODS
# ============================

def log_normal_pdf(x, mu, sigma, eps=1e-8):
    """
    Computes the log density of a diagonal Normal(x | mu, sigma), summed over the last axis.

    Args:
        x: Tensor of shape [..., d], samples.
        mu: Tensor of same shape, mean of Gaussian.
        sigma: Tensor of same shape, stddev of Gaussian.
        eps: Small constant added to sigma inside the log and to sigma^2 in the
            denominator to avoid log(0) / division by zero.

    Returns:
        Tensor of shape [...] with the log-density values; dtype follows the inputs.
    """
    # Plain Python float, so TensorFlow converts it to the dtype of the tensor it is
    # added to. tf.math.log(2.0 * np.pi) would be a float32 tensor and fail for float64 inputs.
    log2pi = float(np.log(2.0 * np.pi))
    return -0.5 * tf.reduce_sum(
        log2pi + 2 * tf.math.log(sigma + eps) + tf.square(x - mu) / (tf.square(sigma) + eps),
        axis=-1,
    )

def log_likelihood_binary(logits, y):
    """
    Computes the per-example Bernoulli log-likelihood log p(y | sigmoid(logits)).

    Args:
        logits: Tensor of shape [batch, 1].
        y: Binary labels of shape [batch] or [batch, 1]. They are cast to the logits'
            dtype and reshaped to the logits' shape.

    Returns:
        Tensor of shape [batch] with log-likelihood values.
    """
    logits = tf.convert_to_tensor(logits)
    # Align labels with [batch, 1] logits. A [batch] label vector would otherwise
    # broadcast against the logits into a [batch, batch] matrix.
    y = tf.reshape(tf.cast(y, logits.dtype), tf.shape(logits))
    # Numerically stable form of y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z)).
    # Unlike log(p + 1e-8), it does not saturate (or lose its gradient) for large |z|.
    ll = -tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)
    return tf.squeeze(ll, axis=-1)

def log_likelihood_multiclass(logits, y, num_classes):
    """
    Computes the per-example log-likelihood for multi-class classification.

    Args:
        logits: Tensor of shape [batch, num_classes].
        y: Sparse class labels of shape [batch] or [batch, 1]. Integer-valued floats
            are accepted and cast to int32.
        num_classes: Number of output classes.

    Returns:
        Tensor of shape [batch] equal to log softmax(logits)[y]. Labels outside
        [0, num_classes) yield an all-zero one-hot row and hence a value of 0.
    """
    logits = tf.convert_to_tensor(logits)
    # tf.one_hot only accepts integer indices.
    labels = tf.cast(tf.reshape(y, [-1]), tf.int32)
    # Match the logits' dtype; the float32 default breaks float64 logits.
    y_one_hot = tf.one_hot(labels, depth=num_classes, dtype=logits.dtype)
    return -tf.nn.softmax_cross_entropy_with_logits(labels=y_one_hot, logits=logits)

def log_likelihood_regression(preds, y, sigma=1.0):
    """
    Element-wise Gaussian log-likelihood log N(y | preds, sigma^2).

    Args:
        preds: Model predictions, e.g. shape [batch, 1].
        y: True values, shape [batch] or [batch, 1]. They are cast to the preds'
            dtype and reshaped to the preds' shape.
        sigma: Observation noise standard deviation (scalar or broadcastable to preds).

    Returns:
        Tensor with the same shape as `preds` (e.g. [batch, 1]).
    """
    preds = tf.convert_to_tensor(preds)
    # Align targets with preds. A [batch] target against [batch, 1] preds would
    # otherwise broadcast to [batch, batch].
    y = tf.reshape(tf.cast(y, preds.dtype), tf.shape(preds))
    # Casting sigma keeps the normalizer in preds' dtype. A bare tf.math.log of a
    # Python float would be float32 and fail for float64 preds.
    sigma = tf.cast(sigma, preds.dtype)
    log_norm = -0.5 * tf.math.log(2.0 * np.pi * tf.square(sigma))
    return log_norm - 0.5 * tf.square((y - preds) / sigma)

# ============================
# INTEGRATED LIKELIHOODS
# ============================

def integrated_log_likelihood_classification(logits, binary=True):
    """
    Per-example negative entropy of the model's predictive distribution, i.e. the
    expected log-probability sum_c p_c * log(p_c + 1e-8) of a label drawn from it.

    Args:
        logits: [batch, 1] if binary (sigmoid), or [batch, num_classes] (softmax).
        binary: True if binary classification.

    Returns:
        Tensor of shape [batch].
    """
    if not binary:
        class_probs = tf.nn.softmax(logits)
        return tf.reduce_sum(class_probs * tf.math.log(class_probs + 1e-8), axis=-1)

    # Binary case: two-outcome distribution {p, 1 - p} from the sigmoid.
    pos_prob = tf.sigmoid(logits)
    neg_prob = 1 - pos_prob
    per_example = pos_prob * tf.math.log(pos_prob + 1e-8) + neg_prob * tf.math.log(neg_prob + 1e-8)
    return tf.squeeze(per_example, axis=-1)

def integrated_log_likelihood_regression(preds, y_min, y_max, num_mc_samples=10):
    """
    Monte Carlo estimate of the expected unit-variance Gaussian log-likelihood of
    each prediction under targets drawn uniformly from [y_min, y_max).

    Note: this averages log-densities over the samples, E_y[log N(y | pred, 1)].
    It is not the log of the averaged density, log E_y[N(y | pred, 1)].

    Args:
        preds: Model predictions (a tensor), shape [batch, 1].
        y_min, y_max: Range for the uniform target samples.
        num_mc_samples: Number of MC samples R (shared by all examples).

    Returns:
        Tensor of shape [batch].
    """
    R = num_mc_samples
    lo = tf.cast(y_min, preds.dtype)
    hi = tf.cast(y_max, preds.dtype)
    # One shared set of R uniform target draws for the whole batch.
    y_samples = tf.random.uniform(shape=[R], minval=lo, maxval=hi, dtype=preds.dtype)
    # [R, 1, 1] - [1, batch, 1] -> [R, batch, 1]
    residuals = tf.reshape(y_samples, [R, 1, 1]) - tf.expand_dims(preds, axis=0)
    # Python float constant so it adopts preds' dtype (float64 preds work too).
    ll = -0.5 * float(np.log(2.0 * np.pi)) - 0.5 * tf.square(residuals)
    ll = tf.squeeze(ll, axis=-1)  # [R, batch]
    return tf.reduce_mean(ll, axis=0)

def normalize_integrated_ll(v):
    """
    Normalize a log-density tensor by subtracting its log-mean-exp.

    The result satisfies mean(exp(result)) == 1 over all elements of `v`. It is
    computed with logsumexp, so very negative entries (where exp underflows to 0)
    do not corrupt the normalizer.

    Args:
        v: Tensor of unnormalized log-likelihoods (any shape; reduced over all elements).

    Returns:
        Tensor of the same shape and dtype as `v`.
    """
    v = tf.convert_to_tensor(v)
    count = tf.cast(tf.size(v), v.dtype)
    # log(mean(exp(v))) = logsumexp(v) - log(N), stable for any range of v.
    log_mean_exp = tf.reduce_logsumexp(v) - tf.math.log(count)
    return v - log_mean_exp

# ============================
# TRAINING
# ============================

def get_phi(h, gX, gXstar):
    """
    Run the posterior network on each test embedding paired with the mean training embedding.

    Args:
        h: Posterior network; called as h([gXstar, context]).
        gX: Training embeddings (rank 2); averaged over axis 0 into one summary row.
        gXstar: Test embeddings; the summary row is tiled once per row of gXstar.

    Returns:
        Tuple (mu, sigma): the two halves of h's output along the last axis. The
        second half is passed through softplus and shifted by 1e-6 so sigma > 0.
    """
    num_test_rows = tf.shape(gXstar)[0]
    summary_row = tf.reduce_mean(gX, axis=0, keepdims=True)
    context = tf.tile(summary_row, multiples=[num_test_rows, 1])

    posterior_params = h([gXstar, context])
    mu, raw_scale = tf.split(posterior_params, num_or_size_splits=2, axis=-1)
    return mu, tf.nn.softplus(raw_scale) + 1e-6


def _split_likelihood_terms(gX, gXstar, theta, tmp_Y, tmp_Ystar, classification, binary,
                            num_classes, num_mc_samples, dtype, y_min=None, y_max=None):
    """
    Compute the observed and integrated log-likelihood terms for one pseudo split.

    Pseudo-train points are scored with one parameter vector, the average of the
    posterior samples. Pseudo-test point i is scored with its own sample theta[i].

    Args:
        gX: Embeddings of the pseudo-train points (rank 2).
        gXstar: Embeddings of the pseudo-test points, one row per posterior sample.
        theta: Posterior samples, one row per pseudo-test point.
        tmp_Y, tmp_Ystar: Targets of the pseudo-train / pseudo-test points.
        classification, binary, num_classes, num_mc_samples, dtype: See `train_posterior`.
        y_min, y_max: Target range for the regression MC integral (unused otherwise).

    Returns:
        Tuple (observed_train_ll, observed_test_ll, int_train_ll, int_test_ll, reg_diag).
        The observed terms are scalar means. The integrated terms are normalized
        per-example vectors. `reg_diag` is (train_preds, test_preds, y_train, y_test)
        for regression and None for classification.
    """
    if not classification:
        theta_avg = tf.reduce_mean(theta, axis=0, keepdims=True)
        train_preds = tf.matmul(gX, theta_avg, transpose_b=True)
        test_preds = tf.reduce_sum(gXstar * theta, axis=-1, keepdims=True)
        # Shape targets like the [batch, 1] predictions. [batch] targets would
        # otherwise broadcast to [batch, batch].
        y_train = tf.reshape(tf.cast(tmp_Y, dtype), tf.shape(train_preds))
        y_test = tf.reshape(tf.cast(tmp_Ystar, dtype), tf.shape(test_preds))
        observed_train_ll = tf.reduce_mean(log_likelihood_regression(train_preds, y_train, sigma=1.0))
        observed_test_ll = tf.reduce_mean(log_likelihood_regression(test_preds, y_test, sigma=1.0))
        # Each integral draws random targets; the train draw comes before the test draw.
        int_train_ll = normalize_integrated_ll(
            integrated_log_likelihood_regression(train_preds, y_min, y_max, num_mc_samples))
        int_test_ll = normalize_integrated_ll(
            integrated_log_likelihood_regression(test_preds, y_min, y_max, num_mc_samples))
        reg_diag = (train_preds, test_preds, y_train, y_test)
        return observed_train_ll, observed_test_ll, int_train_ll, int_test_ll, reg_diag

    if binary:
        theta_avg = tf.reduce_mean(theta, axis=0, keepdims=True)
        train_logits = tf.matmul(gX, theta_avg, transpose_b=True)
        # Row-wise dot product <gXstar[i], theta[i]>, i.e. diag(gXstar @ theta^T)
        # without building the [m, m] matrix.
        test_logits = tf.reduce_sum(gXstar * theta, axis=-1, keepdims=True)
        y_train = tf.reshape(tf.cast(tmp_Y, dtype), tf.shape(train_logits))
        y_test = tf.reshape(tf.cast(tmp_Ystar, dtype), tf.shape(test_logits))
        observed_train_ll = tf.reduce_mean(log_likelihood_binary(train_logits, y_train))
        observed_test_ll = tf.reduce_mean(log_likelihood_binary(test_logits, y_test))
        int_train_ll = normalize_integrated_ll(
            integrated_log_likelihood_classification(train_logits, binary=True))
        int_test_ll = normalize_integrated_ll(
            integrated_log_likelihood_classification(test_logits, binary=True))
        return observed_train_ll, observed_test_ll, int_train_ll, int_test_ll, None

    # Multi-class: each theta row is reshaped into a [d_rep, num_classes] weight matrix.
    d_rep = tf.shape(gXstar)[1]
    theta_multi = tf.reshape(theta, [tf.shape(theta)[0], d_rep, num_classes])
    theta_avg_multi = tf.reshape(tf.reduce_mean(theta, axis=0, keepdims=True), [d_rep, num_classes])
    train_logits = tf.matmul(gX, theta_avg_multi)
    test_logits = tf.einsum('nd, ndc -> nc', gXstar, theta_multi)
    # tf.one_hot (used by the multiclass likelihood) requires integer labels.
    observed_train_ll = tf.reduce_mean(
        log_likelihood_multiclass(train_logits, tf.cast(tmp_Y, tf.int32), num_classes))
    observed_test_ll = tf.reduce_mean(
        log_likelihood_multiclass(test_logits, tf.cast(tmp_Ystar, tf.int32), num_classes))
    int_train_ll = normalize_integrated_ll(
        integrated_log_likelihood_classification(train_logits, binary=False))
    int_test_ll = normalize_integrated_ll(
        integrated_log_likelihood_classification(test_logits, binary=False))
    return observed_train_ll, observed_test_ll, int_train_ll, int_test_ll, None


def train_posterior(Xs, Ys, g, h, optimizer, J, m, n, K=10, tau=1.0, lambda_val=1.0,
                    classification=True, binary=True, num_classes=None,
                    num_mc_samples=10, dtype='float32'):
    """
    Train the posterior network `h` on random pseudo train/test splits.

    At each of the K iterations, J splits are drawn with `splits_fn`. For every split:
    - A reparameterized sample theta ~ N(mu, sigma) is drawn from `get_phi`.
    - The split loss is computed as
      loss_j = -(observed_train_ll + observed_test_ll - lambda_val * kl), where
      kl = mean log q(theta) - (mean int_train_ll + mean int_test_ll).
    The optimized objective is mean_j(loss_j) + tau * var_j(loss_j). Gradients are
    taken with respect to `h.trainable_variables` only; `g` is not updated.

    Args:
        Xs, Ys: Full training data.
        g: Embedding network.
        h: Posterior network (outputs mu and raw sigma, concatenated).
        optimizer: optimizer.
        J: Number of splits per iteration (must be >= 1).
        m, n: Number of samples per test and train environment.
        K: Number of training steps.
        tau: Variance penalty weight.
        lambda_val: KL weight.
        classification: Boolean flag.
        binary: If True, use binary classification.
        num_classes: Number of classes (required for multiclass).
        num_mc_samples: Number of samples for MC integration (regression).
        dtype: Floating point precision.

    Returns:
        Trained model `h`, and list of losses per iteration.

    Raises:
        ValueError: If multiclass is requested without `num_classes`, or J < 1.
    """
    if classification and not binary and num_classes is None:
        raise ValueError("For multi-class classification, num_classes must be provided.")
    if J < 1:
        raise ValueError("J (number of splits per iteration) must be at least 1.")

    y_min = y_max = None
    if not classification:
        # Target range for the uniform MC integral, taken over the full data set.
        Ys_tensor = tf.cast(tf.convert_to_tensor(Ys), dtype)
        y_min = tf.reduce_min(Ys_tensor)
        y_max = tf.reduce_max(Ys_tensor)

    history = []
    for k in tqdm(range(K)):
        splits = splits_fn(Xs, Ys, J, m, n)
        split_losses = []

        with tf.GradientTape() as tape:
            for tmp_X, tmp_Y, tmp_Xstar, tmp_Ystar in splits:
                gX = g(tmp_X)
                gXstar = g(tmp_Xstar)
                mu, sigma = get_phi(h, gX, gXstar)

                # Reparameterized sample theta ~ q(theta | mu, sigma).
                noise = tf.random.normal(shape=tf.shape(mu), dtype=dtype)
                theta = mu + sigma * noise
                posterior_log_probs = log_normal_pdf(theta, mu, sigma)

                (observed_train_ll, observed_test_ll,
                 int_train_ll, int_test_ll, reg_diag) = _split_likelihood_terms(
                    gX, gXstar, theta, tmp_Y, tmp_Ystar, classification, binary,
                    num_classes, num_mc_samples, dtype, y_min, y_max)

                # KL between posterior and integrated prior
                prior_lp = tf.reduce_mean(int_train_ll) + tf.reduce_mean(int_test_ll)
                posterior_lp = tf.reduce_mean(posterior_log_probs)
                kl = posterior_lp - prior_lp

                observed_ll = observed_train_ll + observed_test_ll
                loss_j = -(observed_ll - lambda_val * kl)
                split_losses.append(loss_j)

            # Mean split loss plus a penalty on how much it varies across splits.
            stacked_losses = tf.stack(split_losses)
            split_loss = tf.reduce_mean(stacked_losses)
            L_var = tf.math.reduce_variance(stacked_losses)
            final_loss = split_loss + tau * L_var

        grads = tape.gradient(final_loss, h.trainable_variables)
        optimizer.apply_gradients(zip(grads, h.trainable_variables))
        history.append(final_loss.numpy())

        # Diagnostic print; per-split quantities (and KL) refer to the last split.
        if not classification:
            train_preds, test_preds, y_train, y_test = reg_diag
            mse_train = tf.reduce_mean(tf.square(train_preds - y_train))
            mse_test = tf.reduce_mean(tf.square(test_preds - y_test))
            print("--- iteration %d, loss %.4f, split var: %.4f, train MSE: %.4f, test MSE: %.4f, KL: %.4f ---" %
                  (k, float(final_loss), float(L_var), float(mse_train), float(mse_test), float(kl)))
        else:
            train_loss_val = -observed_train_ll
            test_loss_val = -observed_test_ll
            print("--- iteration %d, loss %.4f, split var: %.4f, train loss: %.4f, test loss: %.4f, KL: %.4f ---" %
                  (k, float(final_loss), float(L_var), float(train_loss_val), float(test_loss_val), float(kl)))

    return h, history

# ============================
# SYNTHETIC ENVIRONMENTS
# ============================

def splits_fn(Xs, Ys, K, m, n):
    """
    Draw K random pseudo train/test splits by sampling rows with replacement.

    For each split, m pseudo-test row indices are drawn first, then n pseudo-train
    row indices. Both are drawn uniformly with replacement from range(len(Xs)),
    independently of each other, so the two sets may overlap.

    Args:
        Xs, Ys: Arrays supporting integer-array indexing (e.g. NumPy arrays).
        K: Number of splits.
        m: Pseudo-test points per split.
        n: Pseudo-train points per split.

    Returns:
        List of K lists [X_train, Y_train, X_test, Y_test].
    """
    def _draw_split():
        pool_size = len(Xs)
        test_rows = np.random.choice(pool_size, m, replace=True)
        train_rows = np.random.choice(pool_size, n, replace=True)
        return [Xs[train_rows], Ys[train_rows], Xs[test_rows], Ys[test_rows]]

    return [_draw_split() for _ in range(K)]
