import argparse
import torch
import random
import numpy as np
import optuna
import os

from utils.dataloader import load_data
from trainer import Trainer

# Expose only GPUs 0 and 1 to this process. CUDA reads this variable when it is
# first initialised, so it must be set before any CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Project root that relative --root_path values are resolved against.
# Set GRFORMER_ROOT to run on another machine; the default keeps the original path.
ROOT_PATH = os.environ.get('GRFORMER_ROOT', 'D:/WorkSpace_Python/GRformer')

def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed shared by all random number generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Disable cuDNN autotuning / non-deterministic kernels so that repeated
    # runs with the same seed reproduce the same numbers (at some speed cost).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_run_names(args):
    """Return the ``(setting, task_path)`` strings that identify one run.

    ``setting`` encodes the model hyper-parameters, plus top-k / MLP type /
    propagation alpha when the GCN branch is enabled. ``task_path`` encodes
    the dataset, input and prediction lengths, the optional RNN and GCN
    options, and the loss name. Both strings are passed to ``Trainer``.
    """
    if args.use_gcn:
        topk = f'_topk{args.subgraph_size}'
        prop = '_prop{:.2f}'.format(args.prop_alpha)
        mlptype = f'_mlp{args.mlp_type}'
    else:
        topk = prop = mlptype = ''

    setting = '{}_fea={}_bs{}_pl{}_std{}_dm{}_df{}_nh{}_drop{:.2f}_el{}{}{}{}'.format(
        args.model,                 # model name used in this run
        args.features,              # forecasting task type (multi-/univariate)
        args.batch_size,            # batch size
        args.patch_len,             # patch length
        args.stride,                # stride
        args.d_model,               # embedding size each patch is mapped to
        args.d_ff,
        args.n_heads,               # number of attention heads
        args.dropout,
        args.e_layers,              # number of encoder layers
        topk,
        mlptype,
        prop,
    )
    task_path = '{}_{}_{}{}{}_loss_{}'.format(
        args.dataset_name,          # dataset of this run
        args.seq_len,               # look-back (history) length
        args.pred_len,              # prediction length
        f'_rnn({args.rnn})' if args.rnn != 0 else '',
        f'_gcn({args.subgraph_size}+{args.gcn_depth})' if args.use_gcn else '',
        args.loss,
    )
    return setting, task_path


