from pathlib import Path

from transformers import AutoTokenizer

from modeling.mamba import MambaConfig, MambaForCausalLM, load_pretrained
from arguments import Args


def get_tokenizer(args: Args):
    """Return the tokenizer loaded via ``AutoTokenizer.from_pretrained``.

    Args:
        args: Parsed script arguments; only ``args.tok_path`` is read and is
            passed straight through as the tokenizer name/location.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained(args.tok_path)`` returns.
    """
    return AutoTokenizer.from_pretrained(args.tok_path)


def get_ckpt_path(path):
    """Resolve ``path`` to a single ``.pt`` checkpoint location.

    Args:
        path: Either a path whose text ends in ``.pt`` (returned as-is) or a
            directory that must contain exactly one entry whose name ends in
            ``.pt``. Both ``str`` and ``pathlib.Path`` are accepted.

    Returns:
        The checkpoint location as a ``Path``.

    Raises:
        FileNotFoundError: ``path`` is a directory with no ``.pt`` entry (or
            does not exist, as raised by ``Path.iterdir``).
        ValueError: ``path`` is a directory with more than one ``.pt`` entry.
    """
    # str() so callers may pass a Path as well as a plain string.
    if str(path).endswith(".pt"):
        return Path(path)

    directory = Path(path)
    candidates = sorted(
        entry for entry in directory.iterdir() if entry.name.endswith(".pt")
    )
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if not candidates:
        raise FileNotFoundError(f"No .pt checkpoint found in {directory}")
    if len(candidates) > 1:
        names = ", ".join(entry.name for entry in candidates)
        raise ValueError(f"Multiple .pt checkpoints found in {directory}: {names}")
    # iterdir() yields paths already prefixed with `directory`; joining the
    # entry onto `directory` again would duplicate a relative prefix.
    return candidates[0]


def get_model(args: Args) -> MambaForCausalLM:
    """Construct a ``MambaForCausalLM`` and optionally load pretrained weights.

    Args:
        args: Parsed script arguments. ``args.model_config`` is logged;
            ``args.pretrained_path`` selects a checkpoint file or directory
            (empty string means "no checkpoint").

    Returns:
        The constructed model, with weights loaded via ``load_pretrained``
        when ``args.pretrained_path`` is non-empty.
    """
    print(f"Loading config from {args.model_config}")
    # NOTE(review): args.model_config is only printed; the model is always
    # built from a default MambaConfig(). TODO confirm whether the config
    # should actually be loaded from that path.
    cfg = MambaConfig()
    print(f"Loading config: {cfg}")
    mamba = MambaForCausalLM(cfg)

    ckpt_arg = args.pretrained_path
    if ckpt_arg == "":
        return mamba

    print(f"Loading checkpoint from: {ckpt_arg}")
    load_pretrained(mamba, get_ckpt_path(ckpt_arg))
    return mamba


def main():
    """Run a single generation smoke test from command-line arguments.

    Parses ``Args``, creates ``args.output_dir`` (including parents), saves
    the parsed arguments to ``args.json`` in that directory, loads the
    tokenizer and model (moving the model to CUDA), then prints the result
    of generating from a fixed prompt.
    """
    args: Args = Args().parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    print("============ args ================")
    print(args)
    args.save(output_dir / "args.json")

    tokenizer = get_tokenizer(args)
    model = get_model(args).cuda()

    prompt = "My name is"
    # The tokenizer and raw prompt are handed to generate() together.
    # Presumably it encodes the prompt itself; confirm against modeling.mamba.
    outputs = model.generate(tokenizer, prompt)
    print(outputs)


# Script entry point: only run main() when executed directly, not on import.
if __name__ == "__main__":
    main()
