import os
import pandas as pd

def _read_and_append(csv_path, row):
    """Return the rows stored in *csv_path* with *row* appended as a new row.

    Existing cells are read back as raw strings (``dtype=str`` with
    ``keep_default_na=False``), so values previously written as e.g. ``0.500``
    are re-emitted verbatim. Without this they would be re-parsed as floats
    and rewritten as ``0.5``.

    If the file does not exist, a frame holding only *row* is returned. If it
    exists but cannot be read, the error is printed and the old content is
    discarded. This is best-effort: the caller will overwrite the file with
    just the new row.
    """
    new_row = pd.DataFrame([row])
    if not os.path.exists(csv_path):
        return new_row
    try:
        existing = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
        return pd.concat([existing, new_row], ignore_index=True)
    except Exception as e:
        # The message uses the bare file name, e.g. "success_rate.csv".
        print(f"读取{os.path.basename(csv_path)}失败: {e}")
        return new_row


def write_csv(dir_path, metrics, config):
    """Append one seed's evaluation results to success_rate.csv and avg_rmse.csv.

    Each file gets one row per call. The row has a ``seed`` column plus one
    ``tau_<tau>`` column per key of *metrics*, with values formatted to three
    decimals. Rows from earlier calls are preserved verbatim. Columns present
    in only some rows are left empty for the others.

    Args:
        dir_path: Directory that holds the CSV files. It is created if missing;
            an empty string means the current directory.
        metrics: Mapping ``tau -> metric``, where each ``metric`` provides
            numeric ``'success_rate'`` and ``'avg_rmse'`` entries. May be empty.
        config: Configuration mapping; ``config['exp']['seed']`` is recorded
            in the ``seed`` column.

    Raises:
        KeyError: If ``config`` lacks ``['exp']['seed']`` or a metric lacks
            ``'success_rate'`` / ``'avg_rmse'``.
    """
    if dir_path:  # os.makedirs('') raises; '' means "current directory" here
        os.makedirs(dir_path, exist_ok=True)
    success_rate_path = os.path.join(dir_path, 'success_rate.csv')
    avg_rmse_path = os.path.join(dir_path, 'avg_rmse.csv')

    seed = config['exp']['seed']
    success_rate_row = {'seed': seed}
    avg_rmse_row = {'seed': seed}
    for tau, metric in metrics.items():
        success_rate_row[f'tau_{tau}'] = f"{metric['success_rate']:.3f}"
        avg_rmse_row[f'tau_{tau}'] = f"{metric['avg_rmse']:.3f}"

    # Read (and merge) both files before writing either one.
    df_success = _read_and_append(success_rate_path, success_rate_row)
    df_rmse = _read_and_append(avg_rmse_path, avg_rmse_row)

    df_success.to_csv(success_rate_path, index=False)
    df_rmse.to_csv(avg_rmse_path, index=False)
    print(f"成功率结果已保存到 {success_rate_path}")
    print(f"RMSE结果已保存到 {avg_rmse_path}")