# services/local_judge.py
"""
Local judge model service.
Uses a local Qwen model instead of a remote API call.
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any
import json
import re
from config.settings import QUERY_DECOMPOSITION_CONFIG


class LocalJudgeModel:
    """Judge model backed by a locally loaded causal language model.

    The tokenizer and the model are loaded eagerly when an instance is
    created. ``generate_response`` then runs a single-turn chat completion
    and returns only the text that follows the model's reasoning block.
    """

    # Token id that closes the reasoning block ("</think>").
    # NOTE(review): hard-coded for the Qwen3 vocabulary -- confirm before
    # pointing this class at a model with a different tokenizer.
    THINK_END_TOKEN_ID = 151668

    def __init__(self, model_path: str = None, device: str = None):
        """Resolve the model location and device, then load the model.

        Args:
            model_path: Passed to ``from_pretrained``. When falsy,
                ``QUERY_DECOMPOSITION_CONFIG["division_model_path"]`` is used.
            device: Passed to ``from_pretrained`` as ``device_map``. When
                falsy, ``QUERY_DECOMPOSITION_CONFIG["division_device_map"]``
                is used.

        Raises:
            Exception: Any error raised while loading is propagated.
        """
        self.model_path = model_path or QUERY_DECOMPOSITION_CONFIG["division_model_path"]
        self.device = device or QUERY_DECOMPOSITION_CONFIG["division_device_map"]
        self.model = None
        self.tokenizer = None
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and the model from ``self.model_path``.

        Raises:
            Exception: Re-raised unchanged after the failure is printed.
        """
        print(f"[LocalJudge] Loading Qwen3 model from {self.model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype="auto",
                device_map=self.device,
            )
            print("[LocalJudge] Qwen3 model loaded successfully")
        except Exception as e:
            print(f"[LocalJudge] Error loading Qwen3 model: {e}")
            # Bare ``raise`` keeps the original traceback intact.
            raise

    @staticmethod
    def _drop_thinking(output_ids: list, think_end_id: int) -> list:
        """Return the ids that follow the last ``think_end_id``.

        When ``think_end_id`` does not occur, ``output_ids`` is returned
        unchanged, i.e. the whole output is treated as the answer.
        """
        # Scan from the end so that the *last* marker wins.
        for position in range(len(output_ids) - 1, -1, -1):
            if output_ids[position] == think_end_id:
                return output_ids[position + 1:]
        return output_ids

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Generate the model's answer to ``prompt`` as a single user turn.

        The prompt is rendered through the tokenizer's chat template with
        ``enable_thinking=True``. Generation runs with sampling disabled.
        Everything up to and including the last ``THINK_END_TOKEN_ID`` in
        the output is discarded; if that token is absent, the whole
        generated text is returned.

        Args:
            prompt: Content of the user message.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded answer with surrounding whitespace stripped, or an
            empty string if any step raises (the error is printed).
        """
        try:
            messages = [{"role": "user", "content": prompt}]

            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )

            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

            # No ``temperature`` here: it is only consulted when sampling,
            # and sampling is switched off by ``do_sample=False``.
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # ``generate`` returns prompt + completion; keep the completion.
            prompt_length = len(model_inputs.input_ids[0])
            output_ids = generated_ids[0][prompt_length:].tolist()

            # Only the answer part is decoded; the reasoning is never used.
            answer_ids = self._drop_thinking(output_ids, self.THINK_END_TOKEN_ID)
            return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()

        except Exception as e:
            print(f"[LocalJudge] Error generating response: {e}")
            return ""


# Module-level cache for the shared LocalJudgeModel instance. It starts as
# None and is filled in by get_judge_model() on its first call (lazy
# initialization), so importing this module does not load the model.
_global_judge_model = None


def get_judge_model() -> LocalJudgeModel:
    """Return the shared LocalJudgeModel, creating it on first use.

    The instance is built with ``LocalJudgeModel()`` (no arguments) and
    stored in the module-level ``_global_judge_model``; later calls return
    that same object.
    """
    global _global_judge_model
    # Fast path: the instance already exists.
    if _global_judge_model is not None:
        return _global_judge_model
    _global_judge_model = LocalJudgeModel()
    return _global_judge_model
