import os.path as osp

try:
    from dotenv import load_dotenv
except Exception:  # pragma: no cover
    # Optional dependency: only needed for local CLI runs.
    load_dotenv = None

from core.arch.base import BaseComponent
from core.arch.system import CompoundAISystem
from core.utils.api import get_llm_output
from core.utils.textresnet import (
    BACKWARD_SYSTEM_PROMPT,
    DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS,
    STOP_GRADIENT,
    BackwardResponse,
    GradientSignal,
    build_backward_context_prompt,
    parse_gradient_response,
    short_text,
)
from examples.metrics.f1_score import f1_score


class QuestionRewriter(BaseComponent):
    """Pipeline stage 1: turn the raw question into a clearer retrieval query.

    The instruction text stored in ``self.variable`` is prepended to the
    original question for the forward LLM call and is the value critiqued in
    ``backward``.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    def __init__(self, model='openai/gpt-4o-mini', max_tokens=1024, temperature=0):
        """Register the component metadata, instruction text and LLM settings.

        Args:
            model: Model identifier passed to ``get_llm_output``.
            max_tokens: Generation cap used by ``forward``.
            temperature: Sampling temperature used by ``forward``.
        """
        llm_settings = {"model": model, "max_tokens": max_tokens, "temperature": temperature}
        role_description = (
            "HotpotQA step 1/5 (QuestionRewriter): rewrite the raw user question into a clear, "
            "retrieval-friendly query while preserving all key entities/constraints. "
            "Upstream: raw `question`. Downstream: `InfoExtractor` consumes `rewritten_query` to "
            "produce `search_keywords` for Wikipedia retrieval. This component is the first "
            "step in the system and is responsible for ensuring that the question is clear and "
            "specific enough for the downstream components to succeed. No components are before this one. "
            "All feedback should be stopped here."
        )
        super().__init__(
            description=role_description,
            input_fields=["question"],
            output_fields=["rewritten_query", "local_fix", "upstream_grad"],
            variable="Rewrite the question to be clearer and specific. Output only the rewritten query.",
            config=llm_settings,
        )

    def forward(self, question):
        """Ask the LLM to rewrite ``question``.

        Returns:
            A dict with the stripped LLM reply under ``rewritten_query`` plus
            the fixed ``local_fix`` / ``upstream_grad`` placeholder values.
        """
        llm_prompt = f"{self.variable}\nOriginal: {question}\nRewritten:"
        completion = get_llm_output(
            message=llm_prompt,
            model=self.config.model,
            max_new_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )
        result = {"rewritten_query": completion.strip()}
        result["local_fix"] = "rewrite generated"
        result["upstream_grad"] = STOP_GRADIENT
        return result

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Turn downstream feedback into a parsed critique of the instruction text.

        Args:
            signal: Incoming gradient signal; only ``signal.feedback`` is read.
            full_traj: Not used by this component.
            component_traj: Mapping with ``"input"`` / ``"output"`` sub-mappings
                recorded for this component; missing entries default to ``""``.

        Returns:
            The response produced by ``parse_gradient_response`` with a
            ``debug`` dict attached.
        """
        recorded_in = component_traj.get("input", {})
        recorded_out = component_traj.get("output", {})
        original_question = recorded_in.get("question", "")
        forward_output = recorded_out.get("rewritten_query", "")

        # Rebuild the exact prompt that forward() sent to the model.
        forward_prompt = f"{self.variable}\nOriginal: {original_question}\nRewritten:"
        feedback = signal.feedback
        output_role = "query used for retrieval"

        context = build_backward_context_prompt(
            variable_desc=self.description,
            variable_value=self.variable,
            lm_input=forward_prompt,
            lm_output=forward_output,
            objective_feedback=feedback,
            response_desc=output_role,
        )
        critique_request = f"{context}\n\n{self.BACKWARD_OUTPUT_CONSTRAINTS}"
        raw_reply = get_llm_output(
            message=critique_request,
            model=self.config.model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )

        response = parse_gradient_response(raw_reply)
        preview_limit = 400
        response.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": preview_limit,
            "variable_short": short_text(self.variable, preview_limit),
            "lm_system_prompt": "",
            "lm_input": forward_prompt,
            "lm_output": forward_output,
            "response_desc": output_role,
            "objective_feedback": feedback,
            "raw_backward_output": raw_reply,
        }
        return response


