from utils.constants import Constants
from utils.print_args import print_args
from exp.exp_mekong_initial import Exp_MeKong as Exp_MeKong_initial
from exp.exp_mekong_for_gnn import Exp_MeKong as Exp_MeKong_for_gnn
from exp.exp_mekong_local_global import Exp_MeKong as Exp_MeKong_local_global
from exp.exp_mekong_flow_list import Exp_MeKong as Exp_MeKong_flow_list
from latex.pred_table import MakeTable
from utils.collect_mts_results import collect_mts_results
import time
import torch
import os
import argparse
import yaml
import random
import numpy as np
import json
import ray
from station_spec_search_hyperparam import StationSpecSearchHyperparam
# Absolute path of the directory that contains this script.
current_dir = os.path.split(os.path.abspath(__file__))[0]


def parse_args():
    """Build the full run configuration from CLI flags, the base YAML and dataset metadata.

    Precedence, lowest to highest:
    1. ``add_argument`` defaults.
    2. Values from the ``--base_configs`` YAML (applied via ``set_defaults``).
    3. Flags given explicitly on the command line.

    If ``--use_common_searched_param`` is set, the entries for ``args.model`` in
    ``configs/common_searched_hyperparam.yaml`` are applied last. They overwrite the
    parsed values, including explicit CLI flags.

    The function also derives ``data_time_dir`` from ``dataset_path``, ``target`` and
    ``data_time_path`` (all expected to come from the base YAML). It then reads
    ``num_vali`` / ``num_test`` from ``<data_time_dir>/other_data.yaml``.

    Returns:
        argparse.Namespace: the merged configuration.

    Raises:
        FileNotFoundError: if any of the YAML files is missing.
        KeyError: if ``other_data.yaml`` lacks ``num_vali``/``num_test``, or the common
            searched-param file has no entry for ``args.model``.
    """
    # Full configuration: YAML values act as defaults, CMD flags override them.
    parser = argparse.ArgumentParser(description='Water Flow Prediction')
    # task setting
    parser.add_argument('--base_configs', type=str, default='configs/Water.Level/base_configs.yaml',
                        help="['configs/base_configs.yaml', 'configs/base_configs_mts.yaml']")
    parser.add_argument('--random_seed', type=int, default=42, help='random seed')
    parser.add_argument('--model', type=str, default='FlowNet')
    parser.add_argument('--station', type=str, default='all')
    parser.add_argument('--other_station', type=str, default='child_parent',
                        help="['st_flow', 'flow', 'child', 'parent', 'child_parent']")
    parser.add_argument('--is_training', type=int, default=1)
    parser.add_argument('--load_model', type=int, default=0, help='load model if it exits')
    parser.add_argument('--run_initial', action='store_true', help='initial', default=False)
    parser.add_argument('--run_for_gnn', action='store_true', help='for gnn', default=False)
    parser.add_argument('--run_local_global', action='store_true', help='local_global', default=False)
    parser.add_argument('--get_flow_list', action='store_true', help='get_flow_list', default=False)
    parser.add_argument('--base_flow', type=str, default='child_parent', help='vali-flow base flow')
    parser.add_argument('--flow_list_vali_factor', type=float, default=1.0)
    parser.add_argument('--features', type=str, default='MS')
    parser.add_argument('--verbose', type=int, default=1)
    parser.add_argument('--des', type=str, default='exp')
    parser.add_argument('--ablation', type=str, default='None')
    parser.add_argument('--out_station_preds', action='store_true', help='out_station_preds', default=False)
    # model params
    parser.add_argument('--train_epochs', type=int, default=100)
    parser.add_argument('--use_common_searched_param', action='store_true', default=False,
                        help='use_common_searched_param')
    parser.add_argument('--use_station_searched_param', action='store_true', default=False,
                        help='use_station_searched_param')
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--alpha', type=float, default=0.5, help='regulation parameter for loss')
    parser.add_argument('--global_loop', type=int, default=3)
    parser.add_argument('--global_lr', type=float, default=0.001)
    parser.add_argument('--global_lr_factor', type=float, default=0.2,
                        help='lr=learning_rate * (global_lr_factor ** loop)')
    parser.add_argument('--e_layers', type=int, default=1)
    parser.add_argument('--d_layers', type=int, default=1)
    parser.add_argument('--d_model', type=int, default=512)
    parser.add_argument('--d_ff', type=int, default=2048)
    parser.add_argument('--n_heads', type=int, default=8)
    parser.add_argument('--factor', type=int, default=1)
    parser.add_argument('--multiscale_levels', type=int, default=1)
    parser.add_argument('--patch_len', type=int, default=16)
    parser.add_argument('--stride', type=int, default=8)
    parser.add_argument('--QAM_start', type=float, default=0.1)
    parser.add_argument('--QAM_end', type=float, default=0.2)
    parser.add_argument('--cycle', type=int, default=24)
    # device
    parser.add_argument('--use_ray', type=int, default=1)
    parser.add_argument('--use_gpu', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--use_multi_gpu', type=int, default=0)
    parser.add_argument('--devices', type=str, default='0,1,2,3')

    # Pass 1: only needed to locate the base YAML. Its values become *defaults*,
    # so flags given on the command line still win over them.
    with open(parser.parse_args().base_configs, 'r') as f:
        yaml_config = yaml.safe_load(f)
    parser.set_defaults(**yaml_config)

    # Pass 2: resolve the data-to-time directory and load num_vali / num_test from it.
    pre_args = parser.parse_args()
    data_time_dir = pre_args.dataset_path + pre_args.target + '/' + pre_args.data_time_path
    parser.set_defaults(data_time_dir=data_time_dir)
    other_data_file = os.path.join(data_time_dir, 'other_data.yaml')
    with open(other_data_file, 'r') as f:
        other_data_config = yaml.safe_load(f)
    parser.set_defaults(num_vali=other_data_config['num_vali'],
                        num_test=other_data_config['num_test'])

    args = parser.parse_args()
    # Optionally overwrite with the hyper-parameters searched for this model
    # (these take precedence even over explicit CLI flags).
    if args.use_common_searched_param:
        param_path = os.path.join('configs', 'common_searched_hyperparam.yaml')
        with open(param_path, 'r') as f:
            param_yaml = yaml.safe_load(f)
        for key, value in param_yaml[args.model].items():
            setattr(args, key, value)
    return args


class MeKongWaterLevelPrediction:
    """Run MeKong water-level experiments for a single station or for every station.

    Station metadata comes from ``Constants(args)``: the station list, child/parent
    relations and the unified scaler dict. Training and testing are delegated to the
    ``Exp_MeKong`` class that matches the requested phase.

    ``args`` is shared and mutated in place. Among others, ``scaler_dict``, ``device``,
    ``data_path``, ``data`` and ``phase`` are written to it.
    """

    # Order of the averaged metrics returned by ``_run_single_station``. ``exp.test``
    # returns these 11 averages followed by the 11 matching medians.
    _METRIC_NAMES = ('mae', 'mse', 'rmse', 'mape', 'mspe', 'r2', 'nse', 'pbias', 'kge', 'flv', 'fhv')
    # Order in which the aggregated metrics are printed after an all-station run.
    _REPORT_ORDER = ('nse', 'pbias', 'kge', 'flv', 'fhv', 'mae', 'mse', 'rmse', 'mape', 'mspe', 'r2')
    # Models that run the 'for_gnn' phase instead of 'initial' when --run_initial is set.
    _GNN_MODELS = ('CrossGNN', 'FourierGNN', 'GWNet', 'TGCN', 'GCN', 'GCNII', 'ResGCN', 'ResGAT',
                   'AGCLSTM', 'AGCLSTM_revin')

    def __init__(self, args):
        self.args = args
        self._constants = Constants(args)
        self.args.scaler_dict = self._constants.get_unified_scaler_dict()
        # all stations list
        self._all_stations = self._constants.all_stations
        self._child_stations_dict = self._constants.child_stations_dict
        self._parent_stations_dict = self._constants.parent_stations_dict
        self._verbose = args.verbose
        self._device = self._acquire_device()
        self.args.device = self._device

    def _acquire_device(self):
        """Pick CUDA (``cuda:<args.gpu>``), then MPS, then CPU, honouring ``args.use_gpu``."""
        if self.args.use_gpu:
            if torch.cuda.is_available():
                # NOTE(review): CUDA_VISIBLE_DEVICES is set after torch has already queried
                # CUDA; it may not take effect in this process - confirm.
                os.environ["CUDA_VISIBLE_DEVICES"] = str(
                    self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
                device = torch.device('cuda:{}'.format(self.args.gpu))
                print('Use GPU: cuda:{}'.format(self.args.gpu))
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
                print('Use MPS')
            else:
                device = torch.device("cpu")
                print('Use CPU')
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _load_best_station_params(self, station):
        """Overwrite ``self.args`` with the searched best hyper-parameters for ``station``.

        The file is ``configs/<target>/sl<seq_len>_pl<pred_len>/<model>_<other_station>[_<base_flow>]_best_station_param.yaml``.
        The ``_<base_flow>`` part is present only when ``other_station == 'flow'``.
        Stations missing from that file get fixed fallback values for alpha/global_*.

        NOTE(review): only the keys listed for a station are overwritten. In sequential
        (non-Ray) runs, other keys set for a previous station persist. ``self._setting``
        (built once in ``run``) does not reflect these per-station values either.
        """
        if self.args.other_station == 'flow':
            file_name = f'{self.args.model}_{self.args.other_station}_{self.args.base_flow}_best_station_param.yaml'
        else:
            file_name = f'{self.args.model}_{self.args.other_station}_best_station_param.yaml'
        param_path = os.path.join('configs', self.args.target,
                                  f'sl{self.args.seq_len}_pl{self.args.pred_len}', file_name)
        with open(param_path, 'r') as f:
            param_yaml = yaml.safe_load(f)
        try:
            best_station_param = param_yaml[station]
        except KeyError:
            self.args.alpha = 0.9
            self.args.global_loop = 5
            self.args.global_lr = 0.001
            self.args.global_lr_factor = 0.1
            return
        for key, value in best_station_param.items():
            setattr(self.args, key, value)

    def _get_base_other_list(self, station, base_flow):
        """Return the candidate neighbour stations of ``station`` selected by ``base_flow``.

        ``'child'`` and ``'parent'`` select those relations, ``'lag_correlation'`` asks
        ``Constants``, and anything else uses children followed by parents.
        """
        # Both lookups happen up front, so an unknown station raises KeyError regardless of base_flow.
        child_list = self._child_stations_dict[station]
        parent_list = self._parent_stations_dict[station]
        if base_flow == 'child':
            return child_list
        if base_flow == 'parent':
            return parent_list
        if base_flow == 'lag_correlation':
            return self._constants.get_other_list(station, 'lag_correlation')
        return child_list + parent_list

    def _collect_useful_others(self, station, phase, base_flow, report_each=False):
        """Train one flow-list experiment per candidate neighbour of ``station``.

        A neighbour is kept when ``exp.train`` returns a truthy value. Prints the final
        list, and each individual result when ``report_each`` is set.

        Returns:
            list: the kept neighbour stations, in candidate order.
        """
        self.args.data_path = station + '.csv'
        setting = self._setting.replace('TARGETSTATION', station.replace(' ', ''))
        self.args.data = phase
        other_list = self._get_base_other_list(station, base_flow)

        useful_other_list = []
        for other_station in other_list:
            exp = Exp_MeKong_flow_list(self.args, phase, other_station, self._verbose, self._device)
            link_exists = exp.train(setting, other_station)
            if link_exists:
                useful_other_list.append(other_station)
            if report_each:
                print(station, other_station, link_exists)
        print(station, useful_other_list)
        return useful_other_list

    def get_flow_station_list(self, flow_type='st_flow', base_flow='child_parent'):
        """Find, per station, which neighbour stations carry a useful flow link.

        For ``args.station == 'all'`` the result for every station is saved as JSON to
        ``<data_time_dir>/<model>_[st_]flow_dict_<base_flow>.json``. For a single station
        the result is only printed.
        """
        if flow_type == 'st_flow':
            print("Running to get seasonal-trend best other station list")
            phase = "st_flow_list"
        else:
            print("Running to get best other station list")
            phase = "flow_list"

        if self.args.station != 'all':
            self._collect_useful_others(self.args.station, phase, base_flow, report_each=True)
            return

        # NOTE(review): for 'st_flow' the 'seasonal'/'trend' sub-dicts are never filled;
        # per-station lists are stored at the top level exactly as for 'flow'.
        useful_other_dict = {'seasonal': {}, 'trend': {}} if flow_type == 'st_flow' else {}
        for station in self._all_stations:
            if self._verbose:
                print(f"Running {station} station")
            useful_other_dict[station] = self._collect_useful_others(station, phase, base_flow)

        json_str = json.dumps(useful_other_dict)
        prefix = 'st_flow_dict' if flow_type == 'st_flow' else 'flow_dict'
        save_name = f'{self.args.model}_{prefix}_{base_flow}.json'
        json_path = os.path.join(self.args.data_time_dir, save_name)
        with open(json_path, 'w') as json_file:
            json_file.write(json_str)

    def _unpack_test_metrics(self, result):
        """Materialise ``exp.test``'s return value and check it has 11 averages + 11 medians.

        Raises:
            ValueError: on a wrong number of values, like the tuple unpacking it replaces.
        """
        metrics = tuple(result)
        expected = 2 * len(self._METRIC_NAMES)
        if len(metrics) != expected:
            raise ValueError(f"exp.test returned {len(metrics)} values, expected {expected}")
        return metrics

    def _run_single_station(self, station, phase, device):
        """Train (if ``args.is_training``) and test one station for ``phase``.

        Returns:
            tuple: the 11 averaged metrics, in ``_METRIC_NAMES`` order.

        Raises:
            ValueError: for an unknown station (except in 'for_gnn') or an unknown phase.
        """
        if phase != 'for_gnn' and station not in self._all_stations:
            raise ValueError(f"Station {station} not found.")

        if self.args.use_station_searched_param:
            self._load_best_station_params(station)

        self.args.data_path = station + '.csv'
        if self._verbose:
            print(f"Running {station} station")
        # exp setting
        setting = self._setting.replace('TARGETSTATION', station.replace(' ', ''))
        self.args.data = phase
        if phase == 'initial':
            exp = Exp_MeKong_initial(self.args, phase, self._verbose, device)
        elif phase == 'for_gnn':
            exp = Exp_MeKong_for_gnn(self.args, phase, self._verbose, device)
        elif phase == 'local_global':
            exp = Exp_MeKong_local_global(self.args, phase, self._verbose, device)
        else:
            raise ValueError(f"Phase {phase} not found.")

        test_kwargs = {}
        if self.args.is_training:
            if self._verbose:
                print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>\n'.format(setting))
            exp.train(setting)
        else:
            # Test-only run: ask the experiment to load the stored checkpoint.
            test_kwargs['test'] = 1
        if self._verbose:
            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n'.format(setting))
        try:
            metrics = self._unpack_test_metrics(exp.test(setting, **test_kwargs))
        except Exception:
            # Deliberate best-effort fallback: retry with the 'local' training phase.
            # NOTE(review): which failures this is meant to cover is not visible here.
            metrics = self._unpack_test_metrics(exp.test(setting, train_phase='local', **test_kwargs))
        torch.cuda.empty_cache()
        return metrics[:len(self._METRIC_NAMES)]

    @ray.remote(num_gpus=0.5)
    def _run_station_task(self, station, run_type, device):
        # Ray task wrapper requesting half a GPU. Callers invoke it as
        # ``self._run_station_task.remote(self, ...)``, passing ``self`` explicitly.
        return self._run_single_station(str(station), run_type, device)

    @staticmethod
    def _fmt_time(seconds):
        """Format a duration as 'Hh Mm Ss', 'Mm Ss' or 'Ss'."""
        m, s = divmod(int(seconds), 60)
        h, m = divmod(m, 60)
        if h > 0:
            return f"{h}h {m}m {s}s"
        if m > 0:
            return f"{m}m {s}s"
        return f"{s}s"

    def _print_progress(self, completed, num_tasks, start_time):
        """Print progress with an ETA extrapolated from the mean time per finished task."""
        elapsed = time.time() - start_time
        avg_time = elapsed / completed
        remaining = (num_tasks - completed) * avg_time
        print(
            f"✅ Progress: {completed}/{num_tasks} | Elapsed: {self._fmt_time(elapsed)} | ETA: {self._fmt_time(remaining)}")

    def _results_dir(self, *tail):
        """Create (if needed) and return the results directory for the current configuration.

        The base is ``./ablation/`` when ``args.ablation != 'None'``, otherwise
        ``./baselines_results/``. Any ``tail`` components are appended after the
        ``sl<seq_len>_pl<pred_len>`` directory.
        """
        if self.args.ablation != 'None':
            parts = ['./ablation/', self.args.target, self.args.data_time_path, self.args.ablation]
        else:
            parts = ['./baselines_results/', self.args.target, self.args.data_time_path]
        path = os.path.join(*parts, f'sl{self.args.seq_len}_pl{self.args.pred_len}', *tail)
        os.makedirs(path, exist_ok=True)
        return path

    def _execute_station_runs(self, tasks_list, phase, start_time):
        """Run ``phase`` for every station, via Ray when ``args.use_ray`` is set.

        Returns:
            list: one 11-metric tuple per station. With Ray the order is completion
            order; only order-independent aggregates are computed from it.
        """
        results = []
        num_tasks = len(tasks_list)
        if self.args.use_ray:
            # Initialise Ray with every visible GPU.
            num_gpus = torch.cuda.device_count()
            ray.init(num_gpus=num_gpus, ignore_reinit_error=True)
            print(f"🚀 Ray start, detect {num_gpus} GPUs")
            try:
                futures = [self._run_station_task.remote(self, s, phase, self._device) for s in tasks_list]
                completed = 0
                while futures:
                    # Collect tasks one at a time, as soon as any of them finishes.
                    done, futures = ray.wait(futures, num_returns=1)
                    results.append(ray.get(done[0]))
                    completed += 1
                    self._print_progress(completed, num_tasks, start_time)
            finally:
                # Shut Ray down even if a task raised.
                ray.shutdown()
        else:
            for completed, station in enumerate(tasks_list, start=1):
                results.append(self._run_single_station(str(station), phase, self._device))
                self._print_progress(completed, num_tasks, start_time)
        return results

    def _run_all_station_tasks(self, phase):
        """Run ``phase`` for all stations, then aggregate the metrics and export the result tables."""
        tasks_list = self._all_stations
        start_time = time.time()
        results = self._execute_station_runs(tasks_list, phase, start_time)

        if phase == 'local_global':
            for station in tasks_list:
                station_search_param = StationSpecSearchHyperparam(self.args, str(station), metric='nse')
                station_search_param.save_best_station_param()
                station_search_param.save_best_results_to_csv()

        if not results:
            raise ValueError(f"No station results were produced for phase {phase}.")
        # Transpose per-station tuples into one column per metric.
        columns = dict(zip(self._METRIC_NAMES, zip(*results)))
        avg = {name: np.mean(values) for name, values in columns.items()}
        median = {name: np.median(values) for name, values in columns.items()}

        print(f"Run all stations in {phase} done, "
              f"overall duration {time.time() - start_time:.2f}s")
        print(', '.join(f'avg_{name}: {avg[name]}' for name in self._REPORT_ORDER))
        print(', '.join(f'median_{name}: {median[name]}' for name in self._REPORT_ORDER))

        setting = self._setting.replace('TARGETSTATION', '')
        with open(f"result_{phase}_all.txt", 'a') as f:
            f.write(setting + "  \n")
            f.write('avg rmse:{}, avg mae:{}, avg mape:{}, avg nse:{}, avg r2:{}'.format(
                avg['rmse'], avg['mae'], avg['mape'], avg['nse'], avg['r2']))
            f.write('\n\n')

        phase_str = 'FlowNet' if phase == 'local_global' else phase
        csv_save_path = self._results_dir(phase_str)
        make_table = MakeTable(self.args, self.args.model, phase, self.args.results, setting,
                               csv_save_path=csv_save_path)
        make_table.save_to_csv()
        print("Make results table done")
        if phase == 'local_global':
            collect_mts_results(self.args, phase, self._constants,
                                read_path=csv_save_path,
                                save_path=csv_save_path, phase='local')
            for loop in range(self.args.global_loop):
                collect_mts_results(self.args, phase, self._constants,
                                    read_path=csv_save_path,
                                    save_path=csv_save_path, phase=f'global_{loop}')

    def _run_one_station_task(self, phase):
        """Run ``phase`` for ``args.station`` (or once for all stations in 'for_gnn') and report."""
        print(f"Running station {self.args.station} in {phase}")
        task_start_time = time.time()
        mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv = self._run_single_station(
            self.args.station, phase, self._device)
        duration = time.time() - task_start_time
        print(f"duration: {duration:.2f}s, "
              f'nse:{nse}, pbias:{pbias}, kge:{kge}, flv:{flv}, '
              f'fhv:{fhv}, mae: {mae}, mse: {mse}, rmse: {rmse}, '
              f'mape: {mape}, mspe: {mspe}, r2: {r2}')
        print(f"Run station {self.args.station} in {phase} done")
        if phase == 'for_gnn':
            csv_save_path = self._results_dir()
            setting = self._setting.replace('TARGETSTATION', '')
            make_table = MakeTable(self.args, self.args.model, phase, self.args.results, setting,
                                   csv_save_path=csv_save_path)
            make_table.save_to_csv()
            print("Make results table done")

    def _run_tasks(self, phase):
        """Dispatch ``phase`` to the all-station or single-station runner.

        'for_gnn' always uses the single-station path, even with ``station == 'all'``.
        """
        if self.args.station == 'all' and phase != 'for_gnn':
            self._run_all_station_tasks(phase)
        else:
            self._run_one_station_task(phase)

    def run(self):
        """Build the experiment setting string and run every phase enabled in ``args``."""
        # 'TARGETSTATION' is a placeholder replaced by the station name per run.
        self._setting = '{}_{}_{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_alpha{}_gloop{}_glr{}_glrf{}_{}_seed{}'.format(
            self.args.model,
            self.args.target,
            'TARGETSTATION',
            self.args.seq_len,
            self.args.label_len,
            self.args.pred_len,
            self.args.d_model,
            self.args.n_heads,
            self.args.e_layers,
            self.args.d_layers,
            self.args.d_ff,
            self.args.factor,
            self.args.embed,
            self.args.distil,
            self.args.other_station,
            self.args.alpha,
            self.args.global_loop,
            self.args.global_lr,
            self.args.global_lr_factor,
            self.args.des,
            self.args.random_seed
        )
        if self.args.run_initial:
            self.args.phase = 'initial'
            # Read the model from self.args (not a module-level global) so this also works when imported.
            if self.args.model in self._GNN_MODELS:
                self._run_tasks('for_gnn')
            else:
                self._run_tasks('initial')
        if self.args.run_local_global:
            self.args.phase = 'local_global'
            self._run_tasks('local_global')
        if self.args.get_flow_list:
            self.get_flow_station_list(flow_type='flow', base_flow=self.args.base_flow)


if __name__ == '__main__':
    args = parse_args()
    # Seed the NumPy, Python and PyTorch RNGs so runs are reproducible.
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    print_args(args)

    forecaster = MeKongWaterLevelPrediction(args)
    forecaster.run()
