import torch
import numpy as np
import argparse
import time

import csv
from torch import nn
import torch.optim as optim

from models import iTransformer, DUETformer, DLinear, STID, GWNET, STAEformer, STOP
from models import SNIPformer

import os
import sys
from functools import partial

import logging
from logging import getLogger
import pickle
import json
from icecream import ic

from utils import dataset_provider, load_aux_data, scheduler, metric, graph_utils



def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``, which is True for
    every non-empty string (including ``"False"`` and ``"0"``).  This converter
    maps the usual textual spellings to real booleans instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string keeps its historical meaning (bool('') is False).
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()

# Name of a JSON file under ./configurations/ whose values override the CLI ones.
parser.add_argument('--config_file', type=str, default='none')

# Data preprocessing options.
parser.add_argument('--dataset_name', type=str, default='none')
parser.add_argument('--input_length', type=int, default=-1)
parser.add_argument('--predict_length', type=int, default=-1)
parser.add_argument('--scaler_type', type=str,default='zscore')
parser.add_argument('--slice_size_per_day', type=int, default=-1)
parser.add_argument('--scale_channel_wise', type=int, default=0)

parser.add_argument('--dataset_seed', type=int, default=1)
parser.add_argument('--adj_type', type=str, default='auto')
parser.add_argument('--adj_link_or_distance', type=str, default='distance')

### Options for SNIP.
parser.add_argument('--prompt_type', type=str, default='SNIP')
parser.add_argument('--period_type', type=str, default='day_week')
parser.add_argument('--stci_cal_type', type=str, default='fast')
parser.add_argument('--static_type', type=str, default='Pp_lap_Sl_Ll')
parser.add_argument('--static_func_type', type=str, default='featMLP')
parser.add_argument('--dynamic_func_type', type=str, default='gcn2')
parser.add_argument('--adj_dropout', type=float, default=0.2)

parser.add_argument('--cal_period_pca_emb_dim', type=int, default=24)
parser.add_argument('--cal_pca_emb_dim', type=int, default=8)
parser.add_argument('--cal_lap_emb_dim', type=int, default=8)
parser.add_argument('--use_drop_data_emb_for_snip', type=int, default=1)
parser.add_argument('--use_meta_emb_as_proxy', type=int, default=0)

parser.add_argument('--before_emb_norm_type', type=str, default='S')
parser.add_argument('--emb_fuse_topk', type=int, default=3)
parser.add_argument('--emb_fuse_type', type=str, default='stci')
parser.add_argument('--stage2_emb_type', type=str, default='fuse_period')
parser.add_argument('--finetune_drop', type=int, default=0)


# Model hyper-parameters.
### For SNIPformer and other models (where applicable).
### A default of -1 means the value is expected to come from the JSON config file.
parser.add_argument('--modelid', type=str, default='SNIPformer')
parser.add_argument('--hid_dim', type=int, default=-1)
parser.add_argument('--n_heads', type=int, default=-1)
parser.add_argument('--M', type=int, default=-1) # the number of proxy nodes
parser.add_argument('--num_layers', type=int, default=-1)
parser.add_argument('--tau', type=int, default=3, help='0:no TCN')
parser.add_argument('--hasTemb', type=int, default=1)
parser.add_argument('--hasRawSemb', type=int, default=0) #### for expanding-node forecasting, original spatial embeddings are not used
parser.add_argument('--hasSTencoder', type=int, default=1)
parser.add_argument('--enc_dropout', type=float, default=0.1)
parser.add_argument('--att_dropout', type=float, default=0.1)
parser.add_argument('--emb_dropout', type=float, default=0.1)
parser.add_argument('--te_emb_dropout', type=float, default=-1)
parser.add_argument('--se_emb_dropout', type=float, default=-1)
parser.add_argument('--return_att', type=int, default=0)
parser.add_argument('--att_type', type=str, default='proxy')
parser.add_argument('--norm_flag', type=str, default='none')
parser.add_argument('--activation_data', type=str, default='relu')
parser.add_argument('--activation_enc', type=str, default='gelu')
parser.add_argument('--activation_dec', type=str, default='gelu')
parser.add_argument('--revin', type=int, default=0)
parser.add_argument('--revin_type', type=str, default='ST')

#### Options for GWNET.
parser.add_argument('--addaptadj', type=int, default=0)
parser.add_argument('--aptonly', type=int, default=0)
parser.add_argument('--use_SNIP_as_adjadp', type=int, default=0)
parser.add_argument('--addSNIPemb_GWN', type=int, default=0)

#### Options for DUETformer.
parser.add_argument('--noisy_gating', type=int, default=1)
parser.add_argument('--num_experts', type=int, default=4)
parser.add_argument('--DUET_k', type=int, default=3)
parser.add_argument('--CI', type=int, default=1)
parser.add_argument('--moving_avg', type=int, default=3)


# Training-process hyper-parameters.
parser.add_argument('--task_type', type=str, default="expand_node_forecasting")
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--learning_rate', type=float,default=0.001)
parser.add_argument('--weight_decay', type=float,default=0.0001)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--print_every', type=int, default=50)
# Parsed with _str2bool so that "--early_stop False" / "--early_stop 0" really disable it.
parser.add_argument('--early_stop', type=_str2bool, default=False)
parser.add_argument('--early_stop_step', type=int, default=20)
parser.add_argument('--lr_decay', type=int, default=1)
parser.add_argument('--huber_delta', type=int, default=2, help='delta in huber loss')
parser.add_argument('--loss_type', type=str, default="huber")
parser.add_argument('--finetune_type', type=str, default="finetune")
parser.add_argument('--eval_mask', type=int, default=0, help='eval_mask; {-1: not mask; 0: exclude zero; x: exclude values below x}')
parser.add_argument('--lr_scheduler_type', type=str, default='cosinelr')
# Underscore-separated epoch numbers at which an extra checkpoint is saved, e.g. "10_20".
parser.add_argument('--save_epochs', type=str, default='')
parser.add_argument('--save_output', type=int, default=0)
parser.add_argument('--note', type=str, default='')
parser.add_argument('--load_expid', type=str, default=None)
parser.add_argument('--load_note', type=str, default='')


# Experiment id: import-time local timestamp formatted as MMDDHHMMSS.
expid = time.strftime("%m%d%H%M%S", time.localtime())



def merge_args(args1, args2):
    """Combine two argparse namespaces, giving priority to ``args1``.

    A copy of ``args1`` is made; every attribute of ``args2`` that the copy
    does not already have is added to it.  Neither input is modified.

    Returns:
        (merged_namespace, added) where ``added`` is True when at least one
        attribute was taken from ``args2``.
    """
    combined = argparse.Namespace(**vars(args1))
    added_any = False
    for attr_name, attr_value in vars(args2).items():
        if hasattr(combined, attr_name):
            # args1 (or the Namespace type itself) already provides this name.
            continue
        setattr(combined, attr_name, attr_value)
        added_any = True
    return combined, added_any