class InfoExtractor(BaseComponent):
    """Pipeline stage 2: derive search keywords from the rewritten query.

    The rewritten query itself is passed through unchanged as
    ``kept_entities`` so that ``backward`` can recover it from the output.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    def __init__(self, model='openai/gpt-4o-mini', max_tokens=1024, temperature=0):
        """Register the component metadata, instruction text and LLM settings.

        Args:
            model: Model identifier passed to ``get_llm_output``.
            max_tokens: Generation cap used by ``forward``.
            temperature: Sampling temperature used by ``forward``.
        """
        llm_settings = {"model": model, "max_tokens": max_tokens, "temperature": temperature}
        role_description = (
            "HotpotQA step 2/5 (InfoExtractor): extract minimal, high-recall keywords/entities from "
            "`rewritten_query` for Wikipedia search, while keeping the rewritten query as identity "
            "(`kept_entities`). Upstream: `QuestionRewriter.rewritten_query`. Downstream: "
            "`WikipediaRetriever` uses `search_keywords` to fetch evidence that `HintGenerator` and "
            "`AnswerGenerator` rely on. In backward, translate downstream feedback about missing/wrong "
            "evidence into concrete keyword changes (add missing entity/alias/attribute/disambiguator; "
            "remove noisy terms), and avoid suggesting changes that are already present in the current keywords."
        )
        super().__init__(
            description=role_description,
            input_fields=["rewritten_query"],
            output_fields=["search_keywords", "kept_entities", "local_fix", "upstream_grad"],
            variable="Extract minimal keywords/entities for retrieval. Output comma-separated terms.",
            config=llm_settings,
        )

    def forward(self, rewritten_query):
        """Ask the LLM for retrieval keywords for ``rewritten_query``.

        Returns:
            A dict with the stripped LLM reply under ``search_keywords``, the
            untouched query under ``kept_entities``, and the fixed
            ``local_fix`` / ``upstream_grad`` placeholder values.
        """
        llm_prompt = f"{self.variable}\nQuery: {rewritten_query}\nKeywords:"
        completion = get_llm_output(
            message=llm_prompt,
            model=self.config.model,
            max_new_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )
        result = {"search_keywords": completion.strip()}
        result["kept_entities"] = rewritten_query
        result["local_fix"] = "keywords extracted"
        result["upstream_grad"] = STOP_GRADIENT
        return result

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Turn downstream feedback into a parsed critique of the instruction text.

        The feedback shown to the critic is extended with the content that the
        ``wikipedia_retriever`` entry of ``full_traj`` recorded as its output.

        Args:
            signal: Incoming gradient signal; only ``signal.feedback`` is read.
            full_traj: Mapping of component name to its recorded trajectory.
            component_traj: Mapping with an ``"output"`` sub-mapping recorded
                for this component; missing entries default to ``""``.

        Returns:
            The response produced by ``parse_gradient_response`` with a
            ``debug`` dict attached.
        """
        # The query is read back from the output side, where forward() echoed it.
        query_seen = component_traj.get("output", {}).get("kept_entities", "")
        keywords_produced = component_traj.get("output", {}).get("search_keywords", "")

        retriever_traj = full_traj.get("wikipedia_retriever", {})
        evidence = retriever_traj.get("output", {}).get("retrieve_content", "")

        # Rebuild the exact prompt that forward() sent to the model.
        forward_prompt = f"{self.variable}\nQuery: {query_seen}\nKeywords:"
        feedback = f"{signal.feedback}\nRetrieved content:\n{evidence}"
        output_role = (
            "keywords used for retrieving core information in wikipedia for the rewritten query"
        )

        context = build_backward_context_prompt(
            variable_desc=self.description,
            variable_value=self.variable,
            lm_input=forward_prompt,
            lm_output=keywords_produced,
            objective_feedback=feedback,
            response_desc=output_role,
        )
        critique_request = f"{context}\n\n{self.BACKWARD_OUTPUT_CONSTRAINTS}"
        raw_reply = get_llm_output(
            message=critique_request,
            model=self.config.model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )

        response = parse_gradient_response(raw_reply)
        preview_limit = 400
        response.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": preview_limit,
            "variable_short": short_text(self.variable, preview_limit),
            "lm_system_prompt": "",
            "lm_input": forward_prompt,
            "lm_output": keywords_produced,
            "response_desc": output_role,
            "objective_feedback": feedback,
            "raw_backward_output": raw_reply,
        }
        return response


