#!/usr/bin/env python3
"""
Model selection utility for LinearizeLLM.
Lets users choose an LLM model (or a custom configuration) at the start of execution,
configures the matching provider's API key, and runs a quick connection test.
"""

import os
import sys
from pathlib import Path
from typing import Union, Optional

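# When run directly as a script, put the directory three levels above this file on
# sys.path so the `src.*` imports below resolve (presumably the project root containing `src/`).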
if __name__ == "__main__":
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from src.utils.llm_model_manager import LLMConfig, get_available_models, create_llm
from src.utils.api_key_manager import APIKeyManager


def select_model_interactive() -> Union[str, LLMConfig]:
    """
    Interactively ask the user which LLM model to use.

    Prints every model returned by ``get_available_models()`` as a numbered
    menu grouped by provider, followed by a "custom configuration" entry and a
    "use default" entry.

    Returns:
        - the chosen model name (``str``) for a listed model;
        - an ``LLMConfig`` from ``get_custom_configuration()`` for the custom
          entry;
        - ``"o3"`` for the default entry, an empty answer, an out-of-range
          number, non-numeric text, Ctrl-C, or end-of-input (closed stdin).
    """
    default_model = "o3"

    print("\n🤖 MODEL SELECTION")
    print("=" * 50)

    available_models = get_available_models()

    print("Available models by provider:")
    print()

    # Flat list of (provider, model); menu number k maps to all_models[k - 1].
    all_models = []
    for provider, models in available_models.items():
        print(f"{provider.upper()}:")
        for model in models:
            all_models.append((provider, model))
            print(f"  {len(all_models)}. {model}")
        print()

    # The two extra entries are numbered right after the last listed model.
    custom_option = len(all_models) + 1
    default_option = custom_option + 1
    print(f"{custom_option}. Custom configuration")
    print(f"{default_option}. Use default (o3/gpt-4o)")
    print()

    # Only the prompt and the integer parse are guarded, so failures inside
    # get_custom_configuration() are not misreported as "invalid input".
    # EOFError covers a closed or exhausted stdin (e.g. non-interactive runs).
    try:
        choice = input(
            f"Select model (1-{default_option}) or press Enter for default: "
        ).strip()
        choice_num = int(choice) if choice else None
    except (ValueError, EOFError, KeyboardInterrupt):
        print("❌ Invalid input, using default model: o3")
        return default_model

    if choice_num is None or choice_num == default_option:
        print("✅ Using default model: o3 (GPT-4o)")
        return default_model
    if choice_num == custom_option:
        return get_custom_configuration()
    if 1 <= choice_num <= len(all_models):
        provider, model = all_models[choice_num - 1]
        print(f"✅ Selected model: {model} ({provider})")
        return model

    print("❌ Invalid choice, using default model: o3")
    return default_model


def _prompt_number(prompt: str, parse, default, is_valid):
    """
    Prompt for a single numeric value, falling back to a default.

    Args:
        prompt: Text shown to the user.
        parse: Callable converting the stripped answer (e.g. ``int`` or
            ``float``); expected to raise ``ValueError`` on bad input.
        default: Value returned for an empty, unparsable, or rejected answer,
            and on Ctrl-C or end-of-input.
        is_valid: Predicate applied to the parsed value; ``False`` rejects the
            value in favour of ``default``.

    Returns:
        The parsed value, or ``default``.
    """
    try:
        raw = input(prompt).strip()
    except (EOFError, KeyboardInterrupt):
        return default
    if not raw:
        return default
    try:
        value = parse(raw)
    except ValueError:
        print(f"❌ Invalid input, using default {default}")
        return default
    if not is_valid(value):
        print(f"❌ Value out of range, using default {default}")
        return default
    return value


def get_custom_configuration() -> LLMConfig:
    """
    Build an ``LLMConfig`` from interactive user input.

    Prompts, in order, for provider (openai / anthropic / google), model name,
    temperature, max tokens and timeout. Every prompt falls back to a default
    on empty, invalid, interrupted (Ctrl-C) or end-of-input answers:

    - provider: ``"openai"``
    - model name: a per-provider default (see ``default_models`` below)
    - temperature: ``0.0``; must lie in [0.0, 1.0] as the prompt advertises
    - max tokens: ``16000``; must be positive
    - timeout (seconds): ``120``; must be positive

    Returns:
        LLMConfig object
    """
    print("\n🔧 CUSTOM CONFIGURATION")
    print("-" * 30)

    # Fallback model per provider, used when the model-name prompt is left
    # empty. Insertion order also defines the menu order.
    default_models = {
        "openai": "gpt-4o",
        "anthropic": "claude-3-opus-20240229",
        "google": "gemini-2.5-pro",
    }
    providers = list(default_models)
    print("Available providers:")
    for i, name in enumerate(providers, 1):
        print(f"  {i}. {name}")

    try:
        provider_choice = int(input(f"\nSelect provider (1-{len(providers)}): ").strip())
    except (ValueError, EOFError, KeyboardInterrupt):
        print("❌ Invalid input, using OpenAI")
        provider = "openai"
    else:
        if 1 <= provider_choice <= len(providers):
            provider = providers[provider_choice - 1]
        else:
            print("❌ Invalid choice, using OpenAI")
            provider = "openai"

    try:
        model_name = input(f"Enter model name for {provider} (e.g., gpt-4o, claude-3-opus, gemini-2.5-pro): ").strip()
    except (EOFError, KeyboardInterrupt):
        model_name = ""
    if not model_name:
        model_name = default_models[provider]

    # The chained comparison also rejects "nan" and "inf", which float() accepts.
    temperature = _prompt_number(
        "Enter temperature (0.0-1.0, default 0.0): ",
        float,
        0.0,
        lambda t: 0.0 <= t <= 1.0,
    )
    max_tokens = _prompt_number(
        "Enter max tokens (default 16000): ",
        int,
        16000,
        lambda n: n > 0,
    )
    timeout = _prompt_number(
        "Enter timeout in seconds (default 120): ",
        int,
        120,
        lambda s: s > 0,
    )

    config = LLMConfig(
        provider=provider,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout
    )

    print(f"✅ Custom configuration created: {provider}/{model_name}")
    return config


def setup_model_with_api_key(model_spec: Union[str, LLMConfig]) -> Union[str, LLMConfig]:
    """
    Expose the API key for ``model_spec``'s provider through the environment.

    The provider is ``model_spec.provider`` for an ``LLMConfig``. For a
    model-name string it is inferred from the name:

    - ``claude-*`` -> anthropic
    - ``gemini-*`` -> google
    - anything else (``gpt-*``, ``o3``, unrecognised names) -> openai

    The key obtained from ``APIKeyManager`` is stored in ``<PROVIDER>_API_KEY``.
    Any failure while fetching or storing it is printed as a warning rather
    than raised.

    Args:
        model_spec: Model name or ``LLMConfig``.

    Returns:
        ``model_spec`` itself, unchanged.
    """
    if not isinstance(model_spec, str):
        provider = model_spec.provider
    elif model_spec.startswith("claude-"):
        provider = "anthropic"
    elif model_spec.startswith("gemini-"):
        provider = "google"
    else:
        # gpt-*, o3 and unrecognised names all route to OpenAI.
        provider = "openai"

    try:
        key_manager = APIKeyManager(provider.capitalize())
        os.environ[f"{provider.upper()}_API_KEY"] = key_manager.get_api_key(provider.capitalize())
        print(f"✅ {provider.capitalize()} API key configured")
    except Exception as exc:
        # Best effort: report the problem and carry on.
        print(f"⚠️ Warning: Could not configure {provider.capitalize()} API key: {exc}")

    return model_spec


def quick_model_test(model_spec: Union[str, LLMConfig]) -> bool:
    """
    Send a short prompt to the model to check that it responds.

    Makes sure ``<PROVIDER>_API_KEY`` is set, fetching it via
    ``APIKeyManager`` only when the variable is missing or empty. It then
    builds the model with ``create_llm`` and invokes it once.

    Args:
        model_spec: Model name or ``LLMConfig``. For a string the provider is
            inferred from the name: ``claude-*`` -> anthropic,
            ``gemini-*`` -> google, anything else -> openai.

    Returns:
        True if the model answered, even when the reply does not contain the
        expected phrase. False if the key could not be obtained, or if
        creating or invoking the model raised.
    """
    try:
        print("🧪 Testing model connection...")

        if not isinstance(model_spec, str):
            provider = model_spec.provider
        elif model_spec.startswith("claude-"):
            provider = "anthropic"
        elif model_spec.startswith("gemini-"):
            provider = "google"
        else:
            # gpt-*, o3 and unrecognised names all route to OpenAI.
            provider = "openai"

        # Ensure the key is in the environment before the LLM is created.
        env_var_name = f"{provider.upper()}_API_KEY"
        if not os.getenv(env_var_name):
            try:
                manager = APIKeyManager(provider.capitalize())
                os.environ[env_var_name] = manager.get_api_key(provider.capitalize())
            except Exception as e:
                print(f"❌ Failed to get {provider.capitalize()} API key: {e}")
                return False

        llm = create_llm(model_spec)

        test_prompt = "Hello! Please respond with 'Model test successful' if you can see this message."
        response = llm.invoke(test_prompt)

        # Chat-message content is not always a str: some wrappers return a
        # list of content parts, and plain-LLM wrappers may return a bare
        # string. Calling .lower() on a non-str would raise and wrongly report
        # a responding model as failed, so coerce to text first.
        content = getattr(response, "content", response)
        if "successful" in str(content).lower():
            print("✅ Model test successful!")
        else:
            print("⚠️ Model responded but test inconclusive")
        return True  # any reply counts as a working connection

    except Exception as e:
        print(f"❌ Model test failed: {e}")
        return False


def main():
    """
    Run the standalone model selector.

    The steps are:

    1. Pick a model interactively.
    2. Configure the matching provider's API key.
    3. Send a short test prompt.

    On success, print a command-line usage hint for the chosen model.
    Otherwise print a failure notice.
    """
    print("🚀 LINEARIZELLM MODEL SELECTOR")
    print("=" * 50)

    spec = setup_model_with_api_key(select_model_interactive())

    if not quick_model_test(spec):
        print("\n❌ Model setup failed. Please check your API key and try again.")
        return

    print("\n🎉 Model setup complete!")
    print(f"Selected model: {spec}")

    # Plain model names can be passed on the command line; a custom LLMConfig
    # has to go through a config file instead.
    print("\nUsage example:")
    if isinstance(spec, str):
        print(f"python run_linearizellm_data.py --file problem.tex --model {spec}")
    else:
        print("# Create config file: model_config.json")
        print("# Then run: python run_linearizellm_data.py --file problem.tex --model-config model_config.json")


if __name__ == "__main__":
    main()