def get_logger(log_dir, log_filename, name=None):
    """Return a DEBUG-level logger that writes to a file and to stdout.

    Args:
        log_dir: directory of the log file; created if it does not exist.
        log_filename: file name of the log inside ``log_dir``.
        name: logger name passed to ``logging.getLogger`` (``None`` = root logger).

    Returns:
        The configured ``logging.Logger``.

    Calling this function again for the same logger and file does not attach
    the handlers a second time, so messages are not emitted in duplicate.
    """
    if log_dir:
        # FileHandler would fail on a missing directory.
        os.makedirs(log_dir, exist_ok=True)
    logfilepath = os.path.join(log_dir, log_filename)

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # FileHandler stores the absolute path of its target in ``baseFilename``.
    target_path = os.path.abspath(logfilepath)
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(logfilepath)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so it must be excluded explicitly.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    logger.info('Log directory: %s', log_dir)
    return logger


class trainer():
    """Builds a model, loss, optimizer and LR scheduler from ``args`` and runs training/validation."""

    def __init__(self, args, device, scaler, inverse=True):
        """Create the model and its training components.

        Args:
            args: namespace with the hyper-parameters (``modelid``, ``loss_type``,
                ``learning_rate``, ``weight_decay``, ``epochs``, ``eval_mask``,
                ``inverse``, ...).
            device: device the model is moved to.
            scaler: object providing ``inverse_transform(tensor, stage=...)``.
            inverse: kept for interface compatibility; the flag actually used
                is ``args.inverse``.
        """
        self._logger = getLogger()
        self.args = args
        self.scaler = scaler
        self.device = device
        self.clip = 1  # max gradient norm used for clipping
        self.eval_mask = args.eval_mask
        self.inverse = args.inverse
        self.epochs = args.epochs

        self.model = self._build_model(args.modelid)
        self.loss = self._build_loss_func(args.loss_type)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
        self.lr_scheduler = self._build_lr_scheduler()

    def _build_model(self, modelid):
        """Instantiate the model named ``modelid``, initialise it and move it to the device.

        Exits the process (``sys.exit('error')``) when the name is unknown.
        """
        model_dict = {
            'GWNET':GWNET,
            'STID':STID,
            'STAEformer':STAEformer,
            'iTransformer':iTransformer,
            'SNIPformer':SNIPformer,
            'DUETformer':DUETformer,
            'DLinear':DLinear
        }
        self.modelid = modelid
        if modelid == 'STOP':
            # STOP wraps an MLP backbone built with fixed sizes.
            base = STOP.MLP(# node_num=node_num,
                input_dim=self.args.in_dim,
                output_dim=self.args.pre_dim,
                num_layer=3,
                model_dim=64,
                prompt_dim=32,
                tod_size=self.args.slice_size_per_day,
                kernel_size=3)
            model = STOP.Model(self.args, stmodel=base)
        elif modelid in model_dict:
            model = model_dict[modelid].Model(self.args)
        else:
            self._logger.error(
                f'no model named `{modelid}`, please try again!')
            sys.exit('error')

        # Xavier-initialise every parameter with more than one dimension.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        model.to(self.device)

        self._logger.info(model)
        for name, param in model.named_parameters():
            self._logger.info(str(name) + '\t' + str(param.shape) + '\t' +
                              str(param.device) + '\t' + str(param.requires_grad))
        total_para_num, trainable_para_num = get_parameter_number(model)
        self._logger.info(f'total_parameter_num is {total_para_num}, trainable_parameter_num is {trainable_para_num}')
        return model

    def _build_loss_func(self, loss_type):
        """Return the loss callable selected by ``loss_type``.

        Raises:
            ValueError: if ``loss_type`` is not one of the supported names.
        """
        if loss_type == 'huber':
            loss_func = partial(metric.huber_loss, delta=self.args.huber_delta)
        elif loss_type == 'masked_huber_0':
            loss_func = partial(metric.masked_huber_loss, delta=self.args.huber_delta, null_val = 0.)
        elif loss_type == 'masked_mae_0':
            loss_func = partial(metric.masked_mae_torch, null_val = 0.)
        elif loss_type == 'masked_huber_eval':
            loss_func = partial(metric.masked_huber_loss, delta=self.args.huber_delta, null_val = 0., mask_val= self.args.eval_mask)
        else:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError(
                f'unknown loss_type `{loss_type}`; expected one of '
                "'huber', 'masked_huber_0', 'masked_mae_0', 'masked_huber_eval'")
        return loss_func

    def _build_lr_scheduler(self):
        """Return the LR scheduler, or ``None`` when decay is off or the type is unsupported.

        Only ``'cosinelr'`` (case-insensitive) is built here.
        """
        self.lr_decay = bool(self.args.lr_decay)
        if not self.lr_decay:
            return None

        self.lr_scheduler_type = self.args.lr_scheduler_type
        self.lr_decay_ratio = 0.1
        self.lr_T_max = 30
        self.lr_eta_min = 0
        self.lr_warmup_epoch = 5
        self.lr_warmup_init = 1e-6
        if self.lr_scheduler_type.lower() == 'cosinelr':
            lr_scheduler = scheduler.CosineLRScheduler(
                self.optimizer, t_initial=self.epochs, lr_min=self.lr_eta_min, decay_rate=self.lr_decay_ratio,
                warmup_t=self.lr_warmup_epoch, warmup_lr_init=self.lr_warmup_init)
            self._logger.info(f'Use {self.lr_scheduler_type.lower()} lr_scheduler.')
        else:
            self._logger.warning(
                f'lr_scheduler_type `{self.lr_scheduler_type}` is not supported; training without lr scheduler.')
            lr_scheduler = None
        return lr_scheduler

    def _forward(self, input, stage=None, mode=None):
        """Run the model and, when ``args.inverse`` is set, map the output back through the scaler.

        Returns:
            (predict, emb_or_loss): the (possibly inverse-transformed) model
            output and the second value returned by the model.
        """
        output, emb_or_loss = self.model(input, mode=mode)  # output:(B,T,N,C) #real: (B,T,N,C)
        if self.args.inverse:
            predict = self.scaler.inverse_transform(output, stage=stage)
        else:
            predict = output

        return predict, emb_or_loss

    def _metric(self, predict, real):
        """Compute (MAE, MAPE, RMSE) as Python floats according to ``self.eval_mask``.

        ``eval_mask == -1`` uses unmasked MAE/RMSE (MAPE is always masked at 0.0);
        any other value is passed on as ``mask_val`` to the masked metrics.
        """
        filter_num = 1e-5
        if self.eval_mask == -1:
            mae = metric.mae(predict, real).item()
            rmse = metric.rmse(predict, real).item()
            mape = metric.masked_mape_torch(predict, real, 0.0).item()
        else:
            # eval_mask == 0 gives mask_val=0.0 exactly as before; other thresholds
            # used to leave the results unbound and now follow the same path.
            mask_val = float(self.eval_mask)
            mae = metric.masked_mae_torch(predict, real, mask_val=mask_val, filter_num=filter_num).item()
            rmse = metric.masked_rmse_torch(predict, real, mask_val=mask_val, filter_num=filter_num).item()
            mape = metric.masked_mape_torch(predict, real, mask_val=mask_val, filter_num=filter_num).item()
        return mae, mape, rmse

    def train_epoch(self, input, real_val, batches_seen, stage=None, mode=None):  # input(B,T,N,C), real_val(B,T,N,C)
        """Run one optimisation step on a single batch (despite the name).

        Returns:
            ([loss, mae, mape, rmse], emb) where ``emb`` is the second value
            returned by the model.
        """
        self.model.train()
        self.optimizer.zero_grad()
        predict, emb = self._forward(input, stage=stage, mode=mode)

        loss = self.loss(predict, real_val)

        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        if self.lr_scheduler is not None:
            if self.lr_scheduler_type.lower() == 'cosinelr':
                # per-iteration update of the cosine scheduler
                self.lr_scheduler.step_update(num_updates=batches_seen)
        mae, mape, rmse = self._metric(predict, real_val)
        return [loss.item(), mae, mape, rmse], emb

    def eval_epoch(self, input, real_val, test_node_set=None, stage=None):
        """Evaluate a single batch without tracking gradients.

        Args:
            test_node_set: optional index applied to the third axis of both
                prediction and target before computing loss and metrics.

        Returns:
            (loss, mae, mape, rmse) as Python floats.
        """
        self.model.eval()
        with torch.no_grad():  # no autograd graph is needed for evaluation
            predict, _ = self._forward(input, stage=stage, mode='val')
            if test_node_set is not None:
                predict_ready = predict[:,:,test_node_set,:]
                real_val_ready = real_val[:,:,test_node_set,:]
            else:
                predict_ready = predict
                real_val_ready = real_val
            loss = self.loss(predict_ready, real_val_ready)
            mae, mape, rmse = self._metric(predict_ready, real_val_ready)
        return loss.item(), mae, mape, rmse


    def training(self, dataloader, aux_data_train, aux_data_val, current_pid, save_path, is_finetune=False, log_tip='Forecasting', stage=None, mode=None, save_epoch_list=None):
        """Train for ``args.epochs`` epochs, validating and checkpointing after each one.

        Args:
            dataloader: mapping with ``'train_loader'`` and ``'val_loader'``,
                each yielding ``(x, x_time, y)`` batches.
            aux_data_train / aux_data_val: passed to the model as the 4th input item.
            current_pid, log_tip: currently unused.
            save_path: string prefix for the checkpoint and record files.
            is_finetune: save the best weights as ``finetune_best.pth`` instead of ``best.pth``.
            stage, mode: forwarded to the forward pass.
            save_epoch_list: optional collection of epoch numbers for which an
                extra ``epoch{i}.pth`` checkpoint is written.

        Returns:
            (training_time_minutes, train_metric, valid_metric) where the two
            metric lists are ``[loss, mae, mape, rmse]`` of the epoch with the
            lowest validation loss.
        """
        all_start_time = time.time()
        his_loss = []
        trainloss_record = []
        val_time = []
        train_time = []
        best_validate_loss = np.inf
        validate_score_non_decrease_count = 0

        self.log_in_train_details = []
        batches_seen = 0

        for i in range(1, self.args.epochs + 1):
            train_loss = []
            train_mae = []
            train_mape = []
            train_rmse = []
            t1 = time.time()

            for batch_idx, (x, x_time, y) in enumerate(dataloader['train_loader']):
                trainx = torch.Tensor(x).to(self.device)
                trainy = torch.Tensor(y).to(self.device)
                trainxtime = torch.LongTensor(x_time).to(self.device)
                trainytime = None
                trainx = [trainx, trainxtime, trainytime, aux_data_train]
                # The embeddings returned with the metrics are not used here, so
                # they are not accumulated (doing so only held on to memory).
                metrics, _ = self.train_epoch(trainx, trainy, batches_seen, stage=stage, mode=mode)
                batches_seen += 1
                train_loss.append(metrics[0])
                train_mae.append(metrics[1])
                train_mape.append(metrics[2])
                train_rmse.append(metrics[3])
                if batch_idx % self.args.print_every == 0:
                    log = 'Iter: {:03d} [{:d}], Train Loss: {:.4f},Train MAE: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}'
                    self._logger.info(log.format(
                        batch_idx, batches_seen, train_loss[-1], train_mae[-1], train_mape[-1], train_rmse[-1]))
            t2 = time.time()
            train_time.append(t2 - t1)

            # validation
            valid_loss = []
            valid_mae = []
            valid_mape = []
            valid_rmse = []

            s1 = time.time()

            for x, x_time, y in dataloader['val_loader']:
                valx = torch.Tensor(x).to(self.device)
                valy = torch.Tensor(y).to(self.device)

                valxtime = torch.LongTensor(x_time).to(self.device)
                valytime = None

                valx = [valx, valxtime, valytime, aux_data_val]

                metrics = self.eval_epoch(valx, valy, stage=stage)
                valid_loss.append(metrics[0])
                valid_mae.append(metrics[1])
                valid_mape.append(metrics[2])
                valid_rmse.append(metrics[3])

            s2 = time.time()
            log = 'Epoch: {:03d}, Inference Time: {:.4f} secs'
            self._logger.info(log.format(i, (s2 - s1)))
            val_time.append(s2 - s1)
            mtrain_loss = np.mean(train_loss)
            mtrain_mae = np.mean(train_mae)
            mtrain_mape = np.mean(train_mape)
            mtrain_rmse = np.mean(train_rmse)

            mvalid_loss = np.mean(valid_loss)
            mvalid_mae = np.mean(valid_mae)
            mvalid_mape = np.mean(valid_mape)
            mvalid_rmse = np.mean(valid_rmse)

            if self.lr_scheduler is not None:
                if self.lr_scheduler_type.lower() == 'reducelronplateau':
                    self.lr_scheduler.step(mvalid_loss)
                elif self.lr_scheduler_type.lower() == 'cosinelr':
                    self.lr_scheduler.step(i)
                else:
                    self.lr_scheduler.step()

            his_loss.append(mvalid_loss)
            # Fourth entry is the train RMSE (it used to repeat the train MAPE).
            trainloss_record.append(
                [mtrain_loss, mtrain_mae, mtrain_mape, mtrain_rmse, mvalid_loss, mvalid_mae, mvalid_mape, mvalid_rmse])
            log = 'Epoch: {:03d}, Train Loss: {:.4f}, Train MAE: {:.4f} Train MAPE: {:.4f}, Train RMSE: {:.4f}, '
            log += 'Valid Loss: {:.4f}, Valid MAE: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}, Training Time: {:.4f}/epoch'
            self._logger.info(log.format(i, mtrain_loss, mtrain_mae, mtrain_mape, mtrain_rmse, mvalid_loss, mvalid_mae, mvalid_mape,
                                mvalid_rmse, (t2 - t1)))
            self.log_in_train_details.append([i, mtrain_loss, mtrain_mae, mtrain_mape, mtrain_rmse, mvalid_loss, mvalid_mae, mvalid_mape,
                                        mvalid_rmse, (t2 - t1)])

            if best_validate_loss > mvalid_loss:
                best_validate_loss = mvalid_loss
                validate_score_non_decrease_count = 0
                if is_finetune:
                    torch.save(self.model.state_dict(), save_path + "finetune_best.pth")
                else:
                    torch.save(self.model.state_dict(), save_path + "best.pth")

                self._logger.info('got best validation result: {:.4f}, {:.4f}, {:.4f}'.format(
                    mvalid_loss, mvalid_mape, mvalid_rmse))

            else:
                # Non-improving epochs only count towards early stopping from epoch 50 on.
                if i >=50:
                    validate_score_non_decrease_count += 1

            if save_epoch_list is not None:
                if i in save_epoch_list:
                    torch.save(self.model.state_dict(), save_path+f'epoch{i}.pth')

            self._logger.info(f'---- {expid}_{self.args.modelid}_{self.args.dataset_name}_{self.args.note} ----')

            if self.args.early_stop and validate_score_non_decrease_count >= self.args.early_stop_step:
                break

        avg_train_time = np.mean(train_time)
        avg_inference_time = np.mean(val_time)
        self._logger.info(
            "Average Training Time: {:.4f} secs/epoch".format(avg_train_time))
        self._logger.info("Average Valid Inference Time: {:.4f} secs".format(avg_inference_time))

        training_time = (time.time() - all_start_time) / 60

        bestid = np.argmin(his_loss)
        self._logger.info("Training finished")
        self._logger.info(
            f"The valid loss on best model is {str(round(his_loss[bestid], 4))}")
        train_metric = trainloss_record[bestid][0:4]
        valid_metric = trainloss_record[bestid][4:]

        # Writing the record is best-effort: a failure must not lose the training result.
        try:
            with open(save_path + 'train_record.csv', 'a+', newline='') as f:
                fcsv = csv.writer(f)
                fcsv.writerow(['bestid',bestid,'train_loss_mae_mape_rmse']+ train_metric +['valid_loss_mae_mape_rmse'] + valid_metric)
        except (OSError, csv.Error) as err:
            self._logger.warning('error in save train_record: %s', err)

        return training_time, train_metric, valid_metric

