import os
import random
import time
import glob
import math
import pickle
# from itertools import chain

import numpy as np
import torch
import torch.nn.parallel
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.tensorboard import SummaryWriter
import tqdm

import model_completion

from args import get_args
from utils import AllGather
from utils import get_cosine_schedule_with_warmup
from utils import get_model_name



from npz_dataset import NpzDataset, MAX_VIS_LEN, MAX_TEXT_LEN, DISTANCE_BTW_FEAT
from torch.utils.data.distributed import DistributedSampler
from eval_sampler import DistributedEvalSampler

# Gather op from utils.AllGather (``.apply`` suggests a torch.autograd.Function —
# presumably an all-gather across ranks that keeps gradients; verify in utils).
# train_one_batch uses it on predictions/labels/masks when args.distributed is set.
allgather = AllGather.apply
# Logit threshold used by evaluate(): a subtask prediction counts as "completed"
# when its raw logit is strictly greater than this value.
FILTER_LEVEL: float = 1.0


def main():
    """Entry point: validate arguments, resolve the distributed setup and start workers.

    When ``args.world_size`` is -1 and ``SLURM_NPROCS`` is set, the world size and
    rank are read from SLURM (``SLURM_NPROCS``, ``SLURM_PROCID``, ``SLURM_JOBID``)
    and a ``file://`` rendezvous URL is derived from ``args.dist_file`` and the job
    id. Otherwise, if ``RANK`` and ``WORLD_SIZE`` are set, they (and ``LOCAL_RANK``)
    are used. With ``args.multiprocessing_distributed`` one ``main_worker`` process
    is spawned per visible GPU; otherwise ``main_worker`` runs in this process.

    Raises:
        ValueError: on an invalid argument combination.
        NotImplementedError: if ``args.world_size`` is -1 and neither the SLURM nor
            the ``RANK``/``WORLD_SIZE`` environment variables are available.
    """
    args = get_args()

    # Explicit checks instead of ``assert`` so validation still runs under ``python -O``.
    if args.evaluate and args.eval_path == '':
        raise ValueError('args.eval_path must be provided when args.evaluate is set')
    if args.train_path == '':
        raise ValueError('args.train_path (training video data) must be provided')
    if args.task_name == '':
        raise ValueError('args.task_name must be provided')
    # ILP extraction and next-step prediction must not be enabled together.
    if args.extract_ilp and args.next_step_pred:
        raise ValueError('args.extract_ilp and args.next_step_pred are mutually exclusive')

    if (args.world_size == -1) and ("SLURM_NPROCS" in os.environ):
        # SLURM-based setting: one rendezvous file per SLURM job.
        args.world_size = int(os.environ["SLURM_NPROCS"])
        args.rank = int(os.environ["SLURM_PROCID"])
        jobid = os.environ["SLURM_JOBID"]
        args.dist_url = "file://{}.{}".format(os.path.realpath(args.dist_file), jobid)
        print(
            "dist-url:{} at PROCID {} / {}".format(
                args.dist_url, args.rank, args.world_size
            )
        )
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launcher-provided environment (single server); noted by the authors as untested.
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.rank = int(os.environ["RANK"])
        args.gpu = int(os.environ['LOCAL_RANK'])
        print(
            "dist-url:{} at PROCID {} / {}".format(
                args.dist_url, args.rank, args.world_size
            )
        )

        # Delay non-zero ranks (presumably so rank 0 reaches the rendezvous first), see
        # http://aaronsplace.co.uk/blog/2018-12-08-pytorch-distributed-learning.html
        if args.rank > 0:
            time.sleep(5)
    elif args.world_size == -1:
        raise NotImplementedError(
            'world_size is -1 and no SLURM or RANK/WORLD_SIZE environment was found'
        )

    # Note: a world size of exactly 1 also counts as distributed here.
    args.distributed = (args.world_size >= 1) or args.multiprocessing_distributed
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Only used with SLURM: world_size is presumably the node count here and
        # becomes the total number of processes (one per GPU).
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # The main call path: run the worker in this process.
        main_worker(args.gpu, ngpus_per_node, args)


def _strip_phrase_suffix(name):
    """Return ``name`` without a trailing ``'_verb_phrases.tsv'``, if it has one."""
    suffix = '_verb_phrases.tsv'
    if name.endswith(suffix):
        return name[:-len(suffix)]
    return name


def _collapse_trajectory(st, num_subtask):
    """Reduce one raw trajectory to at most one start and one end event per subtask.

    Events are processed in (frame number, start flag, subtask id) order, so at the
    same frame an end event (flag 0) is handled before a start event (flag 1). For
    each subtask, the first start event is kept. So is the first end event that
    brings the subtask's open-start counter back to zero. Every other event is dropped.

    Args:
        st: raw trajectory dict with ``'subtask_indices'``, ``'start_flags'`` and
            ``'frame_numbers'`` arrays. Negative entries of ``st['frame_numbers']``
            are clamped to 0 in place.
        num_subtask: number of subtasks of the task (rows of the bookkeeping table).

    Returns:
        list of ``[subtask_id, frame_num, flag]`` rows (flag 1 = start, 0 = end).
    """
    unset = -1000000
    frame_numbers = st['frame_numbers']
    frame_numbers[frame_numbers < 0] = 0

    # Per subtask: [open-start counter, first start frame, first end frame].
    se_counter = np.full((num_subtask, 3), unset, dtype=np.int32)
    events = []

    paired_data = np.stack([st['subtask_indices'], st['start_flags'].astype(np.int32), frame_numbers], axis=-1)
    # Order by frame, then flag, then subtask id. A tuple key gives the same order as
    # an arithmetic key like ``frame * 1000000 + flag * 500000 + id``, but cannot
    # overflow when the stacked array has a narrow integer dtype.
    paired_data = sorted(paired_data, key=lambda r: (r[2], r[1], r[0]))

    for subtask_id, se_flag, frame_num in paired_data:
        if se_flag == 1:
            # start
            if se_counter[subtask_id, 0] == unset:
                se_counter[subtask_id, 0] = 0
                se_counter[subtask_id, 1] = frame_num
                events.append([subtask_id, frame_num, 1])
            se_counter[subtask_id, 0] += 1
        else:
            # end
            se_counter[subtask_id, 0] -= 1
            if (se_counter[subtask_id, 0] == 0) and (se_counter[subtask_id, 2] == unset):
                se_counter[subtask_id, 2] = frame_num
                events.append([subtask_id, frame_num, 0])
    return events


