from byprot import utils
        
def get_initial(args, length, tokenizer, device):
    """Build a tokenized batch of all-mask sequences and move it to ``device``.

    Each sequence is the string ``'<mask>'`` repeated ``length`` times with no
    separator. ``args.num_seqs`` identical copies are encoded together in one
    ``batch_encode_plus`` call, with special tokens added, padding to the
    longest entry and PyTorch tensor output.

    Args:
        args: Configuration object; only ``args.num_seqs`` (the number of
            copies in the batch) is read here.
        length: Number of ``'<mask>'`` tokens per sequence.
        tokenizer: Object exposing ``batch_encode_plus``; presumably a
            HuggingFace-style tokenizer — TODO confirm against callers.
        device: Destination handed to ``utils.recursive_to``.

    Returns:
        A dict, passed through ``utils.recursive_to(..., device)``, with keys:
        ``'input_ids'`` (the tokenizer's ``input_ids``) and ``'input_mask'``
        (the tokenizer's ``attention_mask`` converted with ``.bool()``).
    """
    masked_text = '<mask>' * length
    prompts = [masked_text] * args.num_seqs

    encoded = tokenizer.batch_encode_plus(
        prompts,
        add_special_tokens=True,
        padding="longest",
        return_tensors='pt',
    )

    # Expected id layout (per the original author's note, depends on the
    # tokenizer's special tokens): <cls> <mask> <mask> ... <mask> <eos>
    token_ids = encoded['input_ids']
    # Every prompt has the same length, so no padding occurs and this mask is
    # expected to be all True.
    pad_mask = encoded['attention_mask'].bool()

    return utils.recursive_to(
        {'input_ids': token_ids, 'input_mask': pad_mask},
        device,
    )