import os
import numpy as np
import pandas as pd
from utils.constants_LamaH import Constants
import yaml
# Absolute directory of this script; the __main__ block builds config/result paths from it
# so they do not depend on the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))


def _result_suffix(args, run_type, tag):
    """Build the ``.csv`` file-name suffix used for per-station result files.

    Args:
        args: Run configuration; ``alpha`` and ``other_station`` are read
            unless ``run_type == 'initial'``.
        run_type: ``'initial'`` or another run label embedded in the name.
        tag: Station/basin identifier, or ``'all'`` for the merged summary.

    Returns:
        ``"_initial_{tag}.csv"`` for the initial run, otherwise
        ``"_{run_type}_alpha{alpha}_{other_station}_{tag}.csv"``.
    """
    if run_type == 'initial':
        return f"_initial_{tag}.csv"
    return f"_{run_type}_alpha{args.alpha}_{args.other_station}_{tag}.csv"


def collect_mts_results(args, run_type, constants, read_path, save_path=None):
    """Merge per-station metric CSVs into one table with Mean/Median summary rows.

    For every station (``run_wise == 'station'``) or basin (``'basin'`` /
    ``'basin_station'``), ``{read_path}/{model}{suffix}`` is read. Rows whose
    first column is not a known station id are dropped; this includes summary
    rows such as ``Mean``/``Median`` written by an earlier run. The remaining
    rows are concatenated, Mean and Median rows over the metric columns are
    appended, and the result is written to CSV.

    Args:
        args: Run configuration. Reads ``model`` and ``run_wise``, plus
            ``alpha``/``other_station`` for non-initial file names. When
            ``save_path`` is None it also reads ``data``, ``target``,
            ``data_time_path`` and ``pred_len``.
        run_type: ``'initial'`` or another run label used in file names.
        constants: Object exposing ``basin_dict`` (its keys are used as basin
            ids) and ``all_stations``.
        read_path: Directory holding the per-station CSV files.
        save_path: Output CSV file path. Two cases write the summary inside a
            directory as ``{model}{suffix}`` with tag ``all``:
            if ``save_path`` is a directory (existing, or ending in a path
            separator), the file goes inside it; if it is None, the file goes
            inside ``{data}/{target}/{data_time_path}/pl{pred_len}/wise_{run_wise}/{model}``,
            which is relative to the current working directory.

    Returns:
        The merged DataFrame (a ``Station`` column plus the Mean/Median rows).
        Returns None when ``run_wise == 'all'`` or when no per-station file
        could be read.

    Raises:
        ValueError: If ``args.run_wise`` is not a supported value.
    """
    model = args.model
    metric_cols = ['NSE', 'RMSE', 'MAE', 'PBIAS', 'KGE', 'FLV', 'FHV', 'MSE', 'MAPE', 'MSPE', 'R2']

    if args.run_wise in ('basin', 'basin_station'):
        station_list = constants.basin_dict.keys()
    elif args.run_wise == 'station':
        station_list = constants.all_stations
    elif args.run_wise == 'all':
        return None
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(f"Unsupported run_wise: {args.run_wise!r}")

    # The first column is cast to str before matching, so match against str ids as well.
    # Otherwise integer station ids never match and every row would be silently dropped.
    valid_ids = {str(s) for s in constants.all_stations}

    frames = []
    for basin_i in station_list:
        file_path = os.path.join(read_path, f"{model}{_result_suffix(args, run_type, basin_i)}")
        try:
            df = pd.read_csv(file_path)
            if not df.empty:
                col0 = df.iloc[:, 0].astype(str)
                # Keep only real station rows: drop summary rows and unknown ids.
                mask = (col0 != 'Mean') & (col0 != 'Median') & col0.isin(valid_ids)
                df = df[mask]
            frames.append(df)
        except FileNotFoundError:
            print(f"⛔ 文件不存在: {file_path}")
        except Exception as e:
            # Best-effort collection: report the bad file and continue with the others.
            print(f"❌ 处理文件时出错 ({file_path}): {str(e)}")

    if not frames:
        print(f"⚠️ No per-station result files could be read from {read_path}; nothing saved.")
        return None

    # Concatenate once instead of growing a DataFrame inside the loop.
    df_all = pd.concat(frames, axis=0).set_index('Station')

    # Summary rows over whichever metric columns are present.
    # The mask above already removed any earlier Mean/Median rows.
    metrics = [c for c in metric_cols if c in df_all.columns]
    avg_row = df_all[metrics].mean().to_frame().T
    avg_row.index = ['Mean']
    median_row = df_all[metrics].median().to_frame().T
    median_row.index = ['Median']
    df_all = pd.concat([df_all, avg_row, median_row]).reset_index(names='Station')

    # Resolve the output file. The original code ignored the computed default folder
    # and crashed on save_path=None or on a directory path.
    summary_name = f"{model}{_result_suffix(args, run_type, 'all')}"
    if save_path is None:
        default_dir = f"{args.data}/{args.target}/{args.data_time_path}/pl{args.pred_len}/wise_{args.run_wise}/{model}"
        out_path = os.path.join(default_dir, summary_name)
    else:
        save_path = os.fspath(save_path)
        if os.path.isdir(save_path) or save_path.endswith(('/', os.sep)):
            out_path = os.path.join(save_path, summary_name)
        else:
            out_path = save_path

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df_all.to_csv(out_path, index=False)
    return df_all


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Water Flow Prediction')
    # NOTE(review): defaults are loaded from the CAMELS config while Constants is
    # imported from utils.constants_LamaH — confirm this pairing is intended.
    base_configs = os.path.join(current_dir, '../', 'configs/camels/pl1/base_configs.yaml')
    # Load the YAML config; its keys become attributes on the parsed args.
    with open(base_configs, 'r', encoding='utf-8') as f:
        yaml_config = yaml.safe_load(f) or {}  # an empty YAML file loads as None

    parser.set_defaults(**yaml_config)
    args = parser.parse_args()
    # Hard-coded selection of the experiment whose results are collected.
    args.verbose = 1
    args.run_type = 'FlowNet'
    args.other_station = 'child_parent'
    args.run_wise = 'station'
    args.model = 'DMT'
    args.seq_len = 32
    args.pred_len = 1
    args.alpha = 0.9
    constants = Constants(args)

    # Directory holding the per-station result CSVs; the merged summary is written beside them.
    result_dir = os.path.join(current_dir, f"../baselines_results/{args.data}/{args.target}/pl{args.pred_len}/wise_{args.run_wise}/{args.model}/global_2/")
    summary_path = os.path.join(
        result_dir,
        f"{args.model}_{args.run_type}_alpha{args.alpha}_{args.other_station}_all.csv",
    )
    # run_type and read_path were previously omitted, so `constants` was bound to
    # run_type and the call raised TypeError; the output must also be a file, not a directory.
    collect_mts_results(
        args,
        args.run_type,
        constants,
        read_path=result_dir,
        save_path=summary_path,
    )