import sys
import time
import jsonlines
from functools import partial
sys.path.append('../..')

import torch
from modeling.mamba2.modeling_mamba2_dao import Mamba2ForCausalLM, MambaConfig
from transformers import AutoTokenizer


def convert_state_dict(sd):
    """Return a new dict with the keys of ``sd`` renamed.

    Every key goes through a fixed list of substring substitutions, applied
    one after another in the order given below. Values are carried over
    unchanged and ``sd`` itself is not modified. If two keys of ``sd`` map to
    the same renamed key, the value of the later one is kept.
    """
    # (old substring, new substring) pairs; order matters because each
    # substitution operates on the result of the previous one.
    rename_rules = (
        ('ln0', 'pre_ln'),
        ('emb.', 'embeddings.'),
        ('att.', 'attention.'),
        ('ffn.', 'feed_forward.'),
        ('.time_mix_k', '.time_mix_key'),
        ('.time_mix_v', '.time_mix_value'),
        ('.time_mix_r', '.time_mix_receptance'),
        ('.time_mix_g', '.time_mix_gate'),
    )

    def rename(name):
        for old, new in rename_rules:
            name = name.replace(old, new)
        return name

    return {rename(name): sd[name] for name in sd}


def gen(prompt: str, model, tokenizer, top_k=10, max_new_tokens=5, device='cuda'):
    """Generate a continuation of ``prompt``, print it, and return it.

    Args:
        prompt: Text to condition the model on.
        model: Object exposing ``generate(input_ids=..., max_length=..., ...)``
            whose result has a ``sequences`` attribute that starts with the
            prompt tokens.
        tokenizer: Callable tokenizer returning ``input_ids`` as a tensor and
            providing ``batch_decode``.
        top_k: Accepted for call-site compatibility but NOT forwarded to
            ``model.generate``; generation therefore uses whatever sampling
            defaults ``model.generate`` applies.
        max_new_tokens: Number of tokens allowed beyond the prompt.
        device: Device the prompt token ids are moved to before generation.

    Returns:
        The decoded text of the newly generated tokens only.
    """
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device=device)
    prompt_len = input_ids.shape[1]

    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    out = model.generate(
        input_ids=input_ids,
        # generate() takes a total length, so add the prompt length.
        max_length=prompt_len + max_new_tokens,
        # NOTE(review): cg presumably enables CUDA-graph decoding - confirm
        # against model.generate.
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
    )
    print(f"Time spent: {time.perf_counter() - start_time}")

    # Drop the prompt by token count rather than by character count of the
    # decoded text: decoding is not guaranteed to reproduce the prompt
    # character-for-character, which would shift a character-based cut.
    new_token_ids = out.sequences.tolist()[0][prompt_len:]
    output_text = tokenizer.batch_decode([new_token_ids])[0]
    print("==== Output ====")
    print(output_text)
    print('================')
    return output_text


def main():
    """Load the tokenizer and model, run two test prompts, then open the debugger.

    The first prompt is a short passkey-retrieval task; the second is a
    one-line sanity check. After both have run, ``breakpoint()`` is called so
    that ``infer`` (``gen`` with the model and tokenizer pre-bound) can be
    used interactively.
    """
    pretrained_path = '/mnt/data/user/tc_agi/zxr/mamba/mamba2-370m'
    tok_path = '../../tokenizers/mamba-tok'
    # NOTE: to test with a sample from the passkey dataset instead of the
    # inline prompt below, read
    # '/home/jeeves/long-rnn/data/passkey_example.jsonl' with jsonlines and use
    # examples[0]['context'] + '\nThe pass key is' as the prompt. The file is
    # deliberately not read here, since its content would go unused.
    prompt = '''There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information.

The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The pass key is 12345. Remember it. 12345 is the pass key.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.

What is the pass key?

The pass key is'''
    device = 'cuda'
    dtype = torch.float16
    top_k = 1

    print("==== prompt ====")
    print(prompt)
    print("================")
    print(f"Loading tokenizer {tok_path}")
    tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
    print(f"Loading model {pretrained_path}")
    model = Mamba2ForCausalLM.from_pretrained(pretrained_path).to(dtype=dtype, device=device)
    model.eval()

    print("==== Generating ====")
    start_time = time.time()
    # Pass the device explicitly so the inputs always land on the same device
    # as the model, even if `device` above is changed.
    pred = gen(
        prompt,
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=5,
        top_k=top_k,
        device=device,
    )
    print(f"time spent: {time.time() - start_time}")
    print(pred)

    infer = partial(gen, model=model, tokenizer=tokenizer, device=device)
    infer('The capital of China is')
    # Intentional: stop here so `infer(...)` can be called by hand.
    breakpoint()


if __name__ == '__main__':
    main()