def _build_label_set(args):
    """Load the sequence labels of ``args.task_name`` and clean their trajectories.

    Trajectories without subtask events are skipped. The rest are reduced by
    ``_collapse_trajectory`` and stored with ``name``, ``vid_id`` and ``trad_idx``.
    For ProceL, ``vid_id`` is parsed from the trailing ``_<int>`` of the name;
    otherwise it is the trajectory index.

    Raises:
        Exception: if ``args.seq_data_path`` does not exist (``args.current_label_set``
            is set to None first).
    """
    if not os.path.exists(args.seq_data_path):
        args.current_label_set = None
        raise Exception('label set is required from now on')

    label_set = np.load(args.seq_data_path, allow_pickle=True).item()
    label_set = label_set[args.task_name.strip()]

    cleaned = []
    for trad_idx, st in enumerate(label_set['trajectories']):
        if len(st['subtask_indices']) == 0:
            continue

        name = _strip_phrase_suffix(st['name'])
        single_result = {'name': name}
        if args.dataset_name == 'ProceL':
            single_result['vid_id'] = int(name.split('_')[-1])
            single_result['trad_idx'] = trad_idx
        else:
            single_result['trad_idx'] = trad_idx
            single_result['vid_id'] = trad_idx

        events = np.array(_collapse_trajectory(st, label_set['num_subtask']))
        if len(events) > 0:
            single_result['subtask_indices'] = events[:, 0]
            single_result['frame_numbers'] = events[:, 1]
            single_result['start_flags'] = events[:, 2].astype(np.bool_)
            cleaned.append(single_result)

    label_set['trajectories'] = cleaned
    return label_set


def _load_completion_labels(args):
    """Return ProceL completion test labels for ``args.task_name`` keyed by video id.

    Entries whose first prediction does not have ``num_subtask`` values are reported
    and skipped. Returns None (after printing a message) if the label file is missing.

    Raises:
        Exception: if the label file has no entry for ``args.task_name``.
    """
    test_label_path = os.path.join(args.eval_path, 'ProceL_completion_test_labels.pkl.npy')
    if not os.path.exists(test_label_path):
        print('failed to find completion labels for evaluation:', test_label_path)
        return None

    print('found completion labels for evaluation')
    all_test_labels = np.load(test_label_path, allow_pickle=True).item()
    if args.task_name not in all_test_labels:
        raise Exception('failed to find test labels for the current task')

    num_subtask = args.current_label_set['num_subtask']
    candidate_labels = {}
    for k, v in all_test_labels[args.task_name].items():
        vid_id = int(_strip_phrase_suffix(k).split('_')[-1])
        if len(v[0]) != num_subtask:
            print('incident', k, 'should have', num_subtask, 'values per prediction but it has', len(v[0]), v[0])
        else:
            candidate_labels[vid_id] = v
    return candidate_labels


def _loader_worker_kwargs(num_workers):
    """Return DataLoader worker options for ``num_workers`` loader processes.

    ``persistent_workers`` and ``prefetch_factor`` are only valid with worker
    processes. DataLoader rejects them when ``num_workers == 0``, which happens
    when ``num_thread_reader`` is smaller than the GPU count it is divided by.
    """
    if num_workers > 0:
        return {'num_workers': num_workers, 'persistent_workers': True, 'prefetch_factor': 10}
    return {'num_workers': 0}


def _build_eval_loader(dataset, sampler_seed, args):
    """Return a sequential single-replica DataLoader over ``dataset`` (None if dataset is None)."""
    if dataset is None:
        return None
    # FIXME: I am only using the master gpu for inference now.
    # Every rank builds this same single-replica (num_replicas=1, rank=0) sampler.
    sampler = DistributedEvalSampler(dataset,
                                     num_replicas=1,
                                     rank=0,
                                     seed=sampler_seed,
                                     shuffle=False)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.eval_batch_size * args.world_size,
        sampler=sampler,
        shuffle=False,
        drop_last=False,
        **_loader_worker_kwargs(args.num_thread_reader),
    )


def _open_tensorboard(args):
    """Create ``<tb_root>/<model_name>`` if needed and return a SummaryWriter on it."""
    tensorboard_dir = os.path.join(args.tb_root, args.model_name)
    if (args.tb_root != '') and (not os.path.exists(tensorboard_dir)):
        os.makedirs(tensorboard_dir)
    return SummaryWriter(tensorboard_dir)


def _run_evaluation(test_loader, test_seq_loader, completion_modal, classification_loss,
                    tensorboard, epoch, args):
    """Run ``evaluate``; in next-step mode, also run ``comp_for_next_step`` on ``test_loader``."""
    if test_seq_loader is not None:
        evaluate(test_seq_loader, args.gpu, completion_modal, classification_loss,
                 tensorboard, epoch, args)
        comp_for_next_step(test_loader, args.gpu, completion_modal, epoch, args)
    else:
        evaluate(test_loader, args.gpu, completion_modal, classification_loss,
                 tensorboard, epoch, args)


