#!/usr/bin/env python3
"""
Main script for running group think simulation inference.
"""

import argparse
import sys
import os

# Add this script's directory to sys.path so we can import inference_simulation
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from inference_simulation import group_think_simulation_generate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


def load_model_and_tokenizer(model_path: str):
    """Load a causal language model and its tokenizer from ``model_path``.

    When CUDA is available the weights are requested in float16 with
    ``device_map="auto"``; otherwise they are loaded in float32 with no
    device map. If the tokenizer defines no pad token, its EOS token is
    reused as the pad token.

    Args:
        model_path: Local model directory or Hugging Face model identifier.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model from: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Pick precision and placement together, based on GPU availability.
    if torch.cuda.is_available():
        weight_dtype, placement = torch.float16, "auto"
    else:
        weight_dtype, placement = torch.float32, None

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=weight_dtype,
        device_map=placement,
    )

    # Ensure a pad token exists; fall back to EOS when none is configured.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Model loaded successfully on device: {model.device}")
    return model, tokenizer


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the group think simulation script.

    Returns:
        An ``argparse.ArgumentParser`` configured with all script options.
    """
    parser = argparse.ArgumentParser(
        description="Run group think simulation inference"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model directory or HuggingFace model name",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt (must end with '<Parallel>')",
    )
    parser.add_argument(
        "--num_paths",
        type=int,
        default=4,
        help="Number of parallel thinking paths (default: 4)",
    )
    parser.add_argument(
        "--max_path_tokens",
        type=int,
        default=512,
        help="Maximum tokens per path (default: 512)",
    )
    # main() forwards this as ``shift=`` to group_think_simulation_generate.
    # It used to be read from ``args`` without ever being declared, so every
    # run failed with AttributeError.
    # NOTE(review): type=int is assumed; confirm against the generator's
    # signature. If the generator has no default for ``shift``, this option
    # must be supplied.
    parser.add_argument(
        "--shift",
        type=int,
        default=None,
        help="Value forwarded as 'shift' to the generator "
             "(default: use the generator's own default)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Optional output file to save results",
    )
    return parser


def main():
    """Parse CLI arguments, load the model, run generation, and report results.

    The decoded output is written to ``--output_file`` when given, otherwise
    printed to stdout. Exits with status 1 on a user interrupt or on any error
    raised while loading the model or generating.
    """
    args = build_arg_parser().parse_args()

    try:
        model, tokenizer = load_model_and_tokenizer(args.model_path)

        print(f"\nStarting group think simulation inference with {args.num_paths} paths...")
        print(f"Prompt: {args.prompt}")
        print(f"Max tokens per path: {args.max_path_tokens}")

        gen_kwargs = {
            "model": model,
            "tokenizer": tokenizer,
            "prompt": args.prompt,
            "max_path_tokens": args.max_path_tokens,
            "verbose": args.verbose,
        }
        # Forward ``shift`` only when the user supplied it, so the generator's
        # own default (if it has one) applies otherwise.
        if args.shift is not None:
            gen_kwargs["shift"] = args.shift
        # NOTE(review): --num_paths is only reported (above and in the output
        # file); it is not forwarded to the generator. Confirm whether
        # group_think_simulation_generate accepts a path-count argument.
        output_ids = group_think_simulation_generate(**gen_kwargs)

        # Decode the first sequence of the generator's output once, before
        # touching the output file, so a decode failure leaves no partial file.
        full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        if args.output_file:
            with open(args.output_file, "w", encoding="utf-8") as f:
                f.write(f"Prompt: {args.prompt}\n")
                f.write(f"Model: {args.model_path}\n")
                f.write(f"Parameters: num_paths={args.num_paths}, max_tokens={args.max_path_tokens}\n")
                f.write("=" * 80 + "\n\n")
                f.write(full_text + "\n")

            print(f"\nResults saved to: {args.output_file}")
        else:
            # Without an output file the result was previously discarded.
            print("\n" + "=" * 80)
            print(full_text)

    except KeyboardInterrupt:
        print("\nInference interrupted by user.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level CLI boundary: report and exit non-zero; keep the traceback
        # available for debugging when --verbose is set.
        print(f"Error during inference: {e}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
