import argparse
import torch
import tiktoken
from model import GPT, inference

def main():
    """Run an interactive question-answering loop against a GPT checkpoint.

    Command-line options:
        -i / --model_path   checkpoint passed to ``GPT.from_pretrained`` (required).
        --max_new_tokens    generation budget per answer (default: 100).
        --temperature       forwarded unchanged to ``inference`` (default: 0).

    The model is loaded on CUDA when available, otherwise on CPU. Each
    non-blank line read from stdin is passed to ``inference`` and the
    generated text is printed.

    The loop ends when the user does any of the following:
        - types ``quit``, ``exit`` or ``q`` (case-insensitive, surrounding
          whitespace ignored);
        - sends end-of-input (Ctrl-D);
        - presses Ctrl-C at the prompt.

    An exception raised while generating or printing one answer is reported,
    and the session continues.
    """
    parser = argparse.ArgumentParser(description="Interactive Q&A with GPT model")
    parser.add_argument("-i", "--model_path", type=str, required=True, help="Path to the model checkpoint")
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate per answer (default: 100)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        # argparse applies `type` only to string defaults, so this stays the int 0
        # exactly as the original hard-coded call passed it.
        default=0,
        help="Sampling temperature forwarded to inference (default: 0)",
    )
    args = parser.parse_args()

    exit_commands = {"quit", "exit", "q"}
    # Token id used to stop generation. Presumably the GPT-2 BPE id for "\n"
    # (answers are single lines) -- verify against the tokenizer if changed.
    stop_token = 198

    # Initialize model and tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model from {args.model_path}...")
    model = GPT.from_pretrained(args.model_path, device)
    enc = tiktoken.get_encoding("gpt2")

    print("\nModel loaded! Enter your questions (type 'quit' to exit):")

    while True:
        try:
            user_input = input("\n> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of crashing
            # with a traceback; the newline terminates the dangling "> " line.
            print()
            break

        command = user_input.strip()
        if not command:
            continue

        if command.lower() in exit_commands:
            print("Goodbye!")
            break

        try:
            # The raw (unstripped) input is sent to the model, as before.
            response = inference(
                model=model,
                input_text=user_input,
                tokenizer=enc,
                max_new_tokens=args.max_new_tokens,
                stop_token=stop_token,
                temperature=args.temperature,
            )
            # Print the response as returned, adding a trailing newline only if
            # it lacks one. endswith() is safe for an empty response, where
            # indexing response[-1] would raise IndexError.
            print(response, end="" if response.endswith("\n") else "\n")
        except Exception as e:
            # Top-level boundary of the REPL: report the failure for this
            # question and keep the session alive.
            print(f"Error generating response: {e}")

if __name__ == '__main__':
    # Start the interactive session only when executed as a script, not on import.
    main()