def main_worker(gpu, ngpus_per_node, args):
    """Set up data, model and optimizer for one process and run training/evaluation.

    Args:
        gpu: local GPU index for this process (may be None).
        ngpus_per_node: number of visible CUDA devices. In distributed mode, batch
            sizes and loader workers are divided by it.
        args: parsed arguments. This function also sets ``gpu``, ``rank`` (with
            multiprocessing), ``current_label_set``, ``current_val_completion``,
            ``model_name`` and, on resume, ``start_from``.
    """
    args.gpu = gpu
    # set the NCCL distribution option
    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    if args.rank == 0:
        print(args)

    # set random seed before we start running our model
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
    # DistributedSampler seeds its shuffling generator with ``seed + epoch`` and
    # fails on None, so fall back to its own default of 0.
    sampler_seed = args.seed if args.seed is not None else 0

    args.current_label_set = _build_label_set(args)

    # load the datasets
    train_root = os.path.join(args.train_path, args.task_name)
    train_dataset = NpzDataset(train_root,
                               is_train=True,
                               nextstep_pred=args.next_step_pred,
                               resample_lowerbound=1.0 if args.next_step_pred else args.resample_lowerbound,
                               expected_batch_size=args.train_batch_size,
                               with_text=args.with_text,
                               extract_ilp=args.extract_ilp)
    test_seq_dataset = None
    if args.next_step_pred:
        test_seq_dataset = NpzDataset(train_root,
                                      is_train=False,
                                      resample_lowerbound=1.0,
                                      expected_batch_size=-1,
                                      with_text=args.with_text,
                                      extract_ilp=False)

    test_dataset = None
    args.current_val_completion = None
    if len(args.eval_path) > 0:
        test_dataset = NpzDataset(os.path.join(args.eval_path, args.task_name),
                                  is_train=False,
                                  resample_lowerbound=1.0,
                                  expected_batch_size=-1,
                                  with_text=args.with_text,
                                  extract_ilp=args.extract_ilp)
        if args.dataset_name == 'ProceL':
            args.current_val_completion = _load_completion_labels(args)

    # create models
    completion_modal = model_completion.CompletionModal(input_dim=args.input_dim,
                                                        output_dim=args.current_label_set['num_subtask'],
                                                        hidden_dim=args.hidden_dim,
                                                        num_layers=args.num_layers,
                                                        text_att_n_head=args.text_att_n_head,
                                                        text_feedforward_dim=args.text_feedforward_dim,
                                                        text_att_type=args.text_att_type)
    # load the model to the GPU
    if args.distributed:
        # DDP is only supported with one explicit GPU per process.
        if args.gpu is None:
            raise Exception('not expected')
        torch.cuda.set_device(args.gpu)
        completion_modal.cuda(args.gpu)
        # Per-process share of the configured batch sizes and loader workers.
        args.train_batch_size = int(args.train_batch_size / ngpus_per_node)
        args.eval_batch_size = int(args.eval_batch_size / ngpus_per_node)
        args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
        completion_modal = torch.nn.parallel.DistributedDataParallel(completion_modal, device_ids=[args.gpu])
    elif args.gpu is not None:
        # one GPU, without distribution
        torch.cuda.set_device(args.gpu)
        completion_modal = completion_modal.cuda(args.gpu)
    else:
        completion_modal = torch.nn.DataParallel(completion_modal).cuda()
        print('this experiment is not trained with DDP')

    train_sampler = DistributedSampler(dataset=train_dataset,
                                       num_replicas=args.world_size,
                                       rank=args.rank,
                                       seed=sampler_seed,
                                       shuffle=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        pin_memory=args.pin_memory,
        sampler=train_sampler,
        drop_last=True,
        **_loader_worker_kwargs(args.num_thread_reader),
    )

    train_steps = len(train_loader)
    if args.gpu == 0:
        print('total #steps per epoc: {}'.format(train_steps))

    test_loader = _build_eval_loader(test_dataset, sampler_seed, args)
    test_seq_loader = _build_eval_loader(test_seq_dataset, sampler_seed, args)

    classification_loss = torch.nn.BCELoss(reduction='none')

    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(completion_modal.parameters(), args.learning_rate_in_float)
    elif args.optimizer == 'sgd':
        # ``momemtum`` (sic) must match the attribute name produced by get_args().
        optimizer = torch.optim.SGD(completion_modal.parameters(), args.learning_rate_in_float, momentum=args.momemtum)
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.optimizer))

    scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, train_steps * args.epochs)
    tensorboard = None

    if len(args.resume.strip()) > 0:
        # optionally resume from a checkpoint
        args.model_name = args.resume
        checkpoint_dir = os.path.join(args.cp_root, args.model_name)
        checkpoint_path = get_last_checkpoint(checkpoint_dir)
        if checkpoint_path:
            log("=> loading checkpoint '{}'".format(checkpoint_path), args)
            # Restore tensors onto this process's GPU. By default torch.load puts them
            # back on the device they were saved from, so every rank would load onto it.
            map_location = None if args.gpu is None else 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(checkpoint_path, map_location=map_location)
            args.start_from = checkpoint["epoch"]
            completion_modal.load_state_dict(checkpoint["completion_modal_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            log("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint["epoch"]), args)
        else:
            log("=> no checkpoint found at '{}'".format(checkpoint_dir), args)

        if args.rank == 0:
            tensorboard = _open_tensorboard(args)
    else:
        # fresh run
        args.model_name = get_model_name(args)
        checkpoint_dir = os.path.join(args.cp_root, args.model_name)
        if args.rank == 0:
            if (args.cp_root != '') and (not os.path.exists(checkpoint_dir)):
                os.makedirs(checkpoint_dir)
            tensorboard = _open_tensorboard(args)

    # optional: turn on cudnn benchmark option.
    if args.cudnn_benchmark:
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    total_batch_size = args.world_size * args.train_batch_size
    if args.rank == 0:
        log("Starting training loop total batch size: {}".format(total_batch_size), args)
        n_params = sum(p.numel() for p in completion_modal.parameters() if p.requires_grad)
        log('parameters in model: {}'.format(n_params), args)

    # evaluate the initial (or resumed) model
    if args.evaluate:
        _run_evaluation(test_loader, test_seq_loader, completion_modal, classification_loss,
                        tensorboard, args.start_from, args)
        if args.infer_only:
            return

    for epoch in range(args.start_from, args.epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, args.gpu, completion_modal, classification_loss,
              optimizer, scheduler, tensorboard, epoch, train_steps, args)

        # Checkpoint/evaluate only after the epoch has been trained. Then a checkpoint
        # labelled ``epoch + 1`` holds ``epoch + 1`` trained epochs with the matching
        # optimizer/scheduler state, and resuming from it does not skip an epoch.
        if (epoch + 1) % args.n_eval == 0:
            if args.distributed:
                torch.distributed.barrier()

            # only the main process saves the checkpoint
            if args.rank == 0:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "completion_modal_dict": completion_modal.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                    }, checkpoint_dir, epoch + 1, args
                )

            if args.evaluate:
                _run_evaluation(test_loader, test_seq_loader, completion_modal, classification_loss,
                                tensorboard, epoch + 1, args)

            if args.distributed:
                torch.distributed.barrier()