def get_parameter_number(model):
    """Count the parameters of ``model``.

    Returns:
        (total, trainable): number of elements over all parameters, and over
        those whose ``requires_grad`` flag is set.
    """
    total_num = 0
    trainable_num = 0
    for param in model.parameters():
        size = param.numel()
        total_num += size
        if param.requires_grad:
            trainable_num += size
    return total_num, trainable_num



def testing(args, model, test_dataloader, device, aux_data_test, scaler, pre_dim, logger, result_tip='Forcasting', node_idx_list1=None, node_idx_list2=None, stage=None,
            train_metric=None, valid_metric=None):
    """Evaluate ``model`` on the test loader and append the results to CSV files.

    Args:
        args: run configuration (``inverse``, ``eval_mask``, ``input_length``,
            ``predict_length``, ``modelid``, ``dataset_name``, ``note``, ``hid_dim``, ...).
        model: called as ``model([x, x_time, None, aux_data_test], mode='test')``.
        test_dataloader: yields ``(x, x_time, y)`` or ``(x, x_time, y, y_valid_mask)``.
        device: device the batches are moved to.
        aux_data_test: auxiliary data handed to the model with every batch.
        scaler: provides ``inverse_transform``; used only when ``args.inverse`` is set.
        pre_dim: number of trailing feature channels that are evaluated.
        logger: logger receiving the per-horizon and average metrics.
        result_tip: part of the results directory name ``./EXP_{result_tip}_results/``.
        node_idx_list1, node_idx_list2: optional node index sets; when both are
            given, metrics are additionally computed for each subset.
        stage: forwarded to ``scaler.inverse_transform``.
        train_metric, valid_metric: optional metric lists appended to the result row.

    Raises:
        ValueError: if a batch does not contain 3 or 4 items.
    """
    outputs = []
    realy_list = []
    y_valid_mask_list = []
    inference_time = 0
    inverse = args.inverse
    model.eval()

    for group in test_dataloader:
        if len(group) == 4:
            x, x_time, y, y_valid_mask = group
            y_valid_mask_list.append(y_valid_mask)
        elif len(group) == 3:
            x, x_time, y = group
        else:
            # Previously such a batch silently re-used the tensors of the previous one.
            raise ValueError(f'expected 3 or 4 items per test batch, got {len(group)}')
        testx = torch.Tensor(x).to(device)
        testy = torch.Tensor(y).to(device)
        testxtime = torch.LongTensor(x_time).to(device)
        testytime = None
        model_input = [testx, testxtime, testytime, aux_data_test]
        with torch.no_grad():
            t1 = time.time()
            preds, _ = model(model_input, mode='test')
            inference_time += time.time() - t1
        outputs.append(preds)
        realy_list.append(testy)

    yhat = torch.cat(outputs, dim=0)
    realy = torch.cat(realy_list, dim=0)
    if y_valid_mask_list:
        y_valid_mask = torch.cat(y_valid_mask_list, dim=0)
        assert y_valid_mask.shape == realy.shape
    else:
        y_valid_mask = None
    print('yhat.shape:',yhat.shape)
    print('realy.shape:', realy.shape)
    print('inference_time:', inference_time)

    # predict_length == 0 means the evaluated horizon equals the input length.
    if args.predict_length == 0:
        ready_length = args.input_length
    else:
        ready_length = args.predict_length

    # Metric order used throughout: MAE, MAPE, RMSE, MRE, R2 (first five values of metric.metric_torch).
    metric_count = 5
    has_subsets = node_idx_list1 is not None and node_idx_list2 is not None

    per_horizon_all = []  # [feature][metric][horizon]
    mean_all, mean_sub1, mean_sub2 = [], [], []  # [feature][metric], averaged over horizons

    filter_num = 1e-5
    for feature_idx in range(pre_dim):
        steps_all = [[] for _ in range(metric_count)]
        steps_sub1 = [[] for _ in range(metric_count)]
        steps_sub2 = [[] for _ in range(metric_count)]
        pred_feature = yhat[..., feature_idx]
        real_feature = realy[..., feature_idx]
        if y_valid_mask is not None:
            y_valid_mask_feature = y_valid_mask[..., feature_idx]

        for i in range(ready_length):
            if inverse:
                print('testing inverse')
                pred = scaler.inverse_transform(pred_feature[:, i], stage=stage)
            else:
                pred = pred_feature[:, i]
            real = real_feature[:, i]

            if y_valid_mask is not None:
                y_valid_mask_feature_ready = y_valid_mask_feature[:, i].to(pred.device)
            else:
                y_valid_mask_feature_ready = None
            metrics = metric.metric_torch(pred, real, mask_val=args.eval_mask, valid_mask=y_valid_mask_feature_ready, filter_num=filter_num)
            log = 'Evaluate best model on test data for [dim{:d}] horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
            logger.info(log.format(feature_idx, i + 1, metrics[0], metrics[1], metrics[2]))
            for k in range(metric_count):
                steps_all[k].append(metrics[k])

            if has_subsets:
                # Subset metrics are computed without the validity mask, as before.
                metrics_subset1 = metric.metric_torch(pred[:, node_idx_list1], real[:, node_idx_list1], mask_val=args.eval_mask,filter_num=filter_num)
                metrics_subset2 = metric.metric_torch(pred[:, node_idx_list2], real[:, node_idx_list2], mask_val=args.eval_mask,filter_num=filter_num)
                for k in range(metric_count):
                    steps_sub1[k].append(metrics_subset1[k])
                    steps_sub2[k].append(metrics_subset2[k])

        log = '[dim{:d}] On average over {:d} horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f} Test MRE: {:.4f} Test R2: {:.4f}'
        feature_means = [np.mean(values) for values in steps_all]
        logger.info(log.format(feature_idx, ready_length, *feature_means))
        mean_all.append(feature_means)
        per_horizon_all.append(steps_all)

        if has_subsets:
            mean_sub1.append([np.mean(values) for values in steps_sub1])
            mean_sub2.append([np.mean(values) for values in steps_sub2])

    def _with_percent_mape(values):
        # In the result row the MAPE (second entry) is written as a percentage.
        mae, mape, rmse, mre, r2 = values
        return [mae, mape * 100, rmse, mre, r2]

    result_dir = f'./EXP_{result_tip}_results/{args.modelid}/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(result_dir, exist_ok=True)

    result_csv_path = f'./EXP_{result_tip}_results/{args.modelid}/{args.modelid}_{args.dataset_name}_{args.input_length}to{args.predict_length}_Forecasting_Results.csv'

    with open('Inference_time.csv', 'a+', newline='') as f:
        fcsv = csv.writer(f)
        fcsv.writerow([expid, args.dataset_name, args.note, inference_time])

    with open(result_csv_path, 'a+', newline='') as f0:
        f_csv = csv.writer(f0)
        if args.predict_length == 0:
            row = [expid, args.dataset_name, args.note,args.hid_dim,
                   args.period_type, args.static_type, args.static_func_type, args.dynamic_func_type,
                   args.mask_node_rate, args.remain_rate_type,
                   'test']
        else:
            row = [expid, args.dataset_name, args.note, args.hid_dim, 'test']
        if pre_dim > 1:
            # Average of the per-feature means over all evaluated features.
            overall = [np.array([feature[k] for feature in mean_all]).mean() for k in range(metric_count)]
            row.extend(['avg'] + _with_percent_mape(overall))
        for feature_idx in range(pre_dim):
            row.extend([f'dim-{feature_idx}'] + _with_percent_mape(mean_all[feature_idx]))
            if has_subsets:
                row.extend([f'subset1-dim{feature_idx}'] + _with_percent_mape(mean_sub1[feature_idx]))
                row.extend([f'subset2-dim{feature_idx}'] + _with_percent_mape(mean_sub2[feature_idx]))

        if args.predict_length==12:
            # Selected single-horizon results of the first feature (MAPE not scaled here).
            for horizon in (3, 6, 12):
                row.append(f'horizon-{horizon}')
                row.extend(per_horizon_all[0][k][horizon - 1] for k in range(metric_count))

        if train_metric is not None and valid_metric is not None:
            # Best-effort: a malformed metric argument must not prevent writing the test row.
            try:
                row.append('train_loss_mae_mape_rmse')
                row.extend(train_metric)
                row.append('valid_loss_mae_mape_rmse')
                row.extend(valid_metric)
            except TypeError as err:
                logger.warning('error in save train record results: %s', err)
        f_csv.writerow(row)

    logger.info('log and results saved.')


