from execution_module import generate_task_description, answer_with_principles, answer_question_directly
from memory_module import MemoryManager
from plan_module import generate_difference_list, generate_principles
import json
import os
import requests
from prompt_templates import BASE_MODEL_API_URL, BASE_MODEL_NAME

# Configuration paths
# Dataset read by main().
INPUT_FILE = "inputs/your_input.json"
# Preference pairs written by step 1 and read back by step 2.
DPO_FILE = "outputs/dpo_data.json"
# Answers written by the inference stage.
OUTPUT_FILE = "outputs/final_output.json"
# Backing file handed to MemoryManager.
MEMORY_FILE = "memory/memory.json"

# Shared MemoryManager instance; constructed when this module is imported.
memory = MemoryManager(MEMORY_FILE)
# Instruction template with {question} and {diff} placeholders.
# NOTE(review): no function in this file references this constant; step 1
# builds its instruction inline and labels the diff "Error Points" rather
# than "Mistakes" -- confirm which wording is canonical.
DPO_INSTRUCTION_PROMPT = "Input: Question: {question}\nMistakes: {diff}\nOutput: List of reusable principles"

def gpt_call(user: str, model: str = BASE_MODEL_NAME, api_key: str = None,
             timeout: float = 300.0) -> str:
    """Send one completion request to the base-model endpoint.

    Args:
        user: Prompt text, sent as the ``prompt`` field of the payload.
        model: Model name placed in the request payload.
        api_key: Optional bearer token; when truthy it is sent in the
            ``Authorization`` header.
        timeout: Seconds to wait for the server. ``requests`` waits
            indefinitely when no timeout is given, so a stalled endpoint
            would otherwise hang the whole pipeline.

    Returns:
        The stripped ``text`` of the first choice, or ``""`` when the
        request fails or the response carries no usable text.
    """
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model,
        "prompt": f"{user}",
        "temperature": 0.95,
        "max_tokens": 2048,
        "top_p": 0.7,
    }
    try:
        response = requests.post(
            BASE_MODEL_API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        result = response.json()
        # An empty "choices" list or a null "text" means "no answer" rather
        # than an error, so fall back to an empty string for both.
        choices = result.get("choices") or [{}]
        return (choices[0].get("text") or "").strip()
    except Exception as e:
        # Best-effort by design: callers get "" instead of an exception.
        print(f"[ERROR] LLM API call failed: {e}")
        return ""

# ===== Step 1: Generate DPO Data =====
def prepare_step1_generate_dpo(dataset):
    """Build chosen/rejected principle pairs and write them to DPO_FILE.

    For every item the question is answered directly, the answer is compared
    with the label to obtain a difference list, and principles are generated
    twice from that list: once with ``model="weak"`` (stored as ``rejected``)
    and once with ``model="strong"`` (stored as ``chosen``).

    Args:
        dataset: Iterable of mappings that each provide ``"question"`` and
            ``"label"``.

    Raises:
        KeyError: If an item lacks ``"question"`` or ``"label"``.
    """
    dpo_dataset = []

    for item in dataset:
        question = item["question"]
        label = item["label"]

        # Baseline answer, then its differences from the reference label.
        pred = answer_question_directly(question)
        diff = generate_difference_list(question, pred, label)

        # Same inputs, two generators: the weak output is the rejected
        # sample and the strong output is the chosen one.
        rejected = generate_principles(question, diff, model="weak")
        chosen = generate_principles(question, diff, model="strong")

        dpo_dataset.append({
            "instruction": f"Input: Question: {question}\nError Points: {diff}\nOutput: Reusable principles",
            "chosen": chosen,
            "rejected": rejected
        })

    # Create the output directory first so the results of the loop above
    # are not lost to a missing folder.
    out_dir = os.path.dirname(DPO_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(DPO_FILE, 'w', encoding='utf-8') as f:
        json.dump(dpo_dataset, f, indent=2, ensure_ascii=False)

    print(f"[✔] DPO data generated using both models: {DPO_FILE}")


# ===== Step 2: Update Memory with Fine-Tuned Model Results =====
def _question_from_instruction(instruction):
    """Return the question text embedded in a DPO instruction string."""
    head = instruction
    # Step 1 labels the diff section "Error Points" while
    # DPO_INSTRUCTION_PROMPT labels it "Mistakes"; cut at whichever
    # label appears first so both formats give the bare question.
    for marker in ("\nMistakes:", "\nError Points:"):
        head = head.split(marker, 1)[0]
    prefix = "Input: Question: "
    if head.startswith(prefix):
        head = head[len(prefix):]
    return head.strip()


def _parse_principles(text):
    """Split generated text into non-empty principles, one per line."""
    cleaned = (line.strip("- ").strip() for line in text.splitlines())
    # A bare "-" line cleans down to "", which is not a principle.
    return [principle for principle in cleaned if principle]


def prepare_step2_update_memory_from_dpo():
    """Merge the ``chosen`` principles from DPO_FILE into the shared memory.

    Reads the records written by step 1, extracts the question from each
    record's ``instruction`` and the bullet lines from its ``chosen`` text,
    and merges them into ``memory``, which is saved once at the end.
    Records with no ``chosen`` text or no usable principle lines are
    skipped. Prints a message and returns early when DPO_FILE is missing.

    NOTE(review): the key passed to ``merge_principles`` is the question
    text, whereas inference_with_memory looks memory up with
    ``generate_task_description(question)``; whether these line up depends
    on MemoryManager.retrieve -- confirm.
    """
    if not os.path.exists(DPO_FILE):
        print("[✘] DPO file not found. Please run Step 1 first.")
        return

    with open(DPO_FILE, 'r', encoding='utf-8') as f:
        dpo_data = json.load(f)

    update_count = 0
    for item in dpo_data:
        chosen = item.get("chosen")
        if not chosen:
            continue
        task_desc = _question_from_instruction(item["instruction"])
        principles = _parse_principles(chosen)
        if not principles:
            continue
        memory.merge_principles(task_desc, principles)
        update_count += 1

    memory.save()
    print(f"[✔] Memory updated with {update_count} tasks using enhanced principles.")

# ===== Step 3: Inference Using Existing Memory =====
def inference_with_memory(dataset):
    """Answer each question using principles retrieved from memory.

    A task description is generated for every question and used to look up
    principles in ``memory``. Questions with no retrieved principles are
    reported and left out of the results. The collected answers are written
    to OUTPUT_FILE as JSON.

    Args:
        dataset: Iterable of mappings that each provide ``"question"``.

    Raises:
        KeyError: If an item lacks ``"question"``.
    """
    results = []
    for item in dataset:
        question = item["question"]
        task_desc = generate_task_description(question)
        # retrieve() returns a (key, principles) pair; only the
        # principles are needed here.
        _, principles = memory.retrieve(task_desc)

        if not principles:
            print(f"[!] No principles found for task: {task_desc}")
            continue

        answer = answer_with_principles(question, principles)
        results.append({
            "question": question,
            "task": task_desc,
            "answer": answer
        })

    # Create the output directory first so the answers computed above are
    # not lost to a missing folder.
    out_dir = os.path.dirname(OUTPUT_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"[✔] Inference complete. Results saved to: {OUTPUT_FILE}")

# ===== CLI Entry =====
def main():
    """Show the stage menu and run the selected pipeline stage.

    The dataset in INPUT_FILE is loaded only for stages 1 and 3, which
    consume it. Stage 2 works solely from DPO_FILE, so neither it nor an
    invalid selection depends on the input file being present.
    """
    print("\nMetaEvo Pipeline")
    print("1. Preparation Step 1 - Generate meta-optimization Data")
    print("2. Preparation Step 2 - Update Memory with Reforcement Results")
    print("3. Inference Phase - Use Memory for Reasoning")
    choice = input("Select stage (Enter 1 / 2 / 3): ").strip()

    if choice == "2":
        prepare_step2_update_memory_from_dpo()
        return

    if choice not in ("1", "3"):
        print("[✘] Invalid selection. Please enter 1 / 2 / 3.")
        return

    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    if choice == "1":
        prepare_step1_generate_dpo(dataset)
    else:
        inference_with_memory(dataset)

if __name__ == "__main__":
    main()