class WikipediaRetriever(BaseComponent):
    """Retrieve passages from Wikipedia using API fallback.

    NOTE: This is treated as a tool call (non-optimizable).
    """

    def __init__(self, k: int = 10):
        """Register the tool metadata and remember the result limit.

        Args:
            k: Number of search results requested per query.
        """
        super().__init__(
            description=(
                "HotpotQA step 3/5 (WikipediaRetriever, tool): query Wikipedia with `search_keywords` and "
                "return top-k snippets as `retrieve_content`. Upstream: `InfoExtractor.search_keywords`. "
                "Downstream: `HintGenerator` must ground its hints in `retrieve_content`; if retrieval is "
                "empty/irrelevant, downstream reasoning will likely fail."
            ),
            input_fields=["search_keywords"],
            output_fields=["retrieve_content", "k_used"],
            variable=None,
            config=None,
        )
        self.k = k

    def _wiki_search(self, query: str, k: int) -> list:
        """Fallback: use Wikipedia API for retrieval.

        Args:
            query: Keyword string; newlines and commas are treated as spaces.
            k: Value sent as the API's ``srlimit``.

        Returns:
            Plain-text snippets (HTML tags removed, entities decoded, blanks
            dropped). An empty list is returned for an empty query or when the
            request/parsing fails; failures are logged as warnings.
        """
        import html
        import json
        import logging
        import re
        import urllib.parse
        import urllib.request

        # Normalize common "comma-separated keywords" into a space-separated query.
        q = (query or "").replace("\n", " ").replace(",", " ").strip()
        q = re.sub(r"\s+", " ", q)
        if not q:
            return []

        params = {
            "action": "query",
            "list": "search",
            "srsearch": q,
            "srlimit": int(k),
            "format": "json",
            "srprop": "snippet",
        }
        url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode(params)
        req = urllib.request.Request(
            url,
            headers={
                # Wikipedia recommends identifying clients via UA.
                "User-Agent": "textresnet/0.1 (research; https://github.com/)"
            },
        )
        try:
            with urllib.request.urlopen(req, timeout=10) as resp:
                data = json.loads(resp.read().decode("utf-8"))
            results = data.get("query", {}).get("search", []) or []
            passages = []
            for r in results:
                # Strip markup first, then decode entities, so that escaped
                # angle brackets in the article text are not mistaken for tags.
                text = re.sub(r"<[^>]+>", "", (r.get("snippet", "") or ""))
                text = html.unescape(text).strip()
                if text:
                    passages.append(text)
            return passages
        except Exception as exc:
            # Retrieval stays best-effort, but the failure is recorded so an
            # outage can be told apart from a query with no hits.
            logging.getLogger(__name__).warning(
                "Wikipedia search failed for %r: %s", q, exc
            )
            return []

    def forward(self, **inputs):
        """Run the search for ``inputs["search_keywords"]``.

        Returns:
            A dict with the newline-joined snippets (or ``"No results found."``)
            under ``retrieve_content`` and ``self.k`` under ``k_used``.

        Raises:
            ValueError: If ``search_keywords`` is missing or empty.
        """
        search_keywords = inputs.get("search_keywords")
        if not search_keywords:
            raise ValueError("Missing required input: 'search_keywords'")

        topk_passages = self._wiki_search(search_keywords, self.k)
        retrieve_content = "\n".join(topk_passages) if topk_passages else "No results found."
        return {"retrieve_content": retrieve_content, "k_used": self.k}