def train(train_loader, gpu_id, completion_modal, classification_loss,
          optimizer, scheduler, tensorboard, epoch, train_steps, args):
    """Train ``completion_modal`` for one epoch.

    Each batch (minus metadata keys) is copied to ``gpu_id`` and passed to
    ``train_one_batch``. On rank 0 with ``args.verbose``, the elapsed time, epoch
    progress, learning rate and mean loss since the previous report are logged
    whenever ``(batch_idx + epoch + 1) % args.n_display == 0``. The mean loss is
    also written to TensorBoard when a writer is given.

    Args:
        train_loader: training DataLoader yielding dicts of tensors.
        gpu_id: CUDA device index for the batch tensors.
        completion_modal: model to train.
        classification_loss: elementwise loss passed on to ``train_one_batch``.
        optimizer, scheduler: stepped once per batch inside ``train_one_batch``.
        tensorboard: SummaryWriter or None.
        epoch: zero-based epoch index.
        train_steps: batches per epoch; the loop stops after this many and it is
            used to compute the global step.
        args: parsed arguments.
    """
    # Keys that are not copied to the GPU.
    skip_keys = ('vid_id', 'vis_max_text_idx', 'text_base_idx')
    running_loss = 0.0
    # Number of batches summed in ``running_loss`` since the last report. The first
    # report of an epoch can come after fewer than ``n_display`` batches: the cadence
    # is offset by ``epoch`` and ``running_loss`` restarts on every call. So the
    # average must divide by this count, not by ``n_display``.
    batches_in_window = 0
    s = time.time()
    completion_modal.train()
    for batch_idx, sample_batch in enumerate(train_loader):
        gpu_batch = {k: v.cuda(gpu_id, non_blocking=args.pin_memory)
                     for k, v in sample_batch.items() if k not in skip_keys}
        global_step = epoch * train_steps + batch_idx
        batch_loss = train_one_batch(completion_modal, optimizer, scheduler,
                                     gpu_batch, classification_loss, tensorboard,
                                     global_step, args)
        running_loss += batch_loss
        batches_in_window += 1

        if ((batch_idx + epoch + 1) % args.n_display == 0) and (args.verbose) and (args.rank == 0):
            d = time.time() - s
            avg_loss = running_loss / batches_in_window
            log(
                "Epoch %d, Elapsed Time: %.3f, Epoch status: %.4f, Training loss: %.4f, Learning rate: %.6f"
                % (
                    epoch + 1,
                    d,
                    float(batch_idx) / train_steps,
                    avg_loss,
                    optimizer.param_groups[0]['lr'],
                ), args
            )

            # this branch already runs only on rank 0
            if tensorboard is not None:
                tensorboard.add_scalar('train_epoch/running_loss', avg_loss, global_step)
            running_loss = 0.0
            batches_in_window = 0
            s = time.time()

        if (batch_idx + 1) == train_steps:
            # end of one epoch
            break


