from transformers import AutoTokenizer 
from transformers .tokenization_utils import PreTrainedTokenizer 
import torch 
import numpy as np 
import copy 

# Label value written at positions that must not contribute to the loss.
# -100 matches the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``;
# presumably that is why it was chosen (the loss itself is not in this file).
IGNORE_INDEX = -100

# Tokenizer loaded from the Swallow-7b checkpoint directory.
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
    "/bb/llm/gaf51275/llama/huggingface-checkpoint/Swallow-7b-hf/"
)

# Raise the summarisation threshold so that tensors of up to 4096 elements
# are printed in full instead of being elided with "...".
torch.set_printoptions(threshold=4096)

# Sample record: "input" holds the chat history as role/text turns and
# "output" holds the assistant reply that follows the last turn.
conversations: dict = {
    "input": [
        {"role": "user", "text": "こんにちは"},
        {"role": "assistant", "text": "hello"},
        {"role": "user", "text": "good"},
    ],
    "output": "いいね",
}

# System turn placed in front of every conversation before templating.
SYSTEM_PROMPT = [
    {"role": "system", "text": "あなたは誠実で優秀な日本人のアシスタントです。"},
]

# Render the system turn followed by the input turns through the tokenizer's
# chat template; ``tokenize=False`` asks for the rendered text, not token ids.
prompt: str = tokenizer.apply_chat_template(
    conversation=SYSTEM_PROMPT + conversations["input"],
    tokenize=False,
)

# Warn about long prompts.
# NOTE(review): ``len(prompt)`` counts characters, while 4096 looks like the
# token budget used for padding elsewhere in this file -- confirm that a
# character-based threshold is what is intended.
if 4096 * (2 / 3) <= len(prompt):
    print(f"\n\nWARNING: len(prompt)={len (prompt )}, prompt={prompt }\n\n")

# Full training text: the rendered prompt immediately followed by the reply.
example: str = prompt + conversations["output"]

# Token ids of the prompt alone. Later code only uses its length, to know how
# many leading positions of the labels belong to the prompt.
encoded_prompt: torch.Tensor = torch.tensor(
    tokenizer.encode(prompt, add_special_tokens=False),
    dtype=torch.int64,
)

# Token ids of prompt + reply, with the tokenizer's EOS id appended by hand
# because special tokens are not added during encoding.
encoded_example: list[int] = tokenizer.encode(example, add_special_tokens=False)
encoded_example += [tokenizer.eos_token_id]

encoded_tensor_example: torch.Tensor = torch.tensor(
    encoded_example,
    dtype=torch.int64,
)

# Fixed length every example is padded or truncated to.
MAX_SEQ_LENGTH: int = 4096
# Temporary marker for positions that are not real targets (padding, prompt).
# Any negative value works because the masks below are built with ``ge(0)``.
MASKED_POSITION: int = -1

# Positive: number of pad slots to append. Negative: example is too long.
padding: int = MAX_SEQ_LENGTH - encoded_tensor_example.shape[0]
if padding > 0:
    # Right-pad with the marker so padded slots can be told apart from tokens.
    encoded_tensor_example = torch.cat(
        (
            encoded_tensor_example,
            torch.full((padding,), MASKED_POSITION, dtype=torch.int64),
        )
    )
elif padding < 0:
    # Too long: keep only the first MAX_SEQ_LENGTH ids (this drops the
    # trailing EOS along with the overflow).
    encoded_tensor_example = encoded_tensor_example[:MAX_SEQ_LENGTH]

# Labels start as an independent copy of the (padded) input ids.
labels = encoded_tensor_example.clone()

# Mark the prompt positions in the labels so only the reply stays as target.
labels[: len(encoded_prompt)] = MASKED_POSITION

# True where the position holds a real token / a real training target.
example_mask = encoded_tensor_example.ge(0)
label_mask = labels.ge(0)

# Replace the temporary marker with the final values: token id 0 in the
# inputs and IGNORE_INDEX in the labels.
encoded_tensor_example[~example_mask] = 0
labels[~label_mask] = IGNORE_INDEX

# Show the finished example: input ids, loss labels, and an attention mask
# (1.0 at real-token positions, 0.0 at padded positions).
print(
    {
        "input_ids": encoded_tensor_example,
        "labels": labels,
        "attention_mask": example_mask.float(),
    }
)
