#!/usr/bin/env python3
"""
Interactive script for running Group Think cross-attend text infilling inference.
Shows real-time generation of different thinking paths.
"""

import argparse
import sys
import os
import time
import json
from datetime import datetime
from typing import Optional, List

# Add the current directory to the path so we can import inference_text_infilling
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from inference_text_infilling import PathState, group_think_cross_attend_generate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class SessionLogger:
    """Record every prompt and its generated paths, optionally persisting them as JSON.

    The complete record is kept in ``session_data``. When ``output_file`` is
    set, the whole record is rewritten to that file after each logged
    interaction.
    """

    def __init__(self, output_file: Optional[str] = None):
        """Start an empty session.

        Args:
            output_file: Path of the JSON trace file, or ``None`` to keep the
                session in memory only.
        """
        self.output_file = output_file
        self.session_data = {
            "session_start": datetime.now().isoformat(),
            "interactions": []
        }

    def log_interaction(self, prompt: str, path_states: List[PathState], tokenizer,
                       num_paths: int, shift: int, max_path_tokens: int):
        """Append one interaction (prompt, parameters, all paths) to the session.

        Args:
            prompt: The prompt text that was sent to generation.
            path_states: Final state of each path. Each entry is read through
                ``ids[0]``, ``length`` and ``finished``.
            tokenizer: Object whose ``decode`` turns ``ids[0]`` into text.
            num_paths: Number of paths requested for this interaction.
            shift: Position shift value used for this interaction.
            max_path_tokens: Per-path token limit used for this interaction.

        Raises:
            OSError: If an output file is configured and cannot be written.
            TypeError: If an output file is configured and the session data
                contains a value that ``json`` cannot serialize.
        """
        generated_paths = [
            {
                "thinker_id": thinker_id,
                "full_text": tokenizer.decode(path_state.ids[0], skip_special_tokens=True),
                "tokens_generated": path_state.length,
                "finished": path_state.finished
            }
            for thinker_id, path_state in enumerate(path_states, start=1)
        ]

        self.session_data["interactions"].append({
            "timestamp": datetime.now().isoformat(),
            "prompt": prompt,
            "parameters": {
                "num_paths": num_paths,
                "shift": shift,
                "max_path_tokens": max_path_tokens
            },
            "generated_paths": generated_paths
        })

        # Save incrementally so a crash later in the session loses nothing.
        if self.output_file:
            self.save_session()

    def save_session(self):
        """Write the whole session to ``output_file``; do nothing if none is set.

        The ``session_end`` timestamp is refreshed on every call.

        Raises:
            OSError: If the output file cannot be opened or written.
            TypeError: If the session data is not JSON-serializable. The
                existing file is left untouched in that case.
        """
        if not self.output_file:
            return

        self.session_data["session_end"] = datetime.now().isoformat()

        # Serialize BEFORE opening the file: mode 'w' truncates it, so an
        # error raised halfway through serialization would otherwise wipe the
        # trace that earlier interactions already saved.
        payload = json.dumps(self.session_data, indent=2, ensure_ascii=False)

        with open(self.output_file, 'w', encoding='utf-8') as f:
            f.write(payload)

    def get_summary(self) -> str:
        """Return a one-line summary: interaction count and total paths logged."""
        interactions = self.session_data["interactions"]
        total_interactions = len(interactions)
        total_paths = sum(len(interaction["generated_paths"])
                          for interaction in interactions)
        return f"Session completed: {total_interactions} interactions, {total_paths} total paths generated"


