import json
import time
from mistralai import Mistral
from dotenv import load_dotenv
import os
import copy

# Load environment variables through python-dotenv before reading the API key.
load_dotenv()
# API key taken from the `mistral_key` environment variable (None when unset).
mistral_key = os.getenv('mistral_key')

# Module-level Mistral client shared by every call to get_chat_completion.
client = Mistral(api_key=mistral_key)

def load_assignment_prompts(json_file_path="assignment_prompts.json"):
    """
    Load the assignment problems from a JSON file.

    Each problem in the file is expected to contain:
      - problem_id
      - message (list of {role: "system"/"user", content: "..."})
      - solution_list (ground truth matching)
      - solution_cost (ground truth cost)

    :param json_file_path: Path of the JSON file to read.
    :return: The parsed JSON content (a list of problems).
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; being explicit avoids depending on the
    # platform's default locale encoding when prompts contain non-ASCII text.
    with open(json_file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def get_chat_completion(messages, model="ministral-8b-latest",
                        max_tokens=1024, temperature=0.0):
    """
    Request a single chat completion from the module-level Mistral client.

    The call is a plain (non-streaming) `client.chat.complete` request; only
    the first returned choice is used.

    :param messages: A list of dicts like [{"role": "system", "content": ...}, ...]
    :param model: Name of the model to query
    :param max_tokens: Limit for the output
    :param temperature: Sampling temperature
    :return: The `message.content` of the first choice in the response.
    """
    request_kwargs = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    completion = client.chat.complete(**request_kwargs)

    # Only the first choice is requested/used.
    first_choice = completion.choices[0]
    return first_choice.message.content

def _query_model(messages, problem_id, model):
    """Return the model's reply to *messages*, or "" when the request fails."""
    try:
        return get_chat_completion(
            messages=messages,
            model=model,
            max_tokens=2048,
            temperature=0.0,
        )
    except Exception as e:
        # Deliberate best-effort: a single failed request must not abort the
        # whole run, so the error is reported and an empty reply is recorded.
        print(f"Error occurred while querying the model for problem {problem_id}: {e}")
        return ""


def main(input_path="assignment_prompts.json",
         output_path="result/assignment_model_responses.json",
         model="ministral-8b-latest"):
    """
    Run a three-turn "answer, reflect, retry" conversation for every problem.

    For each problem the model is asked for (1) an initial answer, (2) a
    reflection on that answer, and (3) a final answer. Only the final answer
    is stored, together with the ground-truth solution and cost, and all
    results are written to a JSON file.

    :param input_path: JSON file containing the assignment problems.
    :param output_path: JSON file the collected results are written to.
    :param model: Name of the model used for all three turns.
    """
    # Create the output directory up front so a missing folder is detected
    # before any API calls are made, not after all of them have been spent.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Prompt text is loop-invariant; its content is kept exactly as sent before.
    reflect_prompt = """
        Your previous solution may not be correct. Your task is to reflect on the  problem, your solution, and the correct answer.
        You will then use this information help you answer the same question in the future.
        If you think your previous answer was incorrect, please follow these steps:
        First, explain why you answered the question incorrectly.
        Second, list the keywords that describe the type of your errors from most general to most specific.
        Third, solve the problem again, step-by-step, based on your knowledge of the correct answer.
        """
    retry_prompt = "Based on your reflection, try to solve the problem again."

    # Load the assignment prompts
    assignment_problems = load_assignment_prompts(input_path)

    results = []

    # Iterate over each problem
    for idx, problem in enumerate(assignment_problems, start=1):
        problem_id = problem["problem_id"]
        # Copy the list so follow-up turns are not appended to the loaded data.
        messages = list(problem["message"])  # list of { "role": ..., "content": ... }
        gt_solution_list = problem["solution_list"]
        gt_solution_cost = problem["solution_cost"]

        print(f"\n===== Processing Problem ID {problem_id} ({idx}/{len(assignment_problems)}) =====")

        # Turn 1: initial answer
        model_response = _query_model(messages, problem_id, model)
        messages.append({"role": "assistant", "content": model_response})
        # Wait a bit between requests to avoid rate limits
        time.sleep(1)
        print("Initial response:", model_response)

        # Turn 2: reflection on the initial answer
        messages.append({"role": "user", "content": reflect_prompt})
        model_response = _query_model(messages, problem_id, model)
        messages.append({"role": "assistant", "content": model_response})
        time.sleep(1)
        print("Reflection:", model_response)

        # Turn 3: final answer based on the reflection
        messages.append({"role": "user", "content": retry_prompt})
        model_response = _query_model(messages, problem_id, model)
        print("Final response:", model_response)

        # Store the final answer with the ground truth
        results.append({
            "problem_id": problem_id,
            "model_response": model_response,
            "solution_list_ground_truth": gt_solution_list,
            "solution_cost_ground_truth": gt_solution_cost,
        })

    # Save all results to a JSON file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"Done! Saved responses for {len(results)} problems to '{output_path}'.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
