import json
import time
from openai import OpenAI
from dotenv import load_dotenv
import os

# Load variables from a .env file into the process environment so the
# os.getenv() lookups below can see them; this must run before those lookups.
load_dotenv()
# Value of the `mistral_key` environment variable (None when unset). No other
# code visible in this file reads it.
mistral_key = os.getenv('mistral_key')

# API key for the endpoint below (None when AVIOR_API_KEY is unset).
api_key = os.getenv("AVIOR_API_KEY")  # Original note: "not supported yet" -- NOTE(review): presumably the endpoint does not check keys yet; confirm
# Base URL that the OpenAI client sends its requests to.
api_base = "http://avior.mlfoundry.com/live-inference/v1"
# Module-level client shared by get_chat_completion().
client = OpenAI(
    api_key=api_key,
    base_url=api_base,
)

def load_assignment_prompts(json_file_path="assignment_prompts.json"):
    """
    Load the assignment problems from a JSON file.

    Each problem in the file is expected to contain:
      - problem_id
      - message (list of {role: "system"/"user", content: "..."})
      - solution_list (ground truth matching)
      - solution_cost (ground truth cost)

    :param json_file_path: Path of the JSON file to read.
    :return: The parsed JSON content (a list of problems for the format above).
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8: without it open() falls back to the locale's
    # preferred encoding, which breaks non-ASCII prompt text on some platforms.
    with open(json_file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def get_chat_completion(messages, model="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
                        max_tokens=1024, temperature=0.0):
    """
    Send the conversation (messages) to the chat completion endpoint of the
    module-level ``client`` and return the assistant's reply.

    The request is made without streaming; only the first returned choice is used.

    :param messages: A list of dicts like [{"role": "system", "content": ...}, ...]
    :param model: The model identifier passed to the endpoint
    :param max_tokens: Limit for the output
    :param temperature: Sampling temperature
    :return: The assistant message as a string ("" when the reply carries no text).
    """
    chat_response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = chat_response.choices[0].message.content

    # The message content is optional in the response schema; normalize a
    # missing value to "" so callers always receive a string, as documented.
    return content if content is not None else ""

def main():
    """
    Query the model once per assignment problem and save the raw responses.

    Reads the problems from "assignment_prompts.json", sends each problem's
    messages to the model, and writes one entry per problem (problem id, model
    response, ground-truth solution list and cost) to "result/llama_8B.json".
    A failed request is reported on stdout and recorded with an empty response
    so the remaining problems are still processed.
    """
    # NOTE(review): the file name says "8B" while the model queried below is
    # the 405B one; the path is kept unchanged so existing consumers still work.
    output_path = "result/llama_8B.json"

    # 1) Load the assignment prompts
    assignment_problems = load_assignment_prompts("assignment_prompts.json")
    total = len(assignment_problems)

    # 2) Make sure the output directory exists *before* spending time on API
    #    calls; otherwise the final write fails and every response is lost.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    results = []

    # 3) Iterate over each problem
    for idx, problem in enumerate(assignment_problems, start=1):
        problem_id = problem["problem_id"]
        messages = problem["message"]  # This is a list of { "role": ..., "content": ... }
        gt_solution_list = problem["solution_list"]
        gt_solution_cost = problem["solution_cost"]

        print(f"\n===== Processing Problem ID {problem_id} ({idx}/{total}) =====")

        # 4) Send the prompt to the model
        try:
            model_response = get_chat_completion(
                messages=messages,
                model="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",  # or the model you want to use
                max_tokens=2048,
                temperature=0.0
            )
        except Exception as e:
            # Best-effort: one failing request must not abort the whole run.
            print(f"Error occurred while querying the model for problem {problem_id}: {e}")
            model_response = ""

        # Optional: Wait a bit if needed to avoid rate limits
        time.sleep(1)

        print(model_response)
        # 5) Store the result with ground truth
        results.append({
            "problem_id": problem_id,
            "model_response": model_response,
            "solution_list_ground_truth": gt_solution_list,
            "solution_cost_ground_truth": gt_solution_cost,
        })

    # 6) Save all results to a JSON file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"Done! Saved responses for {len(results)} problems.")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
