import argparse

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from collections import OrderedDict

from model import GPT, GPTLM
from trainer import GPTTrainer
from utils.dataload import local_load_data, pai_load_data
from utils.tool import *


parser = argparse.ArgumentParser(description='Configuration from shell')
# 0 (the default) selects the local run; any other integer selects the PAI/ODPS run.
parser.add_argument('--local_flag', default=0, type=int,
                    help='0 for a local run, non-zero for a PAI/ODPS run')
parser.add_argument('--taskname', default='self_task',
                    help='task name; PAI runs read checkpoints from TBT/<taskname>/')
parser.add_argument('--tables', default="TBT_newMechanism_dnn_model_feed_data_features_label_standard", type=str,
                    help='ODPS input table names')
parser.add_argument('--outputs', default="tables/tbt_mtl_predict_result_seq_attention", type=str,
                    help='ODPS output table names')
parser.add_argument('--load_step', default="0", type=str,
                    help='load [model_ckpt_{}.pt]')
local_args = parser.parse_args()
task_name = local_args.taskname

# Without type=int an explicit "--local_flag 0" arrived as the string "0",
# which never equals the int 0, so it silently selected the PAI path.
Local = local_args.local_flag == 0


def _read_checkpoint_bytes(load_path, max_retries=3, retry_delay=1.0):
    """Fetch ``load_path`` via ``bucket.get_object`` and return it as a ``BytesIO``.

    The read is attempted up to ``max_retries`` times in total (at least once),
    sleeping ``retry_delay * attempt`` seconds between failed attempts. The
    exception from the final attempt is re-raised.

    NOTE(review): ``bucket`` and ``BytesIO`` are not imported explicitly in this
    file; presumably they come from ``utils.tool`` via the star import -- confirm.
    """
    import time

    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return BytesIO(bucket.get_object(load_path).read())
        except Exception as err:  # storage client errors do not share one type
            if attempt == attempts:
                raise
            print('Reading {} failed (attempt {}/{}): {!r}; retrying'.format(
                load_path, attempt, attempts, err))
            time.sleep(retry_delay * attempt)


def _strip_ddp_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Weights saved from a ``DistributedDataParallel``-wrapped model carry a
    ``'module.'`` prefix on every key. Keys without the prefix are kept as-is,
    so checkpoints saved from an unwrapped model load too.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


def predict():
    """Load a trained GPT checkpoint and run recursive multi-task prediction on the test set.

    Behaviour depends on the module-level ``Local`` flag:

    * ``Local`` True: data comes from ``local_load_data``. The checkpoint is read
      from a fixed experiment directory at step 1060000 (overriding
      ``--load_step``), and the raw result dict is printed.
    * ``Local`` False: data comes from ``pai_load_data``. The checkpoint is read
      from ``TBT/<taskname>/`` at ``--load_step``, and one row per prediction is
      written to the ``--outputs`` table, using the ``RANK`` env var (default 0)
      as the slice id.
    """
    args = load_arguments(Local)
    args['Local'] = 1 if Local else 0
    args["device"] = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    args['if_ddp_predict'] = 1
    args['batch_size'] = 1600
    print(args["device"])

    print("Set seed", args['seed'])
    set_seed(args['seed'])

    if not Local:
        print('DDP local multi-task start')

        save_dir = "TBT/" + task_name + "/"
        load_dir = "TBT/" + task_name + "/"
        train_dataloaders, val_dataloaders, test_dataloaders = pai_load_data(args, local_args)
    else:
        save_dir = "TBT/local/"
        load_dir = "TBT/3.63Mhid156layer1emb8_rep_FFN1_QKV64_ECQKV64_96_8card/"
        train_dataloaders, val_dataloaders, test_dataloaders = local_load_data(args)
        # Local runs always evaluate this fixed checkpoint, overriding --load_step.
        local_args.load_step = 1060000

    args['save_dir'] = save_dir
    # Define tasks: per-task metric names, metric/loss callables, and one weight
    # per metric name.
    task_dict = {'trigger': {'metrics': ['triggerMae', '10mAcc', '30mAcc'],
                             'metrics_fn': TriggerMetric(args),
                             'loss_fn': TriggerLoss(args),
                             'weight': [1, 0, 1]},
                 'action': {'metrics': ['acc', 'precision', 'recall', 'combine_acc'],
                            'metrics_fn': ActionMetric(),
                            'loss_fn': ActionLoss(args),
                            'weight': [0, 0, 0, 1]},
                 'info': {'metrics': ['acc'],
                          'metrics_fn': InfoMetric(),
                          'loss_fn': InfoLoss(args),
                          'weight': [1]},
                 'voiceTimes': {'metrics': ['acc', '0_acc', '1_acc', '2_acc', '3_acc', '4_acc'],
                                'metrics_fn': VoiceTimesMetric(),
                                'loss_fn': VoiceTimesLoss(args["voiceTimes_cnt"], args),
                                'weight': [1, 1, 1, 1, 0, 0]}
                 }
    # define trainer
    print("Building GPT model")
    gpt = GPT(args, hidden=args['hidden'], n_layers=args['layers'], attn_heads=args['attn_heads']).to(args["device"])
    model = GPTLM(gpt=gpt, args=args).to(args['device'])
    model.apply(init_weights)
    print(model)

    # load model
    load_path = load_dir + 'model_ckpt_{}.pt'.format(local_args.load_step)
    print(load_path)
    buffer = _read_checkpoint_bytes(load_path, max_retries=3)

    # Map tensors onto the device selected above; a hard-coded 'cuda' fails on
    # CPU-only hosts.
    checkpoint = torch.load(buffer, map_location=args['device'])
    new_state_dict = _strip_ddp_prefix(checkpoint)
    for name, v in new_state_dict.items():
        print(name, v.size())
    model.load_state_dict(new_state_dict)
    TBTModel = GPTTrainer(model, args, train_dataloader=train_dataloaders, test_dataloader=test_dataloaders,
                          val_dataloader=val_dataloaders, task_dict=task_dict)

    print('model load finished')

    task_list = ['trigger', 'action', 'info', 'voiceTimes']
    # TBTModel.predict(task_list, test_dataloaders) is the non-recursive alternative.
    result_dict, target_dict = TBTModel.seq_loss_recursive_predict(task_list, test_dataloaders)

    # One entry per output table column, in column order.
    output_columns = [result_dict['trigger'], target_dict['trigger'],
                      result_dict['action'], target_dict['action'],
                      result_dict['info'], target_dict['info'],
                      result_dict['voiceTimes'], target_dict['voiceTimes'],
                      result_dict['caseid'], result_dict['path_id'],
                      result_dict['seg_id'], result_dict['time'],
                      result_dict['taskid_ori'], result_dict['nearest_label_gpst'],
                      result_dict['ds_to_sub_end'], result_dict['scene'], result_dict['mask_index'],
                      ]
    # Transpose column-major data into one row per prediction.
    res = np.array(output_columns).transpose(1, 0)

    if Local:
        print(result_dict)
    else:
        import common_io

        print('Table name', local_args.tables)
        print('Task name', local_args.taskname)
        writer = common_io.table.TableWriter(
            local_args.outputs,
            # RANK is normally set by the distributed launcher; fall back to slice 0.
            slice_id=int(os.environ.get("RANK", 0))
        )
        try:
            writer.write(res, col_indices=list(range(len(output_columns))))
        finally:
            writer.close()

    print('---------------------finished!---------------------')

    
# Run prediction as soon as this module is loaded.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing this
# module (not only executing it as a script) also starts prediction.
predict()