"""Mock LLM classes for testing training without expensive model inference."""

import itertools
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class MockRequestOutput:
    """Stand-in for ``vllm.RequestOutput`` holding one prompt and its completions.

    If ``prompt_token_ids`` is not supplied, it is filled with the placeholder
    ids ``[0, 1, ..., n - 1]``, where ``n`` is the number of
    whitespace-separated words in ``prompt``.
    """
    prompt: str
    outputs: List['MockCompletionOutput']
    # None means "derive placeholder ids from the prompt" (see __post_init__).
    prompt_token_ids: List[int] = None

    def __post_init__(self):
        # Explicitly supplied ids are kept untouched.
        if self.prompt_token_ids is not None:
            return
        word_count = len(self.prompt.split())
        self.prompt_token_ids = list(range(word_count))


@dataclass
class MockCompletionOutput:
    """Stand-in for ``vllm.CompletionOutput`` carrying a single completion text.

    If ``token_ids`` is not supplied, it is filled with the placeholder ids
    ``[0, 1, ..., n - 1]``, where ``n`` is the number of whitespace-separated
    words in ``text``.
    """
    text: str
    # None means "derive placeholder ids from the text" (see __post_init__).
    token_ids: List[int] = None

    def __post_init__(self):
        # Explicitly supplied ids are kept untouched.
        if self.token_ids is not None:
            return
        word_count = len(self.text.split())
        self.token_ids = list(range(word_count))


class MockLLM:
    """Mock LLM that returns predefined completions in a cyclic manner.

    ``generate`` and ``chat`` draw from one shared completion cycle, so the
    sequence of returned texts continues across calls and across both methods.
    ``call_count`` is incremented once per prompt/conversation served, not once
    per method call.
    """

    def __init__(self, completions: List[str]):
        """Initialize with custom completions list.

        Args:
            completions: Completion texts returned in order, repeating forever.
                The sequence is copied, so later changes to the caller's list
                do not affect the mock.
        """
        self.completions = list(completions)
        self.completion_cycle = itertools.cycle(self.completions)
        self.call_count = 0

    def _next_completion(self) -> str:
        """Return the next completion text from the cycle.

        Raises:
            ValueError: if the mock was created with no completions.
        """
        try:
            return next(self.completion_cycle)
        except StopIteration:
            # itertools.cycle only stops when its source is empty. A bare
            # StopIteration escaping a plain method is confusing and can be
            # silently swallowed by an enclosing loop, so raise a clear error.
            raise ValueError("MockLLM was created with no completions to return") from None

    def _respond(self, prompt: str, call_label: str, item_label: str, index: int) -> "MockRequestOutput":
        """Serve one prompt: count it, log it, and wrap the next completion."""
        self.call_count += 1
        completion_text = self._next_completion()
        print(f"🤖 MockLLM {call_label} #{self.call_count} - {item_label} {index + 1}: {completion_text[:50]}...")
        return MockRequestOutput(
            prompt=prompt,
            outputs=[MockCompletionOutput(text=completion_text)],
        )

    def generate(
        self,
        prompts: Union[str, List[str]],
        sampling_params: Any = None,
        lora_request: Any = None,
        use_tqdm: bool = False,
    ) -> List["MockRequestOutput"]:
        """Generate fake completions for the given prompts.

        Args:
            prompts: A list of prompt strings, or a single prompt string, which
                is treated as a batch of one (as ``vllm.LLM.generate`` accepts).
            sampling_params: Accepted for interface compatibility; ignored.
            lora_request: Accepted for interface compatibility; ignored.
            use_tqdm: Accepted for interface compatibility; ignored.

        Returns:
            One ``MockRequestOutput`` per prompt, each holding a single
            completion taken from the shared cycle.

        Raises:
            ValueError: if there is at least one prompt but no completions.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(prompts, str):
            prompts = [prompts]
        return [self._respond(prompt, "Call", "Prompt", i) for i, prompt in enumerate(prompts)]

    def chat(
        self,
        messages_list: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
        sampling_params: Any = None,
        use_tqdm: bool = False,
        lora_request: Any = None,
    ) -> List["MockRequestOutput"]:
        """Chat method that mimics vllm.LLM.chat interface.

        Args:
            messages_list: A list of conversations, each a list of message
                dicts with ``"role"`` and ``"content"`` keys, or a single such
                conversation, which is treated as a batch of one (as
                ``vllm.LLM.chat`` accepts).
            sampling_params: Accepted for interface compatibility; ignored.
            use_tqdm: Accepted for interface compatibility; ignored.
            lora_request: Accepted for interface compatibility; ignored.

        Returns:
            One ``MockRequestOutput`` per conversation. Its ``prompt`` is a
            simplified transcript: one newline-terminated ``role: content``
            line per message.

        Raises:
            ValueError: if there is at least one conversation but no completions.
        """
        # In a single conversation, the first element is a message dict rather
        # than a list of messages.
        if messages_list and isinstance(messages_list[0], dict):
            messages_list = [messages_list]
        results = []
        for i, messages in enumerate(messages_list):
            prompt = "".join(f"{msg['role']}: {msg['content']}\n" for msg in messages)
            results.append(self._respond(prompt, "Chat Call", "Conversation", i))
        return results


class MockSamplingParams:
    """Mock SamplingParams to avoid vLLM import.

    Every keyword argument passed to the constructor becomes an attribute of
    the same name; no validation or defaults are applied.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def clone(self):
        """Return an independent copy of these parameters.

        Attribute values are deep-copied, mirroring vLLM's
        ``SamplingParams.clone``. Mutating a container on the clone (e.g.
        appending to a ``stop`` list) therefore does not leak back into the
        original. The copy keeps the concrete class of ``self``.
        """
        import copy

        return copy.deepcopy(self)

    def __repr__(self):
        fields = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
        return f"{type(self).__name__}({fields})"


def inject_mock_llm_into_trainer(trainer, mock_llm):
    """Swap a trainer's vLLM engine for ``mock_llm`` and stub out vLLM-specific hooks.

    Mutates ``trainer`` in place and returns None:

    * ``trainer.llm`` is set to ``mock_llm``;
    * ``trainer._last_loaded_step`` is set to 0;
    * ``trainer._move_model_to_vllm`` is replaced by a bound method that only
      prints a notice and re-assigns ``mock_llm`` to ``self.llm``;
    * ``trainer.sampling_params`` becomes a ``MockSamplingParams`` built from
      ``trainer.max_completion_length`` and ``trainer.temperature`` plus fixed
      values for the remaining sampling knobs.
    """
    trainer.llm = mock_llm
    # Presumably stops the trainer from re-attempting a weight load; verify
    # against how the trainer reads this attribute.
    trainer._last_loaded_step = 0

    def _skip_vllm_sync(self):
        print("⚡ Skipping expensive vLLM initialization (using mock LLM)")
        self.llm = mock_llm

    # Bind through the descriptor protocol so the stub receives `trainer` as `self`.
    bound_stub = _skip_vllm_sync.__get__(trainer, type(trainer))
    trainer._move_model_to_vllm = bound_stub

    # Sampling settings not taken from the trainer; fixed for the mock.
    fixed_knobs = dict(top_p=0.9, top_k=-1, min_p=0.0, repetition_penalty=1.0)
    trainer.sampling_params = MockSamplingParams(
        max_tokens=trainer.max_completion_length,
        temperature=trainer.temperature,
        **fixed_knobs,
    )