def objective(trial, args):
    """Optuna objective: run one experiment and return its score.

    The function seeds every RNG, derives the run-naming strings, loads the
    data and builds a ``Trainer``.

    - Training mode (``args.is_training`` truthy): the model is trained, then
      evaluated on ``data['test_loader']``, and the value from
      ``engine.test`` is returned for Optuna to minimise.
    - Otherwise: only ``engine.predict`` is run on ``data['pred_loader']``
      and ``0`` is returned.

    Args:
        trial: The ``optuna.Trial`` of this run. It is only used by the
            commented-out search-space lines below.
        args: Parsed command-line namespace. Hyper-parameters suggested by
            the trial would be written back onto it.

    Returns:
        The result of ``engine.test`` when training (stored as ``mse``;
        presumably the test-set MSE — confirm in ``Trainer.test``), else ``0``.
    """
    _seed_everything(args.random_seed)

    # Hyper-parameter search space: uncomment a line to let Optuna tune it.
    # args.batch_size = trial.suggest_categorical("batch_size", [32])
    # args.learning_rate = trial.suggest_float("learning_rate", 0.00001, 0.00016, step=0.000007)
    # args.fc_dropout = trial.suggest_float("fc_dropout", 0.1, 0.3, step=0.1)
    # args.dropout = trial.suggest_float("dropout", 0.2, 0.4, step=0.05)
    # args.activation = trial.suggest_categorical("activation", ['relu', 'gelu'])
    # args.prop_alpha = trial.suggest_float("props_alpha", 0.03, 0.15, step=0.02)
    # args.subgraph_size = trial.suggest_categorical("subgraph_size", [1, 4])
    # args.learning_rate = trial.suggest_categorical("learning_rate", [0.000076, 0.000091])
    # args.gcn_depth = trial.suggest_int("gcn_depth", 1, 2)
    print('Args in experiment:')
    print(args)

    setting, task_path = _build_run_names(args)

    # `data` holds the datasets and their matching dataloaders
    # (keys such as 'test_loader' and 'pred_loader').
    data, corr, high_correlated_count = load_data(args)
    engine = Trainer(args, setting, task_path, corr, high_correlated_count)

    if not args.is_training:
        # Inference only: run prediction with the trainer and report a
        # dummy score to Optuna.
        engine.predict(pred_loader=data['pred_loader'])
        return 0

    engine.train(data=data)
    mse = engine.test(test_loader=data['test_loader'])

    # Release the splits and the model before the next trial. pop() tolerates
    # a split that load_data did not produce (e.g. no 'val'), so cleanup can
    # never lose the trial's result.
    for name in ('train', 'val', 'test'):
        data.pop(name, None)
        data.pop(name + '_loader', None)
    del engine
    # Return cached, now-unused GPU blocks to the driver once, after the
    # model has been released.
    torch.cuda.empty_cache()
    return mse


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Build the command-line parser for one forecasting experiment.

    Returns:
        argparse.ArgumentParser: Parser whose defaults select the GRformer
        model on the 'weather' dataset.
    """
    parser = argparse.ArgumentParser(description='Autoformer & Transformer family for Time Series Forecasting')

    # random seed
    parser.add_argument('--random_seed',        type=int,   default=2023, help='random seed')

    # basic config
    parser.add_argument('--is_training',        type=int,   default=1, help='status')
    parser.add_argument('--model',              type=str,   default='GRformer', help='model name, options: [Autoformer, Informer, Transformer, PatchTST, MultiPatchTST]')
    # Parsed as a real boolean: without a type, "--decompose False" would
    # store the (truthy) string 'False'.
    parser.add_argument('--decompose',          type=_str2bool, default=False, help='whether to decompose the series')

    # data loader
    parser.add_argument('--dataset_name',       type=str,   default='weather', help='dataset name')
    parser.add_argument('--dataset_type',       type=str,   default='custom', help='dataset type')
    parser.add_argument('--root_path',          type=str,   default='data/weather', help='root path of the data file')
    parser.add_argument('--data_path',          type=str,   default='weather.csv', help='data file')
    parser.add_argument('--features',           type=str,   default='M', help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target',             type=str,   default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq',               type=str,   default='h', help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints',        type=str,   default='./checkpoints/', help='location of model checkpoints')

    # forecasting task
    parser.add_argument('--seq_len',            type=int,   default=336, help='input sequence length')
    parser.add_argument('--label_len',          type=int,   default=48, help='start token length')
    parser.add_argument('--pred_len',           type=int,   default=96, help='prediction sequence length')

    # DLinear
    parser.add_argument('--individual',                     default=False, action='store_true', help='DLinear: a linear layer for each variate(channel) individually')

    # Graph Structure Construct
    parser.add_argument('--d_node',             type=int,   default=48, help='dim of nodes')
    parser.add_argument('--gcn_depth',          type=int,   default=2,  help='graph convolution depth')
    parser.add_argument('--subgraph_size',      type=int,   default=2, help='k')
    parser.add_argument('--tanh_alpha',         type=float, default=2,help='adj alpha')
    parser.add_argument('--prop_alpha',         type=float, default=0.09,help='prop alpha')
    # --use_gcn is on by default (kept for backward compatibility); use
    # --no_use_gcn to switch the graph branch off.
    parser.add_argument('--use_gcn',                        default=True, action='store_true', help='generate a graph?')
    parser.add_argument('--no_use_gcn',         dest='use_gcn', action='store_false', help='disable the graph (GCN) branch')
    parser.add_argument('--rnn',                type=int,   default=2, help='use RNN/LocalRNN PE? 1: LocalRNN 2: RNN')
    parser.add_argument('--mlp_type',           type=int,   default=2, help='0:gc 1:ln 2:custom-ln')
    # MTGNN
    parser.add_argument('--residual_channels',  type=int,   default=32, help='residual channels')
    parser.add_argument('--conv_channels',      type=int,   default=32, help='convolution channels')
    parser.add_argument('--skip_channels',      type=int,   default=128,help='skip channels')
    parser.add_argument('--end_channels',       type=int,   default=128,help='end channels')
    parser.add_argument('--dilation_exponential', type=int, default=2, help='dilation exponential')

    # PatchTST
    parser.add_argument('--fc_dropout',         type=float, default=0.2, help='fully connected dropout')
    parser.add_argument('--head_dropout',       type=float, default=0.0, help='head dropout')
    parser.add_argument('--patch_len',          type=int,   default=16, help='patch length')
    parser.add_argument('--multi_patch',                    default=False, action="store_true", help='use multiple patch lengths (see --patch_len2)')
    parser.add_argument('--patch_len2',         type=str,   default='32,48', help='patch length 2')
    parser.add_argument('--stride',             type=str,   default="half", help='stride, half or full length of patch_len')
    parser.add_argument('--padding_patch',      type=str,   default='end', help='None: None; end: padding on the end')
    parser.add_argument('--revin',              type=int,   default=1, help='RevIN; True 1 False 0')
    parser.add_argument('--affine',             type=int,   default=0, help='RevIN-affine; True 1 False 0')
    parser.add_argument('--subtract_last',      type=int,   default=0, help='0: subtract mean; 1: subtract last')
    parser.add_argument('--decomposition',      type=int,   default=0, help='decomposition; True 1 False 0')
    parser.add_argument('--kernel_size',        type=int,   default=25, help='decomposition-kernel')

    # Formers
    parser.add_argument('--pos_embed_type',     type=str,   default='sincos', help='how do you generate position embedding for Encoder')
    parser.add_argument('--embed_type',         type=int,   default=0,
                        help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding')
    parser.add_argument('--enc_in',             type=int,   default=7, help='encoder input size')
    parser.add_argument('--dec_in',             type=int,   default=7, help='decoder input size')
    parser.add_argument('--c_out',              type=int,   default=7, help='output size')
    parser.add_argument('--d_model',            type=int,   default=128, help='dimension of model')
    parser.add_argument('--d_patch',            type=int,   default=8, help='embedding for single patch num')
    parser.add_argument('--n_heads',            type=int,   default=8, help='num of heads')
    parser.add_argument('--e_layers',           type=int,   default=2, help='num of encoder layers')
    parser.add_argument('--d_layers',           type=int,   default=1, help='num of decoder layers')
    parser.add_argument('--d_ff',               type=int,   default=256, help='dimension of fcn')
    parser.add_argument('--moving_avg',         type=int,   default=25, help='window size of moving average')
    parser.add_argument('--factor',             type=int,   default=1, help='attn factor')
    parser.add_argument('--distil',                         default=True, action='store_false', help='whether to use distilling in encoder, using this argument means not using distilling',)
    parser.add_argument('--dropout',            type=float, default=0.4, help='dropout')
    parser.add_argument('--embed',              type=str,   default='timeF', help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation',         type=str,   default='gelu', help='activation')
    parser.add_argument('--output_attention',               default=False, action='store_true', help='whether to output attention in encoder')
    parser.add_argument('--do_predict',                     default=False, action='store_true', help='whether to predict unseen future data')

    # optimization
    parser.add_argument('--num_workers',        type=int,   default=3, help='data loader num workers')
    parser.add_argument('--itr',                type=int,   default=1, help='experiments times')
    parser.add_argument('--train_epochs',       type=int,   default=100, help='train epochs')
    parser.add_argument('--batch_size',         type=int,   default=32, help='batch size of train input data')
    parser.add_argument('--patience',           type=int,   default=20, help='early stopping patience')
    parser.add_argument('--learning_rate',      type=float, default=0.00027, help='optimizer learning rate')
    parser.add_argument('--loss',               type=str,   default='mae', help='loss function, choose [mse, rmse, mae, mape, huber]')
    parser.add_argument('--lradj',              type=str,   default='type3', help='adjust learning rate')
    parser.add_argument('--pct_start',          type=float, default=0.2, help='how many epochs should the lr_rate get to max?')
    parser.add_argument('--use_amp',                        default=False, action='store_true', help='use automatic mixed precision training')
    parser.add_argument('--opt',                type=str,   default='adam', help='optimizer chosen from [Adam, SGD]')
    parser.add_argument('--n_trials',           type=int,   default=1, help='number of optuna trials to run')

    # GPU
    parser.add_argument('--use_gpu',                        default=False, action='store_true', help='use gpu')
    parser.add_argument('--gpu',                type=int,   default=0, help='gpu')
    parser.add_argument('--use_multi_gpu',                  default=False, action='store_true', help='use multiple gpus')
    parser.add_argument('--test_flop',                      default=False, action='store_true', help='See utils/tools for usage')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    # Use the GPU selected by --gpu when CUDA is available, otherwise the CPU.
    args.device = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu'

    # MTGNN always runs with the graph branch enabled.
    if args.model == 'MTGNN':
        args.use_gcn = True

    if args.use_multi_gpu:
        args.device_ids = list(range(torch.cuda.device_count()))

    # Resolve relative data directories against the project root; an absolute
    # --root_path is used as given instead of being prefixed a second time.
    if not os.path.isabs(args.root_path):
        args.root_path = '{}/{}'.format(ROOT_PATH, args.root_path)

    # Optuna study that minimises the value returned by objective().
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: objective(trial, args), n_trials=args.n_trials)

    # Report the best hyper-parameters and score found.
    print('Best parameters:', study.best_params)
    print('Best score:', study.best_value)