def train_one_batch(completion_modal, opt, scheduler, data,
                    classification_loss, tensorboard, counter, args, epsilon=1e-10):
    """Run one optimization step on a GPU batch and return the loss value.

    The model output ``pad_pred`` is turned into two binary predictions per entry:
    ``sigmoid(pad_pred + 1)`` for target ``label >= 0`` and ``sigmoid(pad_pred - 1)``
    for target ``label >= 1``, both multiplied by ``pred_mask``. Each is scored with
    ``classification_loss`` only at unpadded time steps (``< vis_seq_len``) that are
    set in ``begin_mask`` or ``end_mask``. The score is averaged over the last axis,
    summed over time, and divided by the number of such positions per sample.

    Args:
        completion_modal: model (possibly DDP-wrapped).
        opt: optimizer, stepped once.
        scheduler: LR scheduler, stepped once.
        data: dict of GPU tensors with ``vis_feature``, ``vis_seq_len``, ``label``,
            ``pred_mask``, ``begin_mask``, ``end_mask`` (and ``text_feature`` /
            ``text_seq_len`` when ``args.with_text``).
        classification_loss: elementwise (non-reducing) loss, e.g.
            ``BCELoss(reduction='none')``.
        tensorboard: SummaryWriter or None; per-batch losses are logged on rank 0.
        counter: global step used for the TensorBoard scalars.
        args: parsed arguments.
        epsilon: lower bound for the per-sample position count. A sample with no
            begin/end position then contributes 0 to the loss instead of 0/0 = NaN.

    Returns:
        float: the batch loss.
    """
    # Drop stale gradients; ``p.grad = None`` releases them instead of zero-filling.
    opt.zero_grad()
    for p in completion_modal.parameters():
        p.grad = None

    pad_vis_feature = data['vis_feature']
    vis_seq_len = data['vis_seq_len']
    pad_label = data['label']
    pad_label_mask = data['pred_mask']

    pad_begin_mask = data['begin_mask']
    pad_end_mask = data['end_mask']

    if args.with_text:
        text_seq_len = data['text_seq_len']
        pad_text_feature = data['text_feature']
    else:
        text_seq_len = None
        pad_text_feature = None

    with torch.set_grad_enabled(True):
        pad_pred = completion_modal(vis_feat=pad_vis_feature,
                                    vis_seq_len=vis_seq_len,
                                    begin_mask=pad_begin_mask,
                                    end_mask=pad_end_mask,
                                    text_feat=pad_text_feature,
                                    text_seq_len=text_seq_len)

        if args.distributed:
            # Compute the loss on the batch gathered from every rank.
            pad_pred = allgather(pad_pred, args)
            pad_label = allgather(pad_label, args)
            pad_label_mask = allgather(pad_label_mask, args)
            vis_seq_len = allgather(vis_seq_len, args)

            pad_begin_mask = allgather(pad_begin_mask, args)
            pad_end_mask = allgather(pad_end_mask, args)

        pad_label_mask = pad_label_mask.unsqueeze(1)
        # 1.0 for time steps before each sample's vis_seq_len, 0.0 for padding; shape (B, MAX_VIS_LEN, 1).
        unpad_mask = (torch.arange(MAX_VIS_LEN, device=vis_seq_len.device).reshape(1, MAX_VIS_LEN).repeat(len(vis_seq_len), 1) < vis_seq_len.unsqueeze(-1)).to(torch.float32)
        unpad_mask = unpad_mask.unsqueeze(-1)

        pred1 = torch.sigmoid(pad_pred + 1) * pad_label_mask
        label1 = (pad_label >= 0).to(torch.int16) * pad_label_mask

        pred2 = torch.sigmoid(pad_pred - 1) * pad_label_mask
        label2 = (pad_label >= 1).to(torch.int16) * pad_label_mask

        # Only begin/end positions are scored.
        masked_pred_pos = torch.logical_or(pad_begin_mask, pad_end_mask).to(torch.float32)
        unsqueezed_pred_pos = masked_pred_pos.unsqueeze(-1)
        num_pred_pos = masked_pred_pos.sum(-1).clamp(min=epsilon)
        ce1 = (classification_loss(pred1, label1) * unpad_mask * unsqueezed_pred_pos).mean(-1).sum(-1) / num_pred_pos
        ce2 = (classification_loss(pred2, label2) * unpad_mask * unsqueezed_pred_pos).mean(-1).sum(-1) / num_pred_pos
        loss = ce1.mean() + ce2.mean()

    loss.backward()
    opt.step()
    scheduler.step()

    loss_value = loss.item()
    if (args.rank == 0) and (tensorboard is not None):
        log_items = {'train_batch/loss': loss_value,
                     'train_batch/ce_loss1': ce1.mean().item(),
                     'train_batch/ce_loss2': ce2.mean().item()}
        for k, v in log_items.items():
            tensorboard.add_scalar(k, v, counter)
    return loss_value