class InteractiveGenerator:
    """Hold the per-path display state of one generation and render it to the terminal."""

    def __init__(self, model, tokenizer, num_paths: int, shift: int, max_path_tokens: int):
        """Store the generation settings and start every path empty and unfinished."""
        self.model, self.tokenizer = model, tokenizer
        self.num_paths = num_paths
        self.shift = shift
        self.max_path_tokens = max_path_tokens
        # One text slot and one finished flag per path.
        self.path_outputs = ["" for _ in range(num_paths)]
        self.path_finished = [False for _ in range(num_paths)]
        self.generation_done = False

    def update_path_output(self, path_idx: int, new_text: str):
        """Replace the text currently shown for path ``path_idx``."""
        self.path_outputs[path_idx] = new_text

    def mark_path_finished(self, path_idx: int):
        """Flag path ``path_idx`` as finished."""
        self.path_finished[path_idx] = True

    def is_all_finished(self) -> bool:
        """Return True when no path is still unfinished."""
        for flag in self.path_finished:
            if not flag:
                return False
        return True

    def display_paths(self):
        """Clear the terminal and print the current text and status of every path."""
        rule = "=" * 100

        # ANSI escape: clear the screen and move the cursor to the top-left.
        print("\033[2J\033[H", end="")
        print(rule, "Group Think THINKING PATHS - REAL-TIME GENERATION", rule, sep="\n")

        for number in range(1, self.num_paths + 1):
            slot = number - 1
            status = "⟳ THINKING..." if not self.path_finished[slot] else "✓ FINISHED"
            print(f"\n🧠 THINKER {number} [{status}]")
            print("-" * 50)
            # Fall back to a placeholder until the path has produced text.
            print(self.path_outputs[slot] or "Starting to think...")
            print()

        if self.generation_done:
            footer = "All thinkers have finished!"
        else:
            footer = "Press Ctrl+C to stop generation"
        print(footer)
        print(rule)


def load_model_and_tokenizer(model_path: str):
    """Load the tokenizer and causal-LM model found at ``model_path``.

    Args:
        model_path: Value handed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. If the tokenizer has no pad token, its
        EOS token is used as the pad token.
    """
    print(f"Loading model from: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # With CUDA: float16 and device_map="auto". Without: float32 and no device map.
    cuda_ready = torch.cuda.is_available()
    load_kwargs = {
        "torch_dtype": torch.float16 if cuda_ready else torch.float32,
        "device_map": "auto" if cuda_ready else None,
    }
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    # Reuse EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Model loaded successfully on device: {model.device}")
    return model, tokenizer


def interactive_generate(
    model,
    tokenizer,
    prompt: str,
    num_paths: int = 4,
    shift: int = 2000,
    max_path_tokens: int = 128,
    session_logger: Optional[SessionLogger] = None,
):
    """Run ``group_think_cross_attend_generate`` while redrawing the paths live.

    Args:
        model: Model forwarded to the generation function.
        tokenizer: Tokenizer forwarded to the generation function and used to
            decode each path for display.
        prompt: Prompt text forwarded to the generation function.
        num_paths: Number of paths to generate and display.
        shift: Forwarded unchanged to the generation function.
        max_path_tokens: Forwarded unchanged to the generation function.
        session_logger: If given, the finished interaction is logged to it.

    Returns:
        The path states returned by ``group_think_cross_attend_generate``.
    """
    interactive_gen = InteractiveGenerator(model, tokenizer, num_paths, shift, max_path_tokens)

    def sync_display_state(states, decoder):
        """Copy the decoded text and finished flag of every path into the display state."""
        for i, path_state in enumerate(states):
            current_text = decoder.decode(path_state.ids[0], skip_special_tokens=True)
            interactive_gen.update_path_output(i, current_text)
            if path_state.finished:
                interactive_gen.mark_path_finished(i)

    def step_callback(path_states, tokenizer, step):
        # Redraw only every third step to limit flicker. The display state is
        # read only when drawing, so decoding is skipped on the other steps
        # instead of re-decoding every path on every step.
        if step % 3 != 0:
            return
        sync_display_state(path_states, tokenizer)
        interactive_gen.display_paths()
        time.sleep(0.1)  # Small delay for readability

    # Initial display
    interactive_gen.display_paths()

    path_states = group_think_cross_attend_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        num_paths=num_paths,
        shift=shift,
        max_path_tokens=max_path_tokens,
        verbose=False,  # We handle our own display
        step_callback=step_callback
    )

    # Draw the final screen from the returned states, the same source the
    # session log uses, so the screen and the log cannot disagree.
    sync_display_state(path_states, tokenizer)
    interactive_gen.generation_done = True
    interactive_gen.display_paths()

    # Log the interaction if session logger is provided
    if session_logger:
        session_logger.log_interaction(
            prompt=prompt,
            path_states=path_states,
            tokenizer=tokenizer,
            num_paths=num_paths,
            shift=shift,
            max_path_tokens=max_path_tokens
        )

    return path_states


