import argparse
import os

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

from model import GPT, GPTLM
from trainer import GPTTrainer
from utils.dataload import local_load_data, pai_load_data
from utils.tool import *


parser = argparse.ArgumentParser(description='Configuration from shell')
parser.add_argument('--local_flag', default=0,
                    help='0 (default) = local single-process run; '
                         'any other value = distributed (DDP) run')
parser.add_argument('--taskname', default='self_task')
parser.add_argument('--tables', default="", type=str,
                    help='ODPS input table names')
local_args = parser.parse_args()
task_name = local_args.taskname

# `--local_flag` has no `type=`, so a value given on the command line arrives
# as a str while the default stays the int 0. Comparing the raw value with the
# int 0 therefore treated an explicit `--local_flag 0` as "not local" and sent
# the run down the DDP path. Normalise through str() so the default and an
# explicit 0 behave the same; any other value still selects distributed mode.
Local = str(local_args.local_flag).strip() == '0'


def _select_device(local):
    """Return the torch device this process should train on.

    Without CUDA this is the CPU. In local mode it is ``cuda:0``, as before.
    In distributed mode each process is pinned to the GPU named by the
    ``LOCAL_RANK`` environment variable (set by torchrun-style launchers).
    Previously every rank used ``cuda:0``, so several ranks on one machine
    would share a single GPU. If ``LOCAL_RANK`` is unset, malformed, or names
    a GPU this process cannot see (e.g. one visible GPU per container), the
    function falls back to ``cuda:0``, i.e. the original behavior.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    gpu = 0
    if not local:
        try:
            rank = int(os.environ.get('LOCAL_RANK', 0))
        except ValueError:
            rank = 0
        if 0 <= rank < torch.cuda.device_count():
            gpu = rank
        # Make this the current device before the NCCL group is created.
        torch.cuda.set_device(gpu)
    return torch.device('cuda:{}'.format(gpu))


def _build_task_dict(args):
    """Return the per-task head configuration passed to GPTTrainer.

    Each task lists its metric names, a metric callable, a loss callable and a
    per-metric weight list (one entry per name in ``metrics``).
    NOTE(review): how the weights are used (e.g. for model selection) is
    decided inside GPTTrainer -- confirm there.
    """
    return {
        'trigger': {'metrics': ['triggerMae', '10mAcc', '30mAcc'],
                    'metrics_fn': TriggerMetric(args),
                    'loss_fn': TriggerLoss(args),
                    'weight': [1, 0, 1]},
        'action': {'metrics': ['acc', 'precision', 'recall', 'combine_acc'],
                   'metrics_fn': ActionMetric(),
                   'loss_fn': ActionLoss(args),
                   'weight': [0, 0, 0, 1]},
        'info': {'metrics': ['acc'],
                 'metrics_fn': InfoMetric(),
                 'loss_fn': InfoLoss(args),
                 'weight': [1]},
        'voiceTimes': {'metrics': ['acc', '0_acc', '1_acc', '2_acc', '3_acc', '4_acc'],
                       'metrics_fn': VoiceTimesMetric(),
                       'loss_fn': VoiceTimesLoss(args["voiceTimes_cnt"], args),
                       'weight': [1, 1, 1, 1, 0, 0]},
    }


def _build_optimizer(model, lr, weight_decay, trigger_lr=2e-3):
    """Create Adam with a separate learning rate for the trigger decoder.

    Trainable parameters whose qualified name contains ``'decoders.trigger'``
    use ``trigger_lr``; all other trainable parameters use ``lr``. Both groups
    share ``weight_decay`` and ``eps=1e-8``. Frozen parameters are skipped.
    The substring match also works on DDP-wrapped models (``module.`` prefix).
    """
    base_params, trigger_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'decoders.trigger' in name:
            trigger_params.append(param)
        else:
            base_params.append(param)
    return torch.optim.Adam(
        [
            {'params': base_params, 'eps': 1e-08},
            {'params': trigger_params, 'lr': trigger_lr, 'eps': 1e-08},
        ],
        lr=lr, weight_decay=weight_decay)


def _build_scheduler(optimizer, interval=100000, last=3000000, gamma=0.8):
    """Multiply every group's LR by ``gamma`` at each multiple of ``interval``.

    Milestones run from ``interval`` up to and including ``last``. With the
    defaults that is 100k, 200k, ..., 3M (30 milestones).
    NOTE(review): whether these count batches or epochs depends on how often
    GPTTrainer calls ``scheduler.step()`` -- confirm there.
    """
    milestones = list(range(interval, last + 1, interval))
    return optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones,
                                          gamma=gamma)


def train():
    """Build data, model, optimizer and scheduler, then run GPTTrainer.

    Reads the module-level ``Local`` flag to choose between a local
    single-process run and a distributed (DDP) run, and ``task_name`` to build
    the checkpoint directory in distributed mode.

    Distributed mode initialises a process group (NCCL on GPU, Gloo on CPU),
    loads data with ``pai_load_data`` and wraps the model in DDP. The process
    group is destroyed once training returns.

    Local mode loads data with ``local_load_data`` and pickles the whole
    freshly initialised model to ``./ckpt/model.pt``, creating the directory
    if needed.
    """
    args = load_arguments(Local)
    args['Local'] = 1 if Local else 0
    args["device"] = _select_device(Local)
    args['if_ddp_predict'] = 0
    print("Set seed", args['seed'])
    set_seed(args['seed'])

    # Hyperparameter overrides kept for quick experiments (disabled):
    # args['lr'] = 2e-3
    # args['adam_weight_decay'] = 0
    # args['hidden'] = 156
    # args['mlp_hid_dim'] = 156
    # args['layers'] = 1
    # args['mlp_emb_dim'] = 8
    # args['att_emb'] = 9*args['mlp_emb_dim']+7
    # args['embeded_all_feature_cnt'] = 33*args['mlp_emb_dim']+11+128

    if not Local:
        print('DDP multi-task training start')
        # NCCL for GPU training, Gloo as the CPU-only fallback.
        dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
        save_dir = "TBT/" + task_name + "/"
        train_dataloaders, val_dataloaders, test_dataloaders = pai_load_data(args, local_args)
    else:
        save_dir = "TBT/local/"
        train_dataloaders, val_dataloaders, test_dataloaders = local_load_data(args)

    args['save_dir'] = save_dir
    # Built here (before the model) to keep the original construction order.
    task_dict = _build_task_dict(args)

    other_net_lr = args['lr']
    weight_decay = args['adam_weight_decay']

    print("Building GPT model")
    gpt = GPT(args, hidden=args['hidden'], n_layers=args['layers'], attn_heads=args['attn_heads'], dropout=args["gpt_drop_out"]).to(args["device"])
    model = GPTLM(gpt=gpt, args=args).to(args['device'])
    model.apply(init_weights)
    print(model)
    print('model params: ')
    print(count_module_all_parameters(model))

    if not Local:
        # Some task heads may not contribute to every step's loss, so DDP must
        # tolerate parameters that receive no gradient.
        model = DDP(model, find_unused_parameters=True)
        print('DDP predict')
    else:
        os.makedirs('./ckpt', exist_ok=True)
        torch.save(model, './ckpt/model.pt')

    optimizer = _build_optimizer(model, other_net_lr, weight_decay)
    scheduler = _build_scheduler(optimizer)

    print('other_net_lr, weight_decay', other_net_lr, weight_decay)

    print("Creating GPT Trainer")
    trainer = GPTTrainer(model, args, train_dataloader=train_dataloaders, test_dataloader=test_dataloaders, val_dataloader=val_dataloaders,
                         lr=args['lr'], weight_decay=args['adam_weight_decay'],
                         with_cuda=args['with_cuda'], log_freq=args['log_freq'],
                         optimizer=optimizer, scheduler=scheduler, task_dict=task_dict)

    print("Training Start")
    trainer.train(args)

    # Release the process group so every rank exits cleanly.
    if not Local and dist.is_initialized():
        dist.destroy_process_group()
# Script entry point: start training.
# NOTE(review): the `__main__` guard below is commented out, so train() runs
# whenever this module is executed *or imported*. Presumably this is
# deliberate for the launcher this job runs under -- confirm before
# re-enabling the guard.
# if __name__ == '__main__':
train()