def _apply_config_file(args):
    """Overlay settings from ``./configurations/<args.config_file>`` onto ``args``.

    The JSON file is expected to hold three sections (``data_config``,
    ``model_config`` and ``training_config``); the keys read below are copied
    onto ``args`` in place.  When the path is not an existing regular file a
    notice is printed and ``args`` is left untouched, so the values parsed from
    the command line stay in effect.
    """
    config_path = f'./configurations/{args.config_file}'
    # isfile (rather than exists) so that a directory path is reported as a
    # missing config instead of crashing inside open().
    if not os.path.isfile(config_path):
        print('there is no config file!')
        return

    with open(config_path, 'r') as f:
        x = json.load(f)
    data_config = x['data_config']
    model_config = x['model_config']
    training_config = x['training_config']

    args.dataset_name = data_config['dataset_name']
    args.slice_size_per_day = data_config['slice_size_per_day']
    args.in_dim = data_config['in_dim']
    args.pre_dim = data_config['pre_dim']
    args.input_length = data_config['input_length']
    args.predict_length = data_config['predict_length']
    args.eval_mask = training_config['eval_mask']
    args.stage2_emb_type = data_config['stage2_emb_type']

    args.M = model_config['M']
    args.hid_dim = model_config['hid_dim']
    args.n_heads = model_config['n_heads']
    args.num_layers = model_config['num_layers']
    args.se_emb_dropout = model_config['se_emb_dropout']
    args.te_emb_dropout = model_config['te_emb_dropout']

    args.loss_type = training_config['loss_type']


