from utils.constants_LamaH import Constants
from utils.collect_mts_results import collect_mts_results
from utils.print_args import print_args
from exp.exp_mekong_initial import Exp_MeKong as Exp_MeKong_initial
from exp.exp_mekong_flownet import Exp_MeKong as Exp_MeKong_FlowNet
from exp.exp_mekong_flownet2 import Exp_MeKong as Exp_MeKong_FlowNet2
from exp.exp_mekong_flow_list import Exp_MeKong as Exp_MeKong_flow_list
from latex.pred_table import MakeTable
import time
import torch
import os
import argparse
import yaml
import random
import numpy as np
import pandas as pd
import json
from data_provider.LamaH_preprocess import LamaHDataset
from station_spec_search_hyperparam import StationSpecSearchHyperparam
import ray
# Absolute path of the directory that contains this file.
current_dir = os.path.dirname(os.path.abspath(__file__))


def parse_args():
    """Build the run configuration from the command line and a YAML file.

    The command line is parsed twice: a first pass only discovers
    ``--base_configs``; the YAML mapping found there is then installed as
    parser defaults, and a second pass produces the final namespace, so flags
    given explicitly on the command line win over YAML values.  When
    ``--use_common_searched_param`` is set, the entry for ``args.model`` in
    the searched-hyperparameter YAML overrides the namespace afterwards.

    Returns:
        argparse.Namespace: The merged configuration.

    Raises:
        KeyError: If searched hyper-parameters are requested but the file has
            no entry for ``args.model``.
    """
    # Full configuration (combines YAML and command-line settings).
    parser = argparse.ArgumentParser(description='Water Flow Prediction')
    # task setting
    parser.add_argument('--base_configs', type=str, default='configs/LamaH_daily/pl1/base_configs.yaml',
                        help="['configs/base_configs.yaml', 'configs/base_configs_mts.yaml'")
    parser.add_argument('--random_seed', type=int, default=42, help='random seed')
    parser.add_argument('--model', type=str, default='DMT')
    parser.add_argument('--station', type=str, default='all')
    parser.add_argument('--other_station', type=str, default='child_parent',
                        help="['st_flow', 'flow', 'child', 'parent', 'child_parent']")
    parser.add_argument('--is_training', type=int, default=1)
    parser.add_argument('--load_model', type=int, default=0, help='load model if it exits')
    parser.add_argument('--run_type', default='initial', help="run type: ['initial', 'FlowNet']")
    parser.add_argument('--run_wise', default='all', help="run type: ['station', 'basin', 'all']")
    parser.add_argument('--get_flow_list', action='store_true', help='get_flow_list', default=False)
    parser.add_argument('--base_flow', type=str, default='child_parent', help='vali-flow base flow')
    parser.add_argument('--flow_list_vali_factor', type=float, default=1.0)
    parser.add_argument('--features', type=str, default='MS')
    parser.add_argument('--verbose', type=int, default=1)
    parser.add_argument('--des', type=str, default='exp')
    parser.add_argument('--ablation', type=str, default='None')
    parser.add_argument('--out_station_preds', action='store_true', help='out_station_preds', default=False)
    parser.add_argument('--missing_graph_ratio', type=float, default=0)

    # model params
    parser.add_argument('--train_epochs', type=int, default=100)
    parser.add_argument('--use_common_searched_param', action='store_true', default=False,
                        help='use_common_searched_param')
    parser.add_argument('--use_station_searched_param', action='store_true', default=False,
                        help='use_station_searched_param')
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--alpha', type=float, default=0.9, help='regulation parameter for loss')
    parser.add_argument('--load_model_from_global', action='store_true', default=False)
    parser.add_argument('--start_loop', type=int, default=0)
    parser.add_argument('--global_loop', type=int, default=3)
    parser.add_argument('--global_lr', type=float, default=0.0005)
    parser.add_argument('--global_lr_factor', type=float, default=0.1,
                        help='lr=learning_rate * (global_lr_factor ** loop)')
    parser.add_argument('--e_layers', type=int, default=1)
    parser.add_argument('--d_layers', type=int, default=1)
    parser.add_argument('--d_model', type=int, default=128)
    parser.add_argument('--d_ff', type=int, default=512)
    parser.add_argument('--n_heads', type=int, default=2)
    parser.add_argument('--factor', type=int, default=1)
    parser.add_argument('--multiscale_levels', type=int, default=1)
    parser.add_argument('--patch_len', type=int, default=16)
    parser.add_argument('--stride', type=int, default=8)
    parser.add_argument('--QAM_start', type=float, default=0.1)
    parser.add_argument('--QAM_end', type=float, default=0.2)
    parser.add_argument('--cycle', type=int, default=24)
    parser.add_argument('--use_di', type=int, default=0, help='data integration')
    parser.add_argument('--di_window', type=int, default=1, help='data integration window')
    # device
    parser.add_argument('--use_ray', type=int, default=1)
    parser.add_argument('--use_gpu', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--use_multi_gpu', type=int, default=0)
    parser.add_argument('--devices', type=str, default='0,1,2,3')

    # First pass: only needed to learn which YAML file to load.
    with open(parser.parse_args().base_configs, 'r') as f:
        # safe_load returns None for an empty document; treat that as "no overrides".
        yaml_config = yaml.safe_load(f) or {}

    # YAML values become defaults, so explicit command-line flags still win.
    parser.set_defaults(**yaml_config)

    args = parser.parse_args()
    # Optionally override with the searched hyper-parameters of the chosen model.
    if args.use_common_searched_param:
        if args.run_type == 'FlowNet':
            file_name = 'FlowNet_common_searched_hyperparam.yaml'
        else:
            file_name = 'common_searched_hyperparam.yaml'
        param_path = os.path.join('configs', args.data, f"pl{args.pred_len}", file_name)

        with open(param_path, 'r') as f:
            param_yaml = yaml.safe_load(f) or {}
        if args.model not in param_yaml:
            raise KeyError(f"No searched hyper-parameters for model '{args.model}' in {param_path}")
        # An entry that is present but empty simply overrides nothing.
        for key, value in (param_yaml[args.model] or {}).items():
            setattr(args, key, value)
    return args


class MeKongWaterLevelPrediction:
    def __init__(self, args, setting):
        """Set up device, dataset handle, station lookups, graph and cached data.

        Args:
            args: Configuration namespace. It is stored as ``self.args`` and
                extended in place with ``device``, the graph attributes written
                by ``_get_data_graph`` and ``data_tensor`` / ``df_example``.
            setting: Experiment identifier string passed on to the experiments.
        """
        self.args = args
        self._setting = setting
        self._verbose = args.verbose

        # The device has to be resolved before the graph tensors are moved to it.
        self.args.device = self._acquire_device()
        self._check_dataset()

        # Station list and child/parent lookups as exposed by Constants.
        constants = Constants(args)
        self._constants = constants
        self._all_stations = constants.all_stations
        self._child_stations_dict = constants.child_stations_dict
        self._parent_stations_dict = constants.parent_stations_dict

        self._get_data_graph()
        data_tensor, df_example = self._read_data()
        self.args.data_tensor = data_tensor
        self.args.df_example = df_example

    def _acquire_device(self):
        """Pick the torch device for this run and announce the choice.

        With ``use_gpu`` set, CUDA is preferred, then MPS, then CPU; without it
        the CPU is always used.

        Returns:
            torch.device: The selected device.
        """
        cfg = self.args
        if not cfg.use_gpu:
            device = torch.device('cpu')
            print('Use CPU')
            return device

        if torch.cuda.is_available():
            # NOTE(review): the variable is exported after CUDA availability has
            # already been queried — confirm it still takes effect as intended.
            visible_devices = cfg.devices if cfg.use_multi_gpu else str(cfg.gpu)
            os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
            device = torch.device(f'cuda:{cfg.gpu}')
            print(f'Use GPU: cuda:{cfg.gpu}')
            return device

        if torch.backends.mps.is_available():
            device = torch.device("mps")
            print('Use MPS')
            return device

        device = torch.device("cpu")
        print('Use CPU')
        return device

    def _check_dataset(self):
        """Create the dataset handle that corresponds to ``self.args.data``.

        For names containing ``'LamaH'`` a ``LamaHDataset`` is stored in
        ``self.lamah_data`` (daily or hourly variant). For ``'camels'`` nothing
        is loaded here. Any other name is silently ignored.

        Raises:
            NotImplementedError: If the name contains ``'LamaH'`` but neither
                ``'daily'`` nor ``'hourly'``.
        """
        # Always read the configuration from the instance: relying on a
        # module-level ``args`` only works when this file is run as a script.
        data_name = self.args.data
        if 'LamaH' in data_name:
            if 'daily' in data_name:
                self.lamah_data = LamaHDataset('./dataset/LamaH_daily/', data_type='daily',
                                               missing_graph_ratio=self.args.missing_graph_ratio)
            elif 'hourly' in data_name:
                # NOTE(review): unlike the daily branch, missing_graph_ratio is
                # not forwarded here — confirm this is intended.
                self.lamah_data = LamaHDataset('./dataset/LamaH_hourly/', data_type='hourly')
            else:
                raise NotImplementedError(f"Unsupported LamaH dataset variant: {data_name}")
        elif data_name == 'camels':
            # Nothing to prepare for camels at this stage.
            pass

    def _get_data_graph(self):
        """Populate ``self.args`` with the station graph.

        Sets ``edge_index`` (on ``self.args.device``), ``edge_attr`` and a dense
        ``adj`` matrix held on the host. For LamaH data the graph comes from
        ``self.lamah_data``; for camels it is derived from the adjacency matrix
        exposed by ``self._constants``. Other dataset names leave ``self.args``
        untouched.
        """
        # Read the dataset name from the instance rather than a module-level
        # ``args``, which only exists when this file is run as a script.
        data_name = self.args.data
        if 'LamaH' in data_name:
            self.args.edge_index = self.lamah_data.edge_index.to(self.args.device)
            self.args.edge_attr = self.lamah_data.edge_attr.to(self.args.device)
            # The dense adjacency lives on the CPU (it is converted to NumPy
            # below), so index it with CPU copies of the edge endpoints instead
            # of relying on cross-device advanced indexing.
            edge_index_cpu = self.args.edge_index.cpu()
            num_nodes = edge_index_cpu.max().item() + 1
            adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float32)
            # Mark every source -> target pair listed in edge_index.
            adj[edge_index_cpu[0], edge_index_cpu[1]] = 1
            self.args.adj = adj.numpy()
        elif data_name == 'camels':
            adj_matrix = self._constants.adj_matrix.values
            # Collect all non-zero entries (i, j) as edges.
            rows, cols = np.nonzero(adj_matrix)
            # edge_index is a [2, num_edges] tensor; stack in NumPy first so
            # torch converts a single array instead of a list of arrays.
            self.args.edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long).to(self.args.device)
            self.args.edge_attr = None
            self.args.adj = adj_matrix

    def _read_data(self):
        """Load the cached data tensor and build a small example dataframe.

        Returns:
            tuple: ``(data_tensor, df_example)``. Both are ``None`` for the
            camels dataset and when ``data_tensor.npy`` is not present under
            ``self.args.dataset_path``. Otherwise ``data_tensor`` is the array
            loaded from the cache and ``df_example`` has a ``Timestamp`` column
            spanning train start to test end plus an ``example_flow`` column
            taken from ``data_tensor[:, 0, -1]``.

        Raises:
            NotImplementedError: If a cache file exists but ``self.args.data``
                is neither ``'LamaH_daily'`` nor ``'LamaH_hourly'``.
        """
        if self.args.data == 'camels':
            return None, None

        # Global time range covered by the data.
        target_start = pd.Timestamp(self.args.train_start_time)
        target_end = pd.Timestamp(self.args.test_end_time)
        # Location of the cached tensor.
        cache_path = os.path.join(self.args.dataset_path, "data_tensor.npy")
        if not os.path.exists(cache_path):
            # No cache available: callers receive (None, None).
            return None, None

        # Resolution used for the timestamp column of each supported dataset.
        time_units = {"LamaH_daily": "datetime64[D]", "LamaH_hourly": "datetime64[h]"}
        time_unit = time_units.get(self.args.data)
        if time_unit is None:
            # Previously this surfaced later as an unbound local variable.
            raise NotImplementedError(f"No timestamp resolution defined for dataset: {self.args.data}")

        if self.args.verbose:
            print(f"Loading cached data from {cache_path}")
        data_tensor = np.load(cache_path)

        # df_example is still required even when the tensor comes from cache.
        t_index = pd.date_range(start=target_start, end=target_end, freq=self.args.freq).values.astype(time_unit)
        df_example = pd.DataFrame({"Timestamp": t_index})
        df_example["example_flow"] = data_tensor[:, 0, -1]
        return data_tensor, df_example

    def _load_best_station_params(self, station):
        """Apply the stored best hyper-parameters of ``station`` to ``self.args``.

        The parameters are read from a YAML file under
        ``configs/<data>/pl<pred_len>/``. When the file has no entry for
        ``station``, fixed fallback values are written to ``alpha``,
        ``global_loop``, ``global_lr`` and ``global_lr_factor`` instead.

        Args:
            station: Key to look up in the YAML mapping.
        """
        cfg = self.args
        if cfg.other_station == 'flow':
            file_name = f'{cfg.model}_{cfg.other_station}_{cfg.base_flow}_best_station_param.yaml'
        else:
            file_name = f'{cfg.model}_{cfg.other_station}_best_station_param.yaml'
        param_path = os.path.join('configs', cfg.data, f'pl{cfg.pred_len}', file_name)

        with open(param_path, 'r') as f:
            param_yaml = yaml.safe_load(f)

        try:
            best_station_param = param_yaml[station]
        except KeyError:
            # No tuned entry for this station: fall back to fixed values.
            fallback = (('alpha', 0.9), ('global_loop', 5), ('global_lr', 0.001), ('global_lr_factor', 0.1))
            for name, default_value in fallback:
                setattr(self.args, name, default_value)
            return

        for key, value in best_station_param.items():
            setattr(self.args, key, value)

    def _candidate_other_stations(self, station_str, base_flow):
        """Return the candidate neighbour stations of ``station_str`` for ``base_flow``."""
        child_list = self._child_stations_dict[station_str]
        parent_list = self._parent_stations_dict[station_str]
        if base_flow == 'child':
            return child_list
        if base_flow == 'parent':
            return parent_list
        if base_flow == 'lag_correlation':
            return self._constants.get_other_list(station_str, 'lag_correlation')
        # Any other value (including the default 'child_parent') uses both.
        return child_list + parent_list

    def _collect_useful_stations(self, setting, other_list, phase, report_for=None):
        """Train one flow-list experiment per candidate and keep those reporting a link.

        When ``report_for`` is given, the outcome of every candidate is printed
        together with that station name.
        """
        useful_other_list = []
        for other_station in other_list:
            exp = Exp_MeKong_flow_list(self.args, phase, other_station, self._verbose, self.args.device)
            link_exists = exp.train(setting, other_station)
            if link_exists:
                useful_other_list.append(other_station)
            if report_for is not None:
                print(report_for, other_station, link_exists)
        return useful_other_list

    def get_flow_station_list(self, flow_type='flow', base_flow='child_parent'):
        """Determine, per target station, which neighbouring stations are useful.

        For ``self.args.station == 'all'`` every station in
        ``self._all_stations`` is processed and the resulting mapping is written
        as JSON to ``<data_time_dir>/<model>_flow_dict_<base_flow>.json``.
        Otherwise only ``self.args.station`` is processed and the result is
        printed, not saved.

        Args:
            flow_type: ``'st_flow'`` pre-seeds the result mapping with empty
                ``'seasonal'`` and ``'trend'`` entries; any other value starts
                from an empty mapping.
            base_flow: Candidate source: ``'child'``, ``'parent'``,
                ``'lag_correlation'``, anything else means children + parents.
        """
        print("Running to get best other station list")
        phase = "flow_list"

        if self.args.station != 'all':
            station_str = self.args.station
            # NOTE(review): the 'all' branch prefixes the file name with 'ID_'
            # while this branch does not — confirm which naming is expected.
            self.args.data_path = station_str + '.csv'
            setting = self._setting.replace('TARGETSTATION', station_str.replace(' ', ''))
            other_list = self._candidate_other_stations(station_str, base_flow)
            useful_other_list = self._collect_useful_stations(setting, other_list, phase,
                                                              report_for=station_str)
            print(station_str, useful_other_list)
            return

        if flow_type == 'st_flow':
            useful_other_dict = {'seasonal': {}, 'trend': {}}
        else:
            useful_other_dict = {}
        for station in self._all_stations:
            station_str = str(station)
            self.args.data_path = 'ID_' + station_str + '.csv'
            if self._verbose:
                print(f"Running {station_str} station")
            setting = self._setting.replace('TARGETSTATION', station_str.replace(' ', ''))
            other_list = self._candidate_other_stations(station_str, base_flow)
            useful_other_list = self._collect_useful_stations(setting, other_list, phase)
            print(station, useful_other_list)
            # Key by the string form: JSON object keys end up as strings anyway,
            # and this avoids a serialisation failure for non-native key types.
            useful_other_dict[station_str] = useful_other_list

        # Serialise before opening the file so a failure leaves no partial file.
        json_str = json.dumps(useful_other_dict)
        save_name = f'{self.args.model}_flow_dict_{base_flow}.json'
        json_path = os.path.join(self.args.data_time_dir, save_name)
        with open(json_path, 'w') as json_file:
            json_file.write(json_str)

    def _run_assigned_task(self, station, run_type):
        """Train (when enabled) and test one experiment, then store its metrics.

        Args:
            station: Task identifier handed to the experiment class; for
                run_wise ``'basin'``/``'basin_station'`` it must be a key of
                ``basin_dict``, for ``'station'`` a member of the station list.
            run_type: ``'initial'``, ``'FlowNet'`` or ``'FlowNet2'``.

        Returns:
            tuple: ``(avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2,
            avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv)`` as returned by the
            experiment's ``test``.

        Raises:
            ValueError: If ``station`` is unknown for the active run_wise.
            NotImplementedError: If ``run_type`` is not supported.
        """
        # Use the instance configuration; a module-level ``args`` only exists
        # when this file is executed as a script.
        run_wise = self.args.run_wise
        if run_wise == 'all':
            self.args.num_stations = len(self._constants.all_stations)
        elif run_wise in ['basin', 'basin_station']:
            # Validate before the lookup so an unknown basin raises the
            # intended ValueError instead of a bare KeyError.
            if station not in self._constants.basin_dict:
                raise ValueError(f"Station {station} not found.")
            self.args.num_stations = len(self._constants.basin_dict[station])
        elif run_wise == 'station':
            self.args.num_stations = 1
            if station not in self._all_stations:
                raise ValueError(f"Station {station} not found.")

        task_start_time = time.time()
        exp_classes = {
            'initial': Exp_MeKong_initial,
            'FlowNet': Exp_MeKong_FlowNet,
            'FlowNet2': Exp_MeKong_FlowNet2,
        }
        if run_type not in exp_classes:
            raise NotImplementedError(f"Unsupported run_type: {run_type}")
        exp = exp_classes[run_type](self.args, station, self._setting)

        if self.args.is_training:
            if self._verbose:
                print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>\n'.format(self._setting))
            exp.train()
        test_flag = 0 if self.args.is_training else 1
        if self._verbose:
            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n'.format(self._setting))
        avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2, avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv, \
            median_mae, median_mse, median_rmse, median_mape, median_mspe, median_r2, median_nse, median_pbias, median_kge, median_flv, median_fhv = exp.test(test=test_flag)
        torch.cuda.empty_cache()
        duration = time.time() - task_start_time
        # NOTE(review): ':8f' means width 8 with default precision; '.8f' may
        # have been intended. Left unchanged to keep the output format stable.
        print(f"station {station}, duration: {duration:.2f}s, "
              f"avg_nse:{avg_nse:8f}, avg_pbias:{avg_pbias:8f}, avg_kge:{avg_kge:8f}, avg_flv:{avg_flv:8f}, "
              f"avg_fhv:{avg_fhv:8f}, avg_mae: {avg_mae:8f}, avg_mse: {avg_mse:8f}, avg_rmse: {avg_rmse:8f}, "
              f"avg_mape: {avg_mape:8f}, avg_mspe: {avg_mspe:8f}, avg_r2: {avg_r2:8f}")
        print(f'station {station}, median_nse: {median_nse:8f}, median_pbias: {median_pbias:8f}, median_kge: {median_kge:8f}, '
              f'median_flv: {median_flv:8f}, median_fhv: {median_fhv:8f}, median_mae: {median_mae:8f}, '
              f'median_mse: {median_mse:8f}, median_rmse: {median_rmse:8f}, median_mape: {median_mape:8f}, '
              f'median_mspe: {median_mspe:8f}, median_r2: {median_r2:8f}')

        # Result directory: ablation runs get their own root and a trailing
        # sub-directory named after the ablation.
        is_ablation = self.args.ablation != 'None'
        split_tag = (
            f"train{self.args.train_start_time}_{self.args.train_end_time}_"
            f"vali{self.args.vali_start_time}_{self.args.vali_end_time}_"
            f"test{self.args.test_start_time}_{self.args.test_end_time}"
        )
        path_parts = [
            './ablation/' if is_ablation else './baselines_results/',
            self.args.data, self.args.target, split_tag,
            f'pl{self.args.pred_len}', run_type, f"wise_{self.args.run_wise}", self.args.model,
        ]
        if is_ablation:
            path_parts.append(self.args.ablation)
        csv_save_path = os.path.join(*path_parts)
        # exist_ok avoids a check-then-create race when tasks run in parallel.
        os.makedirs(csv_save_path, exist_ok=True)

        # The per-task table is skipped for both FlowNet variants.
        if run_type not in ['FlowNet', 'FlowNet2']:
            make_table = MakeTable(self.args, self.args.model, run_type, self.args.results, self._setting,
                                   csv_save_path=csv_save_path, target_station=station)
            make_table.save_to_csv()
            print("Make results table done")

        # Append one line per run to the hyper-parameter search log.
        if run_type == 'initial' and self.args.run_wise == 'all':
            save_path = os.path.join("search_results", self.args.data, self.args.target, f"pl_{self.args.pred_len}", run_type)
            os.makedirs(save_path, exist_ok=True)

            with open(os.path.join(save_path, f"{self.args.model}.csv"), 'a') as f:
                # Fixed metric columns come first.
                f.write(
                    f"avg_nse,{avg_nse:8f},avg_pbias,{avg_pbias:8f},avg_kge,{avg_kge:8f},avg_flv,{avg_flv:8f},"
                    f"avg_fhv,{avg_fhv:8f},avg_mae,{avg_mae:8f},avg_mse,{avg_mse:8f},avg_rmse,{avg_rmse:8f},"
                    f"avg_mape,{avg_mape:8f},avg_mspe,{avg_mspe:8f},avg_r2,{avg_r2:8f},"
                    f"median_nse,{median_nse:8f},median_pbias,{median_pbias:8f},median_kge,{median_kge:8f},"
                    f"median_flv,{median_flv:8f},median_fhv,{median_fhv:8f},median_mae,{median_mae:8f},"
                    f"median_mse,{median_mse:8f},median_rmse,{median_rmse:8f},median_mape,{median_mape:8f},"
                    f"median_mspe,{median_mspe:8f},median_r2,{median_r2:8f},"
                )

                # Then every configuration value as "key,value," pairs, except
                # the graph attributes.
                for key, value in vars(self.args).items():
                    if str(key) in ['edge_index', 'edge_attr', 'adj']:
                        continue
                    f.write(f"{key},{value},")
                f.write("\n")
        del exp
        return avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2, avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv

    def _run_assigned_task_flownet2_local(self, station, run_type):
        """Run the 'initial' experiment used as the first stage of FlowNet2.

        Always uses ``Exp_MeKong_initial`` with ``num_stations`` set to the
        full station count; ``run_type`` only affects the result directory.

        Args:
            station: Task identifier handed to the experiment class.
            run_type: Name used for the run-type level of the output path.

        Returns:
            tuple: The eleven average metrics, from ``avg_mae`` to ``avg_fhv``.
        """
        cfg = self.args
        started_at = time.time()
        cfg.num_stations = len(self._constants.all_stations)
        exp = Exp_MeKong_initial(cfg, station, self._setting)

        if cfg.is_training:
            if self._verbose:
                print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>\n'.format(self._setting))
            exp.train()
        if self._verbose:
            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n'.format(self._setting))
        metrics = exp.test(test=0 if cfg.is_training else 1)
        (avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2,
         avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv,
         median_mae, median_mse, median_rmse, median_mape, median_mspe, median_r2,
         median_nse, median_pbias, median_kge, median_flv, median_fhv) = metrics
        torch.cuda.empty_cache()

        duration = time.time() - started_at
        print(f"station {station}, duration: {duration:.2f}s, "
              f"avg_nse:{avg_nse:8f}, avg_pbias:{avg_pbias:8f}, avg_kge:{avg_kge:8f}, avg_flv:{avg_flv:8f}, "
              f"avg_fhv:{avg_fhv:8f}, avg_mae: {avg_mae:8f}, avg_mse: {avg_mse:8f}, avg_rmse: {avg_rmse:8f}, "
              f"avg_mape: {avg_mape:8f}, avg_mspe: {avg_mspe:8f}, avg_r2: {avg_r2:8f}")
        print(f'station {station}, median_nse: {median_nse:8f}, median_pbias: {median_pbias:8f}, median_kge: {median_kge:8f}, '
              f'median_flv: {median_flv:8f}, median_fhv: {median_fhv:8f}, median_mae: {median_mae:8f}, '
              f'median_mse: {median_mse:8f}, median_rmse: {median_rmse:8f}, median_mape: {median_mape:8f}, '
              f'median_mspe: {median_mspe:8f}, median_r2: {median_r2:8f}')

        # Output directory; ablation runs use a separate root plus a trailing
        # directory named after the ablation.
        split_tag = (
            f"train{cfg.train_start_time}_{cfg.train_end_time}_"
            f"vali{cfg.vali_start_time}_{cfg.vali_end_time}_"
            f"test{cfg.test_start_time}_{cfg.test_end_time}"
        )
        common_tail = [cfg.data, cfg.target, split_tag, f'pl{cfg.pred_len}', run_type,
                       f"wise_{cfg.run_wise}", cfg.model]
        if cfg.ablation != 'None':
            csv_save_path = os.path.join('./ablation/', *common_tail, cfg.ablation)
        else:
            csv_save_path = os.path.join('./baselines_results/', *common_tail)
        if not os.path.exists(csv_save_path):
            os.makedirs(csv_save_path)

        MakeTable(cfg, cfg.model, 'initial', cfg.results, self._setting,
                  csv_save_path=csv_save_path, target_station=station).save_to_csv()
        print("Make results table done")
        del exp
        return avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2, avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv

    @ray.remote(num_gpus=0.5)
    def _run_station_task(self, station, run_type):
        """Ray remote wrapper around ``_run_assigned_task``.

        Each invocation requests ``num_gpus=0.5`` from Ray. The station is
        converted to its string form before being forwarded.
        """
        return self._run_assigned_task(str(station), run_type)

    def _run_all_tasks_in_loop(self, run_type):
        """Run ``_run_assigned_task`` for every task and aggregate the metrics.

        Tasks are the keys of ``basin_dict`` for run_wise ``'basin'`` /
        ``'basin_station'`` and the station list otherwise. With
        ``self.args.use_ray`` they are submitted as Ray tasks, else they run
        sequentially. Afterwards the per-task CSV results are merged through
        ``collect_mts_results``.

        Args:
            run_type: Run type forwarded to every task and used in the result
                directory.

        Returns:
            tuple: Mean over all tasks of ``(mae, mse, rmse, mape, mspe, r2,
            nse, pbias, kge, flv, fhv)``.

        Raises:
            ValueError: If there is no task to run, or a task result does not
                contain exactly eleven metrics.
        """
        if self.args.run_wise in ['basin', 'basin_station']:
            tasks_list = list(self._constants.basin_dict.keys())
        else:
            tasks_list = self._all_stations
        num_tasks = len(tasks_list)
        if num_tasks == 0:
            raise ValueError(f"No tasks to run for run_wise '{self.args.run_wise}'.")

        start_time = time.time()
        results = []

        def fmt_time(seconds):
            m, s = divmod(int(seconds), 60)
            h, m = divmod(m, 60)
            if h > 0:
                return f"{h}h {m}m {s}s"
            elif m > 0:
                return f"{m}m {s}s"
            else:
                return f"{s}s"

        def report_progress(completed):
            # ETA extrapolates the average time per finished task.
            elapsed = time.time() - start_time
            remaining = (num_tasks - completed) * (elapsed / completed)
            print(f"✅ Progress: {completed}/{num_tasks} | Elapsed: {fmt_time(elapsed)} | ETA: {fmt_time(remaining)}")

        if self.args.use_ray:
            # Initialise Ray with every GPU torch can see.
            # NOTE(review): each task requests 0.5 GPU; with zero detected GPUs
            # the tasks presumably can never be scheduled — confirm.
            num_gpus = torch.cuda.device_count()
            ray.init(num_gpus=num_gpus, ignore_reinit_error=True)
            try:
                print(f"🚀 Ray start, detect {num_gpus} GPUs")
                # Submit all tasks up front, then consume them as they finish.
                futures = [self._run_station_task.remote(self, s, run_type) for s in tasks_list]
                completed = 0
                while futures:
                    done, futures = ray.wait(futures, num_returns=1)
                    results.append(ray.get(done[0]))
                    completed += 1
                    report_progress(completed)
                # NOTE(review): the per-station FlowNet hyper-parameter step of
                # the sequential branch is not executed here — confirm intent.
            finally:
                # Always release the Ray session, also when a task raised.
                ray.shutdown()
        else:
            for completed, station in enumerate(tasks_list, start=1):
                station_str = str(station)
                results.append(self._run_assigned_task(station_str, run_type))
                report_progress(completed)
                if run_type == 'FlowNet':
                    station_search_param = StationSpecSearchHyperparam(self.args, station_str, metric='nse')
                    station_search_param.save_best_station_param()
                    station_search_param.save_best_results_to_csv()

        # Transpose the per-task tuples into one sequence per metric.
        metric_names = ('mae', 'mse', 'rmse', 'mape', 'mspe', 'r2', 'nse', 'pbias', 'kge', 'flv', 'fhv')
        columns = list(zip(*results))
        if len(columns) != len(metric_names):
            raise ValueError(f"Expected {len(metric_names)} metrics per task, got {len(columns)}.")
        avg = {name: np.mean(values) for name, values in zip(metric_names, columns)}
        median = {name: np.median(values) for name, values in zip(metric_names, columns)}

        print(f"Run all stations in {run_type} done, "
              f"overall duration {time.time() - start_time:.2f}s")
        print(f"avg_nse: {avg['nse']}, avg_pbias: {avg['pbias']}, avg_kge: {avg['kge']}, "
              f"avg_flv: {avg['flv']}, avg_fhv: {avg['fhv']}, avg_mae: {avg['mae']}, "
              f"avg_mse: {avg['mse']}, avg_rmse: {avg['rmse']}, avg_mape: {avg['mape']}, "
              f"avg_mspe: {avg['mspe']}, avg_r2: {avg['r2']}")
        print(f"median_nse: {median['nse']}, median_pbias: {median['pbias']}, median_kge: {median['kge']}, "
              f"median_flv: {median['flv']}, median_fhv: {median['fhv']}, "
              f"median_mae: {median['mae']}, median_mse: {median['mse']}, median_rmse: {median['rmse']}, "
              f"median_mape: {median['mape']}, median_mspe: {median['mspe']}, median_r2: {median['r2']}")

        # Result directory; ablation runs use a separate root plus a trailing
        # directory named after the ablation.
        is_ablation = self.args.ablation != 'None'
        split_tag = (
            f"train{self.args.train_start_time}_{self.args.train_end_time}_"
            f"vali{self.args.vali_start_time}_{self.args.vali_end_time}_"
            f"test{self.args.test_start_time}_{self.args.test_end_time}"
        )
        path_parts = [
            './ablation/' if is_ablation else './baselines_results/',
            self.args.data, self.args.target, split_tag,
            f'pl{self.args.pred_len}', run_type, f"wise_{self.args.run_wise}", self.args.model,
        ]
        if is_ablation:
            path_parts.append(self.args.ablation)
        csv_save_path = os.path.join(*path_parts)
        os.makedirs(csv_save_path, exist_ok=True)

        if run_type in ["FlowNet", "FlowNet2"]:
            # One merged CSV per global loop.
            for loop in range(self.args.global_loop):
                collect_mts_results(self.args, run_type, self._constants,
                                    read_path=os.path.join(csv_save_path, f'global_{loop}'),
                                    save_path=os.path.join(csv_save_path, f'global_{loop}.csv'))
        else:
            collect_mts_results(self.args, run_type, self._constants,
                                read_path=csv_save_path,
                                save_path=os.path.join(csv_save_path, 'initial.csv'))
        return tuple(avg[name] for name in metric_names)

    def run_tasks(self):
        """Dispatch the configured run.

        * ``run_type == 'FlowNet2'``: one 'initial' run over all stations,
          followed by the per-station FlowNet2 loop.
        * ``station == 'all'``: run either a single all-station task or the
          task loop, depending on dataset, ``run_type`` and ``run_wise``.
        * otherwise: run the single configured station, write its result table
          and, for FlowNet, store the station's best hyper-parameters.

        Returns:
            None

        Raises:
            ValueError: If ``run_wise`` is not valid for the dataset.
            NotImplementedError: For unsupported datasets or run types.
        """
        if self.args.run_type == "FlowNet2":
            self.args.run_wise = 'all'
            self._run_assigned_task_flownet2_local('all', 'initial')
            self.args.run_wise = 'station'
            self._run_all_tasks_in_loop('FlowNet2')
            return None

        if self.args.station == 'all':
            print(f"Running all stations with running type {self.args.run_type}, running wise {self.args.run_wise}")
            if not self.args.is_training:
                print('testing')
            if self.args.run_type == 'initial':
                if 'LamaH' in self.args.data:
                    if self.args.run_wise == 'all':
                        self._run_assigned_task('all', self.args.run_type)
                    elif self.args.run_wise == 'station':
                        self._run_all_tasks_in_loop(self.args.run_type)
                    else:
                        # Both 'all' and 'station' are accepted above.
                        raise ValueError("For LamaH data, run_wise must in ['all', 'station']")
                elif 'camels' == self.args.data:
                    if self.args.run_wise == 'all':
                        self._run_assigned_task('all', self.args.run_type)
                    elif self.args.run_wise in ['basin', 'basin_station', 'station']:
                        self._run_all_tasks_in_loop(self.args.run_type)
                    else:
                        raise ValueError("For camels data, run_wise must in ['all', 'basin', 'basin_station', 'station']")
                else:
                    raise NotImplementedError
            elif self.args.run_type == 'FlowNet':
                self._run_all_tasks_in_loop(self.args.run_type)
            else:
                raise NotImplementedError
            return None

        station_str = str(self.args.station)

        print(f"Running station {self.args.station} with running type {self.args.run_type}, running wise {self.args.run_wise}")
        if not self.args.is_training:
            print('testing')
        # run_type is a required argument of _run_assigned_task.
        self._run_assigned_task(station_str, self.args.run_type)

        # Result directory; ablation runs use a separate root plus a trailing
        # directory named after the ablation.
        is_ablation = self.args.ablation != 'None'
        split_tag = (
            f"train{self.args.train_start_time}_{self.args.train_end_time}_"
            f"vali{self.args.vali_start_time}_{self.args.vali_end_time}_"
            f"test{self.args.test_start_time}_{self.args.test_end_time}"
        )
        path_parts = [
            './ablation/' if is_ablation else './baselines_results/',
            self.args.data, self.args.target, split_tag,
            f'pl{self.args.pred_len}', self.args.run_type, f"wise_{self.args.run_wise}", self.args.model,
        ]
        if is_ablation:
            path_parts.append(self.args.ablation)
        csv_save_path = os.path.join(*path_parts)
        os.makedirs(csv_save_path, exist_ok=True)

        make_table = MakeTable(self.args, self.args.model, self.args.run_type, self.args.results, self._setting,
                               csv_save_path=csv_save_path)
        make_table.save_to_csv()
        print("Make results table done")
        if self.args.run_type == 'FlowNet':
            station_search_param = StationSpecSearchHyperparam(self.args, self.args.station, metric='nse')
            station_search_param.save_best_station_param()
            station_search_param.save_best_results_to_csv()
        return None

