from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping_mekong_local_global, adjust_learning_rate, visual
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
from copy import deepcopy
import json
from station_spec_search_hyperparam import StationSpecSearchHyperparam
from latex.pred_table import MakeTable

# Absolute path of the directory that contains this source file.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Install a process-wide "ignore" filter for every warning category on import.
warnings.simplefilter('ignore')


class Exp_MeKong(Exp_Basic):
    """
    Global:
    If target station is B, whose link stations are [A, C], build models:
     B: L_B -> H_B1
     A_2_B: L_A -> H_B2
     C_2_B: L_C -> H_B3
     pred_from_AC = avg(H_B2 + H_B3)
     prediction = alpha * H_B1 + (1 - alpha) * pred_from_AC
    """

    def __init__(self, args, phase, verbose=True, device='cpu'):
        """Pick the stations linked to ``self.target_station`` and build one model per station.

        ``args.other_station`` selects how the linked ("other") stations are chosen:

        * ``'child_parent'``: all child stations followed by all parent stations of the
          target, as listed in ``self.constants``.
        * ``'lag_correlation'`` / ``'flow'`` / ``'random'``: a station list taken from
          ``self.constants``. Each listed station is then sorted into ``self.child_list``
          or ``self.parent_list`` (child wins if it is in both). Stations that are neither
          a child nor a parent stay in ``self.stations_list`` but appear in neither
          sub-list.

        ``self.stations_list`` is the linked stations followed by the target station, so
        the target is always the last element.

        Raises:
            NotImplementedError: if ``args.other_station`` is not one of the values above.
        """
        super(Exp_MeKong, self).__init__(args, phase, verbose=verbose, device=device)
        strategy = self.args.other_station
        if strategy == 'child_parent':
            # Copies, so nothing done to these lists later can reach self.constants.
            self.child_list = list(self.constants.child_stations_dict[self.target_station])
            self.parent_list = list(self.constants.parent_stations_dict[self.target_station])
            other_list = self.child_list + self.parent_list
        elif strategy in ('lag_correlation', 'flow', 'random'):
            if strategy == 'lag_correlation':
                other_list = self.constants.lag_correlation_dict[self.target_station]
            elif strategy == 'flow':
                other_list = self.constants.get_other_list(self.target_station, 'flow')
            else:
                other_list = self.constants.randomize_dict[self.target_station]
            all_child_list = self.constants.child_stations_dict[self.target_station]
            all_parent_list = self.constants.parent_stations_dict[self.target_station]
            # Preserve other_list's order; a station that is both child and parent is a child.
            self.child_list = [s for s in other_list if s in all_child_list]
            self.parent_list = [s for s in other_list
                                if s not in all_child_list and s in all_parent_list]
        else:
            raise NotImplementedError
        # Build a NEW list: for 'lag_correlation' and 'random' (and possibly 'flow',
        # depending on get_other_list) other_list is the very list stored in
        # self.constants. Appending to it would add the target station to the shared
        # constants again on every instantiation.
        self.stations_list = list(other_list) + [self.target_station]
        self.models_dic = self._build_models_dict()
        self.alpha = self.args.alpha
        num_param = sum(self.numel(model, True) for model in self.models_dic.values())
        if self.verbose:
            print('model number of parameters:', num_param)

    def numel(self, m: torch.nn.Module, only_trainable: bool = False):
        """
        Returns the total number of parameters used by `m` (only counting
        shared parameters once); if `only_trainable` is True, then only
        includes parameters with `requires_grad = True`
        """
        parameters = list(m.parameters())
        if only_trainable:
            parameters = [p for p in parameters if p.requires_grad]
        unique = {p.data_ptr(): p for p in parameters}.values()
        return sum(p.numel() for p in unique)

    def _build_models_dict(self):
        models_dic = {}
        if self.args.features == 'S':
            num_channels = 1
        else:
            num_channels = len(self.station_channels_dict[self.target_station])
        self.args.enc_in = num_channels
        self.args.dec_in = num_channels
        self.args.c_out = num_channels
        models_dic[self.target_station] = self.model_dict[self.args.model].Model(self.args).float().to(self.device)
        for name in self.stations_list:
            if name == self.target_station:
                continue
            if self.args.features == 'S':
                num_channels = 1
            else:
                num_channels = len(self.station_channels_dict[name])
            self.args.enc_in = num_channels
            self.args.dec_in = num_channels
            self.args.c_out = num_channels
            models_dic[name] = self.model_dict[self.args.model].Model(self.args).float().to(self.device)
        return models_dic

    def _get_data(self, flag, stations_list=None):
        """Build the (dataset, dataloader) pair for split ``flag`` via ``data_provider``.

        Uses the ``'all_in_list'`` loader over ``stations_list``. When it is None,
        ``self.stations_list`` is used instead.
        """
        stations = self.stations_list if stations_list is None else stations_list
        data_set, data_loader = data_provider(self.args, flag, self.verbose, 'all_in_list',
                                              stations_list=stations)
        return data_set, data_loader

    def _select_optimizer(self, model, lr=None):
        if lr is None:
            lr = self.args.learning_rate
        if isinstance(model, list):
            model_optim = optim.Adam(model, lr=lr)
        else:
            model_optim = optim.Adam(model.parameters(), lr=lr)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, criterion, dataloader, train_phase='local', station=None):
        total_loss = []
        for name in self.models_dic.keys():
            self.models_dic[name].eval()
        with torch.no_grad():
            for i, data_list in enumerate(dataloader):
                input_list = data_list[0]
                true_list = data_list[1]
                input_list = [x.float().to(self.device) for x in input_list]
                true_list = [x.float().to(self.device) for x in true_list]
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                target_x = deepcopy(input_list[-1])
                target_y = deepcopy(true_list[-1])
                # decoder input
                dec_inp = torch.zeros_like(target_y[:, -self.args.pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :self.args.label_len, -1:], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if train_phase == 'local':
                    if station == [self.target_station]:
                        target_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                    else:
                        all_model_out = []
                        for name in station:
                            if name == self.target_station:
                                continue
                            station_i = station.index(name)
                            input_x = deepcopy(input_list[station_i])
                            model_out = self.models_dic[name](input_x, batch_x_mark, dec_inp, batch_y_mark,
                                                              target_x=target_x)
                            all_model_out.append(model_out[:, -self.args.pred_len:, -1:])
                        target_out = torch.stack(all_model_out, dim=0).sum(dim=0)
                else:
                    input_x = target_x
                    target_out = self.models_dic[station](input_x, batch_x_mark, dec_inp, batch_y_mark,
                                                          target_x=target_x)
                    target_out = target_out[:, -self.args.pred_len:, -1:]
                # other_out = self._get_other_out(input_list)
                # real_out is the output for inference: avg from target and others
                real_out = target_out
                real_out = real_out[:, -self.args.pred_len:, -1:]
                target_y = target_y[:, -self.args.pred_len:, -1:]
                del data_list
                del input_list
                del true_list
                real_out = real_out.detach().cpu()
                target_y = target_y.detach().cpu()
                loss = criterion(real_out, target_y)
                total_loss.append(loss)
        total_loss = np.average(total_loss)
        # for name in self.models_dic.keys():
        #     self.models_dic[name].train()
        return total_loss

    def _train_local(self, setting, stations_list=None):
        """Local stage: jointly fit the models of ``stations_list`` to the target series.

        Data loaders are rebuilt for exactly these stations and stored on ``self``.
        The target station must be the LAST element of ``stations_list``, because its
        input/labels are read as ``input_list[-1]`` / ``true_list[-1]``.

        * ``stations_list == [target]``: the target model is trained alone.
        * otherwise: each non-target model is run on its own input (with
          ``target_x=target input``). The last channel of their outputs is summed over the
          prediction window, and that sum is fit to the target labels.

        A single Adam optimizer covers all models in ``stations_list``. Early stopping
        watches the validation loss and writes checkpoints under
        ``<args.checkpoints>/<setting>/<phase>/local/``. At the end, the best checkpoint of
        every station in ``stations_list`` is loaded back.

        Args:
            setting: experiment name used in the checkpoint path.
            stations_list: stations to train; ``None`` means ``self.stations_list``.

        Returns:
            None (returns immediately when ``stations_list`` is empty).
        """
        if stations_list is None:
            # Previously a None default reached the optimizer comprehension and crashed.
            stations_list = self.stations_list
        if not stations_list:
            return None
        self.train_data, self.train_loader = self._get_data(flag='train', stations_list=stations_list)
        self.vali_data, self.vali_loader = self._get_data(flag='val', stations_list=stations_list)
        self.test_data, self.test_loader = self._get_data(flag='test', stations_list=stations_list)
        path = os.path.join(self.args.checkpoints, setting)
        # exist_ok avoids a check-then-create race when several runs share the directory.
        os.makedirs(os.path.join(path, self.phase, 'local'), exist_ok=True)
        ckpt_dir = path + f'/{self.phase}/local/'

        time_now = time.time()

        train_steps = len(self.train_loader)
        early_stopping = EarlyStopping_mekong_local_global(patience=self.args.patience, verbose=self.verbose,
                                                           delta=1e-6)

        opt_params = [{'params': self.models_dic[m].parameters()} for m in stations_list]
        model_optim = self._select_optimizer(opt_params)
        criterion = self._select_criterion()
        target_only = stations_list == [self.target_station]

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            epoch_time = time.time()
            for station in stations_list:
                self.models_dic[station].train()
            for i, data_list in enumerate(self.train_loader):
                iter_count += 1
                input_list = data_list[0]
                true_list = data_list[1]
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                input_list = [x.float().to(self.device) for x in input_list]
                true_list = [x.float().to(self.device) for x in true_list]
                target_x = deepcopy(input_list[-1])
                target_y = deepcopy(true_list[-1])
                # decoder input: first label_len label steps of the last channel, then zeros
                dec_inp = torch.zeros_like(target_y[:, -self.args.pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :self.args.label_len, -1:], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if target_only:
                    real_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    model_out_all = []
                    for station in stations_list:
                        if station == self.target_station:
                            continue
                        station_i = stations_list.index(station)
                        input_x = deepcopy(input_list[station_i])
                        model_out = self.models_dic[station](input_x, batch_x_mark, dec_inp, batch_y_mark,
                                                             target_x=target_x)
                        model_out_all.append(model_out[:, -self.args.pred_len:, -1:])
                    real_out = torch.stack(model_out_all, dim=0).sum(dim=0)
                real_out = real_out[:, -self.args.pred_len:, -1:]
                target_y = target_y[:, -self.args.pred_len:, -1:]
                model_optim.zero_grad()
                loss = criterion(real_out, target_y)
                loss.backward()
                model_optim.step()
                del data_list
                del input_list
                del true_list
                train_loss.append(loss.item())

                if self.verbose:
                    if (i + 1) % 100 == 0:
                        print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                        speed = (time.time() - time_now) / iter_count
                        left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                        print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                        iter_count = 0
                        time_now = time.time()

            if self.verbose:
                print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(criterion, self.vali_loader, train_phase='local', station=stations_list)
            test_loss = self.vali(criterion, self.test_loader, train_phase='local', station=stations_list)
            if self.verbose:
                print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.models_dic, ckpt_dir)
            if early_stopping.early_stop:
                if self.verbose:
                    print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args, self.verbose)
        for station in stations_list:
            ckpt_name = f'checkpoint_{station}.pth'
            best_model_path = ckpt_dir + ckpt_name
            self.models_dic[station].load_state_dict(torch.load(best_model_path))

    def _train_global(self, setting, loop=0):
        """Global stage: co-train child, parent and target models with mutual distillation.

        Uses the ``self.train_loader`` / ``self.vali_loader`` / ``self.test_loader`` already
        built over ``self.stations_list``. Per batch:

        * the target model predicts from the target input -> ``target_out``;
        * children (if any): their summed output is trained with
          ``alpha * MSE(sum, y) + (1 - alpha) * MSE(target_out.detach(), sum)``;
          only the child optimizer steps;
        * parents (if any): same loss on the summed parent output, parent optimizer only;
        * target: ``alpha * MSE(target_out, y) + (1 - alpha) * MSE(other.detach(), target_out)``.
          Here ``other`` is the mean of the child sum and the parent sum, or whichever exists.

        Each group has its own Adam optimizer with
        ``lr = args.global_lr * args.global_lr_factor ** loop``. The current weights are
        first saved to ``<args.checkpoints>/<setting>/<phase>/global/``. Early stopping
        tracks the target-only validation loss, and the best checkpoints are reloaded at
        the end.

        Raises:
            ValueError: if both ``self.child_list`` and ``self.parent_list`` are empty.
        """
        path = os.path.join(self.args.checkpoints, setting)
        global_dir = os.path.join(path, self.phase, 'global')
        # exist_ok avoids a check-then-create race when several runs share the directory.
        os.makedirs(global_dir, exist_ok=True)
        for name in self.models_dic.keys():
            ckpt_name = f'checkpoint_{name}.pth'
            torch.save(self.models_dic[name].state_dict(), os.path.join(global_dir, ckpt_name))
        ckpt_dir = path + f'/{self.phase}/global/'

        time_now = time.time()

        train_steps = len(self.train_loader)
        early_stopping = EarlyStopping_mekong_local_global(patience=self.args.patience, verbose=self.verbose,
                                                           delta=1e-6)

        # Same starting learning rate for every group, decayed geometrically per global loop.
        global_lr = self.args.global_lr * (self.args.global_lr_factor ** loop)
        model_optim_dic = {}
        if self.child_list:
            opt_params_child = [{'params': self.models_dic[m].parameters()} for m in self.child_list]
            model_optim_dic['child'] = self._select_optimizer(opt_params_child, lr=global_lr)
        if self.parent_list:
            opt_params_parent = [{'params': self.models_dic[m].parameters()} for m in self.parent_list]
            model_optim_dic['parent'] = self._select_optimizer(opt_params_parent, lr=global_lr)
        opt_params_target = [{'params': self.models_dic[self.target_station].parameters()}]
        model_optim_dic['target'] = self._select_optimizer(opt_params_target, lr=global_lr)
        criterion = self._select_criterion()
        # Loop-invariant: position of the target station in the loader's lists.
        target_i = self.stations_list.index(self.target_station)

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            epoch_time = time.time()
            for name in self.models_dic.keys():
                self.models_dic[name].train()
            for i, data_list in enumerate(self.train_loader):
                iter_count += 1
                input_list = data_list[0]
                true_list = data_list[1]
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                input_list = [x.float().to(self.device) for x in input_list]
                true_list = [x.float().to(self.device) for x in true_list]
                target_x = deepcopy(input_list[target_i])
                target_y = deepcopy(true_list[target_i])
                # decoder input: first label_len label steps of the last channel, then zeros
                dec_inp = torch.zeros_like(target_y[:, -self.args.pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :self.args.label_len, -1:], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                target_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                target_out = target_out[:, -self.args.pred_len:, -1:]
                # update child
                if self.child_list:
                    all_child_out = []
                    for station in self.child_list:
                        station_i = self.stations_list.index(station)
                        input_x = deepcopy(input_list[station_i])
                        child_out = self.models_dic[station](input_x, batch_x_mark, dec_inp, batch_y_mark,
                                                             target_x=target_x)
                        all_child_out.append(child_out[:, -self.args.pred_len:, -1:])
                    sum_child_out = torch.stack(all_child_out, dim=0).sum(dim=0)
                    model_optim_dic['child'].zero_grad()
                    loss = (self.alpha * criterion(sum_child_out[:, -self.args.pred_len:, -1:],
                                                   target_y[:, -self.args.pred_len:, -1:])
                            + (1 - self.alpha) * criterion(target_out[:, -self.args.pred_len:, -1:].detach(),
                                                           sum_child_out))
                    loss.backward(retain_graph=True)
                    model_optim_dic['child'].step()

                # update parent
                if self.parent_list:
                    all_parent_out = []
                    for station in self.parent_list:
                        station_i = self.stations_list.index(station)
                        input_x = deepcopy(input_list[station_i])
                        parent_out = self.models_dic[station](input_x, batch_x_mark, dec_inp, batch_y_mark,
                                                              target_x=target_x)
                        all_parent_out.append(parent_out[:, -self.args.pred_len:, -1:])
                    sum_parent_out = torch.stack(all_parent_out, dim=0).sum(dim=0)
                    model_optim_dic['parent'].zero_grad()
                    loss = (self.alpha * criterion(sum_parent_out[:, -self.args.pred_len:, -1:],
                                                   target_y[:, -self.args.pred_len:, -1:])
                            + (1 - self.alpha) * criterion(target_out[:, -self.args.pred_len:, -1:].detach(),
                                                           sum_parent_out))
                    loss.backward(retain_graph=True)
                    model_optim_dic['parent'].step()

                # update target
                if self.child_list and self.parent_list:
                    avg_other_out = torch.stack([sum_child_out, sum_parent_out], dim=0).mean(dim=0)
                elif self.child_list:
                    avg_other_out = sum_child_out
                elif self.parent_list:
                    avg_other_out = sum_parent_out
                else:
                    raise ValueError('both child list and parent list are empty')
                model_optim_dic['target'].zero_grad()
                loss = (self.alpha * criterion(target_out[:, -self.args.pred_len:, -1:],
                                               target_y[:, -self.args.pred_len:, -1:])
                        + (1 - self.alpha) * criterion(avg_other_out[:, -self.args.pred_len:, -1:].detach(),
                                                       target_out[:, -self.args.pred_len:, -1:]))
                loss.backward()
                model_optim_dic['target'].step()
                real_out = target_out[:, -self.args.pred_len:, -1:]
                target_y = target_y[:, -self.args.pred_len:, -1:]
                del data_list
                del input_list
                del true_list
                # Logged loss: target-only error of the pre-step prediction. Detached because
                # it is never back-propagated.
                loss = criterion(real_out.detach(), target_y)
                train_loss.append(loss.item())

                if self.verbose:
                    if (i + 1) % 100 == 0:
                        print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                        speed = (time.time() - time_now) / iter_count
                        left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                        print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                        iter_count = 0
                        time_now = time.time()

            if self.verbose:
                print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(criterion, self.vali_loader, train_phase='global', station=self.target_station)
            test_loss = self.vali(criterion, self.test_loader, train_phase='global', station=self.target_station)
            if self.verbose:
                print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.models_dic, ckpt_dir)
            if early_stopping.early_stop:
                if self.verbose:
                    print("Early stopping")
                break

            # Only the first optimizer's adjustment is reported when verbose.
            # NOTE(review): adjust_learning_rate receives self.args; if it derives the LR from
            # args.learning_rate it may override global_lr — confirm in utils.tools.
            for adj_lr_i, name in enumerate(model_optim_dic.keys()):
                verbose = self.verbose if adj_lr_i == 0 else False
                adjust_learning_rate(model_optim_dic[name], epoch + 1, self.args, verbose)

        for name in self.models_dic.keys():
            ckpt_name = f'checkpoint_{name}.pth'
            best_model_path = ckpt_dir + ckpt_name
            self.models_dic[name].load_state_dict(torch.load(best_model_path))

    def train(self, setting):
        """Run the full local-then-global training pipeline for ``self.target_station``.

        1. Local stage: ``_train_local`` on children + target and on parents + target (each
           skipped when that group is empty), then on the target alone.
        2. Rebuild the data loaders over ``self.stations_list`` and evaluate the local
           checkpoints with ``self.test(train_phase='local')``. Record the metrics via
           ``StationSpecSearchHyperparam`` (loop index 0) and a ``MakeTable`` CSV under
           ``.../FlowNet/local``.
        3. Global stage (only when there is at least one child or parent): for each of
           ``args.global_loop`` loops, run ``_train_global``, evaluate with
           ``train_phase='global'``, and record the metrics (loop index ``loop + 1``) and a
           CSV under ``.../FlowNet/global_{loop}``.

        Returns:
            None.
        """
        def make_csv_dir(sub_dir):
            # Ablation runs and regular runs are written to separate result trees.
            if self.args.ablation != 'None':
                csv_dir = os.path.join(
                    './ablation/', self.args.target, self.args.data_time_path, self.args.ablation,
                    f'sl{self.args.seq_len}_pl{self.args.pred_len}', 'FlowNet', sub_dir
                )
            else:
                csv_dir = os.path.join(
                    './baselines_results/', self.args.target, self.args.data_time_path,
                    f'sl{self.args.seq_len}_pl{self.args.pred_len}', 'FlowNet', sub_dir
                )
            # exist_ok avoids a check-then-create race when several runs share the directory.
            os.makedirs(csv_dir, exist_ok=True)
            return csv_dir

        child_list_with_target = self.child_list + [self.target_station] if self.child_list else []
        parent_list_with_target = self.parent_list + [self.target_station] if self.parent_list else []
        for stations_list in [child_list_with_target, parent_list_with_target]:
            self._train_local(setting, stations_list)
        self._train_local(setting, [self.target_station])

        self.train_data, self.train_loader = self._get_data(flag='train')
        self.vali_data, self.vali_loader = self._get_data(flag='val')
        self.test_data, self.test_loader = self._get_data(flag='test')
        mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv, *_ = self.test(setting, train_phase='local',
                                                                                  load_data=False)
        StationSpecSearchHyperparam(self.args, self.target_station, metric='r2').save_station_results(
            rmse, mae, mape, nse, r2, 0)
        if self.verbose:
            print(f"Local, {self.target_station}, rmse: {rmse}, mae: {mae}, mape: {mape}, nse: {nse}, r2:{r2}")
        csv_save_path = make_csv_dir('local')
        make_table = MakeTable(self.args, self.args.model, self.phase, self.args.results, setting,
                               csv_save_path=csv_save_path, target_station=self.target_station, train_phase='local',
                               loop=0)
        make_table.save_to_csv()
        if self.verbose:
            print("Make results table done")

        # Global stage needs at least one child or parent model to distill with.
        if not child_list_with_target and not parent_list_with_target:
            return None

        for loop in range(self.args.global_loop):
            self._train_global(setting, loop)
            mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv, *_ = self.test(setting, train_phase='global',
                                                                                      loop=loop, load_data=False)
            if self.verbose:
                print(
                    f"Global loop {loop + 1}, {self.target_station}, rmse: {rmse}, mae: {mae}, mape: {mape}, nse: {nse}, r2:{r2}")
            StationSpecSearchHyperparam(self.args, self.target_station, metric='r2').save_station_results(
                rmse, mae, mape, nse, r2, loop + 1)
            csv_save_path = make_csv_dir(f'global_{loop}')
            make_table = MakeTable(self.args, self.args.model, self.phase, self.args.results, setting,
                                   csv_save_path=csv_save_path, target_station=self.target_station,
                                   train_phase='global', loop=loop)
            make_table.save_to_csv()
            if self.verbose:
                print("Make results table done")

    def test(self, setting, test=0, train_phase='global', loop=0, load_data=True):
        if load_data:
            self.train_data, self.train_loader = self._get_data(flag='train')
            self.vali_data, self.vali_loader = self._get_data(flag='val')
            self.test_data, self.test_loader = self._get_data(flag='test')

        if self.verbose:
            print('loading model')
        self.models_dic = self._build_models_dict()
        for name in self.models_dic.keys():
            ckpt_name = f'checkpoint_{name}.pth'
            self.models_dic[name].load_state_dict(torch.load(
                os.path.join('./checkpoints/' + setting, self.phase, train_phase, ckpt_name)
            ))

        preds = []
        trues = []
        inputs = []
        stations_preds = {}
        for station_i, station in enumerate(self.stations_list):
            stations_preds[station] = []
        folder_path = './test_results/' + setting + f'/{self.phase}/{train_phase}/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        for station in self.stations_list:
            station_out_folder_path = './test_results/' + setting + f'/{self.phase}/{train_phase}/' + f'{station}/'
            if not os.path.exists(station_out_folder_path):
                os.makedirs(station_out_folder_path)

        for name in self.models_dic.keys():
            self.models_dic[name].eval()
        with torch.no_grad():
            for i, data_list in enumerate(self.test_loader):
                input_list = data_list[0]
                true_list = data_list[1]
                input_list = [x.float().to(self.device) for x in input_list]
                true_list = [x.float().to(self.device) for x in true_list]
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                target_i = self.stations_list.index(self.target_station)
                target_x = deepcopy(input_list[target_i])
                target_y = deepcopy(true_list[target_i])
                # decoder input
                dec_inp = torch.zeros_like(target_y[:, -self.args.pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :self.args.label_len, -1:], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                target_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                target_out = target_out[:, -self.args.pred_len:, -1:]
                # other_out = self._get_other_out(input_list)
                # real_out is the output for inference: avg from target and others
                real_out = target_out
                real_out = real_out[:, -self.args.pred_len:, -1:]
                target_y = target_y[:, -self.args.pred_len:, -1:]
                real_out = real_out.detach().cpu().numpy()
                target_y = target_y.detach().cpu().numpy()
                if self.args.inverse:
                    shape = real_out.shape
                    real_out = self.test_data.inverse_transform(real_out.reshape(shape[0] * shape[1], -1)).reshape(
                        shape)
                    target_y = self.test_data.inverse_transform(target_y.reshape(shape[0] * shape[1], -1)).reshape(
                        shape)

                pred = real_out
                true = target_y
                preds.append(pred)
                trues.append(true)
                # if i % 20 == 0:
                for station_i, station in enumerate(self.stations_list):
                    station_x = deepcopy(input_list[station_i])
                    station_out = self.models_dic[station](station_x, batch_x_mark, dec_inp, batch_y_mark,
                                                           target_x=target_x)
                    station_out = station_out[:, -self.args.pred_len:, -1:]
                    station_out = station_out.detach().cpu().numpy()
                    if self.args.inverse:
                        shape = station_out.shape
                        station_out = self.test_data.inverse_transform(
                            station_out.reshape(shape[0] * shape[1], -1)).reshape(shape)
                    stations_preds[station].append(station_out)
                input = target_x[:, :, -1:].detach().cpu().numpy()
                if self.args.inverse:
                    shape = input.shape
                    input = self.test_data.inverse_transform(input.reshape(shape[0] * shape[1], -1)).reshape(shape)
                inputs.append(input)
                for bs_i in range(input.shape[0]):
                    if bs_i % 7 == 0:
                        for station_i, station in enumerate(self.stations_list):
                            gt = np.concatenate((input[bs_i, :, -1], true[bs_i, :, -1]), axis=0)
                            pd = np.concatenate((input[bs_i, :, -1], station_out[bs_i, :, -1]), axis=0)
                            station_out_folder_path = ('./test_results/' + setting +
                                                       f'/{self.phase}/{train_phase}/' + f'{station}/')
                            visual(gt, pd, os.path.join(station_out_folder_path, f'{i}_{bs_i}.pdf'))
                del data_list
                del input_list
                del true_list

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        inputs = np.concatenate(inputs, axis=0)
        if self.verbose:
            print('test shape:', preds.shape, trues.shape)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        inputs = inputs.reshape(-1, inputs.shape[-2], inputs.shape[-1])
        if self.verbose:
            print('test shape:', preds.shape, trues.shape)

        for station_i, station in enumerate(self.stations_list):
            stations_pred = stations_preds[station]
            stations_pred = np.concatenate(stations_pred, axis=0)
            stations_pred = stations_pred.reshape(-1, stations_pred.shape[-2], stations_pred.shape[-1])
            mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv = metric(stations_pred[:, :, -1], trues[:, :, -1])
            station_out_folder_path = './test_results/' + setting + f'/{self.phase}/{train_phase}/' + f'{station}/'
            with open(os.path.join(station_out_folder_path, f"rmse.txt"), "w") as text_file:
                text_file.write(f"RMSE {rmse}, NSE {nse}")
            if self.args.out_station_preds:
                folder_path = 'out_preds/' + setting + f'/{self.phase}/'
                if not os.path.exists(folder_path):
                    os.makedirs(folder_path)
                np.save(folder_path + f'{station}_metrics.npy',
                        np.array([mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv]))
                np.save(folder_path + f'{station}_pred.npy', stations_pred)

        # result save
        folder_path = self.args.results + setting + f'/{train_phase}_{loop}/{self.target_station}/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv = metric(preds[:, :, -1], trues[:, :, -1])
        if self.verbose:
            print(f'nse:{nse}, pbias:{pbias}, kge:{kge}, flv:{flv}, '
                  f'fhv:{fhv}, mae: {mae}, mse: {mse}, rmse: {rmse}, '
                  f'mape: {mape}, mspe: {mspe}, r2: {r2}')
        f = open("result_mekong_phase_a.txt", 'a')
        f.write(setting + "  \n")
        f.write(f'nse:{nse}, pbias:{pbias}, kge:{kge}, flv:{flv}, '
                f'fhv:{fhv}, mae: {mae}, mse: {mse}, rmse: {rmse}, '
                f'mape: {mape}, mspe: {mspe}, r2: {r2}')
        f.write('\n')
        f.write('\n')
        f.close()

        np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe, nse, r2]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)
        np.save(folder_path + 'input.npy', inputs)
        if self.args.out_station_preds:
            folder_path = 'out_preds/' + setting + f'/{self.phase}/'
            if not os.path.exists(folder_path):
                os.makedirs(folder_path)
            np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe, nse, r2]))
            np.save(folder_path + 'pred.npy', preds)
            np.save(folder_path + 'true.npy', trues)
            np.save(folder_path + 'input.npy', inputs)
        return mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv, \
            mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv