import sys
sys.path.append('../..')

from modeling.rwkv import RWKV
from transformers import AutoTokenizer

# Debug script: tokenize a short prompt, load a local RWKV checkpoint and run
# a single forward pass on the GPU, stopping in the debugger before and after
# the call so the inputs and outputs can be inspected interactively.

tok_name = 'RWKV/v5-EagleX-v2-7B-HF'
print(f"Loading tokenizer from: {tok_name}")
# `trust_remote_code` is the option that lets transformers run tokenizer code
# shipped in the Hub repo. It was previously misspelled `truth_remote_code`,
# which is not a recognized option and therefore never enabled it.
tok = AutoTokenizer.from_pretrained(tok_name, trust_remote_code=True)

text = 'My name is'
# NOTE(review): no `padding=` / `truncation=` is passed, so `max_length` and
# `pad_to_multiple_of` probably have no effect here -- confirm whether padding
# the prompt to 128 tokens was intended before enabling it.
input_ids = tok(text, return_tensors='pt', max_length=128, pad_to_multiple_of=128).input_ids.cuda()
print(input_ids)

# Checkpoint path is relative to the current working directory.
path = '../../RWKV-5-World-0.4B-v2-20231113-ctx4096.pt'
print(f"Loading checkpoint from: {path}")
model = RWKV(path).cuda()

# Pause before the forward pass to inspect `model` and `input_ids`.
breakpoint()
outputs = model(input_ids)
print(outputs)
# Pause again so `outputs` can be examined before the script exits.
breakpoint()
