#!/usr/bin/env python3
"""
Main script for running multiverse cross-attend text infilling inference.
"""

import argparse
import sys
import os

# Add the current directory to the path so we can import inference_text_infilling
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from inference_text_infilling import group_think_cross_attend_generate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


def load_model_and_tokenizer(model_path: str):
    """Return ``(model, tokenizer)`` loaded from *model_path*.

    When CUDA is available the model is loaded in ``torch.float16`` with
    ``device_map="auto"``; otherwise it is loaded in ``torch.float32`` and no
    device map is passed. If the tokenizer defines no pad token, its EOS token
    is reused as the pad token.

    Args:
        model_path: Local model directory or Hugging Face model name.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model from: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Pick precision and placement together from a single CUDA check.
    if torch.cuda.is_available():
        weight_dtype, placement = torch.float16, "auto"
    else:
        weight_dtype, placement = torch.float32, None

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=weight_dtype,
        device_map=placement,
    )

    # Fall back to EOS as the pad token when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Model loaded successfully on device: {model.device}")
    return model, tokenizer


def format_output(path_states, tokenizer):
    """Print every generated path as a numbered "Thinker" section.

    Args:
        path_states: Sequence of path-state objects; each must expose an
            ``ids`` attribute whose first element is handed to
            ``tokenizer.decode``.
        tokenizer: Object providing ``decode(ids, skip_special_tokens=...)``.
    """
    rule = "=" * 80
    separator = "-" * 40

    print("\n" + rule)
    print("GENERATED PATHS:")
    print(rule)

    for thinker_no, state in enumerate(path_states, start=1):
        print(f"\n--- Thinker {thinker_no} ---")
        # Decode the whole first sequence in ``ids`` (expected to include the seed).
        decoded = tokenizer.decode(state.ids[0], skip_special_tokens=True)
        print(decoded)
        print(separator)


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="Run multiverse cross-attend text infilling inference"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model directory or HuggingFace model name"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt (must end with '<Parallel>')"
    )
    parser.add_argument(
        "--num_paths",
        type=int,
        default=4,
        help="Number of parallel thinking paths (default: 4)"
    )
    parser.add_argument(
        "--shift",
        type=int,
        default=3000,
        help="Position shift between paths (default: 3000)"
    )
    parser.add_argument(
        "--max_path_tokens",
        type=int,
        default=512,
        help="Maximum tokens per path (default: 512)"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output"
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Optional output file to save results"
    )
    return parser


def _save_results(args, path_states, tokenizer) -> None:
    """Write the run parameters and every decoded path to ``args.output_file``."""
    with open(args.output_file, 'w', encoding='utf-8') as f:
        f.write(f"Prompt: {args.prompt}\n")
        f.write(f"Model: {args.model_path}\n")
        f.write(f"Parameters: num_paths={args.num_paths}, shift={args.shift}, max_tokens={args.max_path_tokens}\n")
        f.write("="*80 + "\n\n")

        for i, path_state in enumerate(path_states):
            f.write(f"--- Thinker {i+1} ---\n")
            full_text = tokenizer.decode(path_state.ids[0], skip_special_tokens=True)
            f.write(full_text + "\n")
            f.write("-" * 40 + "\n\n")


def main():
    """Parse CLI arguments, run cross-attend inference, and report results.

    Results are printed to stdout and, when ``--output_file`` is given, also
    written to that file. Invalid arguments exit with status 2 (argparse
    convention) before the model is loaded; runtime failures and Ctrl-C exit
    with status 1, with messages on stderr.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Fail fast on values that cannot yield meaningful output, before paying
    # for the (potentially slow) model load.
    if args.num_paths < 1:
        parser.error("--num_paths must be at least 1")
    if args.max_path_tokens < 1:
        parser.error("--max_path_tokens must be at least 1")
    # The help text documents this requirement; warn rather than reject in
    # case the generator tolerates other prompt forms.
    if not args.prompt.endswith("<Parallel>"):
        print(
            "Warning: prompt does not end with '<Parallel>'; "
            "generation may not behave as expected.",
            file=sys.stderr,
        )

    try:
        # Load model and tokenizer
        model, tokenizer = load_model_and_tokenizer(args.model_path)

        # Run inference
        print(f"\nStarting inference with {args.num_paths} paths...")
        print(f"Prompt: {args.prompt}")
        print(f"Max tokens per path: {args.max_path_tokens}")
        print(f"Position shift: {args.shift}")

        path_states = group_think_cross_attend_generate(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            num_paths=args.num_paths,
            shift=args.shift,
            max_path_tokens=args.max_path_tokens,
            verbose=args.verbose
        )

        # Format and display results
        format_output(path_states, tokenizer)

        # Save to file if requested
        if args.output_file:
            _save_results(args, path_states, tokenizer)
            print(f"\nResults saved to: {args.output_file}")

    except KeyboardInterrupt:
        print("\nInference interrupted by user.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report on stderr so piped stdout stays clean,
        # and keep the traceback available for debugging.
        print(f"Error during inference: {e}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
