import jax
import jax.numpy as jnp
import equinox as eqx
import jax.scipy.stats as stats
import numpy as np
import sys

@eqx.filter_jit
def generate_pi(selected_errors: jax.Array):
    """Build empirical prediction-interval bounds from sampled errors.

    For every alpha on a fixed grid (0.05, 0.10, ..., 0.45) the alpha- and
    (1 - alpha)-quantiles are taken across the sample axis (axis 0),
    independently for each (batch, step, error_dim) position.

    Args:
        selected_errors: array of shape [num_samples, bs, seq_len, error_dim].

    Returns:
        Array of shape [bs, 1 + 2 * num_alphas, seq_len, error_dim]. Along
        axis 1 the layout is:
            index 0                      mean of all upper and lower quantiles
            next num_alphas entries      upper (1 - alpha) quantiles
            last num_alphas entries      lower (alpha) quantiles
    """
    alphas = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45])

    num_samples, bs = selected_errors.shape[0], selected_errors.shape[1]
    error_dim = selected_errors.shape[-1]

    # Merge batch and sequence axes: [num_samples, bs*seq_len, error_dim].
    flat_errors = jnp.asarray(selected_errors).reshape(num_samples, -1, error_dim)

    def _bounds(levels):
        # jnp.quantile accepts a 1-D vector of levels and prepends that axis
        # to the result, so a single call replaces nested vmaps over alphas,
        # error_dim and bs*seq_len: output is [num_alphas, bs*seq_len, error_dim].
        q = jnp.quantile(flat_errors, levels, axis=0)
        q = q.reshape(levels.shape[0], bs, -1, error_dim)
        # -> [bs, num_alphas, seq_len, error_dim]
        return jnp.transpose(q, (1, 0, 2, 3))

    upper_quant = _bounds(1 - alphas)
    lower_quant = _bounds(alphas)

    quant = jnp.concatenate((upper_quant, lower_quant), axis=1)
    mean_quant = jnp.mean(quant, axis=1, keepdims=True)

    return jnp.concatenate((mean_quant, quant), axis=1)

def pi_width_func(lower_mu_quant: jax.Array, upper_mu_quant: jax.Array, batch_sz: int):
    """Return the total absolute interval width divided by ``batch_sz``.

    Args:
        lower_mu_quant: lower interval bounds, expected as a flat
            [batch_size*seq_len] array.
        upper_mu_quant: upper interval bounds, same shape as
            ``lower_mu_quant``.
        batch_sz: divisor applied to the summed widths.

    Returns:
        Scalar array: sum of ``|upper - lower|`` over all elements, divided
        by ``batch_sz``. Note the divisor is ``batch_sz``, not the number of
        elements, so for flattened [batch_size*seq_len] inputs this is the
        per-sample width summed over the sequence axis.
    """
    gaps = jnp.abs(upper_mu_quant - lower_mu_quant)
    return jnp.sum(gaps) / batch_sz
    

def ensemble_pred(pred_mu: jax.Array, pred_logvar: jax.Array):
    """Collapse per-model predictions into one ensemble mean and variance.

    Args:
        pred_mu: per-model predicted means, reduced over axis 0; expected
            shape [models, batch_size, seq_len, outs].
        pred_logvar: per-model predicted log-variances, same shape as
            ``pred_mu``. They are exponentiated here.

    Returns:
        tuple: (ens_mu, ens_var), each with the leading model axis removed
            (i.e. [batch_size, seq_len, outs] for the expected inputs).
            - ens_mu: mean of ``pred_mu`` over models.
            - ens_var: mean of the per-model variances plus the variance of
              the per-model means over models.
    """
    member_var = jnp.exp(pred_logvar)

    # Average predicted variance of the members ...
    avg_member_var = jnp.mean(member_var, axis=0)
    # ... plus the spread of the member means around the ensemble mean.
    spread_of_means = jnp.var(pred_mu, axis=0)

    ens_mu = jnp.mean(pred_mu, axis=0)
    ens_var = avg_member_var + spread_of_means
    return ens_mu, ens_var

def z_scores():
    """Return standard-normal z-scores for a fixed grid of tail probabilities.

    The grid is alpha = 0.05, 0.10, ..., 0.45. The lower scores are the
    normal percent-point function at alpha, the upper scores at 1 - alpha.

    Returns:
        tuple: (lower_z_score, upper_z_score), each of shape
        [num_alphas, 1, 1] so they broadcast against
        [positions, outs]-shaped arrays.
    """
    tail_probs = jnp.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45])

    # Inverse CDF of the standard normal at each tail probability; the two
    # trailing singleton axes turn (num_alphas,) into (num_alphas, 1, 1).
    lower_z_score = stats.norm.ppf(tail_probs)[:, None, None]
    upper_z_score = stats.norm.ppf(1 - tail_probs)[:, None, None]

    return lower_z_score, upper_z_score

