import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Seed for NumPy's global RNG; plot_kalman_filter uses it when re-seeding
# before each per-n simulation run so every n sees the same noise draws.
SEED_NR = 44
def kalman_update(n, M, x, alpha, beta, mu, hatlambda_t, barN_t, P_t, simulate = False):
    """Advance the scalar Kalman filter for lambda by one time step.

    The step derives the probability ``p_t`` from the current estimate,
    computes the mean/variance of ``N_t`` and the observation-noise variance,
    then performs the usual predict and update stages.

    Args:
        n: Enters only the observation-noise variance. Must be non-zero
            because that formula divides by ``n``.
        M: Size parameter scaling the mean and variances of ``N_t``.
        x: Covariate entering ``p_t`` through ``(1 + 0.1 * x)`` and the
            ``2 + x`` term of its denominator.
        alpha: Scale factor; ``alpha * M`` must be non-zero (it is a divisor).
        beta: Coefficient applied to the previous state in the prediction.
        mu: Baseline level; also the lower bound of the updated estimate.
        hatlambda_t: Current state estimate (scalar or NumPy array).
        barN_t: Current observation (scalar or NumPy array).
        P_t: Current variance of the state estimate (scalar or NumPy array).
        simulate: When true, additionally draw a synthetic next observation
            from NumPy's global RNG. This path reads ``mu_N_t.shape`` and so
            needs NumPy inputs rather than plain Python scalars.

    Returns:
        Tuple ``(hatlambda_t_plus1, barN_t_plus1, P_t_plus1, K_t)`` where
        ``barN_t_plus1`` is ``None`` unless ``simulate`` is true.
    """
    # Scaling factor 1 - barN_t / (alpha * M) applied to the probability.
    c = (alpha*M-barN_t)/(alpha*M)

    # Probability p_t implied by the current estimate of lambda_t.
    # NOTE: p_t is not clipped; if it leaves [0, 1] the variances below become
    # negative and the square roots in the simulate branch yield NaN.
    p_t = ((hatlambda_t / (2 + x + hatlambda_t * (1 + 0.1 * x))) * (1 + 0.1 * x))*c

    # Mean and variance of N_t
    mu_N_t = (mu * (1 - beta) +alpha *  M * p_t)
    sigma_N_t_sq = (alpha**2) * M * p_t * (1 - p_t)

    # Observation noise variance
    sigma_eps_t_sq = (alpha**2) * p_t * (1 - p_t) * (M - n + ((M - n)**2) / n)

    # Prediction step
    hatlambda_t_plus1_pred = beta * hatlambda_t + mu_N_t  # Predicted lambda_{t+1}
    P_t_plus1_pred = beta**2 * P_t + sigma_N_t_sq  # Predicted variance

    # Kalman gain. np.divide with `where=` performs the division only where the
    # denominator is positive; elsewhere the pre-filled 0 is kept. (np.where
    # would evaluate the quotient everywhere and still trigger 0/0 warnings.)
    gain_denominator = P_t_plus1_pred + sigma_eps_t_sq
    K_t = np.divide(
        P_t_plus1_pred,
        gain_denominator,
        out=np.zeros_like(gain_denominator, dtype=float),
        where=gain_denominator > 0,
    )

    # Update step
    hatlambda_t_plus1 = hatlambda_t_plus1_pred + K_t * (barN_t - mu_N_t)  # Updated lambda_{t+1}

    # Clip lambda to be no smaller than mu
    hatlambda_t_plus1 = np.maximum(hatlambda_t_plus1, mu)

    P_t_plus1 = (1 - K_t) * P_t_plus1_pred  # Updated variance

    if not simulate:
        return hatlambda_t_plus1, None, P_t_plus1, K_t

    # Simulate the next observation barN_{t+1}: first sample the true N_t from
    # the process-noise distribution, then add observation noise on top.
    true_N_t = mu_N_t + np.random.normal(0, np.sqrt(sigma_N_t_sq), size=mu_N_t.shape)
    barN_t_plus1 = true_N_t + np.random.normal(0, np.sqrt(sigma_eps_t_sq), size=mu_N_t.shape)
    return hatlambda_t_plus1, barN_t_plus1, P_t_plus1, K_t

# Global Matplotlib text styling: serif body font and Computer Modern math.
matplotlib.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "cm",
})