class HintGenerator(BaseComponent):
    """Pipeline stage 4: produce hints for the query from the retrieved evidence.

    The instruction text stored in ``self.variable`` heads the forward prompt
    and is the value critiqued in ``backward``.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    def __init__(self, model='openai/gpt-4o-mini', max_tokens=512, temperature=0):
        """Register the component metadata, instruction text and LLM settings.

        Args:
            model: Model identifier passed to ``get_llm_output``.
            max_tokens: Generation cap used by ``forward``.
            temperature: Sampling temperature used by ``forward``.
        """
        llm_settings = {"model": model, "max_tokens": max_tokens, "temperature": temperature}
        role_description = (
            "HotpotQA step 4/5 (HintGenerator): generate 1–3 concise, evidence-grounded hints from "
            "`retrieve_content` to support answering `rewritten_query`. Upstream: `rewritten_query` "
            "and `retrieve_content` from Wikipedia. Downstream: `AnswerGenerator` uses these hints; "
            "hallucinated or non-evidence hints directly cause wrong answers."
        )
        super().__init__(
            description=role_description,
            input_fields=["rewritten_query", "retrieve_content"],
            output_fields=["hints", "local_fix", "upstream_grad"],
            variable="Generate 1-3 concise hints grounded in the retrieved content. Cite phrases.",
            config=llm_settings,
        )

    def forward(self, rewritten_query, retrieve_content):
        """Ask the LLM for hints given the query and the retrieved evidence.

        Returns:
            A dict with the stripped LLM reply under ``hints`` plus the fixed
            ``local_fix`` / ``upstream_grad`` placeholder values.
        """
        llm_prompt = f"{self.variable}\nQuery: {rewritten_query}\nEvidence:\n{retrieve_content}\nHints:"
        completion = get_llm_output(
            message=llm_prompt,
            model=self.config.model,
            max_new_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )
        result = {"hints": completion.strip()}
        result["local_fix"] = "hints generated"
        result["upstream_grad"] = STOP_GRADIENT
        return result

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Turn downstream feedback into a parsed critique of the instruction text.

        Args:
            signal: Incoming gradient signal; only ``signal.feedback`` is read.
            full_traj: Not used by this component.
            component_traj: Mapping with ``"input"`` / ``"output"`` sub-mappings
                recorded for this component; missing entries default to ``""``.

        Returns:
            The response produced by ``parse_gradient_response`` with a
            ``debug`` dict attached.
        """
        recorded_in = component_traj.get("input", {})
        query_seen = recorded_in.get("rewritten_query", "")
        evidence_seen = recorded_in.get("retrieve_content", "")
        hints_produced = component_traj.get("output", {}).get("hints", "")

        # Rebuild the exact prompt that forward() sent to the model.
        forward_prompt = (
            f"{self.variable}\nQuery: {query_seen}\nEvidence:\n{evidence_seen}\nHints:"
        )
        feedback = signal.feedback
        output_role = "hints for answering"

        context = build_backward_context_prompt(
            variable_desc=self.description,
            variable_value=self.variable,
            lm_input=forward_prompt,
            lm_output=hints_produced,
            objective_feedback=feedback,
            response_desc=output_role,
        )
        critique_request = f"{context}\n\n{self.BACKWARD_OUTPUT_CONSTRAINTS}"
        raw_reply = get_llm_output(
            message=critique_request,
            model=self.config.model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )

        response = parse_gradient_response(raw_reply)
        preview_limit = 400
        response.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": preview_limit,
            "variable_short": short_text(self.variable, preview_limit),
            "lm_system_prompt": "",
            "lm_input": forward_prompt,
            "lm_output": hints_produced,
            "response_desc": output_role,
            "objective_feedback": feedback,
            "raw_backward_output": raw_reply,
        }
        return response


