import argparse
import os
import torch
from exp.exp_main_original import Exp_Main
from exp.exp_main_distributed import Exp_Main as Exp_Main_dist

import random
import numpy as np
import data_provider.data_info
import torch.distributed as dist
import torch.multiprocessing as mp
import time 

# Example invocation:
# python run_longExp.py --model PatchTST --data weather --seq_len 512 --pred_len 96 --patch_len 64 --stride 64  --dropout 0.3 --d_model 32 --n_heads 4 --e_layers 3 --d_ff 128 --gpu 0


def str2bool(value):
    """Parse a boolean command-line value without evaluating it as Python.

    Used instead of ``type=eval`` so that the command line cannot execute
    arbitrary code. Accepts, case-insensitively and ignoring surrounding
    whitespace, ``true/t/yes/y/1`` and ``false/f/no/n/0`` -- so the
    ``True``/``False``/``1``/``0`` spellings that ``eval`` handled still work.

    Raises:
        argparse.ArgumentTypeError: for any other value, which argparse turns
        into an ordinary usage error.
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Autoformer & Transformer family for Time Series Forecasting')


# random seed
parser.add_argument('--random_seed', type=int, default=2021, help='random seed')

# setting config
parser.add_argument('--is_training', type=int, default=1, help='status')
parser.add_argument('--model', type=str, default='PatchTST')
# The default is never used because the flag is required.
parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
parser.add_argument('--global_path', type=str, default='exp_results', help='global path of results')

# model config
parser.add_argument('--dropout', type=float, default=0.3, help='dropout')
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--d_model', type=int, default=16, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=4, help='num of heads')
parser.add_argument('--e_layers', type=int, default=3, help='num of encoder layers')
parser.add_argument('--d_ff', type=int, default=128, help='dimension of fcn')
# Non-positive --split_mult / --comp_dim are resolved from other settings at startup.
parser.add_argument('--split_num', type=int, default=1,
                    help='number of splits; split_mult defaults to enc_in // split_num')
parser.add_argument('--split_mult', type=int, default=-1,
                    help='split multiplier; values <= 0 mean enc_in // split_num')
parser.add_argument('--comp_dim', type=int, default=-1,
                    help='comp_dim of the model; values <= 0 fall back to d_model')
# NOTE(review): the meaning of rep_num is defined by the model code, not visible here.
parser.add_argument('--rep_num', type=int, default=3, help='rep_num hyper-parameter of the model')
parser.add_argument('--adaptive_dilated_atten', type=str2bool, default=True)
parser.add_argument('--no_dilated_atten', type=str2bool, default=False)
parser.add_argument('--multi_query', type=str2bool, default=True)


# optimization
parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')

# others (fix)
parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
parser.add_argument('--patience', type=int, default=100, help='early stopping patience')
parser.add_argument('--loss', type=str, default='mse', help='loss function')
parser.add_argument('--lradj', type=str, default='type3', help='adjust learning rate')
parser.add_argument('--label_len', type=int, default=0, help='start token length')
parser.add_argument('--root_path', type=str, default='./dataset/', help='root path of the data file')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
parser.add_argument('--save_checkpoints', type=str2bool, default=True)

# distributed
# A comma-separated --gpu list (e.g. "0,1") switches the launcher to multi-process mode.
parser.add_argument('--gpu', type=str, default='7', help='gpu')
parser.add_argument('--num_workers_per_proc','-nwpp', type=int, default=2)




def init_for_distributed(rank, opts):
    """Bind this worker process to its GPU and join the NCCL process group.

    Called once per spawned worker. Reads ``opts.gpu`` (indexed by *rank*;
    each entry must be ``int``-convertible), ``opts.ori_gpu`` and
    ``opts.world_size``. Writes ``opts.rank``, ``opts.local_gpu_id`` and
    ``opts.device``. Once the group is joined, printing is muted on every
    rank except 0; ``print(..., force=True)`` still prints.

    Args:
        rank: index of this process among ``opts.world_size`` workers.
        opts: the parsed argument namespace; mutated in place.

    Returns:
        The same ``opts`` object.
    """
    opts.rank = rank

    # Pin this process to the CUDA device assigned to its rank.
    gpu_id = int(opts.gpu[rank])
    opts.local_gpu_id = gpu_id
    opts.device = f'cuda:{gpu_id}'
    torch.cuda.set_device(gpu_id)
    # NOTE: always true here -- a None rank would already have failed the
    # indexing above.
    if rank is not None:
        print("Use GPU: {} for training".format(gpu_id))

    # Rendezvous over localhost TCP. The port is offset by the first
    # requested physical GPU id, presumably so concurrent runs on different
    # GPU sets do not collide on one port.
    rendezvous_port = 23456 + int(opts.ori_gpu[0])
    dist.init_process_group(
        backend='nccl',
        init_method=f'tcp://127.0.0.1:{rendezvous_port}',
        world_size=opts.world_size,
        rank=rank,
    )

    # Block until every rank has joined before continuing.
    dist.barrier()

    # From here on only rank 0 prints (force=True overrides).
    setup_for_distributed(rank == 0)
    print(opts)
    return opts


def setup_for_distributed(is_master):
    """Mute ``print`` in every process except the master.

    Replaces ``builtins.print`` process-wide with a wrapper around the
    ``print`` that was installed at call time. The wrapper forwards only when
    *is_master* is true or the caller passes ``force=True``. The ``force``
    keyword is consumed and never reaches the real ``print``.
    """
    import builtins

    original_print = builtins.print

    def _gated_print(*args, force=False, **kwargs):
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    builtins.print = _gated_print

def main_worker1(rank, args, setting):
    """Per-process entry point for distributed training followed by testing.

    Invoked by ``mp.spawn`` with *rank* as the first argument. Joins the
    process group, builds the distributed experiment, then runs
    ``exp.train(setting)`` and ``exp.test(setting)``. Finally it tears down
    the process group so the worker exits cleanly.

    Args:
        rank: worker index supplied by ``mp.spawn``.
        args: parsed argument namespace (see ``init_for_distributed``).
        setting: experiment identifier string.
    """
    args = init_for_distributed(rank, args)
    exp = Exp_Main_dist(args)  # set experiments
    print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
    exp.train(setting)
    # NOTE(review): unlike the single-process path, this always tests,
    # regardless of args.save_checkpoints -- confirm this is intended.
    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    exp.test(setting)
    # Release the process group joined in init_for_distributed.
    if dist.is_initialized():
        dist.destroy_process_group()
    

def main_worker2(rank, args, setting):
    """Per-process entry point for distributed evaluation only.

    Invoked by ``mp.spawn`` when ``--is_training`` is 0. Joins the process
    group, builds the distributed experiment, runs
    ``exp.test(setting, test=1)``, then tears down the process group.

    Args:
        rank: worker index supplied by ``mp.spawn``.
        args: parsed argument namespace (see ``init_for_distributed``).
        setting: experiment identifier string.
    """
    args = init_for_distributed(rank, args)
    exp = Exp_Main_dist(args)  # set experiments
    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    exp.test(setting, test=1)
    # Release the process group joined in init_for_distributed.
    if dist.is_initialized():
        dist.destroy_process_group()
        


def attention_mark(args):
    """Return the tag naming the attention variant selected on the command line.

    ``no_dilated_atten`` takes precedence over ``adaptive_dilated_atten``.
    If neither is set, ``"dil"`` is returned.
    """
    if args.no_dilated_atten:
        return "ndil"
    if args.adaptive_dilated_atten:
        return "adil"
    return "dil"


def build_setting(args):
    """Return the experiment identifier passed to ``exp.train`` / ``exp.test``.

    The string joins model, data and hyper-parameter values with ``_``, so
    that each configuration gets its own name.
    """
    return '{}_{}_sl{}_pl{}_p{}_s{}_d{}_dm{}_nh{}_el{}_df{}_{}_{}_{}_lr{}_c{}_{}_{}'.format(
        args.model, args.data, args.seq_len, args.pred_len, args.patch_len, args.stride,
        args.dropout, args.d_model, args.n_heads, args.e_layers, args.d_ff,
        args.split_num, args.split_mult, args.rep_num, args.learning_rate, args.comp_dim,
        attention_mark(args), args.random_seed)


if __name__ == "__main__":

    args = parser.parse_args()

    # Dataset-specific override: electricity always uses patience 50,
    # whatever --patience says.
    if args.data == "electricity":
        args.patience = 50

    # random seed
    fix_seed = args.random_seed
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)
    # --train_epochs is honoured as given (the parser default is 100).

    # Look up the dataset's file path and enc_in from the data registry.
    args.data_path, args.enc_in = data_provider.data_info.data_information[args.data]

    # Non-positive values mean "derive from other settings".
    if args.comp_dim <= 0:
        args.comp_dim = args.d_model
    if args.split_mult <= 0:
        if args.split_num == 0:
            parser.error('--split_num must be non-zero to derive --split_mult')
        args.split_mult = args.enc_in // args.split_num

    # use_gpu is forced on regardless of the command line.
    args.use_gpu = True
    # A comma in --gpu (e.g. "0,1") selects multi-process training.
    args.distributed = ',' in args.gpu

    if args.distributed:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        args.ori_gpu = args.gpu.split(",")
        # With CUDA_VISIBLE_DEVICES set, the selected devices are addressed
        # as 0..k-1 inside the workers.
        args.gpu = [str(i) for i in range(len(args.ori_gpu))]
        args.world_size = len(args.gpu)
        args.num_workers = len(args.gpu) * args.num_workers_per_proc

    print('Args in experiment:')
    setting = build_setting(args)

    if args.distributed:
        worker = main_worker1 if args.is_training else main_worker2
        mp.spawn(worker, args=(args, setting), nprocs=args.world_size, join=True)
    else:
        print(args)
        exp = Exp_Main(args)  # set experiments
        if args.is_training:
            print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
            exp.train(setting)
            if args.save_checkpoints:
                print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
                exp.test(setting)
        else:
            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            exp.test(setting, test=1)
        torch.cuda.empty_cache()
