#!/usr/bin/env python3
"""Test script for training using a mock LLM that returns constant completions."""

import unsloth
import os, random, numpy as np, torch
import logging
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional

# Add project root to Python path so the project-local packages imported in this
# file (utils, envs, rewards, tools, mocks, trainers) resolve regardless of CWD.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import configuration utilities
from utils.config_loader import load_config, get_run_name

# Load configuration directly from the standard location.
# Announce the path *before* loading so that, if loading fails, the output
# still shows which file was being read.
CONFIG_PATH = project_root / 'configs' / 'training_config.yaml'
print(f"Loading configuration from: {CONFIG_PATH}")
config = load_config(str(CONFIG_PATH))

# Optional log suppression, controlled by config['environment']['suppress_logs'].
if config['environment']['suppress_logs']:
    # Raise the root logger's threshold to CRITICAL. Note this does not silence
    # loggers that set an explicit level of their own.
    logging.getLogger().setLevel(logging.CRITICAL)

# Import the refactored ToolEnvironment class
from envs.environments import ToolEnvironment
from envs.validation_config import ValidationConfig
from rewards.power_system_reward import PowerSystemReward

# Global random seed taken from the training section of the config.
SEED = config['training']['seed']

# Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs so runs are
# reproducible.
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so assigning it here only
# affects child processes started later, not hash randomization in this process.
os.environ["PYTHONHASHSEED"] = str(SEED)
for _seed_all in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_all(SEED)
del _seed_all

# Prefer deterministic cuDNN kernels over benchmark-selected (autotuned) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import dependencies
from trl import GRPOConfig
from utils.data_utils import preprocess_dataset
from tools.electrical_check import electrical_check

# Import mock utilities
from mocks.mock_llm import MockLLM, MockSamplingParams, inject_mock_llm_into_trainer
from mocks.mock_model import MockModel, MockTokenizer

# Monkey patch to avoid vLLM import: register a stub ``vllm`` module in
# sys.modules so the trainer import below binds to it instead of the real
# package. Generation is served by the MockLLM injected into the trainer later.
# NOTE(review): this only takes effect if ``vllm`` has not already been imported
# (e.g. transitively by ``import unsloth`` at the top of the file) — confirm.
import types


class _PlaceholderLLM:
    """Stand-in for ``vllm.LLM`` that accepts (and ignores) any constructor args."""

    def __init__(self, *args, **kwargs):
        pass


fake_vllm = types.ModuleType('vllm')
fake_vllm.SamplingParams = MockSamplingParams
# Must be a class, not an instance: ``LLM(...)`` or ``isinstance(obj, LLM)``
# against an instance would raise TypeError.
fake_vllm.LLM = _PlaceholderLLM
sys.modules['vllm'] = fake_vllm

# Now we can safely import the trainer (it must come after the stub is installed).
from trainers.grpo_env_trainer_unsloth import UnslothGRPOEnvTrainer


