import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, log_likelihood

class BayesianLogisticRegression:
    """Bayesian binary logistic regression with independent Normal priors.

    The weight vector is drawn from ``Normal(mean, scale)`` (one entry per
    input feature) and the intercept from ``Normal(bias, scale)``. The
    observations are Bernoulli with logits ``x @ weight + intercept``.
    """

    def __init__(self, input_dim, mean=None, bias=0.0, scale=1.0):
        """
        Args:
            input_dim: number of input features.
            mean: prior mean of the weights, shape ``(input_dim,)``;
                defaults to zeros.
            bias: prior mean of the intercept.
            scale: prior standard deviation shared by weights and intercept.

        Raises:
            ValueError: if ``mean`` does not have shape ``(input_dim,)``.
        """
        self.input_dim = input_dim
        self.mean = jnp.zeros(input_dim) if mean is None else jnp.array(mean)
        self.bias = bias
        self.scale = scale

        # Fail fast, like BayesianLogisticRegressionMulti does, instead of
        # letting a mis-shaped mean define a weight vector whose length does
        # not match the input features (which only fails later in jnp.dot).
        if self.mean.shape != (input_dim,):
            raise ValueError(f"Expected mean shape ({input_dim},), got {self.mean.shape}")

    def model(self, x, y=None):
        """numpyro model.

        Args:
            x: design matrix; ``jnp.dot(x, weight)`` must be valid, so its
                last axis has length ``input_dim``.
            y: observed binary labels, or None to sample them.
        """
        # to_event(1): the whole weight vector is a single multivariate draw.
        weight = numpyro.sample("weight", dist.Normal(self.mean, self.scale).to_event(1))
        bias = numpyro.sample("bias", dist.Normal(self.bias, self.scale))

        logits = jnp.dot(x, weight) + bias
        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)

class BayesianLogisticRegressionMulti:
    """Bayesian multinomial (softmax) logistic regression.

    A weight matrix of shape ``(output_dim, input_dim)`` and a bias vector of
    shape ``(output_dim,)`` get independent Normal priors with a shared
    standard deviation. Observations are Categorical over the class logits
    ``x @ weight.T + bias``.
    """

    def __init__(self, input_dim, output_dim, mean=None, bias=None, scale=1.0):
        """
        Args:
            input_dim: number of input features.
            output_dim: number of classes.
            mean: prior mean of the weight matrix; defaults to zeros.
            bias: prior mean of the bias vector; defaults to zeros.
            scale: prior standard deviation shared by weights and biases.

        Raises:
            ValueError: if ``mean`` or ``bias`` has the wrong shape.
        """
        weight_shape = (output_dim, input_dim)
        bias_shape = (output_dim,)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mean = jnp.zeros(weight_shape) if mean is None else jnp.array(mean)
        self.bias = jnp.zeros(bias_shape) if bias is None else jnp.array(bias)
        self.scale = scale

        # Reject user-supplied priors whose shapes disagree with the declared
        # dimensions (the mean is checked before the bias).
        if self.mean.shape != weight_shape:
            raise ValueError(f"Expected mean shape ({output_dim}, {input_dim}), got {self.mean.shape}")
        if self.bias.shape != bias_shape:
            raise ValueError(f"Expected bias shape ({output_dim},), got {self.bias.shape}")

    def model(self, x, y=None):
        """numpyro model.

        Args:
            x: design matrix whose last axis has length ``input_dim``.
            y: observed integer class labels, or None to sample them.
        """
        # Both priors are declared as single multivariate events
        # (2 event dims for the matrix, 1 for the vector).
        weight_prior = dist.Normal(self.mean, self.scale).to_event(2)
        bias_prior = dist.Normal(self.bias, self.scale).to_event(1)

        weight = numpyro.sample("weight", weight_prior)
        bias = numpyro.sample("bias", bias_prior)

        class_logits = jnp.dot(x, weight.T) + bias
        numpyro.sample("obs", dist.Categorical(logits=class_logits), obs=y)

class BayesianNeuralNetwork:
    """Bayesian multilayer perceptron for regression with a scalar output.

    Every weight and bias has an independent standard-normal prior. The
    observation noise scale is drawn from the user-supplied ``noise_dist``.
    """

    def __init__(
        self,
        hid_dim,
        activation,
        num_hidden_layers,
        noise_dist
    ):
        """
        Args:
            hid_dim: width of every hidden layer.
            activation: elementwise nonlinearity applied after each hidden
                layer (e.g. ``jnp.tanh``).
            num_hidden_layers: number of hidden layers. The first layer is
                always built, so values below 1 behave like 1.
            noise_dist: numpyro distribution used as the prior of
                ``noise_scale``.
        """
        self.hid_dim = hid_dim
        self.activation = activation
        self.num_hidden_layers = num_hidden_layers
        self.noise_dist = noise_dist

    def model(self, x, y=None):
        """numpyro model: MLP regression with a Normal likelihood.

        Args:
            x: 2-D input array of shape (N, D_X).
            y: observed targets, or None when sampling predictions. When
                given, the (N, 1) network output is reshaped to ``y.shape``
                before the likelihood is applied.
        """
        _, in_dim = x.shape  # unpacking also enforces a 2-D input

        # Common prior distribution for all weights and biases.
        prior_dist = dist.Normal(0.0, 1.0)

        # First hidden layer (always present).
        w1 = numpyro.sample("w1", prior_dist.expand([in_dim, self.hid_dim]).to_event(2))
        b1 = numpyro.sample("b1", prior_dist.expand([self.hid_dim]).to_event(1))
        # w1 is sampled with shape (D_X, hid_dim), so a plain 2-D matmul is
        # the correct forward pass (no per-sample batch axis exists here).
        z = self.activation(jnp.dot(x, w1) + b1)

        # Additional hidden layers w2..wK, K = num_hidden_layers.
        for i in range(2, self.num_hidden_layers + 1):
            w = numpyro.sample(f"w{i}", prior_dist.expand([self.hid_dim, self.hid_dim]).to_event(2))
            b = numpyro.sample(f"b{i}", prior_dist.expand([self.hid_dim]).to_event(1))
            z = self.activation(jnp.dot(z, w) + b)

        # Linear output layer producing one value per data point.
        w_out = numpyro.sample("w_out", prior_dist.expand([self.hid_dim, 1]).to_event(2))
        b_out = numpyro.sample("b_out", prior_dist.expand([1]).to_event(1))

        # Observation noise scale.
        noise_scale = numpyro.sample("noise_scale", self.noise_dist)

        # y_hat has shape (N, 1); match the observed target layout if given.
        y_hat = jnp.dot(z, w_out) + b_out
        if y is not None:
            y_hat = y_hat.reshape(*y.shape)
        # to_event(1): the last axis of y_hat is treated as an event dimension.
        dis = dist.Normal(y_hat, noise_scale).to_event(1)

        numpyro.sample("obs", dis, obs=y)
        
