import sys
sys.path.append('../..')

import torch
from modeling.rwkv5.modeling_rwkv5_hf import Rwkv5ForCausalLM
from modeling.rwkv5.configuration_rwkv5 import Rwkv5Config
from transformers import AutoTokenizer
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


def convert_state_dict(sd):
    """Rename RWKV-5 checkpoint keys to the parameter names used by ``Rwkv5ForCausalLM``.

    Every key except ``head.weight`` is prefixed with ``rwkv.``. Then each
    substring in the replacement table is rewritten, in table order
    (e.g. ``att.`` -> ``attention.``, ``.time_mix_k`` -> ``.time_mix_key``).
    Values are passed through unchanged; they are not copied.

    Args:
        sd: Mapping from original checkpoint key to value. Presumably this is
            the dict returned by ``torch.load`` on an official RWKV-5 ``.pt``
            file (confirm for other sources).

    Returns:
        A new dict mapping each renamed key to the original value object,
        in the same iteration order as ``sd``.
    """
    # Built once rather than once per key. The tuple order is the order in
    # which substitutions are applied to each key.
    replacements = (
        ('ln0', 'pre_ln'),
        ('emb.', 'embeddings.'),
        ('att.', 'attention.'),
        ('ffn.', 'feed_forward.'),
        ('.time_mix_k', '.time_mix_key'),
        ('.time_mix_v', '.time_mix_value'),
        ('.time_mix_r', '.time_mix_receptance'),
        ('.time_mix_g', '.time_mix_gate'),
    )
    new_sd = {}
    for old_key, value in sd.items():
        # The LM head lives at the top level of the HF model; everything
        # else sits under the ``rwkv`` submodule.
        new_key = old_key if old_key == 'head.weight' else 'rwkv.' + old_key
        for src, dst in replacements:
            new_key = new_key.replace(src, dst)
        new_sd[new_key] = value
    return new_sd


def main():
    """Load an RWKV-5 ``.pt`` checkpoint into ``Rwkv5ForCausalLM`` and print a greedy continuation.

    The checkpoint path is absolute. The tokenizer path is relative to the
    current working directory. Requires a CUDA device.
    """
    pretrained_name = '/mnt/data/user/tc_agi/cyf/rwkv5/RWKV-5-World-0.4B-v2-20231113-ctx4096.pt'
    prompt = '\n' + "我们发现" * 100

    # NOTE(review): hidden_size=1024 presumably matches the 0.4B checkpoint;
    # all other fields use Rwkv5Config defaults. Confirm they match the file.
    config = Rwkv5Config(hidden_size=1024)
    tokenizer = AutoTokenizer.from_pretrained('../../tokenizers/rwkv5-tok', trust_remote_code=True)
    model = Rwkv5ForCausalLM(config)
    # Restore tensors onto the CPU regardless of the device they were saved
    # from; the model is moved to CUDA explicitly below.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    sd = torch.load(pretrained_name, map_location='cpu')
    model.load_state_dict(convert_state_dict(sd))
    # A directly constructed nn.Module starts in training mode (unlike
    # from_pretrained). Switch to eval mode for inference.
    model = model.cuda().eval()

    print("==== prompt ====")
    print(prompt)
    print("================")
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
    output_ids = model.generate(input_ids, max_new_tokens=20)
    output_text = tokenizer.batch_decode(output_ids)[0]
    print(output_text)


if __name__ == '__main__':
    # Run the load-and-generate smoke test only when executed as a script,
    # not when this module is imported.
    main()