def evaluate(test_loader, gpu_id, completion_modal, classification_loss,
             tensorboard, epoch, args):
    """Evaluate ``completion_modal`` on ``test_loader`` and report/save the results.

    All ranks run the forward passes. Only rank 0 does the rest:
    - stores per-video predictions and averages the per-video loss;
    - computes the hit ratio against ``args.current_val_completion`` for videos
      that have labels there;
    - logs the result and writes TensorBoard scalars;
    - when ``args.extract_ilp`` or ``args.next_step_pred`` is set, saves the
      predictions with ``np.save`` to
      ``<cp_root>/<model_name>/completion_<task_name>_<epoch>.pkl.npy``.

    The model is put back into train mode before returning. This also happens
    when ``test_loader`` is None, in which case nothing is evaluated.

    Args:
        test_loader: evaluation DataLoader yielding dicts of tensors, or None.
        gpu_id: CUDA device index the batch tensors are copied to.
        completion_modal: model under evaluation.
        classification_loss: elementwise loss. Its output is reduced here with
            ``.mean(-1).sum(-1)``, so it presumably must use ``reduction='none'``.
        tensorboard: SummaryWriter or None.
        epoch: label used in log messages, as the TensorBoard step, and in the
            output filename.
        args: parsed arguments; reads ``current_label_set`` and
            ``current_val_completion`` among others.
    """
    if test_loader is None:
        completion_modal.train()
        return

    completion_modal.eval()
    if args.rank == 0:
        log('Evaluating on {}, Epoch {}'.format(args.dataset_name, epoch), args)

    # Rank 0 only: vid_id -> (interleaved begin/end logits, per-video loss).
    answer_dict = {}
    # taking the representation from the model
    with torch.no_grad():
        for sample_batch in tqdm.tqdm(test_loader):
            # Unlike train(), 'vid_id' is copied too: it keys answer_dict below.
            gpu_batch = {k: v.cuda(gpu_id, non_blocking=args.pin_memory) for k, v in sample_batch.items() if k not in ['vis_max_text_idx', 'text_base_idx']}

            # first to subsample the dataset
            pad_vis_feature = gpu_batch['vis_feature']
            vis_seq_len = gpu_batch['vis_seq_len']
            pad_label = gpu_batch['label']
            # With D = DISTANCE_BTW_FEAT: "begin" steps are indices 0, D, 2D, ...
            # and "end" steps are indices D-1, 2D-1, ... along dim 1.
            pad_label_begin = pad_label[:, ::DISTANCE_BTW_FEAT]
            pad_label_end = pad_label[:, DISTANCE_BTW_FEAT-1::DISTANCE_BTW_FEAT]
            pad_label_mask = gpu_batch['pred_mask']
            vid_id = gpu_batch['vid_id']

            pad_begin_mask = gpu_batch['begin_mask']
            pad_end_mask = gpu_batch['end_mask']

            if args.with_text:
                text_seq_len = gpu_batch['text_seq_len']
                pad_text_feature = gpu_batch['text_feature']
            else:
                text_seq_len = None
                pad_text_feature = None

            pad_pred = completion_modal(vis_feat=pad_vis_feature,
                                        vis_seq_len=vis_seq_len,
                                        begin_mask=pad_begin_mask,
                                        end_mask=pad_end_mask,
                                        text_feat=pad_text_feature,
                                        text_seq_len=text_seq_len)
            # Same begin/end subsampling applied to the model output.
            pad_pred_begin = pad_pred[:, ::DISTANCE_BTW_FEAT]
            pad_pred_end = pad_pred[:, DISTANCE_BTW_FEAT-1::DISTANCE_BTW_FEAT]

            pad_label_mask = pad_label_mask.unsqueeze(1)
            # remove any padded seq here
            unpad_mask = (torch.arange(MAX_VIS_LEN, device=vis_seq_len.device).reshape(1, MAX_VIS_LEN).repeat(len(vis_seq_len), 1) < vis_seq_len.unsqueeze(-1)).to(torch.float32)
            unpad_mask = unpad_mask.unsqueeze(-1)
            unpad_mask_begin = unpad_mask[:, ::DISTANCE_BTW_FEAT]
            unpad_mask_end = unpad_mask[:, DISTANCE_BTW_FEAT-1::DISTANCE_BTW_FEAT]

            # Two binary targets, as in training: label >= 0 is scored against
            # sigmoid(logit + 1) and label >= 1 against sigmoid(logit - 1).
            pred1_begin = torch.sigmoid(pad_pred_begin + 1) * pad_label_mask
            label1_begin = (pad_label_begin >= 0).to(torch.int16) * pad_label_mask

            pred1_end = torch.sigmoid(pad_pred_end + 1) * pad_label_mask
            label1_end = (pad_label_end >= 0).to(torch.int16) * pad_label_mask

            pred2_begin = torch.sigmoid(pad_pred_begin - 1) * pad_label_mask
            label2_begin = (pad_label_begin >= 1).to(torch.int16) * pad_label_mask

            pred2_end = torch.sigmoid(pad_pred_end - 1) * pad_label_mask
            label2_end = (pad_label_end >= 1).to(torch.int16) * pad_label_mask

            # Length in subsampled steps (vis_seq_len / D truncated to int32);
            # a length of 0 is raised to 1 so the divisions below are safe.
            vis_seq_len = (vis_seq_len / DISTANCE_BTW_FEAT).to(torch.int32)
            vis_seq_len[vis_seq_len == 0] = 1

            ce1_begin = (classification_loss(pred1_begin, label1_begin) * unpad_mask_begin).mean(-1).sum(-1) / vis_seq_len
            ce1_end =(classification_loss(pred1_end, label1_end) * unpad_mask_end).mean(-1).sum(-1) / vis_seq_len

            ce2_begin = (classification_loss(pred2_begin, label2_begin) * unpad_mask_begin).mean(-1).sum(-1) / vis_seq_len
            ce2_end = (classification_loss(pred2_end, label2_end) * unpad_mask_end).mean(-1).sum(-1) / vis_seq_len

            # Per-video loss: mean over begin/end and over the two targets.
            ce1 = (ce1_begin + ce1_end) / 2
            ce2 = (ce2_begin + ce2_end) / 2
            loss = (ce1 + ce2) / 2

            # Interleave begin/end logits along time as [b0, e0, b1, e1, ...].
            # Assumes pad_pred is (batch, time, num_subtask) and that the begin and
            # end slices have equal length -- TODO confirm. The valid length doubles.
            pad_pred_merged = torch.cat([pad_pred_begin.unsqueeze(2), pad_pred_end.unsqueeze(2)], dim=2)
            pad_pred_merged = pad_pred_merged.reshape(len(vis_seq_len), -1, args.current_label_set['num_subtask'])
            vis_seq_len = vis_seq_len * 2

            if args.rank == 0:
                for single_name, single_len, single_pred, single_loss in zip(vid_id, vis_seq_len, pad_pred_merged, loss):
                    answer_dict[single_name.cpu().item()] = (single_pred.detach().cpu()[:single_len.cpu().item()], single_loss.detach().cpu())


    if args.rank == 0:
        final_loss = 0.0
        final_hit = 0.0
        hit_item_count = 0
        completion_logit = {}
        completion_binary = {}
        label_match_result = {}
        for single_name, (single_pred, single_loss) in tqdm.tqdm(answer_dict.items(), total=len(answer_dict)):
            single_logit = single_pred.numpy()
            # An entry counts as completed when its raw logit exceeds FILTER_LEVEL.
            single_completion = single_logit > FILTER_LEVEL

            completion_logit[single_name] = single_logit
            completion_binary[single_name] = single_completion

            # Hit ratio of a video: fraction of entries where the thresholded
            # prediction equals the provided completion label.
            if  (args.current_val_completion is not None) and (single_name in args.current_val_completion):
                label_match_result[single_name] = (single_completion == args.current_val_completion[single_name])
                single_hit = label_match_result[single_name].astype(np.float32).mean()
                final_hit = final_hit + single_hit
                hit_item_count = hit_item_count + 1

            final_loss = final_loss + single_loss.item()

        if len(answer_dict) > 0:
            final_loss = final_loss / len(answer_dict)
        if hit_item_count > 0:
            final_hit = final_hit / hit_item_count

        trajectory_dict = {}
        if args.extract_ilp or args.next_step_pred:
            if args.current_label_set is None:
                # No label set: save only the raw logits, keyed by vid_id.
                completion_name = os.path.join(args.cp_root, args.model_name, 'completion_{}_{}.pkl.npy'.format(args.task_name, epoch))
                np.save(completion_name, completion_logit)
            else:
                # Attach thresholded and raw predictions to copies of the label-set trajectories.
                for trad_idx, st in enumerate(args.current_label_set['trajectories']):
                    # which subtask number is used
                    vid_id = st['vid_id']
                    trajectory_dict[vid_id] = dict(st)
                    if vid_id in completion_logit:
                        trajectory_dict[vid_id]['completion_pred'] = completion_binary[vid_id]
                        trajectory_dict[vid_id]['completion_pred_number'] = completion_logit[vid_id]

                # remove non-exists files
                # (trajectories that received no prediction in this evaluation)
                for vid_id in list(trajectory_dict.keys()):
                    if 'completion_pred' not in trajectory_dict[vid_id]:
                        trajectory_dict.pop(vid_id)

                # sort the results
                # (by original trajectory index) and save them in the label-set layout
                sorted_list = sorted(trajectory_dict.values(), key=lambda item: item['trad_idx'])
                final_task_dict = {}
                final_task_dict['num_subtask'] = args.current_label_set['num_subtask']
                final_task_dict['subtask_labels'] = args.current_label_set['subtask_labels']
                final_task_dict['trajectories'] = sorted_list
                completion_sr_name = os.path.join(args.cp_root, args.model_name, 'completion_{}_{}.pkl.npy'.format(args.task_name, epoch))
                np.save(completion_sr_name, {args.task_name: final_task_dict})

        log('Task: {}, Epoch: {}, Hit: {:.4} Loss: {:.5}'.format(args.task_name, epoch, final_hit, final_loss), args)

        # it is only for rank == 0
        loss_item = {'eval_epoch/loss': final_loss,
                     'eval_epoch/hit_ratio': final_hit}
        if tensorboard is not None:
            for k, v in loss_item.items():
                tensorboard.add_scalar(k, v, epoch)
    completion_modal.train()


