from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import make_interp_spline
from scipy.ndimage import gaussian_filter1d

# CIFAR-10: FedRPF vs DP
# Hard-coded measurements for the two methods compared in the figure below.
# Each dict holds two parallel 20-element lists; element i of 'psnr_avg' pairs
# with element i of 'acc_last' (they become one DataFrame row each).
# 'psnr_avg' is plotted on the x-axis ("PSNR (dB)") and 'acc_last' on the
# y-axis ("Accuracy").
# NOTE(review): presumably one entry per experimental setting, with 'acc_last'
# being the accuracy at the last round - confirm against the experiment logs.
cifar10_fedrpf = {
    'psnr_avg': [6.9901, 7.5701, 7.6279, 7.7406, 7.8707, 7.8934, 7.9857, 8.1317, 8.1496, 8.2546,
                 8.3027, 8.3885, 8.4483, 8.4737, 8.5132, 8.5313, 8.7237, 8.8306, 9.0182, 9.0274],
    'acc_last': [0.3122, 0.3500, 0.3230, 0.3098, 0.3570, 0.3528, 0.3310, 0.3384, 0.4078, 0.4240,
                 0.4106, 0.3828, 0.3658, 0.3612, 0.3880, 0.4244, 0.4292, 0.4378, 0.4142, 0.4226]
}

# Baseline ("DP", presumably differential privacy); same layout as above.
cifar10_dp = {
    'psnr_avg': [8.5106, 8.5213, 8.5431, 8.5467, 8.5469, 8.5474, 8.5586, 8.5637, 8.5696, 8.5781,
                 8.5807, 8.5831, 8.5867, 8.5946, 8.5989, 8.5998, 8.6094, 8.6503, 8.6760, 8.7110],
    'acc_last': [0.1326, 0.2572, 0.2598, 0.2486, 0.2544, 0.1526, 0.2522, 0.2572, 0.2072, 0.2476,
                 0.1704, 0.2490, 0.2362, 0.2426, 0.2630, 0.1878, 0.2452, 0.2170, 0.2588, 0.1650]
}

# Convert a column dict to a DataFrame whose rows are ordered by PSNR.
def create_sorted_df(data):
    """Return ``data`` as a DataFrame with rows sorted by ascending ``psnr_avg``.

    ``data`` maps column names to equal-length sequences and must contain a
    ``'psnr_avg'`` column. The original row labels are kept (not reset), so
    each row still records its position in the input lists.
    """
    frame = pd.DataFrame(data)
    ordered = frame.sort_values(by='psnr_avg')
    return ordered

# Smoothing helper used for the plotted curves.
def smooth_curve(x, y, smooth_factor=0.3, points=300):
    """Smooth a sampled curve with a Gaussian filter, then densify it with a cubic spline.

    Parameters
    ----------
    x : array-like
        Sample positions. With 4 or more samples they must be strictly
        increasing, because ``make_interp_spline`` requires it.
    y : array-like
        Sample values, same length as ``x``. Converted to float first so the
        Gaussian pass cannot truncate integer input.
    smooth_factor : float, default 0.3
        Smoothing strength (0-1, larger is smoother). The Gaussian sigma is
        ``smooth_factor * len(y) / 10`` samples. A sigma of 0 (or less)
        skips the Gaussian pass.
    points : int, default 300
        Number of evenly spaced points in the densified curve.

    Returns
    -------
    tuple of ndarray
        With ``len(x) >= 4``, this is ``(x_new, y_new)``: ``points`` evenly
        spaced positions spanning ``[min(x), max(x)]`` and the spline values
        there. Otherwise it is ``(x, y_smooth)`` at the original positions.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different lengths.

    Notes
    -----
    The Gaussian pass works in sample-index space, not in ``x`` units. With
    unevenly spaced ``x``, the effective smoothing width therefore varies
    along the curve. The interpolating cubic spline may also overshoot
    between widely spaced samples.
    """
    # Accept plain lists too: the spline branch below calls x.min()/x.max().
    x = np.asarray(x, dtype=float)
    # gaussian_filter1d returns an array of the input dtype, so int y would
    # be truncated; work in float.
    y = np.asarray(y, dtype=float)
    if len(x) != len(y):
        raise ValueError(
            f"x and y must have the same length, got {len(x)} and {len(y)}"
        )

    # First, Gaussian-filter the raw values.
    sigma = smooth_factor * len(y) / 10
    # gaussian_filter1d raises ZeroDivisionError for sigma == 0, which the
    # documented 0-1 range allows; treat a non-positive sigma as "no smoothing".
    y_smooth = gaussian_filter1d(y, sigma=sigma) if sigma > 0 else y.copy()

    # A cubic (k=3) spline needs at least k + 1 = 4 samples.
    if len(x) < 4:
        return x, y_smooth

    # Then densify with a cubic interpolating spline through the filtered values.
    spl = make_interp_spline(x, y_smooth, k=3)
    x_new = np.linspace(x.min(), x.max(), points)
    return x_new, spl(x_new)

