import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Font sizes: the numbers on the axis ticks are drawn at 20 pt; every other piece of
# text (default text, titles, axis labels, legend) is drawn at 24 pt.
# `mpl.rcParams` is the same global settings object that `plt.rcParams` refers to.
mpl.rcParams.update(
    dict.fromkeys(('font.size', 'axes.titlesize', 'axes.labelsize', 'legend.fontsize'), 24)
)
mpl.rcParams.update(dict.fromkeys(('xtick.labelsize', 'ytick.labelsize'), 20))


def compute_snr_values(N_values, *, beta_min=0.1, beta_max=1.0, var_x0=1.0):
    """
    Compute SNR_K for every step K = 1..N, for each N in ``N_values``.

    The noise schedule is linear in the step index,
    ``beta_k = beta_min + (k / N) * (beta_max - beta_min)`` for k = 0..N-1,
    with step size ``delta_t = 1 / N`` and per-step decay factor
    ``c_k = 1 - (delta_t / 2) * beta_k``. After K steps::

        a_K   = prod_{j<K} c_j
        S_K   = sum_{j<K} beta_j * delta_t * prod_{j<k<K} c_k ** 2
        SNR_K = a_K ** 2 * var_x0 / S_K

    Both quantities are accumulated incrementally
    (``a_{K+1} = a_K * c_K`` and ``S_{K+1} = S_K * c_K ** 2 + beta_K * delta_t``),
    so each N costs O(N) rather than the O(N^3) of rebuilding every product
    from scratch for each K.

    Args:
        N_values: Iterable of positive integers (number of discretisation steps).
        beta_min: First value of the linear beta schedule (keyword-only).
        beta_max: Value the linear beta schedule would reach at k = N (keyword-only).
        var_x0: Variance of the initial signal x_0 (keyword-only).

    Returns:
        dict mapping each N to a list of N floats ``[SNR_1, ..., SNR_N]``.
        An entry is ``inf`` when no noise has been accumulated yet
        (only possible when ``beta_min == 0``).

    Raises:
        ValueError: if any N is smaller than 1.
    """
    results = {}
    for N in N_values:
        if N < 1:
            raise ValueError(f"N must be a positive integer, got {N!r}")
        delta_t = 1.0 / N
        beta_span = beta_max - beta_min

        a_K = 1.0        # running product of decay factors c_0 .. c_{K-1}
        noise_var = 0.0  # running sum of squared, decayed noise gains (S_K)
        snr_values = []
        for k in range(N):
            beta_k = beta_min + (k / N) * beta_span
            c_k = 1 - (delta_t / 2) * beta_k
            a_K *= c_k
            # Every earlier noise contribution is decayed once more by c_k,
            # and step k injects a fresh one of variance beta_k * delta_t.
            noise_var = noise_var * c_k * c_k + beta_k * delta_t
            if noise_var > 0:
                snr_values.append(a_K * a_K * var_x0 / noise_var)
            else:
                snr_values.append(float('inf'))
        results[N] = snr_values

    return results


def plot_snr(results, N_values, *, base_N=50, max_K=50, show=True):
    """
    Plot SNR_K curves and their ratios to a reference N, saving each figure as a PDF.

    Two figures are produced and written to the current working directory:
      * ``SNR_values.pdf``: SNR_K against K for every N in ``N_values``.
      * ``SNR_ratios.pdf``: SNR_K(N) / SNR_K(base_N) at equal K for every
        N in ``N_values`` other than ``base_N``.

    Args:
        results: Mapping N -> sequence of SNR values indexed by K - 1
            (e.g. the output of ``compute_snr_values``).
        N_values: The N values to plot; each must be a key of ``results``.
        base_N: Reference N for the ratio plot (keyword-only). It must be a key
            of ``results`` whenever ``N_values`` contains some other N.
        max_K: Only K = 1..max_K are plotted, for readability (keyword-only).
        show: Call ``plt.show()`` after saving; pass False for headless or
            batch runs (keyword-only).

    Raises:
        KeyError: if a required N (or ``base_N``) is missing from ``results``.
    """
    # Figure 1: SNR values.
    fig1, ax1 = plt.subplots(figsize=(12, 9))
    for N in N_values:
        snr = results[N][:max_K]
        ax1.plot(range(1, len(snr) + 1), snr, label=f'N={N}', linewidth=2.5)

    ax1.set_title('SNR_K for Different Values of N', pad=20)  # extra gap between title and axes
    ax1.set_xlabel('K')
    ax1.set_ylabel('SNR_K')
    ax1.grid(True)
    ax1.legend(frameon=True, fancybox=True, shadow=True, loc='best')

    # Tighten the margins to reduce white space on the left and bottom.
    fig1.subplots_adjust(left=0.12, bottom=0.12, right=0.95, top=0.92)
    fig1.savefig('SNR_values.pdf', format='pdf')

    # Figure 2: SNR ratios at the same K, relative to base_N.
    fig2, ax2 = plt.subplots(figsize=(12, 9))
    for N in N_values:
        if N == base_N:
            continue
        base = results[base_N]
        # Compare only the K values present in both curves (and within max_K).
        n_common = min(len(results[N]), len(base), max_K)
        ratios = [results[N][k] / base[k] for k in range(n_common)]
        ax2.plot(range(1, n_common + 1), ratios, label=f'N={N}/N={base_N}', linewidth=2.5)

    ax2.set_title(f'SNR_K Ratios at Same K (Relative to N={base_N})', pad=20)  # extra gap between title and axes
    ax2.set_xlabel('K')
    ax2.set_ylabel('Ratio')
    ax2.grid(True)
    ax2.legend(frameon=True, fancybox=True, shadow=True, loc='best')

    # Tighten the margins to reduce white space on the left and bottom.
    fig2.subplots_adjust(left=0.12, bottom=0.12, right=0.95, top=0.92)
    fig2.savefig('SNR_ratios.pdf', format='pdf')

    # Display the figures last, after both PDFs have been written.
    if show:
        plt.show()


# Main execution: runs only when this file is executed as a script, so importing
# it (e.g. to reuse compute_snr_values) does not compute, plot or write files.
if __name__ == "__main__":
    N_values = [50, 100, 200, 400, 1000]
    base_N = 50
    results = compute_snr_values(N_values)

    # Print the samples before plotting: plot_snr ends with plt.show(), which
    # blocks until the figure windows are closed.
    # float() keeps the output as plain numbers regardless of whether the values
    # are Python floats or NumPy scalars (whose repr is verbose in NumPy >= 2).
    print("Sample SNR values:")
    for N in N_values:
        print(f"N={N}, SNR_K for K=1,2,3: {[float(v) for v in results[N][:3]]}")

    print("\nSample SNR ratios (relative to N=50):")
    for N in N_values:
        if N != base_N:
            ratios = [float(a / b) for a, b in zip(results[N][:3], results[base_N][:3])]
            print(f"N={N}/N={base_N} for K=1,2,3: {ratios}")

    plot_snr(results, N_values)