def main():
    """Parse the command line, load the model, and run the interactive prompt loop.

    The loop reads a prompt, appends ``<Parallel>`` when missing, runs the
    generation, and waits for Enter before asking again. It ends on
    ``quit``/``exit``/``q``, on Ctrl+C, or when the input stream is closed.
    When ``--output_file`` is given, the session is saved and summarized at
    the end.

    Exits with status 1 if a ``ValueError``, ``RuntimeError`` or ``OSError``
    escapes the prompt loop, for example while loading the model.
    """
    parser = argparse.ArgumentParser(
        description="Run interactive Group Think cross-attend text infilling inference"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model directory or HuggingFace model name"
    )
    parser.add_argument(
        "--num_paths",
        type=int,
        default=4,
        help="Number of parallel thinking paths (default: 4)"
    )
    parser.add_argument(
        "--shift",
        type=int,
        default=3000,
        help="Position shift between paths (default: 3000)"
    )
    parser.add_argument(
        "--max_path_tokens",
        type=int,
        default=512,
        help="Maximum tokens per path (default: 512)"
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Optional output file to save all session traces (JSON format)"
    )

    args = parser.parse_args()

    try:
        # Load model and tokenizer
        model, tokenizer = load_model_and_tokenizer(args.model_path)

        # Initialize session logger if output file is specified
        session_logger = SessionLogger(args.output_file) if args.output_file else None

        print("\n" + "="*80)
        print("INTERACTIVE GROUP THINKING")
        print("="*80)
        print("Enter your prompt (it will automatically be appended with '<Parallel>')")
        print("Type 'quit' or 'exit' to stop")
        if args.output_file:
            print(f"Session traces will be saved to: {args.output_file}")
        print("="*80)

        while True:
            try:
                # Get user input
                user_prompt = input("\n🧠 Your question: ").strip()

                if user_prompt.lower() in ['quit', 'exit', 'q']:
                    print("Goodbye!")
                    break

                if not user_prompt:
                    print("Please enter a question.")
                    continue

                # Ensure prompt ends with <Parallel>
                if not user_prompt.endswith("<Parallel>"):
                    user_prompt += "<Parallel>"

                print(f"\nStarting generation with {args.num_paths} thinkers...")
                print("Watch as different thinking paths develop in real-time!")
                time.sleep(1)

                # Run interactive generation. The returned states are not
                # needed here; logging happens inside via session_logger.
                interactive_generate(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=user_prompt,
                    num_paths=args.num_paths,
                    shift=args.shift,
                    max_path_tokens=args.max_path_tokens,
                    session_logger=session_logger
                )

                # Wait for user to continue
                input("\nPress Enter to ask another question...")

            except KeyboardInterrupt:
                print("\n\nGeneration interrupted by user.")
                break
            except EOFError:
                # input() raises EOFError on Ctrl+D or a closed/piped stdin.
                # Leave the loop cleanly so the session is still saved below.
                print("\n\nInput stream closed. Goodbye!")
                break
            except (ValueError, RuntimeError, OSError) as e:
                # A failed generation should not end the session.
                print(f"\nError during generation: {e}")
                continue

        # Save final session data and show summary
        if session_logger:
            session_logger.save_session()
            print(f"\n{session_logger.get_summary()}")
            print(f"Session data saved to: {args.output_file}")

    except (ValueError, RuntimeError, OSError) as e:
        print(f"Error: {e}")
        sys.exit(1)


# Start the interactive session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