def plot_kalman_filter(
        timesteps=100,
        S=1,
        n_list=None,  # Iterable of n values; None means [450]
        M=500,
        x=0,
        alpha=0.002,
        beta=0.5,
        mu=0.001,
        seed=SEED_NR  # Seed applied before each n's simulation
    ):
    """Simulate the Kalman filter for several values of n and plot the results.

    For every ``n`` in ``n_list`` the filter is run for ``timesteps`` steps on
    ``S`` parallel simulations via ``kalman_update(..., simulate=True)``. The
    per-step means of the state estimate, of the estimate variance (plotted as
    its square root) and of the Kalman gain are drawn in three side-by-side
    panels. The figure is saved as a PDF and then shown.

    Args:
        timesteps: Number of filter steps per simulation.
        S: Number of simulations run in parallel (length of the state vectors).
        n_list: Values of ``n`` to compare; ``None`` is treated as ``[450]``.
        M, x, alpha, beta, mu: Forwarded unchanged to ``kalman_update``;
            ``mu`` is also the initial state estimate.
        seed: Seed for NumPy's global RNG. The RNG is re-seeded with this value
            before each ``n`` so that all curves share the same noise draws.
            Previously this argument was overridden by ``SEED_NR`` inside the
            loop; the default is now ``SEED_NR`` so default output is
            unchanged while an explicit seed is honoured.

    Side effects:
        Re-seeds NumPy's global RNG, creates a Matplotlib figure, writes a PDF
        under ``$BASE_PATH/scripts/notebooks/data/`` and calls ``plt.show()``.
    """
    import os

    # Materialise once: the values are iterated twice (simulate, then plot).
    n_list = [450] if n_list is None else list(n_list)

    plt.figure(figsize=(24, 7))

    colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown']  # Extend if needed

    lambda_data = {}
    P_data = {}
    K_data = {}

    for n in n_list:
        # Initial filter state: estimate at mu, zero variance, zero observation.
        hatlambda_t_vec = np.full(S, mu)
        P_t_vec = np.zeros(S)
        barN_t_vec = np.zeros(S)

        lambda_means = np.zeros(timesteps)
        P_means = np.zeros(timesteps)
        K_means = np.zeros(timesteps)

        # Same seed for every n so the curves differ only through n.
        np.random.seed(seed)
        for t in range(timesteps):
            hatlambda_t_vec, barN_t_vec, P_t_vec, K_t_vec = kalman_update(
                n, M, x, alpha, beta, mu,
                hatlambda_t_vec, barN_t_vec, P_t_vec,
                simulate=True,
            )
            lambda_means[t] = np.mean(hatlambda_t_vec)
            P_means[t] = np.mean(P_t_vec)
            K_means[t] = np.mean(K_t_vec)

        lambda_data[n] = lambda_means
        P_data[n] = np.sqrt(P_means)  # Convert variance to standard deviation
        K_data[n] = K_means

    axes = [plt.subplot(1, 3, i+1) for i in range(3)]
    titles = [
        r"Estimated State $\hat{\lambda}(t)$",
        "Standard Deviation of State Estimate",
        "Kalman Gain $K(t)$"
    ]
    ylabels = [
        r"$\hat{\lambda}(t)$",
        r"$\sqrt{P(t)}$",
        r"$K(t)$"
    ]

    time_axis = range(timesteps)
    for i, n in enumerate(n_list):
        color = colors[i % len(colors)]  # Cycle colours if there are many n

        axes[0].plot(time_axis, lambda_data[n], label=f'n = {n}', color=color, linewidth=2.3)
        axes[1].plot(time_axis, P_data[n], label=f'n = {n}', color=color, linewidth=2.3)
        axes[2].plot(time_axis, K_data[n], label=f'n = {n}', color=color, linewidth=2.3)

    for ax, title, ylabel in zip(axes, titles, ylabels):
        ax.set_title(title, fontsize=24)
        ax.set_ylabel(ylabel, fontsize=22)
        ax.grid(True, linestyle='-')
        ax.tick_params(axis='both', labelsize=18)
        ax.set_xlabel("Time", fontsize=20)

    axes[2].set_ylim(0, 1.05)  # Ensure K(t) is in range [0, 1.05]

    # Single joint legend below the figure; keep ncol >= 1 for an empty n_list.
    handles, labels = axes[0].get_legend_handles_labels()
    plt.figlegend(handles, labels, loc="lower center", ncol=max(len(n_list), 1), fontsize=18, frameon=False)

    plt.tight_layout(rect=[0, 0.1, 1, 1])  # Adjust layout for legend space

    BASE_PATH = os.environ.get("BASE_PATH", "")
    if BASE_PATH and BASE_PATH.endswith('/'):
        BASE_PATH = BASE_PATH[:-1]
    # NOTE(review): with BASE_PATH unset the path below starts at the
    # filesystem root ("/scripts/...") -- confirm that is intended.
    # NOTE(review): the filename records alpha*2 and M*2 rather than the values
    # actually simulated -- presumably a notation choice; verify.
    plt.savefig(f'{BASE_PATH}/scripts/notebooks/data/kalman_filter_simulation_S_{S}_mu_{mu}_alpha_{alpha*2}_beta_{beta}_M_{M*2}_x_{x}.pdf', dpi=300, bbox_inches='tight')
    plt.show()

if __name__ == '__main__':
    # Simulation parameters, keyed by the plot_kalman_filter argument they feed.
    simulation_params = {
        "timesteps": 100,             # Number of timesteps
        "S": 1,                       # Number of simulations (batch size)
        "n_list": [5, 50, 300, 500],  # Values of n to compare
        "M": 500,                     # Given parameter M
        "x": 1,                       # Given parameter x
        "alpha": 0.002,               # Given parameter alpha
        "beta": 0.5,                  # Given parameter beta
        "mu": 0.001,                  # Initial mean value of lambda
    }
    plot_kalman_filter(**simulation_params)

    