def fit_model(model, x, y, num_samples=2000, burn_in=1000, num_chains=4, rng_key=0, bar=False):
    """Fit ``model.model`` with NUTS and return the finished MCMC run.

    Args:
        model: object exposing a ``model(x, y)`` numpyro model method.
        x, y: data passed to the model as keyword arguments.
        num_samples: post-warmup draws per chain.
        burn_in: warmup (adaptation) steps per chain.
        num_chains: number of chains, run with ``chain_method='vectorized'``.
        rng_key: integer seed turned into a ``jax.random.PRNGKey``.
        bar: whether to show the progress bar.

    Returns:
        The numpyro ``MCMC`` object after ``run``.
    """
    seed_key = jax.random.PRNGKey(rng_key)
    sampler = MCMC(
        NUTS(model.model),
        num_warmup=burn_in,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method='vectorized',
        progress_bar=bar,
    )
    sampler.run(seed_key, x=x, y=y)
    return sampler

def get_log_probs(mcmc, test_data, model=None, intermediate=False):
    """Estimate the joint log predictive density of all test points.

    Returns ``log((1/S) * sum_s exp(sum_n log p(y_n | theta_s)))`` over the
    S posterior draws, i.e. a Monte Carlo estimate of the joint predictive.

    If ``model`` is None, ``mcmc`` should be a numpyro ``MCMC`` object.
    Otherwise ``model`` should be the numpyro model function and ``mcmc`` the
    dictionary of posterior samples.

    Args:
        mcmc: numpyro ``MCMC`` object, or a posterior-samples dict.
        test_data: keyword arguments forwarded to the model
            (e.g. ``{"x": ..., "y": ...}``).
        model: optional numpyro model function (see above).
        intermediate: if True, return the raw log-likelihood array of the
            ``"obs"`` site (leading axis = posterior draws) instead.

    Returns:
        Scalar joint log predictive density, or the raw array when
        ``intermediate`` is True.
    """
    # Compute log likelihood for observations; caches are cleared around the
    # call as in the original workflow.
    jax.clear_caches()
    if model is None:
        log_probs = log_likelihood(mcmc.sampler.model, mcmc.get_samples(), **test_data)["obs"]
    else:
        log_probs = log_likelihood(model, mcmc, **test_data)["obs"]
    jax.clear_caches()

    if intermediate:
        return log_probs

    num_draws = log_probs.shape[0]
    # Joint log-likelihood per draw: sum over every non-draw axis. Reshaping
    # first also handles a 1-D (draws,) array, which occurs when the "obs"
    # site puts all points into its event shape.
    per_draw = log_probs.reshape(num_draws, -1).sum(axis=1)
    # Average the likelihood (not its log) over draws: logsumexp - log(S).
    # Dividing by the total element count would be off by log(N).
    return jax.scipy.special.logsumexp(per_draw, axis=0) - jnp.log(num_draws)

def get_indiv_log_probs(mcmc, test_data, model=None):
    """Estimate the log predictive density of each test point separately.

    For every observation element, returns
    ``log((1/S) * sum_s p(y_n | theta_s))`` over the S posterior draws.

    If ``model`` is None, ``mcmc`` should be a numpyro ``MCMC`` object.
    Otherwise ``model`` should be the numpyro model function and ``mcmc`` the
    dictionary of posterior samples.

    Args:
        mcmc: numpyro ``MCMC`` object, or a posterior-samples dict.
        test_data: keyword arguments forwarded to the model.
        model: optional numpyro model function (see above).

    Returns:
        Array of per-point log predictive densities (the draw axis reduced).
    """
    jax.clear_caches()
    if model is not None:
        site_log_probs = log_likelihood(model, mcmc, **test_data)["obs"]
    else:
        site_log_probs = log_likelihood(mcmc.sampler.model, mcmc.get_samples(), **test_data)["obs"]

    draw_count = site_log_probs.shape[0]
    # Mean of likelihoods over draws, computed stably in log space.
    return jax.scipy.special.logsumexp(site_log_probs, axis=0) - jnp.log(draw_count)