# ==============================================================================
# CONFIGURABLE COMPLETIONS FOR TESTING
# ==============================================================================
# Completions the mock LLM is given when main() is called without arguments.
# Entries 1-7: a <think> step, one electrical_check <tool> call, then an <answer>.
# Entries 8-9: answer-only turns with no tool call.
DEFAULT_TEST_COMPLETIONS = [
    # Single-check tool calls, one per check type, plus an all-checks call.
    (
        '<think>I need to check the electrical system for voltage issues.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict", "checks": ["voltage"]}}</tool>\n'
        '<answer>Voltage analysis completed successfully.</answer>'
    ),
    (
        '<think>Let me check the connectivity of the power system.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict", "checks": ["connectivity"]}}</tool>\n'
        '<answer>Connectivity check shows proper system connections.</answer>'
    ),
    (
        '<think>I should verify the frequency parameters.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict", "checks": ["frequency"]}}</tool>\n'
        '<answer>Frequency parameters are within acceptable ranges.</answer>'
    ),
    (
        '<think>Let me check the system completeness for isolated blocks.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict", "checks": ["completeness"]}}</tool>\n'
        '<answer>System completeness check performed successfully.</answer>'
    ),
    (
        '<think>Let me run a comprehensive electrical check.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict"}}</tool>\n'
        '<answer>Comprehensive electrical validation completed.</answer>'
    ),
    # Tool calls with a multi-check list and a repeated completeness check.
    (
        '<think>Based on the system requirements, I need to perform electrical validation.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict", "checks": ["voltage", "connectivity"]}}</tool>\n'
        '<answer>Multi-parameter electrical check completed successfully.</answer>'
    ),
    (
        '<think>The system analysis requires checking for potential issues.</think>\n'
        '<tool>{"name": "electrical_check", "args": {"system_name": "system_dict", "checks": ["completeness"]}}</tool>\n'
        '<answer>System integrity validation completed.</answer>'
    ),
    # Final answers only (no <tool> call).
    (
        '<think>Based on all the electrical checks performed, I can now provide a comprehensive analysis.</think>\n'
        '<answer>The power system analysis is complete. Based on the electrical checks, '
        'the system appears to be functioning properly with adequate voltage levels and connectivity.</answer>'
    ),
    (
        '<think>After reviewing the system specifications and constraints.</think>\n'
        '<answer>After performing multiple electrical validation checks, the system shows proper '
        'connectivity between generators and loads with appropriate voltage levels.</answer>'
    ),
]


def _build_environment():
    """Create the ToolEnvironment used by the test run.

    Loads the "train" split of ``config['project']['dataset_name']``, registers
    the ``electrical_check`` tool with a ``PowerSystemReward``, and overrides the
    environment's ``max_steps`` to 3 so multi-turn conversations are exercised.

    Returns:
        The configured ``ToolEnvironment``.
    """
    # Tool prompt for reasoning. ``{tool_descriptions}`` is a template placeholder,
    # presumably filled in by ToolEnvironment — confirm there.
    system_prompt = """
Think step-by-step inside <think>...</think> tags. Provide your final answer inside <answer>...</answer> tags.

You have access to tools to help solve problems:
{tool_descriptions}

Call tools using a JSON command within <tool> tags, including:

"name": tool name
"args": tool arguments
Tool output will appear in <result> tags. Multiple tool calls are allowed if needed.
<answer>...</answer> tags must contain only the final answer.
"""

    dataset = preprocess_dataset(config['project']['dataset_name'], "train")
    validation_config = ValidationConfig()

    return ToolEnvironment(
        dataset=dataset,
        system_prompt=system_prompt,
        tools=[electrical_check],
        max_steps=3,  # Override config to allow multiple conversation steps for testing
        reward=PowerSystemReward(
            tools=[electrical_check],
            power_system_weights=config['reward_weights'],
        ),
        validation_config=validation_config,
    )


def _build_training_args(env):
    """Build a short-run ``GRPOConfig`` from the loaded config.

    Optimizer schedule, temperature, precision, grad clipping, GRPO
    ``num_iterations``/``beta``, sequence lengths and logging flags come from
    ``config``. Step count, batch size, generation count, checkpointing, saving
    and reporting are overridden for a quick test.

    Args:
        env: The environment whose reward weights are forwarded to GRPO.

    Returns:
        The ``GRPOConfig`` for the trainer.
    """
    train_cfg = config['training']
    model_cfg = config['model']
    save_cfg = config['saving']
    max_prompt_length = int(model_cfg['max_prompt_length'])

    return GRPOConfig(
        seed=SEED,
        output_dir="outputs/fake_test",
        run_name="fake_test",
        learning_rate=float(train_cfg['learning_rate']),
        lr_scheduler_type=train_cfg['lr_scheduler_type'],
        warmup_steps=int(train_cfg['warmup_steps']),
        num_train_epochs=1,  # Short test
        temperature=float(train_cfg['temperature']),
        max_steps=5,  # Override for testing (short run)
        bf16=bool(train_cfg['bf16']),
        max_grad_norm=float(train_cfg['max_grad_norm']),
        num_iterations=int(train_cfg['num_iterations']),
        beta=float(train_cfg['beta']),
        max_prompt_length=max_prompt_length,
        # Completion budget is whatever the model's sequence length leaves after the prompt.
        max_completion_length=int(model_cfg['max_seq_length']) - max_prompt_length,
        per_device_train_batch_size=2,
        num_generations=2,
        gradient_accumulation_steps=int(train_cfg['gradient_accumulation_steps']),
        gradient_checkpointing=False,  # Disable for fake model testing
        save_strategy="no",  # Don't save during testing
        save_steps=1000,
        save_only_model=bool(save_cfg['save_only_model']),
        use_vllm=bool(config['vllm']['use_vllm']),
        vllm_gpu_memory_utilization=0.1,  # Minimal since we won't actually use it
        logging_steps=int(save_cfg['logging_steps']),
        log_on_each_node=bool(save_cfg['log_on_each_node']),
        log_completions=bool(save_cfg['log_completions']),
        # "none" disables every reporting integration (wandb etc.). Passing None
        # means "all installed integrations" in transformers v4.
        report_to="none",
        reward_weights=env.get_reward_weights(),
    )