class AnswerGenerator(BaseComponent):
    """Pipeline stage 5: produce the final short answer from the hints.

    The instruction text stored in ``self.variable`` heads the forward prompt
    and is the value critiqued in ``backward``.
    """

    BACKWARD_OUTPUT_CONSTRAINTS = DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS

    def __init__(self, model='openai/gpt-4o-mini', max_tokens=256, temperature=0):
        """Register the component metadata, instruction text and LLM settings.

        Args:
            model: Model identifier passed to ``get_llm_output``.
            max_tokens: Generation cap used by ``forward``.
            temperature: Sampling temperature used by ``forward``.
        """
        llm_settings = {"model": model, "max_tokens": max_tokens, "temperature": temperature}
        role_description = (
            "HotpotQA step 5/5 (AnswerGenerator): produce the final short `answer` to "
            "`rewritten_query` using ONLY the provided `hints`. Upstream: hints from `HintGenerator` "
            "(which depend on Wikipedia evidence). Downstream: evaluation compares `answer` with gold; "
            "formatting/conciseness errors are local, missing facts usually indicate upstream evidence/hints issues."
        )
        super().__init__(
            description=role_description,
            input_fields=["rewritten_query", "hints"],
            output_fields=["answer", "local_fix", "upstream_grad"],
            variable="Answer concisely using provided hints. Output only the answer.",
            config=llm_settings,
        )

    def forward(self, rewritten_query, hints):
        """Ask the LLM for the final answer given the query and hints.

        Returns:
            A dict with the stripped LLM reply under ``answer`` plus the fixed
            ``local_fix`` / ``upstream_grad`` placeholder values.
        """
        llm_prompt = f"{self.variable}\nQuery: {rewritten_query}\nHints:\n{hints}\nAnswer:"
        completion = get_llm_output(
            message=llm_prompt,
            model=self.config.model,
            max_new_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )
        result = {"answer": completion.strip()}
        result["local_fix"] = "answer generated"
        result["upstream_grad"] = STOP_GRADIENT
        return result

    def backward(self, signal: GradientSignal, full_traj, component_traj) -> BackwardResponse:
        """Turn evaluation feedback into a parsed critique of the instruction text.

        Args:
            signal: Incoming gradient signal; only ``signal.feedback`` is read.
            full_traj: Not used by this component.
            component_traj: Mapping with ``"input"`` / ``"output"`` sub-mappings
                recorded for this component; missing entries default to ``""``.

        Returns:
            The response produced by ``parse_gradient_response`` with a
            ``debug`` dict attached.
        """
        recorded_in = component_traj.get("input", {})
        query_seen = recorded_in.get("rewritten_query", "")
        hints_seen = recorded_in.get("hints", "")
        answer_produced = component_traj.get("output", {}).get("answer", "")

        # Rebuild the exact prompt that forward() sent to the model.
        forward_prompt = f"{self.variable}\nQuery: {query_seen}\nHints:\n{hints_seen}\nAnswer:"
        feedback = signal.feedback
        output_role = "final answer"

        context = build_backward_context_prompt(
            variable_desc=self.description,
            variable_value=self.variable,
            lm_input=forward_prompt,
            lm_output=answer_produced,
            objective_feedback=feedback,
            response_desc=output_role,
        )
        critique_request = f"{context}\n\n{self.BACKWARD_OUTPUT_CONSTRAINTS}"
        raw_reply = get_llm_output(
            message=critique_request,
            model=self.config.model,
            temperature=0.4,
            system_prompt=BACKWARD_SYSTEM_PROMPT,
        )

        response = parse_gradient_response(raw_reply)
        preview_limit = 400
        response.debug = {
            "variable_desc": self.description,
            "variable_value": self.variable,
            "variable_short_max": preview_limit,
            "variable_short": short_text(self.variable, preview_limit),
            "lm_system_prompt": "",
            "lm_input": forward_prompt,
            "lm_output": answer_produced,
            "response_desc": output_role,
            "objective_feedback": feedback,
            "raw_backward_output": raw_reply,
        }
        return response


def system_engine(*args, **kwargs):
    """Assemble the five-component HotpotQA pipeline as a ``CompoundAISystem``.

    Every component is created with its default settings except the
    retriever, which is built with ``k=15``. Extra positional and keyword
    arguments are forwarded unchanged to the ``CompoundAISystem`` constructor.

    Returns:
        The constructed ``CompoundAISystem``, configured with ``answer`` as the
        final output field, ``gd_answer`` as the ground field and ``f1_score``
        as the evaluation function.
    """
    pipeline = {
        "question_rewriter": QuestionRewriter(),
        "info_extractor": InfoExtractor(),
        "wikipedia_retriever": WikipediaRetriever(k=15),
        "hint_generator": HintGenerator(),
        "answer_generator": AnswerGenerator(),
    }
    # Positional extras are unpacked first; binding is the same as when the
    # star-unpacking trails the keyword arguments.
    return CompoundAISystem(
        *args,
        components=pipeline,
        final_output_fields=["answer"],
        ground_fields=["gd_answer"],
        eval_func=f1_score,
        **kwargs,
    )
