import os
import torch
from model import GPTConfig, GPT

def load_model(checkpoint_path, device="cuda:0"):
    """
    Loads a GPT model from a checkpoint file.

    The loaded checkpoint is indexed with two keys: 'model_args' (keyword
    arguments passed to GPTConfig) and 'model' (the state dict passed to
    load_state_dict). State-dict keys that start with '_orig_mod.' -- the
    prefix torch.compile puts on the parameter names of a wrapped module --
    are renamed without the prefix, so checkpoints saved from a compiled
    model load into the plain GPT module as well.

    Args:
        checkpoint_path (str): The path to the model checkpoint file.
        device (str, optional): The device to load the model onto. Defaults to "cuda:0".
    Returns:
        tuple: A tuple containing the loaded model (in eval mode) and its configuration.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed PyTorch version's `weights_only` default -- only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Rename keys in place (rather than building a new dict) so any extra
    # attributes attached to the state-dict object are left untouched.
    state_dict = checkpoint['model']
    compile_prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(compile_prefix):
            state_dict[key[len(compile_prefix):]] = state_dict.pop(key)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, gptconf

def encode(text, tokenizer, block_size):
    """
    Turns a text string into a batched tensor of token IDs.

    The text is tokenized with `tokenizer.encode`, only the first
    `block_size` token IDs are kept, and a leading batch dimension of
    size one is added.

    Args:
        text (str): The input text to encode.
        tokenizer: The tokenizer object with an `encode` method.
        block_size (int): The maximum sequence length.
    Returns:
        torch.Tensor: A tensor of token IDs with shape [1, seq_len].
    """
    # Keep at most `block_size` IDs, counted from the start of the text.
    token_ids = tokenizer.encode(text)[:block_size]
    sequence = torch.tensor(token_ids, dtype=torch.long)
    # Add the batch axis in front: [seq_len] -> [1, seq_len].
    return torch.unsqueeze(sequence, 0)

def decode(tokens, tokenizer):
    """
    Converts token IDs back into text.

    This is a thin wrapper that passes `tokens` unchanged to
    `tokenizer.decode` and returns whatever that call produces.

    Args:
        tokens: The token IDs to decode.
        tokenizer: The tokenizer object with a `decode` method.
    Returns:
        str: The decoded text string.
    """
    decoded_text = tokenizer.decode(tokens)
    return decoded_text

if __name__ == "__main__":
    # Command-line entry point: load a checkpoint, generate a continuation of
    # the given prompt, and print the decoded result.
    import argparse
    import tiktoken

    # Fall back to the CPU when CUDA is unavailable, so the script also runs
    # without an explicit --device on machines that have no GPU.
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint (ckpt.pt)")
    parser.add_argument("--prompt", type=str, required=True, help="Prompt text for generation")
    parser.add_argument("--max_new_tokens", type=int, default=100, help="Number of tokens to generate")
    parser.add_argument("--device", type=str, default=default_device, help="Device to run inference on")
    args = parser.parse_args()

    # Load tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Load model
    model, gptconf = load_model(args.ckpt, device=args.device)

    # Encode input prompt (cut to the model's block size) and move it to the
    # same device the model was loaded onto.
    input_ids = encode(args.prompt, tokenizer, gptconf.block_size).to(args.device)

    # Generate new tokens; gradients are not needed for inference.
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=args.max_new_tokens)
    # Take the first row of the result and turn it into a plain Python list
    # of token IDs for the tokenizer.
    output_ids = output_ids[0].cpu().tolist()

    # Decode output tokens
    output_text = decode(output_ids, tokenizer)
    print("=== Generated Text ===")
    print(output_text)