def main(test_completions: Optional[List[str]] = None) -> bool:
    """Test training using fake LLM with configurable completions.

    Builds the tool environment, a mock model/tokenizer and the real
    ``UnslothGRPOEnvTrainer``, injects a ``MockLLM`` built from
    ``test_completions``, and runs ``trainer.train()``.

    Args:
        test_completions: List of completion strings for the fake LLM to cycle
            through. If None, uses ``DEFAULT_TEST_COMPLETIONS``.

    Returns:
        True if ``trainer.train()`` completed, False if it raised. Failures are
        reported with a full traceback instead of being re-raised, because the
        mock LLM can legitimately break parts of the training loop.
    """
    import traceback

    if test_completions is None:
        test_completions = DEFAULT_TEST_COMPLETIONS

    print("="*70)
    print("TESTING TRAINING WITH FAKE LLM")
    print("="*70)

    env = _build_environment()

    print("Environment system prompt preview:")
    print(env.system_prompt[:200] + "...")

    # Create minimal mock model and tokenizer (skip expensive real model loading)
    print("Creating minimal mock model and tokenizer...")
    model = MockModel()
    tokenizer = MockTokenizer()

    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()

    training_args = _build_training_args(env)

    # Initialize real trainer
    print("Initializing real trainer...")
    trainer = UnslothGRPOEnvTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=env.get_reward_funcs(),
        env=env,
        args=training_args,
        train_dataset=env.get_dataset(),
        eval_dataset=env.get_eval_dataset(),
        my_eval_steps=None,  # No eval during testing
    )

    # Inject mock LLM into real trainer
    print("Injecting mock LLM...")
    mock_llm = MockLLM(test_completions)
    inject_mock_llm_into_trainer(trainer, mock_llm)

    print("🚀 Starting test with real trainer and mock LLM...")
    print("Training configuration:")
    print(f"  - Training max_steps: {training_args.max_steps}")
    print(f"  - Environment max_steps: {env.state_manager.max_steps}")
    print(f"  - per_device_train_batch_size: {training_args.per_device_train_batch_size}")
    print(f"  - num_generations: {training_args.num_generations}")
    print(f"  - Dataset size: {len(env.get_dataset())}")
    print(f"  - Test completions: {len(test_completions)} completions")

    # Run training for a few steps
    try:
        print("📚 Starting trainer.train()...")
        trainer.train()
    except Exception as e:
        print(f"❌ Training failed (expected with mock LLM): {str(e)[:100]}...")
        # Keep the full traceback; the one-line summary above is truncated.
        traceback.print_exc()
        return False

    print("✅ Training completed successfully!")
    return True


if __name__ == "__main__":
    # You can customize completions here for testing different scenarios
    # Fixed completions for max_steps=1 with proper XML format
    custom_completions = [
        '<think>Testing custom completion 1.</think>\n<answer>Based on the electrical check, the system analysis is complete.</answer>',
        '<think>Testing custom completion 2.</think>\n<answer>Connectivity analysis completed.</answer>',
        '<think>Testing comprehensive approach.</think>\n<answer>Custom test completed successfully.</answer>',
    ]
    
    # Run with default completions
    # main()
    
    # Uncomment to run with custom completions:
    main(test_completions=custom_completions)
    