import json
import os
import time

from huggingface_hub import InferenceClient

# 1) Configure your Hugging Face Inference client.
#    The client is created with no arguments here.
#    NOTE(review): presumably the library then falls back to its default
#    endpoint and credentials -- confirm for your environment.
#    To target a specific endpoint or key, pass them explicitly, e.g.:
#        InferenceClient(
#            base_url="YOUR_BASE_URL",  # e.g., "https://api-inference.huggingface.co"
#            api_key="YOUR_API_KEY",
#        )
client = InferenceClient()

def load_assignment_prompts(json_file_path="assignment_prompts.json"):
    """
    Load the assignment problems from a JSON file.

    Each problem in the file is expected to contain:
      - problem_id
      - message (list of {role: "system"/"user", content: "..."})
      - solution_list (ground truth matching)
      - solution_cost (ground truth cost)

    :param json_file_path: Path to the JSON file holding the problems.
    :return: The parsed JSON content (a list of problem dicts).
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification; pinning the encoding keeps prompts
    # with non-ASCII characters from depending on the platform's locale.
    with open(json_file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def _read_field(obj, key, default=None):
    """Read *key* from a response piece that may be a dict or a plain object."""
    if obj is None:
        return default
    if isinstance(obj, dict):
        return obj.get(key, default)
    return getattr(obj, key, default)


def get_chat_completion(messages, model="meta-llama/Llama-3.1-8B-Instruct",
                        max_tokens=1024, temperature=0.0, stream=True,
                        inference_client=None):
    """
    Sends the conversation (messages) to the HF chat completion endpoint and
    returns the assistant's reply as a single string.

    With ``stream=True`` the incremental ``choices[0].delta.content`` pieces
    of each chunk are concatenated. With ``stream=False`` the reply is read
    from ``choices[0].message.content`` of the single response object
    (OpenAI-style layout).

    :param messages: A list of dicts like [{"role": "system", "content": ...}, ...]
    :param model: The model repo on HF
    :param max_tokens: Limit for the output
    :param temperature: Sampling temperature
    :param stream: Whether to stream tokens or not
    :param inference_client: Optional client to use instead of the
        module-level ``client``; it must expose ``chat.completions.create``.
    :return: The complete assistant message as a string ("" if the response
        carried no text).
    """
    active_client = client if inference_client is None else inference_client
    response = active_client.chat.completions.create(
        model=model,
        messages=messages,
        stream=stream,
        max_tokens=max_tokens,
        temperature=temperature,
    )

    if not stream:
        # A non-streaming call yields one complete response, not an iterator
        # of chunks, so the text lives under "message" rather than "delta".
        choices = _read_field(response, "choices") or []
        if not choices:
            return ""
        message = _read_field(choices[0], "message")
        return _read_field(message, "content") or ""

    # Streaming: each chunk looks like
    #   {"choices": [{"delta": {"content": "..."}, "index": 0, "finish_reason": None}]}
    pieces = []
    try:
        for chunk in response:
            choices = _read_field(chunk, "choices") or []
            if not choices:
                continue
            delta = _read_field(choices[0], "delta")
            content = _read_field(delta, "content")
            # "content" can be present but None; skip it instead of letting
            # a str + None TypeError cut the response short.
            if content:
                pieces.append(content)
    except TypeError:
        # The response was not iterable as a stream of chunks. Stay
        # best-effort: report it and return whatever text was collected.
        print("Error processing response:", response)

    return "".join(pieces)

def main():
    """
    Run every assignment problem through the model and save the responses.

    Loads the problems from "assignment_prompts.json", queries the model once
    per problem, pairs each raw response with the problem's ground-truth
    solution and cost, and writes the collected entries to
    "result/assignment_model_responses.json".
    """
    # 2) Load the assignment prompts
    assignment_problems = load_assignment_prompts("assignment_prompts.json")
    total = len(assignment_problems)

    results = []

    # 3) Iterate over each problem
    for idx, problem in enumerate(assignment_problems, start=1):
        problem_id = problem["problem_id"]
        messages = problem["message"]  # This is a list of { "role": ..., "content": ... }
        gt_solution_list = problem["solution_list"]
        gt_solution_cost = problem["solution_cost"]

        print(f"\n===== Processing Problem ID {problem_id} ({idx}/{total}) =====")

        # 4) Send the prompt to the HF model. A failure on one problem is
        #    logged and recorded as an empty response so the run continues.
        try:
            model_response = get_chat_completion(
                messages=messages,
                model="meta-llama/Llama-3.1-8B-Instruct",  # or the model you want to use
                max_tokens=1024,
                temperature=0.0,
                stream=True
            )
        except Exception as e:
            print(f"Error occurred while querying the model for problem {problem_id}: {e}")
            model_response = ""

        # Optional: Wait a bit if needed to avoid rate limits
        time.sleep(2)

        print(model_response)
        # 5) Store the result with ground truth
        results.append({
            "problem_id": problem_id,
            "model_response": model_response,
            "solution_list_ground_truth": gt_solution_list,
            "solution_cost_ground_truth": gt_solution_cost
        })

    # 6) Save all results to a JSON file. Create the output directory first:
    #    otherwise open() fails only after every model call has been made,
    #    discarding all collected responses.
    output_path = os.path.join("result", "assignment_model_responses.json")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"Done! Saved responses for {len(results)} problems to '{output_path}'.")

# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
