import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse

def chat():
    """Run an interactive, multi-turn terminal chat against a local causal LM.

    Reads the required ``--model_path`` command-line argument, loads the
    tokenizer and model from that path, then repeatedly prompts for user
    input, generates a reply and prints it. The loop ends when the user types
    ``exit`` (case-insensitive, surrounding whitespace ignored) or when stdin
    is closed / interrupted (Ctrl-D / Ctrl-C).

    Previous user and assistant turns are kept in memory and replayed through
    the tokenizer's chat template on every turn, so the model sees the whole
    conversation rather than only the latest message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to the exported model")
    args = parser.parse_args()

    print(f"🚀 Loading model from: {args.model_path}")

    # NOTE(review): trust_remote_code=True lets the checkpoint run its own
    # Python code on load -- only point --model_path at checkpoints you trust.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )

    print("\n💬 Model loaded! Type 'exit' to quit.\n")
    print("-" * 50)

    system_message = {"role": "system", "content": "You are a helpful assistant."}
    # Completed user/assistant turns, oldest first. The system prompt is kept
    # separate and prepended on every turn.
    # NOTE(review): history grows without bound; a very long session can
    # exceed the model's context window -- consider truncating if that matters.
    history = []

    while True:
        try:
            query = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin (Ctrl-D / piped input exhausted) or Ctrl-C: leave
            # the loop quietly instead of crashing with a traceback.
            print()
            break

        command = query.strip()
        if command.lower() == "exit":
            break
        if not command:
            # Nothing to send; prompt again rather than generating on an
            # empty user message.
            continue

        user_message = {"role": "user", "content": query}
        messages = [system_message, *history, user_message]

        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            # temperature / top_p only take effect when sampling is enabled;
            # set it explicitly instead of relying on the checkpoint's
            # default generation config.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # generate() returns prompt + completion for each batch row; the batch
        # here is a single sequence, so drop the prompt tokens from row 0.
        prompt_length = model_inputs.input_ids.shape[1]
        new_token_ids = generated_ids[0][prompt_length:]

        response = tokenizer.decode(new_token_ids, skip_special_tokens=True)
        print(f"Agent: {response}\n")
        print("-" * 50)

        # Record the finished exchange only after a successful generation so
        # a failed turn never leaves a dangling user message in the context.
        history.append(user_message)
        history.append({"role": "assistant", "content": response})

# Start the interactive chat only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    chat()