import argparse
import os
import torch
import random
import numpy as np
import uuid
import datetime
import importlib
import wandb


# from exp.exp_online import Exp_TS2VecSupervised


def init_dl_program(
        device_name,
        seed=None,
        use_cudnn=True,
        deterministic=False,
        benchmark=False,
        use_tf32=False,
        max_threads=None
):
    """Configure threading, RNG seeding and cuDNN flags, and resolve devices.

    Args:
        device_name: A single device spec (str, int or ``torch.device``) or
            an iterable of such specs; each one is passed to ``torch.device``.
        seed: When not ``None``, seeds ``random``, NumPy and torch (CPU and
            CUDA) and switches torch to deterministic algorithms.
        use_cudnn: Value assigned to ``torch.backends.cudnn.enabled``.
        deterministic: Requests ``torch.backends.cudnn.deterministic``; it is
            forced to ``True`` whenever ``seed`` is given.
        benchmark: Requests ``torch.backends.cudnn.benchmark``; it is forced
            to ``False`` whenever ``seed`` is given.
        use_tf32: Value assigned to the cuDNN / CUDA-matmul ``allow_tf32``
            flags when the installed torch exposes them.
        max_threads: When not ``None``, limits torch intra-op/inter-op threads
            (and MKL threads if the ``mkl`` module is importable).

    Returns:
        A single ``torch.device`` when exactly one device was requested,
        otherwise a list of ``torch.device`` in the order given.
    """
    import torch
    if max_threads is not None:
        torch.set_num_threads(max_threads)  # intraop
        if torch.get_num_interop_threads() != max_threads:
            torch.set_num_interop_threads(max_threads)  # interop
        try:
            import mkl
        except ImportError:
            # mkl is optional; only a missing module is silently tolerated.
            pass
        else:
            mkl.set_num_threads(max_threads)

    # Remember whether reproducibility was requested; `seed` itself is
    # incremented further down for every CUDA device.
    seeded = seed is not None

    if seeded:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'

        # avoiding nondeterministic algorithms (see https://pytorch.org/docs/stable/notes/randomness.html)
        torch.use_deterministic_algorithms(True)

    if isinstance(device_name, (str, int, torch.device)):
        device_name = [device_name]

    devices = []
    # Iterate in reverse so that the first requested CUDA device is the one
    # left selected by the last torch.cuda.set_device call.
    for t in reversed(device_name):
        t_device = torch.device(t)
        devices.append(t_device)
        if t_device.type == 'cuda':
            assert torch.cuda.is_available(), 'CUDA device requested but CUDA is not available'
            torch.cuda.set_device(t_device)
            if seed is not None:
                seed += 1
                torch.cuda.manual_seed(seed)
    devices.reverse()

    torch.backends.cudnn.enabled = use_cudnn
    # A seeded run must stay deterministic: previously these two assignments
    # silently overwrote the deterministic settings with the keyword defaults.
    torch.backends.cudnn.deterministic = bool(deterministic) or seeded
    torch.backends.cudnn.benchmark = bool(benchmark) and not seeded

    if hasattr(torch.backends.cudnn, 'allow_tf32'):
        torch.backends.cudnn.allow_tf32 = use_tf32
        torch.backends.cuda.matmul.allow_tf32 = use_tf32

    return devices if len(devices) > 1 else devices[0]


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` treats every non-empty string (including ``'False'``) as
    ``True``; this converter recognises the usual spellings instead and
    rejects anything else with an argparse error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line interface of the experiment script.
# NOTE(review): several help strings below look copy-pasted (e.g. 'num of
# encoder layers' / 'latent dimension of koopman embedding' on unrelated
# options); they are left untouched because the intended meaning cannot be
# confirmed from this file.
parser = argparse.ArgumentParser(description='[Informer] Long Sequences Forecasting')

# LSTD

parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
parser.add_argument('--depth', type=int, default=9, help='depth of TCN')
parser.add_argument('--hidden_dim', type=int, default=512, help='hidden dimension of en/decoder')
parser.add_argument('--hidden_layers', type=int, default=2, help='number of hidden layers of en/decoder')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--learning_rate', type=float, default=0.003, help='optimizer learning rate')
parser.add_argument('--tau', type=float, default=0.6, help='trigger_data')
parser.add_argument('--zd_kl_weight', type=float, default=0.001, help='num of encoder layers')
parser.add_argument('--zc_kl_weight', type=float, default=0.001, help='num of encoder layers')
parser.add_argument('--L1_weight', type=float, default=0.001, help='num of encoder layers')
parser.add_argument('--L2_weight', type=float, default=0.001, help='num of encoder layers')
parser.add_argument('--rec_weight', type=float, default=0.5, help='latent dimension of koopman embedding')
parser.add_argument('--mode', type=str, default='time', help='latent dimension of koopman embedding')
parser.add_argument('--n_class', type=int, default=4, help='num of encoder layers')
# NOTE(review): store_true combined with default=True means this flag can
# never be switched off from the command line.
parser.add_argument('--No_prior', action='store_true', default=True, help='num of encoder layers')
parser.add_argument('--is_bn', action='store_true', default=False, help='num of encoder layers')
parser.add_argument('--dynamic_dim', type=int, default=128, help='latent dimension of koopman embedding')
parser.add_argument('--lags', type=int, default=1, help='num of encoder layers')

# Dataset selection and I/O locations
parser.add_argument('--data', type=str, default='ETTh2', help='data')
parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh2.csv', help='data file')
parser.add_argument('--features', type=str, default='M',
                    help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
                    help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

parser.add_argument('--seq_len', type=int, default=96, help='input sequence length of Informer encoder')
parser.add_argument('--label_len', type=int, default=0, help='start token length of Informer decoder')
parser.add_argument('--pred_len', type=int, default=1, help='prediction sequence length')
# Informer decoder input: concat[start token series(label_len), zero padding series(pred_len)]

# Model sizes; enc_in/dec_in/c_out may be overridden after parsing for known datasets
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=32, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--s_layers', type=str, default='3,2,1', help='num of stack encoder layers')
parser.add_argument('--d_ff', type=int, default=128, help='dimension of fcn')
parser.add_argument('--factor', type=int, default=5, help='probsparse attn factor')
parser.add_argument('--padding', type=int, default=0, help='padding type')
parser.add_argument('--distil', action='store_false',
                    help='whether to use distilling in encoder, using this argument means not using distilling',
                    default=True)

parser.add_argument('--attn', type=str, default='prob', help='attention used in encoder, options:[prob, full]')
parser.add_argument('--embed', type=str, default='timeF',
                    help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
parser.add_argument('--mix', action='store_false', help='use mix attention in generative decoder', default=True)
parser.add_argument('--cols', type=str, nargs='+', help='certain cols from the data files as the input features')
parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=3, help='train epochs')
parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
parser.add_argument('--learning_rate_w', type=float, default=0.001, help='optimizer learning rate')
parser.add_argument('--learning_rate_bias', type=float, default=0.001, help='optimizer learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-3, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--loss', type=str, default='mse', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)
parser.add_argument('--method', type=str, default='onenet_fsnet')

# PatchTST
parser.add_argument('--fc_dropout', type=float, default=0.05, help='fully connected dropout')
parser.add_argument('--head_dropout', type=float, default=0.0, help='head dropout')
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--padding_patch', default='end', help='None: None; end: padding on the end')
parser.add_argument('--revin', type=int, default=0, help='RevIN; True 1 False 0')
parser.add_argument('--affine', type=int, default=0, help='RevIN-affine; True 1 False 0')
parser.add_argument('--subtract_last', type=int, default=0, help='0: subtract mean; 1: subtract last')
parser.add_argument('--decomposition', type=int, default=0, help='decomposition; True 1 False 0')
parser.add_argument('--kernel_size', type=int, default=25, help='decomposition-kernel')
parser.add_argument('--tcn_output_dim', type=int, default=320, help='decomposition-kernel')
parser.add_argument('--tcn_layer', type=int, default=2, help='decomposition-kernel')
parser.add_argument('--tcn_hidden', type=int, default=160, help='decomposition-kernel')
parser.add_argument('--individual', type=int, default=1, help='individual head; True 1 False 0')

parser.add_argument('--teacher_forcing', action='store_true', help='use teacher forcing during forecasting',
                    default=False)
parser.add_argument('--online_learning', type=str, default='full')
parser.add_argument('--opt', type=str, default='adam')

parser.add_argument('--test_bsz', type=int, default=1)
parser.add_argument('--n_inner', type=int, default=1)
# Parsed with _str2bool: with type=bool, '--channel_cross False' evaluated to True.
parser.add_argument('--channel_cross', type=_str2bool, default=False)

# Parsed with _str2bool: with type=bool, '--use_gpu False' evaluated to True.
parser.add_argument('--use_gpu', type=_str2bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')

parser.add_argument('--finetune', action='store_true', default=False)
parser.add_argument('--finetune_model_seed', type=int)

parser.add_argument('--aug', type=int, default=0, help='Training with augmentation data aug iterations')
parser.add_argument('--lr_test', type=float, default=1e-3, help='learning rate during test')

# supplementary config for FEDformer model
parser.add_argument('--version', type=str, default='Wavelets',
                    help='for FEDformer, there are two versions to choose, options: [Fourier, Wavelets]')
parser.add_argument('--mode_select', type=str, default='random',
                    help='for FEDformer, there are two mode selection method, options: [random, low]')
parser.add_argument('--modes', type=int, default=64, help='modes to be selected random 64')
parser.add_argument('--L', type=int, default=3, help='ignore level')
parser.add_argument('--base', type=str, default='legendre', help='mwt base')
parser.add_argument('--cross_activation', type=str, default='tanh',
                    help='mwt cross atention activation function tanh or softmax')
# NOTE(review): no type/nargs is given, so a value passed on the command line
# arrives as a plain string while the default is a list of ints.
parser.add_argument('--moving_avg', default=[24], help='window size of moving average')

parser.add_argument('--gamma', type=float, default=0.1)
parser.add_argument('--m', type=int, default=24)
parser.add_argument('--loss_aug', type=float, default=0.5, help='weight for augmentation loss')
# NOTE(review): store_true combined with default=True means this flag can
# never be switched off from the command line.
parser.add_argument('--use_adbfgs', action='store_true', help='use the Adbfgs optimizer', default=True)
parser.add_argument('--period_len', type=int, default=12)
parser.add_argument('--mlp_depth', type=int, default=3)
parser.add_argument('--mlp_width', type=int, default=256)
parser.add_argument('--station_lr', type=float, default=0.0001)

parser.add_argument('--sleep_interval', type=int, default=1, help='latent dimension of koopman embedding')
parser.add_argument('--sleep_epochs', type=int, default=1, help='latent dimension of koopman embedding')
parser.add_argument('--sleep_kl_pre', type=float, default=0, help='latent dimension of koopman embedding')
parser.add_argument('--delay_fb', action='store_true', default=False, help='use delayed feedback')
parser.add_argument('--online_adjust', type=float, default=0.5, help='latent dimension of koopman embedding')
parser.add_argument('--offline_adjust', type=float, default=0.5, help='latent dimension of koopman embedding')
parser.add_argument('--online_adjust_var', type=float, default=0.5, help='latent dimension of koopman embedding')
parser.add_argument('--var_weight', type=float, default=0.0, help='latent dimension of koopman embedding')
parser.add_argument('--alpha_w', type=float, default=0.0001, help='spectrum filter ratio')
parser.add_argument('--alpha_d', type=float, default=0.003, help='spectrum filter ratio')
parser.add_argument('--test_lr', type=float, default=0.1, help='spectrum filter ratio')
parser.add_argument('--seed', type=int, default=1, help='set of seed')

parser.add_argument('--x_dim', type=int, default=10, help='x embed dimension/length')
parser.add_argument('--output', type=str, default='results', help='result output path')
parser.add_argument('--sparsity_weight', type=float, default=0.001, help='sparsity loss weight')

args = parser.parse_args()

# Use the GPU only when it was requested AND CUDA is actually available.
args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)
# A test batch size of -1 means "same as the training batch size".
if args.test_bsz == -1:
    args.test_bsz = args.batch_size
if args.use_gpu and args.use_multi_gpu:
    # '--devices' is a comma-separated id list; the first id becomes the primary GPU.
    args.devices = args.devices.replace(' ', '')
    device_ids = args.devices.split(',')
    args.device_ids = [int(id_) for id_ in device_ids]
    args.gpu = args.device_ids[0]

# Per-dataset presets: 'data' is the file name, 'T' the target column, and each
# feature-mode key ('M', 'S', 'MS') maps to [enc_in, dec_in, c_out].
data_parser = {
    'ETTh1': {'data': 'ETTh1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTh2': {'data': 'ETTh2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTm1': {'data': 'ETTm1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTm2': {'data': 'ETTm2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'WTH': {'data': 'WTH.csv', 'T': 'WetBulbCelsius', 'M': [12, 12, 12], 'S': [1, 1, 1], 'MS': [12, 12, 1]},
    'ECL': {'data': 'ECL.csv', 'T': 'OT', 'M': [321, 321, 321], 'S': [1, 1, 1], 'MS': [321, 321, 1]},
    'Solar': {'data': 'solar_AL.csv', 'T': 'POWER_136', 'M': [137, 137, 137], 'S': [1, 1, 1], 'MS': [137, 137, 1]},
    'Toy': {'data': 'Toy.csv', 'T': 'Value', 'S': [1, 1, 1]},
    'ToyG': {'data': 'ToyG.csv', 'T': 'Value', 'S': [1, 1, 1]},
    'Exchange': {'data': 'exchange_rate.csv', 'T': 'OT', 'M': [8, 8, 8]},
    'Illness': {'data': 'national_illness.csv', 'T': 'OT', 'M': [7, 7, 7]},
    'Traffic': {'data': 'traffic.csv', 'T': 'OT', 'M': [862, 862, 862]},
}
if args.data in data_parser:
    data_info = data_parser[args.data]
    args.data_path = data_info['data']
    args.target = data_info['T']
    if args.features not in data_info:
        # Not every dataset defines every feature mode (e.g. 'Toy' only has 'S');
        # report it as a usage error instead of failing with a bare KeyError.
        supported = [key for key in data_info if key not in ('data', 'T')]
        parser.error('--features {} is not available for --data {} (supported: {})'.format(
            args.features, args.data, ', '.join(supported)))
    args.enc_in, args.dec_in, args.c_out = data_info[args.features]

# '3,2,1' -> [3, 2, 1]
args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ', '').split(',')]
# Keep the full frequency string in detail_freq; freq itself is reduced to its
# last character (e.g. '15min' -> 'n', 'h' -> 'h').
args.detail_freq = args.freq
args.freq = args.freq[-1:]

print('Args in experiment:')
print(args)

# Resolve the experiment class from the module 'exp.exp_<method>'; every such
# module is expected to expose a class named 'Exp_TS2VecSupervised'.
Exp = getattr(importlib.import_module('exp.exp_{}'.format(args.method)), 'Exp_TS2VecSupervised')

# Accumulators for the values returned by exp.test() across runs.
metrics, preds, true, mae, mse = [], [], [], [], []

# NOTE(review): exactly one run is executed; '--itr' is not consulted here.
for _ in range(1):
    print('\n ====== Run {} ====='.format(args.seed))
    # setting record of experiments
    method_name = args.method
    uid = uuid.uuid4().hex[:4]
    suffix = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M") + "_" + uid
    setting = '{}_{}_pl{}_ol{}_opt{}_tb{}_{}'.format(method_name, args.data, args.pred_len, args.online_learning,
                                                     args.opt, args.test_bsz, suffix)

    init_dl_program(args.gpu, seed=args.seed)
    args.finetune_model_seed = args.seed
    exp = Exp(args)  # set experiments
    print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
    print('Total parameters ', sum(p.numel() for p in exp.model.parameters() if p.requires_grad))
    exp.train(setting)

    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    m, mae_, mse_, p, t = exp.test(setting)
    # The first two entries of the metrics sequence are logged as MAE and MSE.
    _mae, _mse = m[0], m[1]

    # Append this run's configuration and scores to
    # ./<output>/<data>/<method>/<seq_len>_<pred_len>.txt
    result_dir = f"./{args.output}/{args.data}/{args.method}/"
    os.makedirs(result_dir, exist_ok=True)
    # The context manager closes the file even if one of the writes raises.
    with open(f"{result_dir}{args.seq_len}_{args.pred_len}.txt", 'a') as f:
        f.write(f"{args}\n")
        f.write(f"{args.seed}\n")
        f.write('mse:{}, mae:{}'.format(_mse, _mae))
        f.write('\n\n')

    metrics.append(m)
    if str(args.data) == 'Traffic':
        # For 'Traffic' only placeholders are kept instead of the returned
        # predictions/targets (presumably because of their size -- confirm).
        preds = [0]
        true = [0]
    else:
        preds.append(p)
        true.append(t)

    mae.append(mae_)
    mse.append(mse_)
    torch.cuda.empty_cache()

# folder_path = './results/' + setting + '/'
# if not os.path.exists(folder_path):
#     os.makedirs(folder_path)
# np.save(folder_path + 'metrics.npy', np.array(metrics))
# np.save(folder_path + 'preds.npy', np.array(preds))
# np.save(folder_path + 'trues.npy', np.array(true))
# np.save(folder_path + 'mae.npy', np.array(mae))
# np.save(folder_path + 'mse.npy', np.array(mse))