# Styling: seaborn white-grid base with publication-style overrides.
plt.style.use('seaborn-v0_8-whitegrid')  # 'seaborn-v0_8-*' names need matplotlib >= 3.6
plt.rcParams.update({
    'font.size': 13,
    # Request the generic serif family with Times New Roman first. A machine
    # without that font then falls back to another serif face, instead of
    # matplotlib's default sans-serif plus a "findfont" warning.
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
    'axes.linewidth': 1.5,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.spines.left': True,
    'axes.spines.bottom': True,
    'grid.alpha': 0.2,
    'grid.linewidth': 0.8,
    'figure.dpi': 300,
    'lines.antialiased': True,
    'patch.antialiased': True
})

# One colour per method, shared by its dots, curve and fill.
FEDRPF_COLOR = '#2E86AB'
DP_COLOR = '#A23B72'
OUTPUT_PATH = Path('./Plot/cifar10_psnr_accuracy_smooth.png')

# Create the figure. Sort each method's samples by PSNR so x is increasing,
# which make_interp_spline inside smooth_curve requires.
fig, ax = plt.subplots(figsize=(8, 6))
df_fedrpf = create_sorted_df(cifar10_fedrpf)
df_dp = create_sorted_df(cifar10_dp)

# Smooth each accuracy-vs-PSNR series into a dense curve.
x_fedrpf_smooth, y_fedrpf_smooth = smooth_curve(
    df_fedrpf['psnr_avg'].values,
    df_fedrpf['acc_last'].values,
    smooth_factor=0.4
)

x_dp_smooth, y_dp_smooth = smooth_curve(
    df_dp['psnr_avg'].values,
    df_dp['acc_last'].values,
    smooth_factor=0.4
)

# Raw measurements as semi-transparent dots, drawn above the curves (zorder 3).
ax.scatter(df_fedrpf['psnr_avg'], df_fedrpf['acc_last'],
           color=FEDRPF_COLOR, s=30, alpha=0.6, zorder=3)
ax.scatter(df_dp['psnr_avg'], df_dp['acc_last'],
           color=DP_COLOR, s=30, alpha=0.6, zorder=3)

# Smoothed curves; these carry the legend labels.
ax.plot(x_fedrpf_smooth, y_fedrpf_smooth,
        color=FEDRPF_COLOR, linewidth=3, label='FedRPF', alpha=0.9, zorder=2)
ax.plot(x_dp_smooth, y_dp_smooth,
        color=DP_COLOR, linewidth=3, label='DP', alpha=0.9, zorder=2)

# Optional: light, flat (not gradient) fill between each curve and y = 0.
# fill_between's default baseline is 0, so this also pulls 0 into the y range.
ax.fill_between(x_fedrpf_smooth, y_fedrpf_smooth, alpha=0.1, color=FEDRPF_COLOR)
ax.fill_between(x_dp_smooth, y_dp_smooth, alpha=0.1, color=DP_COLOR)

# Axis labels and title.
ax.set_xlabel('PSNR (dB)', fontsize=14, fontweight='bold', labelpad=10)
ax.set_ylabel('Accuracy', fontsize=14, fontweight='bold', labelpad=10)
ax.set_title('CIFAR-10: Privacy-Performance Trade-off',
             fontsize=16, fontweight='bold', pad=20)

# Subtle grid, kept behind the data.
ax.grid(True, linestyle='-', linewidth=0.5, alpha=0.3, color='gray')
ax.set_axisbelow(True)

# Widen the autoscaled limits slightly: 2% of the span on x, 5% on y.
x_lo, x_hi = ax.get_xlim()
y_lo, y_hi = ax.get_ylim()
x_margin = (x_hi - x_lo) * 0.02
y_margin = (y_hi - y_lo) * 0.05
ax.set_xlim(x_lo - x_margin, x_hi + x_margin)
ax.set_ylim(y_lo - y_margin, y_hi + y_margin)

# Legend in a white, lightly bordered, shadowed box.
legend = ax.legend(frameon=True, fontsize=13, loc='best',
                   fancybox=True, shadow=True, framealpha=0.9)
legend.get_frame().set_facecolor('white')
legend.get_frame().set_edgecolor('lightgray')

# Tick appearance.
ax.tick_params(axis='both', which='major', labelsize=12, length=6, width=1.5)
ax.tick_params(axis='both', which='minor', length=3, width=1)

# Lay out, save, and show. savefig does not create missing directories, so
# make sure ./Plot exists first; otherwise the first run would fail.
plt.tight_layout()
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(OUTPUT_PATH,
            dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
plt.show()