import numpy as np
from scipy.stats import truncnorm

# Compute normalized surprise
def calculate_surprise(r_obs, r_pred, sigma, a=-1, b=1):
    """Return the normalized surprise of ``r_obs`` under a truncated normal.

    The reference distribution is a normal with mean ``r_pred`` and standard
    deviation ``sigma``, truncated to the interval ``[a, b]``. Surprise is the
    negative log of the density at ``r_obs`` relative to the density at
    ``r_pred``::

        surprise = -log(pdf(r_obs) / pdf(r_pred))

    Because both densities share the same truncation constant, for ``r_obs``
    and ``r_pred`` inside ``[a, b]`` this equals
    ``(r_obs - r_pred)**2 / (2 * sigma**2)``. An observation equal to the
    prediction therefore has surprise 0.

    Args:
        r_obs: Observed value(s); scalars or NumPy arrays (broadcast by scipy).
        r_pred: Predicted value(s); used as the location of the distribution.
        sigma: Scale (standard deviation) of the untruncated normal; expected
            to be positive.
        a: Lower truncation bound in the original (unscaled) units.
        b: Upper truncation bound in the original (unscaled) units.

    Returns:
        The surprise as a float or ndarray. It is ``inf`` when ``r_obs`` lies
        outside ``[a, b]``. When ``r_pred`` itself lies outside ``[a, b]`` the
        reference density is zero and the result is ``-inf`` (or ``nan`` if
        ``r_obs`` is also outside).
    """
    # scipy.stats.truncnorm takes its bounds in standardized units relative
    # to loc/scale, so convert [a, b] accordingly.
    a_scaled = (a - r_pred) / sigma
    b_scaled = (b - r_pred) / sigma

    # Work in log space. The raw pdf underflows to 0 for observations many
    # sigmas from the prediction, so the old pdf ratio returned inf even for
    # in-bounds observations. The logpdf difference stays finite there.
    log_density_obs = truncnorm.logpdf(
        r_obs, a_scaled, b_scaled, loc=r_pred, scale=sigma
    )
    log_density_mean = truncnorm.logpdf(
        r_pred, a_scaled, b_scaled, loc=r_pred, scale=sigma
    )

    # -log(p_obs / p_mean) == log p_mean - log p_obs
    return log_density_mean - log_density_obs