from model import *
import argparse
import os
from data import *
from train import *


def main(args, **kwargs):
    setup_seed(args.seed)

    for file in ['pic', 'loss', 'src', 'data', 'model']:
        os.makedirs(f'{args.working_dir}/{file}', exist_ok=True)

    datas = get_data(args, True, **kwargs)
    print('prepare data done!')
    train(args, datas, **kwargs)



if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Pytorch distributed")

    parser.add_argument('-N_train', '--train_data_size', type = int, default = 1000) 
    parser.add_argument('-N_test', '--test_data_size', type = int, default = 500) 
    parser.add_argument('-sl', '--seq_len', type = int, default = 9)
    parser.add_argument('-dmin', '--data_min', type = int, default = 20)
    parser.add_argument('-dmax', '--data_max', type = int, default = 100)

    parser.add_argument('-dmode', '--data_mode', nargs='*', type=str, default = [1])
    parser.add_argument('-dp', '--data_percent', nargs='*', type=float, default = [1])
    parser.add_argument('-dn', '--data_name', nargs='*', type=str, default = ['full data'])
    parser.add_argument('-dmask', '--data_mask', nargs='*', type=int, default = [0])
    parser.add_argument('-dshow', '--data_show', nargs='*', type=int, default = [0])

    parser.add_argument('-func', '--target', type = str, default = '3x_to_x')

    parser.add_argument('-bs', '--batch_size', type = int, default = 10) 
    parser.add_argument('-vs', '--vocab_size', type = int, default = 201) 
    parser.add_argument('-mp', '--max_pos', type = int, default = 20)
    parser.add_argument('-dm', '--d_model', type = int, default = 400)
    parser.add_argument('-d_ff', '--d_feedforward', type = int, default = 1200)
    parser.add_argument('-dk', '--d_k', type = int, default = 64)
    parser.add_argument('-dv', '--d_v', type = int, default = 64)
    parser.add_argument('-nl', '--n_layers', type = int, default = 4)
    parser.add_argument('-nh', '--n_heads', type = int, default = 4)
    parser.add_argument('-cl', '--clip', type = int, default = 1)

    parser.add_argument('-ne', '--n_epoch', type = int, default = 3000) 
    parser.add_argument('-lr', '--lr', type = float, default = 1.e-4) 
    parser.add_argument('-lds', '--lr_decay_step', type = int, default = 1000) 
    parser.add_argument('-ldr', '--lr_decay_rate', type = float, default = 1)  
    parser.add_argument('-seed', '--seed', type = int, default = 1)  
    parser.add_argument('-scheduler', '--scheduler', type = str, default = 'GradualWarmupScheduler_CosineAnnealingLR')


    parser.add_argument('-m', '--model', type = str, default = 'GPT') 
    parser.add_argument('-op', '--optim', choices = ['Adam', 'SGD', 'AdamW'], default = 'AdamW')  

    parser.add_argument('-sme', '--save_model_epoch', type = int, default = 100) 
    parser.add_argument('-ple', '--print_loss_epoch', type = int, default = 10)
    parser.add_argument('-pae', '--print_acc_epoch', type = int, default = 100)
    parser.add_argument('-plae', '--plot_loss_acc_epoch', type = int, default = 500)
    
    parser.add_argument('-prefix', '--prefix', type = str, default = ' ')
    parser.add_argument('-suffix', '--suffix', type = str, default = ' ')

    parser.add_argument('-dir_suffix', '--dir_suffix', type = str, default = ' ')

    args, remaining = parser.parse_known_args()

    remaining_dict = {}
    for i in range(0, len(remaining), 2):
        key = remaining[i].lstrip('-')
        value = remaining[i+1]
        remaining_dict[key] = value

    working_dir = f'{args.target}-seed_{int(args.seed)}-N_{int(args.train_data_size)}'
    
    if args.prefix != ' ':
        working_dir = f'{args.prefix}-{working_dir}'
    if args.suffix != ' ':
        working_dir = f'{working_dir}-{args.suffix}'
    
    if args.dir_suffix != ' ':
        args.working_dir = f'./result/{args.model}_{args.dir_suffix}/{working_dir}'
    else:
        args.working_dir = f'./result/{args.model}/{working_dir}'

    print(args.working_dir)

    main(args, **remaining_dict)