"""
Tests for HFAPIBackend — unit (no network) + integration (requires HF_TOKEN).

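Unit tests mock InferenceClient and verify:
  - generate() wraps the prompt in a user message, plus an optional system message
  - generate_chat() passes the full message list through to chat_completion()
  - system_prompt is prepended exactly once and is not injected if the caller
    already supplied a system message
  - the caller's message list is never mutated
  - None (empty) content is returned as an empty string
  - token usage (prompt, completion, cached) is accumulated across calls, and
    the most recent call's counts plus an overall cache hit rate are exposed

Integration tests (skipped without HF_TOKEN) hit the real HF Inference API.
They use meta-llama/Llama-3.2-1B-Instruct (fast/cheap) to verify that
generate_chat() maintains context across turns, plus smoke tests against
deepseek-ai/DeepSeek-Prover-V2-671B, the actual target model.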

Run all:              pytest tests/test_hf_api_backend.py -v
Run unit only:        pytest tests/test_hf_api_backend.py -v -m "not integration"
Run integration only: pytest tests/test_hf_api_backend.py -v -m integration
"""
from __future__ import annotations

import os
import sys
import types
from unittest.mock import MagicMock, patch

import pytest

# ---------------------------------------------------------------------------
# HF_TOKEN guard for integration tests
# ---------------------------------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")

# Reusable marker for live-API tests: skip whenever the token is unset or empty.
requires_hf_token = pytest.mark.skipif(
    HF_TOKEN in (None, ""),
    reason="HF_TOKEN not set — skipping live API tests",
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_response(text: str, prompt_tokens: int = 0,
                   completion_tokens: int = 0, cached_tokens: int = 0):
    """Build a fake InferenceClient chat_completion response."""
    details = types.SimpleNamespace(cached_tokens=cached_tokens) if cached_tokens else None
    usage = types.SimpleNamespace(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        prompt_tokens_details=details,
    )
    return types.SimpleNamespace(
        choices=[types.SimpleNamespace(message=types.SimpleNamespace(content=text))],
        usage=usage,
    )


def _make_backend(system_prompt: str | None = None, **kwargs):
    """
    Instantiate HFAPIBackend with a mocked InferenceClient so no HTTP call
    is made.  Returns (backend, mock_client).

    The backend is always built with model_id="test-model",
    temperature=0.7 and max_new_tokens=64 (unit tests assert on these).
    ``system_prompt`` and any extra ``**kwargs`` are forwarded unchanged to
    the HFAPIBackend constructor.

    ``mock_client`` is the same MagicMock installed as ``backend.client``;
    tests configure ``mock_client.chat_completion`` and inspect its
    ``call_args``.
    """
    # Stub huggingface_hub at the module level so the import inside __init__
    # succeeds without the real package (or with it — both paths work).
    fake_client = MagicMock()
    hf_hub = types.ModuleType("huggingface_hub")
    # Any InferenceClient(...) construction through the stub yields fake_client.
    hf_hub.InferenceClient = MagicMock(return_value=fake_client)
    # NOTE(review): on exit, patch.dict restores sys.modules to its exact
    # prior contents, so any module first imported inside this block
    # (including meta_rg.backends.hf_api_backend on a first import) is
    # removed again and re-imported on the next call. Fine for pure-Python
    # modules; confirm nothing imported transitively dislikes re-import.
    with patch.dict(sys.modules, {"huggingface_hub": hf_hub}):
        from meta_rg.backends.hf_api_backend import HFAPIBackend
        backend = HFAPIBackend(
            model_id="test-model",
            temperature=0.7,
            max_new_tokens=64,
            system_prompt=system_prompt,
            **kwargs,
        )
    # Swap in our pre-built mock so call_args is accessible.
    # Assigning explicitly also guarantees the mock is used even if the
    # backend did not obtain its client through the stubbed module.
    # (Assumes HFAPIBackend issues requests via self.client — TODO confirm.)
    backend.client = fake_client
    return backend, fake_client


# ---------------------------------------------------------------------------
# Unit tests — generate()
# ---------------------------------------------------------------------------

class TestGenerate:
    """generate(): a single prompt becomes chat messages; the reply text is returned."""

    def test_wraps_prompt_in_user_message(self):
        backend, mock_client = _make_backend()
        mock_client.chat_completion.return_value = _make_response("hello")

        backend.generate("say hi")

        sent = mock_client.chat_completion.call_args.kwargs["messages"]
        expected = [{"role": "user", "content": "say hi"}]
        assert sent == expected

    def test_prepends_system_prompt(self):
        backend, mock_client = _make_backend(system_prompt="Be concise.")
        mock_client.chat_completion.return_value = _make_response("ok")

        backend.generate("question")

        sent = mock_client.chat_completion.call_args.kwargs["messages"]
        system_msg, user_msg = sent[0], sent[1]
        assert system_msg == {"role": "system", "content": "Be concise."}
        assert user_msg == {"role": "user", "content": "question"}

    def test_returns_content_string(self):
        backend, mock_client = _make_backend()
        mock_client.chat_completion.return_value = _make_response("42 1 2 3")
        assert backend.generate("prompt") == "42 1 2 3"

    def test_empty_content_returns_empty_string(self):
        backend, mock_client = _make_backend()
        # content=None stands in for a reply that carries no text.
        mock_client.chat_completion.return_value = _make_response(None)
        assert backend.generate("prompt") == ""


# ---------------------------------------------------------------------------
# Unit tests — generate_chat()
# ---------------------------------------------------------------------------

class TestGenerateChat:
    """generate_chat(): history forwarding, system-prompt handling, model params."""

    def test_passes_messages_to_chat_completion(self):
        backend, client = _make_backend()
        client.chat_completion.return_value = _make_response("answer")
        history = [
            {"role": "user",      "content": "turn 1"},
            {"role": "assistant", "content": "resp 1"},
            {"role": "user",      "content": "turn 2"},
        ]
        backend.generate_chat(history)
        sent = client.chat_completion.call_args.kwargs["messages"]
        assert sent == history

    def test_system_prompt_prepended_when_absent(self):
        backend, client = _make_backend(system_prompt="You are helpful.")
        client.chat_completion.return_value = _make_response("ok")
        history = [
            {"role": "user",      "content": "turn 1"},
            {"role": "assistant", "content": "resp 1"},
            {"role": "user",      "content": "turn 2"},
        ]
        backend.generate_chat(history)
        sent = client.chat_completion.call_args.kwargs["messages"]
        assert sent[0] == {"role": "system", "content": "You are helpful."}
        assert sent[1:] == history

    def test_caller_list_not_mutated(self):
        backend, client = _make_backend(system_prompt="Sys.")
        client.chat_completion.return_value = _make_response("ok")
        history = [{"role": "user", "content": "hello"}]
        # Copy each message dict as well as the list. A length check alone
        # would miss in-place edits such as rewriting a message's content or
        # replacing an element.
        snapshot = [dict(m) for m in history]
        backend.generate_chat(history)
        assert history == snapshot, "generate_chat() must not mutate the caller's list"

    def test_no_double_prepend_if_system_already_present(self):
        backend, client = _make_backend(system_prompt="Backend sys.")
        client.chat_completion.return_value = _make_response("ok")
        history = [
            {"role": "system",    "content": "Caller sys."},
            {"role": "user",      "content": "question"},
        ]
        backend.generate_chat(history)
        sent = client.chat_completion.call_args.kwargs["messages"]
        system_msgs = [m for m in sent if m["role"] == "system"]
        assert len(system_msgs) == 1, "Only one system message should be present"
        assert system_msgs[0]["content"] == "Caller sys.", "Caller's system message should be preserved"
        # The rest of the conversation must survive the system-prompt check.
        assert sent[-1] == {"role": "user", "content": "question"}, (
            "The user turn should be forwarded unchanged"
        )

    def test_no_system_prompt_passes_messages_unchanged(self):
        backend, client = _make_backend(system_prompt=None)
        client.chat_completion.return_value = _make_response("done")
        msgs = [{"role": "user", "content": "hi"}]
        backend.generate_chat(msgs)
        sent = client.chat_completion.call_args.kwargs["messages"]
        assert sent == msgs

    def test_returns_content_string(self):
        backend, client = _make_backend()
        client.chat_completion.return_value = _make_response("42 1 2 3")
        result = backend.generate_chat([{"role": "user", "content": "x"}])
        assert result == "42 1 2 3"

    def test_empty_content_returns_empty_string(self):
        backend, client = _make_backend()
        client.chat_completion.return_value = _make_response(None)
        result = backend.generate_chat([{"role": "user", "content": "x"}])
        assert result == ""

    def test_uses_correct_model_params(self):
        backend, client = _make_backend()
        client.chat_completion.return_value = _make_response("ok")
        backend.generate_chat([{"role": "user", "content": "hi"}])
        kw = client.chat_completion.call_args.kwargs
        # These mirror the fixed constructor arguments used by _make_backend().
        assert kw["model"] == "test-model"
        assert kw["max_tokens"] == 64
        assert kw["temperature"] == 0.7


# ---------------------------------------------------------------------------
# Token tracking (usage field + prefix cache stats)
# ---------------------------------------------------------------------------

class TestTokenTracking:
    """Usage accounting: running totals, prefix-cache stats, and last-call counts."""

    @staticmethod
    def _response_without_usage(text="ok"):
        """Fake chat_completion response that has no ``usage`` attribute at all."""
        return types.SimpleNamespace(
            choices=[types.SimpleNamespace(message=types.SimpleNamespace(content=text))],
        )

    def test_prompt_and_completion_tokens_accumulated(self):
        backend, mock_client = _make_backend()
        mock_client.chat_completion.side_effect = [
            _make_response("a", prompt_tokens=100, completion_tokens=20),
            _make_response("b", prompt_tokens=200, completion_tokens=30),
        ]
        for prompt in ("x", "y"):
            backend.generate(prompt)
        totals = (backend.total_prompt_tokens, backend.total_completion_tokens)
        assert totals == (300, 50)

    def test_cached_tokens_accumulated(self):
        backend, mock_client = _make_backend()
        mock_client.chat_completion.side_effect = [
            _make_response("a", prompt_tokens=100, completion_tokens=10, cached_tokens=0),
            _make_response("b", prompt_tokens=200, completion_tokens=10, cached_tokens=150),
        ]
        # Mix both entry points: counters are shared between them.
        backend.generate("x")
        backend.generate_chat([{"role": "user", "content": "y"}])
        assert backend.total_cached_tokens == 150
        expected_rate = 150 / 300
        assert abs(backend.cache_hit_rate - expected_rate) < 1e-9

    def test_no_cached_tokens_when_details_absent(self):
        backend, mock_client = _make_backend()
        # cached_tokens defaults to 0, so prompt_tokens_details is None.
        mock_client.chat_completion.return_value = _make_response(
            "ok", prompt_tokens=50, completion_tokens=5,
        )
        backend.generate("x")
        assert backend.total_cached_tokens == 0
        assert backend.cache_hit_rate == 0.0

    def test_no_usage_field_does_not_crash(self):
        backend, mock_client = _make_backend()
        mock_client.chat_completion.return_value = self._response_without_usage()
        backend.generate("x")
        assert backend.total_prompt_tokens == 0
        assert backend.total_cached_tokens == 0

    def test_cache_hit_rate_zero_before_any_calls(self):
        backend, _unused_client = _make_backend()
        assert backend.cache_hit_rate == 0.0

    def test_last_call_tokens_reflect_most_recent_call(self):
        backend, mock_client = _make_backend()
        mock_client.chat_completion.side_effect = [
            _make_response("a", prompt_tokens=100, completion_tokens=10, cached_tokens=0),
            _make_response("b", prompt_tokens=200, completion_tokens=10, cached_tokens=150),
        ]
        expectations = (("x", 100, 0), ("y", 200, 150))
        for prompt, want_prompt, want_cached in expectations:
            backend.generate(prompt)
            assert backend.last_call_prompt_tokens == want_prompt
            assert backend.last_call_cached_tokens == want_cached

    def test_last_call_tokens_zero_when_no_usage(self):
        backend, mock_client = _make_backend()
        # Call 1 reports usage; call 2 has no usage field, so the per-call
        # counters must drop back to zero rather than keep stale values.
        mock_client.chat_completion.side_effect = [
            _make_response("a", prompt_tokens=100, completion_tokens=10, cached_tokens=80),
            self._response_without_usage(),
        ]
        backend.generate("x")
        assert backend.last_call_prompt_tokens == 100
        backend.generate("y")
        assert backend.last_call_prompt_tokens == 0
        assert backend.last_call_cached_tokens == 0


# ---------------------------------------------------------------------------
# Integration tests — real HF Inference API
# ---------------------------------------------------------------------------

@pytest.mark.integration
class TestHFAPIIntegration:
    """
    These tests hit the real HF Inference API.  They use Llama-3.2-1B-Instruct
    (fast and cheap) rather than the large prover models so the suite stays
    practical as a smoke test.

    Model used: meta-llama/Llama-3.2-1B-Instruct
    """

    @pytest.fixture(scope="class")
    def backend(self):
        # Skip imperatively: applying a skipif mark to a function only attaches
        # the mark and returns the function, so calling the result never skips.
        if not HF_TOKEN:
            pytest.skip("HF_TOKEN not set — skipping live API tests")
        from meta_rg.backends.hf_api_backend import HFAPIBackend
        return HFAPIBackend(
            model_id="meta-llama/Llama-3.2-1B-Instruct",
            temperature=0.1,
            max_new_tokens=32,
        )

    @requires_hf_token
    def test_generate_returns_string(self, backend):
        result = backend.generate("Reply with the single word HELLO.")
        assert isinstance(result, str)
        assert len(result) > 0

    @requires_hf_token
    def test_generate_chat_single_turn(self, backend):
        result = backend.generate_chat([
            {"role": "user", "content": "Reply with the single word HELLO."},
        ])
        assert isinstance(result, str)
        assert len(result) > 0

    @requires_hf_token
    def test_generate_chat_preserves_context(self, backend):
        """
        The model should recall a value given in an earlier turn — this
        verifies that generate_chat() passes the full history to the API
        rather than only the last message.
        """
        history = [
            {"role": "user",      "content": "Remember the secret number 7."},
            {"role": "assistant", "content": "I will remember the secret number 7."},
            {"role": "user",      "content": "What is the secret number? Reply with the digit only."},
        ]
        result = backend.generate_chat(history)
        assert "7" in result, (
            f"Model did not recall the number from history. Response: {result!r}"
        )

    @requires_hf_token
    def test_generate_chat_does_not_mutate_history(self, backend):
        history = [{"role": "user", "content": "Hi."}]
        # Snapshot the dicts too, so in-place edits are caught, not just appends.
        snapshot = [dict(m) for m in history]
        backend.generate_chat(history)
        assert history == snapshot


@pytest.mark.integration
class TestDeepSeekProverIntegration:
    """
    Smoke tests against deepseek-ai/DeepSeek-Prover-V2-671B via HF Inference API.
    Verifies that generate() and generate_chat() work with the actual target model.
    All tests are skipped when HF_TOKEN is not set.
    """

    @pytest.fixture(scope="class")
    def backend(self):
        # Guard here as well as on each test, so the live backend is never
        # constructed without credentials even if a future test in this class
        # forgets the @requires_hf_token marker.
        if not HF_TOKEN:
            pytest.skip("HF_TOKEN not set — skipping live API tests")
        from meta_rg.backends.hf_api_backend import HFAPIBackend
        return HFAPIBackend(
            model_id="deepseek-ai/DeepSeek-Prover-V2-671B",
            temperature=0.1,
            max_new_tokens=64,
        )

    @requires_hf_token
    def test_generate_returns_string(self, backend):
        result = backend.generate("Reply with the single word HELLO.")
        assert isinstance(result, str)
        assert len(result) > 0, "Expected non-empty response from DeepSeek-Prover-V2-671B"

    @requires_hf_token
    def test_generate_chat_single_turn(self, backend):
        result = backend.generate_chat([
            {"role": "user", "content": "Reply with the single word HELLO."},
        ])
        assert isinstance(result, str)
        assert len(result) > 0, "Expected non-empty response from generate_chat()"

    @requires_hf_token
    def test_generate_chat_preserves_context(self, backend):
        # Recalling "7" from an earlier turn shows the full history was sent.
        history = [
            {"role": "user",      "content": "Remember the secret number 7."},
            {"role": "assistant", "content": "I will remember the secret number 7."},
            {"role": "user",      "content": "What is the secret number? Reply with the digit only."},
        ]
        result = backend.generate_chat(history)
        assert "7" in result, (
            f"Model did not recall the number from history. Response: {result!r}"
        )