def comp_for_next_step(test_loader, gpu_id, completion_modal, epoch, args):
    """Run step-by-step (prefix-only) completion inference and save next-step predictions.

    Every batch is unrolled in increments of ``DISTANCE_BTW_FEAT`` visual steps.
    At each increment the model only receives the prefix ``[:step_len]`` of the
    visual features and begin/end masks (plus a matched text prefix when
    ``args.with_text``). For each video, the model output at the newest begin
    and end positions is recorded.

    On rank 0 the collected scores are thresholded with ``FILTER_LEVEL`` and
    written with ``np.save`` to
    ``<cp_root>/<model_name>/next_step_<task_name>_<epoch>.pkl.npy``. The
    contents depend on ``args.current_label_set``:

    * When it is ``None``, the raw ``{vid_id: scores}`` dict is saved.
    * Otherwise, ``{task_name: {num_subtask, subtask_labels, trajectories}}``
      is saved, with ``completion_pred`` / ``completion_pred_number`` attached
      to every trajectory that has predictions.

    Args:
        test_loader: iterable of dict batches, or ``None`` to skip inference.
        gpu_id: CUDA device index the batch tensors are moved to.
        completion_modal: model called with ``vis_feat``, ``vis_seq_len``,
            ``begin_mask``, ``end_mask``, ``text_feat`` and ``text_seq_len``.
        epoch: epoch number used in the log message and output filename.
        args: parsed arguments. Reads ``rank``, ``pin_memory``, ``with_text``,
            ``dataset_name``, ``cp_root``, ``model_name``, ``task_name`` and
            ``current_label_set``.

    The model is switched to eval mode for inference and back to train mode
    before returning.
    """
    if test_loader is None:
        completion_modal.train()
        return

    completion_modal.eval()
    if args.rank == 0:
        log('Inferring next step on {}, Epoch {}'.format(args.dataset_name, epoch), args)

    answer_dict = {}
    with torch.no_grad():
        for sample_batch in tqdm.tqdm(test_loader):
            gpu_batch = {k: v.cuda(gpu_id, non_blocking=args.pin_memory) for k, v in sample_batch.items()}

            pad_vis_feature = gpu_batch['vis_feature']
            vis_seq_len = gpu_batch['vis_seq_len']
            vid_id = gpu_batch['vid_id']
            pad_begin_mask = gpu_batch['begin_mask']
            pad_end_mask = gpu_batch['end_mask']

            if args.with_text:
                text_seq_len = gpu_batch['text_seq_len']
                pad_text_feature = gpu_batch['text_feature']
                vis_max_text_idx = gpu_batch['vis_max_text_idx']
                text_base_idx = gpu_batch['text_base_idx']
                slice_text_feature = torch.zeros_like(pad_text_feature)
            else:
                slice_text_feature = None
            # overwritten at every unroll step when text is used
            slice_text_seq_len = None

            # run the data per each step (unroll step by step)
            total_steps = vis_seq_len.to(torch.int32).max().cpu().item()

            # prefix buffers: positions past the current step stay zero
            slice_vis_feature = torch.zeros_like(pad_vis_feature)
            slice_begin_mask = torch.zeros_like(pad_begin_mask)
            slice_end_mask = torch.zeros_like(pad_end_mask)

            answer_single_loop_dict = {}
            for step_len in range(DISTANCE_BTW_FEAT, total_steps + 1, DISTANCE_BTW_FEAT):
                slice_vis_feature[:, :step_len] = pad_vis_feature[:, :step_len]
                slice_vis_seq_len = torch.where(vis_seq_len < step_len, vis_seq_len, step_len * torch.ones_like(vis_seq_len))
                slice_begin_mask[:, :step_len] = pad_begin_mask[:, :step_len]
                slice_end_mask[:, :step_len] = pad_end_mask[:, :step_len]

                if slice_text_feature is not None:
                    # NOTE(review): column `step_len` is one past the visible prefix
                    # [:step_len]; if vis_max_text_idx has exactly total_steps columns
                    # this slice is empty at the last step -- confirm intended.
                    max_text_idx = vis_max_text_idx[:, step_len:step_len + 1]
                    # first position in text_base_idx equal to the per-sample max text index
                    text_matched_idx = ((text_base_idx - max_text_idx) == 0).to(torch.int8).argmax(dim=1)

                    slice_text_seq_len = torch.where(text_seq_len < text_matched_idx, text_seq_len, text_matched_idx)
                    for idx, (single_idx, orig_feat) in enumerate(zip(text_matched_idx, pad_text_feature)):
                        if single_idx > 0:
                            slice_text_feature[idx, :single_idx] = orig_feat[:single_idx]

                pad_pred = completion_modal(vis_feat=slice_vis_feature,
                                            vis_seq_len=slice_vis_seq_len,
                                            begin_mask=slice_begin_mask,
                                            end_mask=slice_end_mask,
                                            text_feat=slice_text_feature,
                                            text_seq_len=slice_text_seq_len)
                pad_pred_begin = pad_pred[:, ::DISTANCE_BTW_FEAT]
                pad_pred_end = pad_pred[:, DISTANCE_BTW_FEAT - 1::DISTANCE_BTW_FEAT]

                # index of the newest begin/end pair covered by this prefix
                current_pred_idx = step_len // DISTANCE_BTW_FEAT - 1

                if args.rank == 0:
                    for single_name, single_begin, single_end in zip(vid_id, pad_pred_begin, pad_pred_end):
                        per_video = answer_single_loop_dict.setdefault(single_name.cpu().item(), [])
                        per_video.append(single_begin.detach().cpu()[current_pred_idx])
                        per_video.append(single_end.detach().cpu()[current_pred_idx])

            if args.rank == 0:
                # two scores (begin, end) were recorded per DISTANCE_BTW_FEAT visual steps
                num_preds = (vis_seq_len / DISTANCE_BTW_FEAT).to(torch.int32) * 2
                for single_name, single_len in zip(vid_id, num_preds):
                    save_key = single_name.cpu().item()
                    per_video = answer_single_loop_dict.get(save_key)
                    if not per_video:
                        # The unroll loop never ran (every sequence in this batch is
                        # shorter than DISTANCE_BTW_FEAT), so there is nothing to stack.
                        answer_dict[save_key] = torch.zeros(0)
                        continue
                    value_arr = torch.nan_to_num(torch.stack(per_video, dim=0), nan=0.0)
                    answer_dict[save_key] = value_arr[:single_len.cpu().item()]

    if args.rank == 0:
        completion_logit = {}
        completion_binary = {}
        number_of_complete_pred_for_debug = []
        for single_name, single_pred in tqdm.tqdm(answer_dict.items(), total=len(answer_dict)):
            single_logit = single_pred.numpy()
            single_completion = single_logit > FILTER_LEVEL
            number_of_complete_pred_for_debug.append(single_completion.astype(np.int32).sum())
            completion_logit[single_name] = single_logit
            completion_binary[single_name] = single_completion
        print('completion output: {}'.format(sum(number_of_complete_pred_for_debug)))

        next_step_sr_name = os.path.join(args.cp_root, args.model_name, 'next_step_{}_{}.pkl.npy'.format(args.task_name, epoch))
        if args.current_label_set is None:
            # No label set to attach predictions to: store the raw per-video scores,
            # mirroring how the completion evaluation handles this case.
            np.save(next_step_sr_name, completion_logit)
        else:
            trajectory_dict = {}
            for st in args.current_label_set['trajectories']:
                vid_id = st['vid_id']
                trajectory_dict[vid_id] = dict(st)
                if vid_id in completion_logit:
                    trajectory_dict[vid_id]['completion_pred'] = completion_binary[vid_id]
                    trajectory_dict[vid_id]['completion_pred_number'] = completion_logit[vid_id]

            # remove trajectories without predictions
            for vid_id in list(trajectory_dict.keys()):
                if 'completion_pred' not in trajectory_dict[vid_id]:
                    trajectory_dict.pop(vid_id)

            # NOTE(review): assumes every trajectory entry already carries a 'trad_idx' key
            sorted_list = sorted(trajectory_dict.values(), key=lambda item: item['trad_idx'])
            final_task_dict = {}
            final_task_dict['num_subtask'] = args.current_label_set['num_subtask']
            final_task_dict['subtask_labels'] = args.current_label_set['subtask_labels']
            final_task_dict['trajectories'] = sorted_list
            np.save(next_step_sr_name, {args.task_name: final_task_dict})

    completion_modal.train()


