import fire

from src.generator import Generator
from src.modeling_args import LoraModelArgs
from src.modeling_lora import LoraLLaMA
from src.tokenizer import Tokenizer
from src.utils import setup_model_parallel, barrier


def main(
        ckpt_dir: str = "result/7B/gsm8k-alpha-0.01-diversity-5-seed-2048",
        max_seq_len: int = 512,
        lora_rank: int = 64,
        config_file: str = 'config/7B/params.json',
        tokenizer_path: str = 'config/tokenizer.model',
        seed: int = None
):
    """Load a LoRA LLaMA checkpoint and print generations for demo prompts.

    Args:
        ckpt_dir: Checkpoint location handed to ``LoraLLaMA.load``.
        max_seq_len: Sequence length given to ``LoraModelArgs``; the same
            value is reused as ``max_gen_len`` for generation.
        lora_rank: Passed to ``LoraModelArgs`` as ``r``.
        config_file: JSON file passed to ``LoraModelArgs.from_json``.
        tokenizer_path: Tokenizer model file passed to ``Tokenizer``.
        seed: Forwarded unchanged to ``setup_model_parallel``; the default
            ``None`` is passed through as-is.
    """
    # Each prompt is assembled by implicit concatenation of adjacent string
    # literals, so every fragment but the last must end with a space;
    # otherwise sentences are glued together (e.g. "apples.How many").
    prompts = [
        "If I had five apples, I gave apples to each of my parents, "
        "and then I bought 6 more apples. "
        "How many apples I have now?",
        "A red ball costs $5, a blue ball costs $8 and a yellow ball costs $9. "
        "I bought five blue balls, and bought yellow balls twice that many "
        "of blue balls I bought. "
        "How much would I pay?",
        "Could you write a Python code to express your emotion?",
        "I would like to go to school, and I can ride a bike or drive a car. "
        "Riding bike costs me $1 per hour, Driving car costs me $5 per hour. "
        "And riding bike takes 4 times longer than driving car to get me "
        "to the school. "
        "Which costs me less?",
    ]

    local_rank, world_size = setup_model_parallel(
        use_float16=True, seed=seed)

    # Constructor arguments are supplied first, then from_json() is applied
    # with the config file.
    # NOTE(review): which side wins on overlapping keys depends on
    # LoraModelArgs.from_json - confirm there.
    params = LoraModelArgs(
        max_seq_len=max_seq_len,
        local_rank=local_rank,
        world_size=world_size,
        r=lora_rank).from_json(config_file)

    model = LoraLLaMA(params)
    model.load(ckpt_dir)
    # NOTE(review): presumably synchronizes all ranks after loading before
    # generation starts - confirm in src.utils.barrier.
    barrier()

    generator = Generator(model, Tokenizer(tokenizer_path))
    # NOTE(review): temperature=0.0 presumably selects deterministic (greedy)
    # decoding - confirm in Generator.generate.
    results = generator.generate(
        prompts=prompts,
        max_gen_len=max_seq_len,
        temperature=0.0)

    for result in results:
        print("=" * 100)
        print(result['instruction'] + result['output'])


# Run only when executed as a script. fire.Fire builds the command-line
# interface from main's signature, so each parameter can be overridden with
# a flag of the same name (e.g. --ckpt_dir=..., --seed=...).
if __name__ == '__main__':
    fire.Fire(main)