def _split_stage2_nodes(stage1_node_idx, stage2_node_idx):
    """Split positions of ``stage2_node_idx`` into (already in stage 1, newly added).

    Returns two lists of positions (indices into ``stage2_node_idx``), not node ids.
    """
    remain_positions = []
    newadd_positions = []
    for idx, node in enumerate(stage2_node_idx):
        if node in stage1_node_idx:
            remain_positions.append(idx)
        else:
            newadd_positions.append(idx)
    return remain_positions, newadd_positions


def _save_or_merge_run_config(args, save_path):
    """Persist ``args`` for a fresh run, or merge them with the stored ones on reload.

    * Fresh run (``args.load_expid is None``): dump ``args`` to ``config.pkl``
      and ``config.json`` under ``save_path`` and return ``args`` unchanged.
    * Reload: read the pickled arguments of the original run, combine them with
      the current ones through ``merge_args`` and, if that reports additions,
      write the merged result to ``config_newload_<load_note>.{pkl,json}``.
      ``load_expid`` and ``task_type`` of the *current* invocation always win.

    Returns the namespace the caller must continue with (it may be a different
    object than the one passed in).
    """
    if args.load_expid is None:
        with open(f'{save_path}/config.pkl','wb') as f:
            pickle.dump(args,f)
        with open(f'{save_path}/config.json','w') as f:
            json.dump(vars(args), f)
        return args

    load_expid = args.load_expid
    task_type = args.task_type
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # reload experiment directories that this script produced itself.
    with open(f'{save_path}/config.pkl','rb') as f:
        ori_args = pickle.load(f)
    args, add_flag = merge_args(ori_args, args)
    if add_flag:
        with open(f'{save_path}/config_newload_{args.load_note}.pkl','wb') as f:
            pickle.dump(args,f)
        with open(f'{save_path}/config_newload_{args.load_note}.json','w') as f:
            json.dump(vars(args), f)
    args.load_expid = load_expid
    args.task_type = task_type
    return args


