import torch

def split_tokenize_context_prompt_targets(tok, contexts, prompt, target):
    """Tokenize contexts, a prompt and a target into teacher-forcing tensors.

    Each context is followed by the same prompt and the same target. The
    decoder input drops the last target token, so the label row written into
    ``rewriting_targets`` is already shifted: position ``j`` holds the token
    expected after input position ``j``.

    Args:
        tok: Tokenizer-like callable. Must expose ``padding_side`` (required
            to be ``'left'``), ``name_or_path``, ``pad_token_id`` and
            ``eos_token_id``, and return a mapping with ``'input_ids'``.
        contexts: Context strings, tokenized together with padding and
            truncation. If they produce no tokens the batch size is 1.
        prompt: Prompt string, tokenized without special tokens.
        target: Target string, tokenized without special tokens. Must yield
            at least one token.

    Returns:
        dict with the following entries (``B`` batch size, ``Lc``/``Lp``/``Lt``
        context/prompt/target lengths):

        * ``decoder_inputs``: ``[B, Lc + Lp + Lt - 1]`` context + prompt +
          target without its last token.
        * ``attention_mask``: bool, ``decoder_inputs != tok.pad_token_id``.
        * ``rewriting_targets``: zeros, except the last ``Lt`` positions of
          every row, which hold the target ids.
        * ``prompt_tokens``: ``[1, Lp]`` prompt ids.
        * ``target_tokens``: ``[1, Lt]`` target ids.
        * ``context_tokens``: context ids as returned by the tokenizer.
        * ``prompt_mask``: bool, True on the prompt positions of
          ``decoder_inputs``.
        * ``cp_inputs``: ``[B, Lc + Lp]`` context + prompt.
        * ``cp_prompt_mask``: bool, True on the prompt positions of
          ``cp_inputs``.

    Raises:
        AssertionError: If the tokenizer does not pad on the left, or the
            last context ends with the EOS token.
        ValueError: If the target is empty, or if context and prompt are
            both empty (there would be no position to predict the target
            from).
    """
    assert tok.padding_side == 'left'
    if 'llama' in tok.name_or_path:
        target = target.strip()
        prompt = prompt.strip()

    context_toks = tok(contexts, return_tensors='pt', padding=True, truncation=True)  # [N_c, L]
    context_tokens = context_toks['input_ids']
    # torch builds float tensors from empty id lists; cast so an empty prompt
    # cannot promote the concatenated ids to float (same treatment the empty
    # context gets below).
    prompt_tokens = tok(prompt, add_special_tokens=False, return_tensors='pt')['input_ids'].long()
    target_tokens = tok(target, add_special_tokens=False, return_tensors='pt')['input_ids'].long()

    tgt_len, pt_len = target_tokens.shape[1], prompt_tokens.shape[1]
    if tgt_len == 0:
        raise ValueError('target must tokenize to at least one token')

    if context_tokens.numel() != 0:
        # may have bos, but shouldn't contain eos
        assert context_tokens[-1][-1].item() != tok.eos_token_id
        bsz = context_tokens.shape[0]
    else:
        bsz = 1
        context_tokens = context_tokens.long()

    prompt_tokens_rep = prompt_tokens.repeat([bsz, 1])
    # The final target token is only ever a label, never an input.
    target_tokens_rep = target_tokens[:, :-1].repeat([bsz, 1])

    # concat input tensor
    input_tensor = torch.cat([context_tokens, prompt_tokens_rep, target_tokens_rep], dim=1)
    cp_tensor = torch.cat([context_tokens, prompt_tokens_rep], dim=1)
    # NOTE(review): any id equal to pad_token_id is masked, including ids in
    # the prompt/target. If the pad id is shared with a real token, those
    # positions are masked too -- confirm this is intended by the callers.
    attention_mask = input_tensor != tok.pad_token_id

    seq_len, cp_len = input_tensor.shape[1], cp_tensor.shape[1]
    if cp_len == 0:
        raise ValueError('context and prompt must provide at least one token')

    # Non-negative offsets are used throughout: a negative-offset slice turns
    # into ``[0:]`` (the whole row) as soon as one of the lengths makes the
    # offset zero.
    tgt_start = seq_len - tgt_len  # index of the last context/prompt token

    # construct rewriting targets (the [1, Lt] ids broadcast over the batch)
    rewriting_targets = torch.zeros_like(input_tensor)
    rewriting_targets[:, tgt_start:] = target_tokens.to(input_tensor.device)

    prompt_mask = torch.zeros_like(input_tensor, dtype=torch.bool)
    prompt_mask[:, tgt_start - pt_len + 1:tgt_start + 1] = True

    cp_prompt_mask = torch.zeros_like(cp_tensor, dtype=torch.bool)
    cp_prompt_mask[:, cp_len - pt_len:] = True

    return {
        'decoder_inputs': input_tensor,
        'attention_mask': attention_mask,
        'rewriting_targets': rewriting_targets,
        'prompt_tokens': prompt_tokens,
        'target_tokens': target_tokens,
        'context_tokens': context_tokens,
        'prompt_mask': prompt_mask,
        'cp_inputs': cp_tensor,
        'cp_prompt_mask': cp_prompt_mask,
    }
    

# Running this file as a script is a no-op: the guard body is empty.
if __name__ == '__main__':
    pass