from __future__ import annotations

import os
from typing import Any, List, Optional, Tuple

from openai import OpenAI, BadRequestError, OpenAIError

def _get_client(api_key: Optional[str] = None) -> OpenAI:
    """Create an OpenAI client.

    Args:
        api_key: Explicit API key. When ``None``, the ``OPENAI_API_KEY``
            environment variable is used instead.

    Returns:
        A configured ``OpenAI`` client. If no key is available from either
        source, the OpenAI SDK raises ``OpenAIError`` at construction.
    """
    # Credentials must come from the caller or the environment, never from a
    # literal in source code; the caller-supplied key takes precedence.
    if api_key is None:
        api_key = os.environ.get("OPENAI_API_KEY")
    return OpenAI(api_key=api_key)

def _build_primary_prompt(
    rag_demos: List[str],
    observation: str,
    skill_code: str,
    instruction: str,
    object_list: Any,
) -> str:
    parts = [f"Instruction:\n{instruction}\n"]
    for i, demo in enumerate(rag_demos, 1):
        parts.append(f"Demo {i}:\n{demo}\n")
    parts.append(f"Observation:\n{observation}\n")
    parts.append(f"Skill Code:\n{skill_code}\n")
    parts.append(f"Object_list:\n{object_list}\n")
    return "\n".join(parts)

def _mean_logprob(
    token_logprobs: Optional[List[Optional[float]]],
) -> Optional[float]:
    """Return the mean of the non-``None`` entries of *token_logprobs*.

    Returns ``None`` when the sequence is missing, empty, or contains only
    ``None`` entries.
    """
    if not token_logprobs:
        return None
    # With ``echo=True`` the API reports ``None`` for a token that has no
    # preceding context (the first prompt token). Summing it would raise
    # TypeError, so such entries are excluded from the average.
    scored = [lp for lp in token_logprobs if lp is not None]
    if not scored:
        return None
    return sum(scored) / len(scored)

def _build_feedback_prompt(
    rag_demos: List[str],
    observation: str,
    skill_code: str,
    instruction: str,
    object_list: Any,
    llm_score: float,
    threshold: float,
) -> str:
    """Assemble the user message sent to the feedback chat model."""
    parts: List[str] = [
        "Below is the full context for the skill. Provide feedback following the requested format.\n",
        "=== Demos ===\n",
    ]
    parts.extend(f"Demo {i}:\n{demo}\n" for i, demo in enumerate(rag_demos, 1))
    parts.extend(
        [
            "=== Observation ===\n",
            f"{observation}\n\n",
            "=== Skill Code ===\n",
            f"{skill_code}\n\n",
            "=== LLM Score ===\n",
            f"The current LLM Score is {llm_score:.4f}, below the threshold ({threshold}).\n\n",
            "=== Instruction ===\n",
            f"{instruction}\n\n",
            "=== Object_list ===\n",
            f"{object_list}\n\n",
            "=== Task ===\n",
            "Please provide feedback on how to improve the skill code.\n\n",
            "=== Format ===\n",
            "1. **Problem Identification**\n",
            "2. **Justification**\n",
            "3. **Proposed Solutions**\n",
            "4. **Additional Notes**\n\n",
        ]
    )
    return "".join(parts)

def compute_llm_score_and_feedback(
    rag_demos: List[str],
    observation: str,
    skill_code: str,
    instruction: str,
    object_list: Any,
    *,
    threshold: float = -1.0,
    eval_model: str = "gpt-3.5-turbo-instruct",
    feedback_model: str = "gpt-4o-mini",
    api_key: Optional[str] = None,
) -> Tuple[Optional[float], Optional[str]]:
    """Score *skill_code* by prompt log-likelihood and optionally get feedback.

    The primary prompt (instruction, demos, observation, skill code, object
    list) is sent to the legacy Completions API with ``echo=True`` and
    ``max_tokens=0``, so the response carries log-probabilities for the
    prompt tokens only. The score is their mean. If the score is strictly
    below *threshold*, a chat model is asked for structured improvement
    feedback.

    Args:
        rag_demos: Retrieved demonstration texts included in both prompts.
        observation: Observation text for the current scene.
        skill_code: The skill source code being evaluated.
        instruction: The task instruction.
        object_list: Object information; interpolated via ``str.format``.
        threshold: Feedback is requested only when the score is below this.
        eval_model: Completions model used for scoring.
        feedback_model: Chat model used to generate feedback.
        api_key: Optional explicit API key forwarded to the client factory.

    Returns:
        ``(llm_score, feedback)``. ``llm_score`` is ``None`` when the API
        returned no usable log-probabilities. ``feedback`` is ``None`` unless
        the score fell below *threshold* and the chat model returned text.

    Raises:
        RuntimeError: Wrapping any ``OpenAIError`` raised while creating the
            client or calling either API.
    """
    try:
        client = _get_client(api_key)
    except OpenAIError as e:
        raise RuntimeError(f"[LLM Score] Failed to create OpenAI client: {e}") from e

    primary_prompt = _build_primary_prompt(
        rag_demos, observation, skill_code, instruction, object_list
    )

    try:
        resp = client.completions.create(
            model=eval_model,
            prompt=primary_prompt,
            max_tokens=0,
            temperature=0.0,
            logprobs=5,
            echo=True,
        )
    except BadRequestError as e:
        raise RuntimeError(
            f"[LLM Score] Completions API call failed: {e}"
        ) from e
    except OpenAIError as e:
        raise RuntimeError(f"[LLM Score] OpenAI API error: {e}") from e

    # ``logprobs`` may be absent on the choice; treat that as "no score".
    logprobs = resp.choices[0].logprobs
    token_lps = logprobs.token_logprobs if logprobs is not None else None
    llm_score = _mean_logprob(token_lps)
    print(f"[LLM Score] {llm_score:.4f}" if llm_score is not None else "[LLM Score] N/A")

    feedback: Optional[str] = None
    if llm_score is not None and llm_score < threshold:
        feedback_prompt = _build_feedback_prompt(
            rag_demos,
            observation,
            skill_code,
            instruction,
            object_list,
            llm_score,
            threshold,
        )

        try:
            chat_resp = client.chat.completions.create(
                model=feedback_model,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are an expert robotics code reviewer. "
                            "Provide constructive, step-by-step feedback."
                        ),
                    },
                    {"role": "user", "content": feedback_prompt},
                ],
                max_tokens=700,
                temperature=0.7,
            )
        except OpenAIError as e:
            raise RuntimeError(f"[Feedback] OpenAI API error: {e}") from e

        # Chat message content is nullable (e.g. refusals), so guard .strip().
        content = chat_resp.choices[0].message.content
        feedback = content.strip() if content is not None else None

    return llm_score, feedback

if __name__ == "__main__":
    # Minimal smoke run: one demo and a trivial skill body. Feedback is
    # requested only if the score falls below -3.0.
    demo_inputs = dict(
        rag_demos=["Pick up the red block and place it on the blue block."],
        observation="The red block is on the table.",
        skill_code="def skill(): pass",
        instruction="Stack the red block on the blue block.",
        object_list={"red_block": "(0.1, 0.2, 0.3)"},
    )
    result_score, result_feedback = compute_llm_score_and_feedback(
        **demo_inputs, threshold=-3.0
    )
    print("\n=== Result ===")
    print("LLM Score:", result_score)
    if result_feedback:
        print("Feedback:\n", result_feedback)