# llm_judge.py
import concurrent.futures
import hashlib
import json
import os
import threading
import time
from functools import lru_cache
from typing import Dict, List, Optional, Set

import torch
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from .universal_data_loader import UniversalBanditDataLoader

# Endpoint URL passed as `base_url` to the ChatOpenAI client that
# JudgeLLM.__init__ constructs.
BASE_URL = "https://api.openai.com/v1"


def format_query(original_question: str, sub_question: str) -> str:
    """
    Build the single-line query string combining both questions.

    Surrounding whitespace is stripped from each question; the original
    question is labelled `Context` and the sub-question is labelled `Focus`.
    """
    context = original_question.strip()
    focus = sub_question.strip()
    return f"Context: {context} | Focus: {focus}"


class JudgeDecision(BaseModel):
    """The decision made by the Judge LLM."""

    # Schema handed to the chat model through
    # `with_structured_output(JudgeDecision)` in JudgeLLM.__init__.
    # NOTE(review): the class docstring and the field description below are
    # presumably forwarded to the model as part of that schema, so treat them
    # as prompt text rather than plain documentation -- confirm before editing.
    best_document_index: int = Field(
        description="The 0-based index of the most relevant document."
    )


class JudgeLLM:
    """
    A wrapper for a Language Model that acts as a judge to provide a reward signal for retrieval.

    Successful decisions are memoized per instance under the SHA-256 digest of
    the question and the formatted documents. When `persist_cache` is enabled
    they are also appended to a JSONL file and reloaded on construction.
    Failed judgements (the `-1` sentinel) are never memoized, so a transient
    outage does not permanently poison an input.
    """

    # Total number of attempts made against the judge model before giving up.
    _MAX_ATTEMPTS = 5
    # Seconds to wait between two consecutive failed attempts.
    _RETRY_DELAY_SECONDS = 1

    def __init__(
        self,
        data_loader: "UniversalBanditDataLoader",
        cache_path: Optional[str] = None,
        persist_cache: bool = True,
    ):
        """
        Args:
            data_loader: Data loader used to map indices to texts
            cache_path: Optional path to a JSONL cache file for persisting judge decisions
                (defaults to `runs/judge_cache.jsonl`)
            persist_cache: Whether to persist/load on-disk cache
        """
        self.model = ChatOpenAI(
            model="gpt-4o-mini", temperature=0, base_url=BASE_URL
        ).with_structured_output(JudgeDecision)
        self.system_prompt = self._build_system_prompt()
        self.data_loader = data_loader

        # Cache setup. `_disk_cache` is the in-memory map of known decisions;
        # it is always consulted and is mirrored to disk only when
        # `persist_cache` is true.
        self.persist_cache = bool(persist_cache)
        self.cache_path = cache_path or os.path.join(
            "runs", "judge_cache.jsonl"
        )
        self._cache_lock = threading.Lock()
        self._disk_cache: Dict[str, int] = {}
        if self.persist_cache:
            os.makedirs(os.path.dirname(self.cache_path) or ".", exist_ok=True)
            self._load_disk_cache()

    def _load_disk_cache(self) -> None:
        """Populate `_disk_cache` from the JSONL file at `cache_path`.

        Loading is best-effort: a missing file, blank lines, lines that are
        not valid JSON and records without a string `key` and an integer
        `best_document_index` are skipped. Later records override earlier
        ones with the same key.
        """
        try:
            with open(self.cache_path, "r", encoding="utf-8") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except (ValueError, RecursionError):
                        # Corrupt or truncated line (e.g. interrupted write).
                        continue
                    if not isinstance(record, dict):
                        continue
                    key = record.get("key")
                    val = record.get("best_document_index")
                    # bool is a subclass of int; a JSON `true` is not an index.
                    if (
                        isinstance(key, str)
                        and isinstance(val, int)
                        and not isinstance(val, bool)
                    ):
                        self._disk_cache[key] = val
        except FileNotFoundError:
            # No cache written yet; start empty.
            pass

    def _build_system_prompt(self) -> str:
        """Return the fixed system prompt sent with every judge request."""
        return """You are an impartial and meticulous AI judge. Your task is to determine which of the provided documents contains useful information to answer the given question, especially the `Focus` one.

Carefully review each document and respond with a JSON object containing the 0-based index of the relevant document. Smaller index is more relevant.
"""

    def _make_cache_key(self, question: str, formatted_docs: str) -> str:
        """Return the SHA-256 hex digest identifying a (question, documents) pair."""
        data = (question + "\n-----\n" + formatted_docs).encode("utf-8")
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def _format_documents(texts) -> str:
        """Join document texts into the numbered listing shown to the judge."""
        return "\n\n".join(
            f"--- Document {pos} ---\n{doc}" for pos, doc in enumerate(texts)
        )

    def _store_decision(self, key: str, idx: int) -> None:
        """Memoize a decision and, when enabled, append it to the cache file.

        Persistence is best-effort: an I/O error is reported and otherwise
        ignored so that a valid decision is never lost because of it.
        """
        with self._cache_lock:
            # Update in-memory map first
            self._disk_cache[key] = idx
            if not self.persist_cache:
                return
            rec = {"key": key, "best_document_index": idx}
            try:
                with open(self.cache_path, "a", encoding="utf-8") as f:
                    f.write(json.dumps(rec) + "\n")
            except OSError as e:
                print(f"Error persisting Judge LLM decision: {e}")

    def judge_documents(self, question: str, formatted_docs: str) -> int:
        """
        Asks the LLM to judge which document is most relevant.

        Args:
            question: The query text shown to the judge.
            formatted_docs: The candidate documents, already rendered as one
                numbered text block.

        Returns:
            The index reported by the judge, or -1 if every attempt failed.
            The value is returned as reported and is not range-checked here.
        """
        # Check the memoized decisions first (covers entries loaded from disk).
        key = self._make_cache_key(question, formatted_docs)
        cached = self._disk_cache.get(key)
        if cached is not None:
            return cached

        human_prompt = f"""
Question:
{question}

Retrieved Documents:
{formatted_docs}

Based on the question, which document is the most relevant?
"""
        messages = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=human_prompt),
        ]

        for attempt in range(self._MAX_ATTEMPTS):
            try:
                decision = self.model.invoke(messages)
                idx = int(decision.best_document_index)  # type: ignore
            except Exception as e:
                # Boundary with a remote service: any failure is retried.
                print(f"Error invoking Judge LLM: {e}")
                if attempt + 1 < self._MAX_ATTEMPTS:
                    time.sleep(self._RETRY_DELAY_SECONDS)
                continue
            # Only non-negative indices are remembered; -1 doubles as the
            # failure sentinel and must stay retryable.
            if idx >= 0:
                self._store_decision(key, idx)
            return idx
        return -1

    def batched_judge(
        self,
        batch_query_indices: torch.Tensor,
        batch_candidate_arm_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Judge every query of a batch against its own candidate documents.

        Args:
            batch_query_indices: Query indices, resolved to texts through the
                data loader (assumed to be one index per batch row).
            batch_candidate_arm_indices: 2-D tensor whose row `b` holds the
                candidate arm indices for query `b`.

        Returns:
            A 1-D long tensor on the device of `batch_candidate_arm_indices`
            holding, per query, the chosen position within its candidate row.
            Failed judgements are replaced by a uniformly random position and
            every value is clamped into `[0, k - 1]`.
        """
        query_texts = self.data_loader.get_texts_by_indices(
            query_indices=batch_query_indices.tolist()
        )["queries"]

        batch_candidate_arms_text = [
            self._format_documents(
                self.data_loader.get_texts_by_indices(arm_indices=arm_indices)[
                    "arms"
                ]
            )
            for arm_indices in batch_candidate_arm_indices.tolist()
        ]

        # For each query and its k candidate arms, run judge_documents.
        # Threads are used because the work is dominated by network waits.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(
                executor.map(
                    self.judge_documents,
                    query_texts,
                    batch_candidate_arms_text,
                )
            )

        rslt_tensor = torch.tensor(
            results, device=batch_candidate_arm_indices.device, dtype=torch.long
        )

        # If any position is -1 (judge failed), replace with a random index in [0, k-1]
        k = batch_candidate_arm_indices.size(1)
        mask = rslt_tensor == -1
        if mask.any():
            rslt_tensor[mask] = torch.randint(
                0,
                k,
                size=(int(mask.sum().item()),),
                device=batch_candidate_arm_indices.device,
                dtype=torch.long,
            )

        # Safety: clamp any out-of-range indices into [0, k-1]
        rslt_tensor = torch.clamp(rslt_tensor, 0, k - 1)

        return rslt_tensor