def run(args):
    """Build the experiment setting string and execute the configured tasks.

    Args:
        args: Configuration namespace produced by ``parse_args``.

    Returns:
        Whatever ``MeKongWaterLevelPrediction.run_tasks`` returns.
    """
    # Underscore-separated identifier assembled from the key hyper-parameters.
    setting_parts = [
        f"{args.model}",
        f"{args.data}",
        f"{args.target}",
        f"{args.run_type}",
        f"sl{args.seq_len}",
        f"ll{args.label_len}",
        f"pl{args.pred_len}",
        f"dm{args.d_model}",
        f"nh{args.n_heads}",
        f"el{args.e_layers}",
        f"dl{args.d_layers}",
        f"df{args.d_ff}",
        f"fc{args.factor}",
        f"eb{args.embed}",
        f"dt{args.distil}",
        f"{args.other_station}",
        f"alpha{args.alpha}",
        f"gloop{args.global_loop}",
        f"glr{args.global_lr}",
        f"glrf{args.global_lr_factor}",
        f"{args.des}",
        f"seed{args.random_seed}",
    ]
    setting = "_".join(setting_parts)

    runner = MeKongWaterLevelPrediction(args, setting)
    return runner.run_tasks()


if __name__ == '__main__':
    # Script entry point: parse the configuration, seed the Python, NumPy and
    # PyTorch random generators with the same value, print the configuration
    # and start the run.
    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)
    print_args(args)

    run(args)