def save_checkpoint(state, checkpoint_dir, epoch, args):
    """Persist ``state`` as ``epochNNNN.pth.tar`` and store the run config once.

    The epoch number is zero-padded to four digits in the filename. The
    ``vars(args)`` dict is pickled to ``config.pkl`` inside ``checkpoint_dir``
    only when that file is not already present. The first saved config is
    therefore never overwritten.
    """
    ckpt_file = "epoch{:0>4d}.pth.tar".format(epoch)
    torch.save(state, os.path.join(checkpoint_dir, ckpt_file))

    config_path = os.path.join(checkpoint_dir, 'config.pkl')
    if os.path.exists(config_path):
        return
    with open(config_path, 'wb') as handle:
        pickle.dump(vars(args), handle)


def get_last_checkpoint(checkpoint_dir):
    """Return the greatest (by string order) ``epoch*.pth.tar`` path in ``checkpoint_dir``.

    Ordering is plain string comparison, so it matches epoch order only for
    names zero-padded to a common width (``save_checkpoint`` pads to four
    digits). Returns ``''`` when nothing matches.
    """
    candidates = glob.glob(os.path.join(checkpoint_dir, 'epoch*.pth.tar'))
    return max(candidates) if candidates else ''


def log(output, args):
    """Append ``output`` plus a newline to ``<cp_root>/<model_name>/log.txt``.

    The message is also echoed to stdout first when ``args.verbose`` is truthy.
    """
    verbose = args.verbose
    if verbose:
        print(output)
    log_path = os.path.join(args.cp_root, args.model_name, 'log.txt')
    with open(log_path, mode="a") as handle:
        handle.write(output + '\n')


# Script entry point: run main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