def main():
    """Command-line entry point: prepare data, then train and/or test one model.

    Steps, in order:

    1. Parse CLI arguments and overlay the JSON config file (if present).
    2. Derive the experiment id, output directory and logger.
    3. Load the full dataset, the temporal stage division and the node division.
    4. Build dataloaders, scalers and auxiliary embedding inputs
       (period features, delay/CSD matrices, adjacency) for the selected task.
    5. Dispatch on ``args.task_type``:

       * ``transductive_forecasting``: train on the full node set and test.
       * ``stage2_transductive_forecasting``: train (``mode='finetune'``) on
         stage-2 data with the stage-2 scaler and test.
       * ``expand_node_forecasting``: train on stage 1, finetune on stage 2,
         then test on stage 2 with separate metrics for remaining and newly
         added nodes.  Only runs for a fresh experiment (no ``--load_expid``).

    When ``--load_expid`` is given for the first two task types, training is
    skipped and only testing runs.
    """
    current_pid = os.getpid()
    args = parser.parse_args()

    # A fresh run gets a timestamp id; a reload reuses the id it was given.
    if args.load_expid is None:
        expid = time.strftime("%m%d%H%M%S", time.localtime())
    else:
        expid = args.load_expid

    device = torch.device(args.device)

    args.cache_data_path = 'datasets'
    # --save_epochs is an underscore-separated list of epoch numbers, e.g. "10_20".
    if len(args.save_epochs)>0:
        save_epoch_list = [int(it) for it in args.save_epochs.split('_')]
    else:
        save_epoch_list = None

    _apply_config_file(args)

    # Legacy per-dataset overrides, kept for reference (currently disabled):
    # se_drop_dict ={'PEMS04':0.1,'EPeMS':0.1,'nrel_al':0.7,'AQI437':0.7,'SeaLoop':0.1}
    # te_drop_dict ={'PEMS04':0.5,'EPeMS':0.7,'nrel_al':0.7,'AQI437':0.1,'SeaLoop':0.9}

    # if args.dataset_name == 'nrel_al':
    #     args.hid_dim = 32
    #     args.stage2_emb_type = 'fuse_period_stci'
    #     args.num_layers = 1

    # if args.se_emb_dropout < 0:
    #     args.se_emb_dropout = se_drop_dict[args.dataset_name]
    # if args.te_emb_dropout < 0:
    #     args.te_emb_dropout = te_drop_dict[args.dataset_name]

    args.norm_y = 0
    args.inverse = True

    suffix = '' if args.note =='' else f'_{args.note}'
    args.note = f'bs{args.batch_size}_d{args.hid_dim}_L{args.num_layers}{suffix}'
    save_path = f'./EXP_experiments/{args.modelid}/{args.dataset_name}/{args.modelid}_{args.dataset_name}_{args.input_length}to{args.predict_length}_exp' + str(
        expid) + "/"
    log_path = save_path

    if args.load_expid is None:
        log_filename = f'{args.modelid}_{args.dataset_name}_{args.input_length}to{args.predict_length}_exp{expid}.log'
    else:
        # Reloads get their own log file, stamped with the reload time.
        current = time.strftime("%m%d%H%M", time.localtime())
        log_filename = f'{args.modelid}_{args.dataset_name}_{args.input_length}to{args.predict_length}_exp{expid}_load_ON{current}.log'

    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(log_path, exist_ok=True)

    logger = get_logger(log_path, log_filename)
    exp_name = f'{args.modelid}_{args.dataset_name}'
    logger.info(exp_name)

    # if args.dataset_name in ['SeaLoop','nrel_al']:
    #     args.loss_type= 'masked_mae_0'

    print('args.slice_size_per_day*7:', args.slice_size_per_day*7)

    full_data, full_time, full_df_time, full_valid_mask = dataset_provider.load_full_data_time(args.dataset_name,1)

    stage_division = dataset_provider.load_stage_division(args.dataset_name, data_length=full_data.shape[0], df_time=full_df_time,
                                                        slices_per_week= args.slice_size_per_day*7)
    node_division = dataset_provider.load_node_division(args.dataset_name, all_num_nodes=full_data.shape[1],
                                                        seed=args.dataset_seed, new_add_rate=0.2, retire_rate=0.05)
    stage1_start_idx, stage2_start_idx, test_start_idx =  stage_division
    stage1_node_idx, stage2_node_idx, newadd_set, remain_set, retire_set = node_division

    A_link, A_Distance, full_distance = graph_utils.load_adj_matrix(args.dataset_name)
    full_adj = A_link if args.adj_link_or_distance == 'link' else A_Distance

    # Positions inside the stage-2 node list, used to report metrics separately
    # for nodes that already existed in stage 1 and for newly added ones.
    remain_node_idx_in_stage2, newadd_node_idx_in_stage2 = _split_stage2_nodes(stage1_node_idx, stage2_node_idx)

    if args.task_type == 'transductive_forecasting':

        # Everything before the stage-2 start is used to derive the auxiliary features.
        train_indices = list(range(0,stage2_start_idx))
        full_train_data = full_data[train_indices,:,0]
        full_train_df_time = full_df_time.iloc[train_indices]
        args.num_nodes = full_data.shape[1]

        full_dataloader, scaler, full_train_indices = dataset_provider.load_stage_dataloader(args.dataset_name, full_data, full_time, stage2_start_idx, test_start_idx, None,None,'full',
                                                                                                 scaler=None,
                                                                                                 batch_size=args.batch_size, valid_mask =full_valid_mask,
                                                                                                 input_length=args.input_length, predict_length=args.predict_length,
                                                                                                 channel_wise_norm=bool(args.scale_channel_wise), norm_y=bool(args.norm_y), scaler_type=args.scaler_type
                                                                                                 )
        full_train_data = scaler.transform(full_train_data, stage='full')

        # load_tip is passed to the aux-data loaders as an identifier for this
        # data split / normalisation variant.
        full_load_tip ='full'
        if args.before_emb_norm_type!= 'none':
            full_load_tip = f'{full_load_tip}_{args.before_emb_norm_type}norm'

        full_node_period_feats = load_aux_data.load_init_period_features(args.dataset_name, args.cal_period_pca_emb_dim,
                                                                 args.slice_size_per_day, args.period_type,
                                                                 load_tip=full_load_tip, full_train_data=full_train_data, df_time=full_train_df_time, target_node_idx = None,norm_type = args.before_emb_norm_type,
                                                                 )
        full_node_delay_matrix, full_node_csd_matrix = load_aux_data.load_init_stci_matrix(args.dataset_name, args.cal_lap_emb_dim, args.input_length,
                                                                              load_tip=full_load_tip, full_train_data=full_train_data,norm_type = args.before_emb_norm_type)

        full_node_set = np.arange(full_node_period_feats.shape[0])
        if full_adj is None:
            # No adjacency shipped with the dataset: derive one from the training data.
            A_link, A_Distance, full_distance = graph_utils.get_corr_adj(full_train_data)
            full_adj = A_link if args.adj_link_or_distance == 'link' else A_Distance
        aux_data_full, emb_dim, static_feats_dim_list = load_aux_data.build_aux_data_from_all(full_node_set, full_node_set, full_node_set, full_adj,
                                              full_node_period_feats, full_node_delay_matrix, full_node_csd_matrix,
                                              args.cal_lap_emb_dim, device, args.static_type, args.period_type)
        args.static_feats_dim_list = static_feats_dim_list
        ic(args.static_feats_dim_list)
        args.ref_emb_dim = emb_dim
        args.support_len =len(aux_data_full['adj_mx'])
        train_scaler = scaler

        args = _save_or_merge_run_config(args, save_path)

        # Set after the config dump, so the adjacency list is not written to config.pkl/json.
        if args.modelid == 'GWNET':
            args.support = aux_data_full['adj_mx']
    else:
        # Stage-based tasks: slice time into stage 1 / stage 2 and restrict each
        # stage to its own node subset.
        stage1_data = full_data[stage1_start_idx:stage2_start_idx]
        stage2_data = full_data[stage2_start_idx:]
        stage1_time = full_time[stage1_start_idx:stage2_start_idx]
        stage2_time = full_time[stage2_start_idx:]

        stage1_data = stage1_data[:,stage1_node_idx]
        stage2_data = stage2_data[:,stage2_node_idx]

        stage1_indices = list(range(stage1_start_idx, stage2_start_idx))
        stage1_df_time = full_df_time.iloc[stage1_indices]
        stage2_indices = list(range(stage2_start_idx, full_data.shape[0]))
        stage2_df_time = full_df_time.iloc[stage2_indices]

        all_mean, all_std, stage1_mean, stage1_std, stage2_mean, stage2_std = dataset_provider.get_full_scaler_info(full_data[...,0], stage2_start_idx, test_start_idx, remain_set, retire_set, newadd_set, bool(args.scale_channel_wise))

        stage1_dataloader, stage1_scaler, stage1_train_indices = dataset_provider.load_stage_dataloader(args.dataset_name, full_data, full_time, stage2_start_idx, test_start_idx,
                                                                                                 stage1_node_idx, stage2_node_idx,'stage1', scaler=None,
                                                                                                 batch_size=args.batch_size, valid_mask = full_valid_mask,
                                                                                                 input_length=args.input_length, predict_length=args.predict_length,
                                                                                                 channel_wise_norm=bool(args.scale_channel_wise), norm_y=bool(args.norm_y),
                                                                                                 scaler_mean = all_mean, scaler_std = all_std, scaler_type=args.scaler_type
                                                                                                 )

        stage1_train_data = stage1_data[stage1_train_indices,:,0]
        stage1_train_df_time = stage1_df_time.iloc[stage1_train_indices]
        stage1_train_data = stage1_scaler.transform(stage1_train_data, stage='stage1')

        if args.dataset_name in ['EPeMS']:
            stage1_load_tip = f'stage1'
        else:
            stage1_load_tip = f'stage1_seed{args.dataset_seed}_known0.8'

        if args.before_emb_norm_type!= 'none':
            stage1_load_tip = f'{stage1_load_tip}_{args.before_emb_norm_type}norm'

        adj_type = 'doubletransition' if args.modelid=='GWNET' else 'auto'

        # NOTE(review): unlike the 'full' and 'stage2' paths, the stage-1 features
        # are built from the unscaled, whole stage-1 slice (stage1_data / stage1_df_time)
        # rather than from the scaled stage1_train_data - confirm this is intended.
        stage1_node_period_feats = load_aux_data.load_init_period_features(args.dataset_name, args.cal_period_pca_emb_dim,
                                                                args.slice_size_per_day, args.period_type,
                                                                load_tip=stage1_load_tip, full_train_data=stage1_data[...,0], df_time=stage1_df_time, target_node_idx = None,
                                                                norm_type = args.before_emb_norm_type,)
        stage1_node_delay_matrix, stage1_node_csd_matrix = load_aux_data.load_init_stci_matrix(args.dataset_name, args.cal_lap_emb_dim, args.input_length,
                                                                            load_tip=stage1_load_tip,  full_train_data=stage1_data[...,0],
                                                                            norm_type = args.before_emb_norm_type)

        args.num_nodes = stage1_data.shape[1]
        remain_idx = np.arange(stage1_data.shape[1])
        if full_adj is None:
            A_link, A_Distance, full_distance = graph_utils.get_corr_adj(stage1_train_data)
            stage1_adj = A_link if args.adj_link_or_distance == 'link' else A_Distance
            print('stage1_adj.shape:',stage1_adj.shape, 'remain_idx.shape:', remain_idx.shape)
        else:
            # Restrict the dataset adjacency to the stage-1 nodes (rows, then columns).
            stage1_adj = full_adj[stage1_node_idx, :][:, stage1_node_idx]

        aux_data_stage1, emb_dim, static_feats_dim_list = load_aux_data.build_aux_data_from_all(remain_idx, remain_idx, remain_idx, stage1_adj,
                                              stage1_node_period_feats, stage1_node_delay_matrix, stage1_node_csd_matrix,
                                              args.cal_lap_emb_dim, device, args.static_type, args.period_type, adj_type = adj_type)
        args.ref_emb_dim = emb_dim
        args.static_feats_dim_list = static_feats_dim_list
        ic(args.static_feats_dim_list)

        args = _save_or_merge_run_config(args, save_path)

        # Only the first days of stage 2 are treated as observed when building
        # the stage-2 auxiliary features (5 days for EPeMS, 7 days otherwise).
        if args.dataset_name in ['EPeMS']:
            stage2_known_indices = list(range(0, args.slice_size_per_day*5))
        else:
            stage2_known_indices = list(range(0, args.slice_size_per_day*7))

        stage2_train_data = stage2_data[stage2_known_indices,:,0]
        stage2_train_df_time = stage2_df_time.iloc[stage2_known_indices]

        ic(stage2_train_data.shape)
        ic(stage2_train_df_time.shape)
        remain_idx2 = np.arange(stage2_data.shape[1])
        if full_adj is None:
            A_link, A_Distance, full_distance = graph_utils.get_corr_adj(stage2_train_data)
            stage2_adj = A_link if args.adj_link_or_distance == 'link' else A_Distance
        else:
            stage2_adj = full_adj[stage2_node_idx, :][:, stage2_node_idx]

        print('stage2_adj has nan:',np.isnan(stage2_adj).any())

        if args.dataset_name in ['EPeMS']:
            stage2_load_tip = 'stage2'
        else:
            stage2_load_tip = f'stage2_seed{args.dataset_seed}_known0.8_retire0.05_add0.2'

        if args.before_emb_norm_type!= 'none':
            stage2_load_tip = f'{stage2_load_tip}_{args.before_emb_norm_type}norm'

        if args.task_type == 'stage2_transductive_forecasting':
            # Stage 2 is trained on its own: fresh scaler with stage-2 statistics.
            args.num_nodes = stage2_data.shape[1]
            stage2_dataloader, stage2_scaler, stage2_finetune_indices = dataset_provider.load_stage_dataloader(args.dataset_name, full_data, full_time, stage2_start_idx, test_start_idx,
                                                                                            stage1_node_idx, stage2_node_idx,'stage2', scaler=None,
                                                                                            batch_size=args.batch_size, valid_mask = full_valid_mask,
                                                                                            input_length=args.input_length, predict_length=args.predict_length,
                                                                                            channel_wise_norm=bool(args.scale_channel_wise), norm_y=bool(args.norm_y),
                                                                                            scaler_mean = stage2_mean, scaler_std = stage2_std, scaler_type=args.scaler_type
                                                                                            )

            stage2_train_data = stage2_scaler.transform(stage2_train_data, stage='stage2')
            stage2_node_period_feats = load_aux_data.load_init_period_features(args.dataset_name, args.cal_period_pca_emb_dim,
                                                                 args.slice_size_per_day, args.period_type,
                                                                 load_tip=stage2_load_tip, full_train_data=stage2_train_data, df_time=stage2_train_df_time, target_node_idx = None,
                                                                 norm_type = args.before_emb_norm_type)

        else:
            # Stage 2 continues from stage 1: the stage-1 scaler is handed to the loader.
            input_stage2_scaler = stage1_scaler
            stage2_dataloader, stage2_scaler, stage2_finetune_indices = dataset_provider.load_stage_dataloader(args.dataset_name, full_data, full_time, stage2_start_idx, test_start_idx,
                                                                                        stage1_node_idx, stage2_node_idx,'stage2', scaler=input_stage2_scaler,
                                                                                        batch_size=args.batch_size, valid_mask = full_valid_mask,
                                                                                        input_length=args.input_length, predict_length=args.predict_length,
                                                                                        channel_wise_norm=bool(args.scale_channel_wise), norm_y=bool(args.norm_y),
                                                                                        scaler_mean = all_mean, scaler_std = all_std, scaler_type=args.scaler_type
                                                                                        )
            stage2_train_data = stage2_scaler.transform(stage2_train_data, stage='stage2')
            stage2_node_period_feats = load_aux_data.load_init_period_features(args.dataset_name, args.cal_period_pca_emb_dim,
                                                        args.slice_size_per_day, args.period_type,
                                                        load_tip=stage2_load_tip, full_train_data=stage2_train_data, df_time=stage2_train_df_time, target_node_idx = None,
                                                        norm_type = args.before_emb_norm_type, ignore_full_weeks= args.dataset_name in ['EPeMS','EElectricy','EWeather'])

        stage2_node_delay_matrix, stage2_node_csd_matrix = load_aux_data.load_init_stci_matrix(args.dataset_name, args.cal_lap_emb_dim, args.input_length,
                                                                              load_tip=stage2_load_tip, full_train_data=stage2_train_data,
                                                                              norm_type = args.before_emb_norm_type)

        aux_data_stage2, emb_dim, static_feats_dim_list = load_aux_data.build_aux_data_from_all(remain_idx2, remain_idx2, remain_idx2, stage2_adj,
                                            stage2_node_period_feats, stage2_node_delay_matrix, stage2_node_csd_matrix,
                                            args.cal_lap_emb_dim, device, args.static_type, args.period_type, adj_type=adj_type
                                            )
        fused_stage2_period_feats, fused_stage2_stci_feats, fused_stage2_structure_feats = load_aux_data.build_aux_data_from_stage1(stage2_train_data, aux_data_stage1, stage1_node_idx, stage2_node_idx,stage2_node_csd_matrix,device, args.emb_fuse_topk, fuse_type = args.emb_fuse_type)

        # stage2_emb_type selects (by substring) which stage-2 features are
        # replaced by the versions fused with stage-1 information.
        if 'period' in args.stage2_emb_type:
            aux_data_stage2['period_feats'] = fused_stage2_period_feats
        if 'stci' in args.stage2_emb_type:
            aux_data_stage2['stci_feats'] = fused_stage2_stci_feats
        if 'structure' in args.stage2_emb_type:
            aux_data_stage2['structure_feats'] = fused_stage2_structure_feats

        args.support_len = len(aux_data_stage2['adj_mx'])
        # Both stages must expose the same number of adjacency supports.
        assert args.support_len == len(aux_data_stage1['adj_mx'])

        if args.task_type == 'stage2_transductive_forecasting':
            train_scaler = stage2_scaler
        else:
            train_scaler = stage1_scaler

    torch.cuda.empty_cache()

    logger.info(args)
    logger.info(args.note)

    args.note = f'{args.note}_{args.load_note}'
    logger.info("start training...")
    if args.task_type == 'transductive_forecasting':
        if args.modelid.startswith('Demo'): args.aux_data_train = aux_data_full
        engine = trainer(scaler= train_scaler, device=device, args=args, inverse=args.inverse)

        # On reload, training is skipped and only the test pass runs.
        if args.load_expid is None:
            training_time, train_metric, valid_metric = engine.training(full_dataloader, aux_data_full, aux_data_full, current_pid, save_path, log_tip='TransductiveForecasting',stage='full', save_epoch_list=save_epoch_list)

        logger.info("start testing...")
        testing(args, engine.model, full_dataloader['test_loader'], device, aux_data_full, scaler, args.pre_dim, logger, result_tip ='TransductiveForecasting',stage='full')

    elif args.task_type == 'stage2_transductive_forecasting':
        if args.modelid.startswith('Demo'): args.aux_data_train = aux_data_stage2

        engine = trainer(scaler= train_scaler, device=device, args=args, inverse=args.inverse)

        if args.load_expid is None:
            training_time, train_metric, valid_metric = engine.training(stage2_dataloader, aux_data_stage2, aux_data_stage2, current_pid, save_path, log_tip='stage2_TransductiveForecasting',stage='stage2', mode='finetune')

        logger.info("start testing...")
        testing(args, engine.model, stage2_dataloader['test_loader'], device, aux_data_stage2, stage2_scaler, args.pre_dim, logger, result_tip ='Stage2_TransductiveForecasting',stage='stage2')

    elif args.task_type == 'expand_node_forecasting':
        if args.load_expid is None:
            engine = trainer(scaler= train_scaler, device=device, args=args, inverse=args.inverse)

            engine.training(stage1_dataloader, aux_data_stage1, aux_data_stage1, current_pid, save_path, log_tip='ExpandNodeForecasting_stage1_tdc',stage='stage1',mode='train')

            # Switch the run configuration to the stage-2 node set before finetuning.
            args.num_nodes = stage2_data.shape[1]
            args.support_len =len(aux_data_stage2['adj_mx'])

            logger.info("start stage2 finetuning...")
            # Distinct names: the original bound both metric values to the same
            # variable, silently dropping the first one.
            finetune_time, finetune_train_metric, finetune_valid_metric = engine.training(stage2_dataloader, aux_data_stage2, aux_data_stage2, current_pid, save_path, is_finetune=True, log_tip='stage2_FineTuning', stage='stage2',mode='finetune')
            logger.info(f'finetune time :{finetune_time}')
            stage2_result_tip = 'stage2_FinetuneForecasting'

            logger.info("start stage2 testing...")
            test_scaler = stage1_scaler
            testing(args, engine.model, stage2_dataloader['test_loader'], device, aux_data_stage2, test_scaler, args.pre_dim, logger, result_tip = stage2_result_tip,
                    node_idx_list1=remain_node_idx_in_stage2, node_idx_list2= newadd_node_idx_in_stage2, stage='stage2')
        else:
            # Previously a silent no-op; make it visible in the log.
            logger.warning('expand_node_forecasting does not support --load_expid; nothing was trained or tested.')

    else:
        # Previously a silent no-op after the whole data pipeline had run.
        logger.warning(f'unknown task_type {args.task_type!r}; nothing was trained or tested.')

    

if __name__ == "__main__":
    # Script entry point: run the whole pipeline and report wall-clock duration.
    started_at = time.time()
    main()
    elapsed_seconds = time.time() - started_at
    print("Total time spent: {:.4f}".format(elapsed_seconds))