def save_scores(lower_mu_quant: jax.Array, upper_mu_quant: jax.Array,
                true_error: jax.Array):
    """Compute calibration scores and prediction-interval widths.

    Args:
        lower_mu_quant: [alphas, batch_size, seq_len, outs]
        upper_mu_quant: [alphas, batch_size, seq_len, outs]
        true_error: [batch_size, seq_len, outs]

    Returns:
        tuple: (scores, pi_width)
            - scores: [outs, alphas]; ``calibration_score`` evaluated for
              every (output, alpha) pair over the flattened
              batch_size*seq_len positions.
            - pi_width: [outs]; ``pi_width_func`` evaluated for every
              (output, alpha) pair and then averaged over the alpha axis.
    """
    n_batch, n_steps = true_error.shape[0], true_error.shape[1]
    n_outs = true_error.shape[-1]
    n_points = n_batch * n_steps

    # Merge batch and sequence axes so each (alpha, output) pair is a vector.
    lower_flat = lower_mu_quant.reshape(-1, n_points, n_outs)
    upper_flat = upper_mu_quant.reshape(-1, n_points, n_outs)
    error_flat = true_error.reshape(n_points, n_outs)

    ############ Calibration Scores ############
    # Inner map runs over alphas (axis 0 of the bounds; the error vector is
    # shared). Outer map runs over outs (axis 2 of the bounds, axis 1 of the
    # errors). The result is therefore laid out as [outs, alphas].
    coverage_per_alpha = jax.vmap(calibration_score, in_axes=(0, 0, None))
    coverage_per_out = jax.vmap(coverage_per_alpha, in_axes=(2, 2, 1))
    scores = coverage_per_out(lower_flat, upper_flat, error_flat)

    ############ Prediction Interval Width ############
    # Same nesting: inner over alphas, outer over outs -> [outs, alphas].
    width_per_alpha = jax.vmap(pi_width_func, in_axes=(0, 0, None))
    width_per_out = jax.vmap(width_per_alpha, in_axes=(2, 2, None))
    pi_width = width_per_out(lower_flat, upper_flat, n_batch)
    # NOTE(review): axis 1 is the alpha axis here, so this averages widths
    # across alpha levels for each output. If a per-alpha width averaged
    # over outputs was intended, the axis should be 0 -- confirm with callers.
    pi_width = jnp.mean(pi_width, axis=1)

    return scores, pi_width

def calibration_score(lower_mu_quant: jax.Array, upper_mu_quant: jax.Array,
                      true_error: jax.Array):
    """Return the observed fraction of errors lying inside the interval.

    This is the empirical coverage to be compared against the nominal
    coverage of the interval.

    Args:
        lower_mu_quant: [batch_size*seq_len] lower bounds.
        upper_mu_quant: [batch_size*seq_len] upper bounds.
        true_error: [batch_size*seq_len] observed values.

    Returns:
        float32 scalar array: (number of points that are neither strictly
        below the lower bound nor strictly above the upper bound) divided by
        the length of ``true_error``. Points equal to a bound count as inside.
    """
    n_points = true_error.shape[0]

    # Count the misses on each side separately, then subtract from the total.
    n_under = jnp.sum(lower_mu_quant > true_error)
    n_over = jnp.sum(upper_mu_quant < true_error)
    n_covered = n_points - (n_under + n_over)

    return n_covered.astype(jnp.float32) / n_points

def select_quantiles(unnorm_mu: jax.Array, unnorm_var: jax.Array,
                         lower_z_scores: jax.Array, upper_z_scores: jax.Array):
    """
    Calculates Gaussian prediction-interval bounds from a predicted mean and
    variance using the provided z-scores.

    The bounds are ``mu + sigma * z`` with ``sigma = sqrt(var)``; the lower
    bound uses ``lower_z_scores`` and the upper bound uses ``upper_z_scores``.
    ``lower_z_scores`` are therefore expected to be signed (negative for tail
    probabilities below 0.5), as produced by ``z_scores()`` in this module.

    Args:
        unnorm_mu: array of shape [batch_size, seq_len, outs]
                   (unnormalized predictions).
        unnorm_var: array of shape [batch_size, seq_len, outs]
                    (predicted variance, not log-variance; sqrt is applied
                    directly).
        lower_z_scores: array of shape [num_alphas, 1, 1]
        upper_z_scores: array of shape [num_alphas, 1, 1]

    Returns:
        tuple: (unnorm_lower_mu_np, unnorm_upper_mu_np, mu_pi)
               All NumPy arrays.
               Shapes:
               - unnorm_lower_mu_np: [num_alphas, batch_size, seq_len, outs]
               - unnorm_upper_mu_np: [num_alphas, batch_size, seq_len, outs]
               - mu_pi: [batch_size, 1 + 2 * num_alphas, seq_len, outs];
                 along axis 1: the mean, then the upper bounds, then the
                 lower bounds.
    """
    batch_sz, seq_len, outs = unnorm_mu.shape

    # Flatten to [1, batch_size * seq_len, outs] so the [num_alphas, 1, 1]
    # z-scores broadcast along a new leading alpha axis.
    flat_shape = (1, batch_sz * seq_len, outs)
    mu_flat = unnorm_mu.reshape(flat_shape)
    sigma_flat = jnp.sqrt(unnorm_var.reshape(flat_shape))

    # [num_alphas, batch_size * seq_len, outs]
    unnorm_upper_mu = mu_flat + jnp.multiply(sigma_flat, upper_z_scores)
    # Use the caller-supplied lower z-scores instead of mirroring the upper
    # ones, so the argument is honoured even for asymmetric intervals.
    unnorm_lower_mu = mu_flat + jnp.multiply(sigma_flat, lower_z_scores)

    # Back to [num_alphas, batch_size, seq_len, outs], converted to NumPy.
    unnorm_upper_mu_np = np.array(unnorm_upper_mu.reshape(-1, batch_sz, seq_len, outs))
    unnorm_lower_mu_np = np.array(unnorm_lower_mu.reshape(-1, batch_sz, seq_len, outs))
    unnorm_mu_np = np.array(unnorm_mu)

    # Stack [mean, uppers..., lowers...] and move the batch axis to the front.
    mu_pi = np.concatenate(
        (np.expand_dims(unnorm_mu_np, axis=0), unnorm_upper_mu_np, unnorm_lower_mu_np),
        axis=0,
    )
    mu_pi = np.transpose(mu_pi, axes=(1, 0, 2, 3))

    return unnorm_lower_mu_np, unnorm_upper_mu_np, mu_pi
