

import os
import json
import sys
from tqdm import tqdm
from openai import OpenAI

# Work around Unicode encoding issues on Windows by forcing stdout to UTF-8.
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8")


# ------------------------------
# Config
# ------------------------------

# JSON dataset files to read. They must exist before this script runs;
# a file that is missing is skipped with a warning by the loader.
DATASET_FILES = ["dataset_GSM8K.json", "dataset_ASDiv.json", "dataset_SVAMP.json"]

# Models to query; the main loop pairs every model with every loaded dataset.
MODELS = [
    "gpt-4o-mini",
    "gpt-3.5-turbo-1106",
]

# (model, dataset_name) tuples to skip, e.g. because they are already completed.
SKIP_COMBINATIONS = []

# ------------------------------
# Load API key & Initialize Client
# ------------------------------
# Read the key from the environment; a missing or empty value aborts the
# script with a ValueError instead of building a client.
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
    client = OpenAI(api_key=api_key)
else:
    raise ValueError("OPENAI_API_KEY environment variable is not set")

# ------------------------------
# Helper Functions
# ------------------------------
def load_dataset_from_file(file_path):
    """Read a dataset from a JSON file on disk.

    Returns the parsed JSON content of *file_path*. When the path does not
    exist, a warning is printed and an empty list is returned instead.
    """
    if os.path.exists(file_path):
        print(f"Loading {file_path}...")
        with open(file_path, "r", encoding='utf-8') as handle:
            records = json.loads(handle.read())
        print(f"Loaded {len(records)} questions from {file_path}")
        return records

    # Missing file: report it and let the caller carry on with nothing.
    print(f"Warning: {file_path} not found. Skipping.")
    return []

# System prompt that pins the JSON shape _parse_model_response() expects.
_SYSTEM_PROMPT = """
You are a helpful mathematician. Solve the user's math problem.
Your final output must be a single JSON object. This object must have two keys:
1. "reasoning_trace": A JSON object where each key is a step (e.g., "Step 1", "Step 2") and the value is the explanation for that step.
2. "final_answer": A string or number representing the final answer.

Example response format:
{
  "reasoning_trace": {
    "Step 1": "First, I calculated the cost of the apples, which is 5 apples * $0.50/apple = $2.50.",
    "Step 2": "Next, I added the cost of the oranges, which is 3 oranges * $0.75/orange = $2.25.",
    "Step 3": "The total cost is $2.50 + $2.25 = $4.75."
  },
  "final_answer": 4.75
}
"""


def _load_processed_ids(output_file):
    """Return the set of question IDs already recorded in *output_file*."""
    processed_ids = set()
    if not os.path.exists(output_file):
        return processed_ids

    print(f"\nFound existing responses file: {output_file}")
    print("Loading already processed questions...")
    with open(output_file, "r", encoding='utf-8') as f:
        for line in f:
            try:
                processed_ids.add(json.loads(line)["id"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Blank or truncated lines, JSON values that are not objects,
                # and objects without an "id" must not abort a resume.
                continue
    print(f"Already processed: {len(processed_ids)} questions")
    return processed_ids


def _parse_model_response(response_content):
    """Convert raw model output into a ``(reasoning_text, final_answer)`` pair."""
    try:
        response_data = json.loads(response_content)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a response whose content is None.
        return "Error: Model did not return valid JSON.", None

    if not isinstance(response_data, dict):
        return "Error: Model did not return a JSON object.", None

    reasoning = response_data.get("reasoning_trace", {})
    final_answer = response_data.get("final_answer", None)

    if isinstance(reasoning, dict):
        # Convert the dictionary of steps into a single multi-line string
        reasoning_text = "\n".join(f"{key}: {value}" for key, value in reasoning.items())
    elif isinstance(reasoning, list):
        # Models sometimes return the steps as a list; keep one step per line.
        reasoning_text = "\n".join(str(step) for step in reasoning)
    elif reasoning is None:
        reasoning_text = ""
    else:
        reasoning_text = str(reasoning)

    return reasoning_text, final_answer


def generate_responses_for_dataset(dataset_items, model_name, client, output_file):
    """Generate reasoning responses for a dataset with resume capability.

    Every item whose ID is not yet present in *output_file* is sent to
    *model_name* as a chat completion. One JSON line per item is appended to
    *output_file* as soon as its result is known, so an interrupted run can be
    resumed by calling this function again with the same arguments.

    Args:
        dataset_items: Iterable of dicts with "id", "dataset", "question" and
            "answer" keys.
        model_name: Model identifier passed to the chat completions API.
        client: Object exposing ``chat.completions.create``.
        output_file: Path of the JSONL file used both to detect finished
            items and to append new results.

    Note:
        Items whose API call or response parsing failed are still written
        (with a response starting with "Error:" and a ``None`` model answer),
        so they count as processed on a later run.
    """
    processed_ids = _load_processed_ids(output_file)

    # Filter out already processed items
    items_to_process = [item for item in dataset_items if item["id"] not in processed_ids]

    if not items_to_process:
        print(f"\nAll questions already processed for {model_name}!")
        return

    print(f"\nProcessing {len(items_to_process)} remaining questions...")

    for item in tqdm(items_to_process, desc=f"Generating with {model_name}"):
        qid = item["id"]

        # Only the request itself is guarded here, so that a parsing problem
        # is never reported as an API failure.
        try:
            completion = client.chat.completions.create(
                model=model_name,
                response_format={"type": "json_object"},
                messages=[
                    {"role": "system", "content": _SYSTEM_PROMPT},
                    {"role": "user", "content": item["question"]}
                ],
                temperature=0.0,
                max_tokens=2000
            )
            response_content = completion.choices[0].message.content
        except Exception as e:
            print(f"\nError processing question {qid}: {e}")
            reasoning_text = f"Error: API call failed - {str(e)}"
            final_answer = None
        else:
            reasoning_text, final_answer = _parse_model_response(response_content)

        # Save in a structured format
        entry = {
            "id": qid,
            "dataset": item["dataset"],
            "question": item["question"],
            "ground_truth": item["answer"],
            "response": reasoning_text,
            "model_answer": final_answer,
            "model": model_name
        }

        # Append immediately to file (enables resume on crash)
        with open(output_file, "a", encoding='utf-8') as f:
            f.write(json.dumps(entry) + "\n")

# ------------------------------
# Main Execution
# ------------------------------
# Startup banner: script title and the models that will be queried.
print("=" * 70)
print("Reasoning Response Generator")
print("Models: " + ", ".join(MODELS))
print("=" * 70)

# Step 1: Load datasets from JSON files
print("\n" + "=" * 70)
print("STEP 1: Loading Datasets from JSON files")
print("=" * 70)

# Maps each dataset name to the list of question items loaded for it.
all_datasets = {}

for dataset_file in DATASET_FILES:
    dataset_items = load_dataset_from_file(dataset_file)
    if not dataset_items:
        # Nothing was loaded for this file, so there is nothing to register.
        continue

    # Prefer the name stored in the data itself; otherwise derive it from
    # the filename (dataset_GSM8K.json -> GSM8K).
    dataset_name = (
        dataset_items[0]['dataset']
        if 'dataset' in dataset_items[0]
        else dataset_file.replace('dataset_', '').replace('.json', '')
    )

    # Backfill the fields that older dataset files may lack.
    for i, item in enumerate(dataset_items):
        item.setdefault('dataset', dataset_name)
        item.setdefault('id', f"{dataset_name}_{i}")

    all_datasets[dataset_name] = dataset_items

total_questions = sum(map(len, all_datasets.values()))
print(f"\nTotal questions loaded: {total_questions}")

# Without any questions there is nothing to generate; stop with an error code.
if not total_questions:
    print("\nNo datasets found! Please run download_datasets.py first.")
    sys.exit(1)

# Step 2: Generate responses for all model-dataset combinations
print("\n" + "=" * 70)
print("STEP 2: Generating Reasoning Responses")
print("=" * 70)

# Calculate total combinations
total_combinations = len(MODELS) * len(all_datasets)

# Count the combinations that will actually run. Subtracting
# len(SKIP_COMBINATIONS) from the total is wrong whenever that list holds
# duplicates or pairs whose model/dataset is not part of this run.
pending_combinations = sum(
    1
    for model in MODELS
    for dataset_name in all_datasets
    if (model, dataset_name) not in SKIP_COMBINATIONS
)
completed_combinations = 0
skipped_combinations = 0

for model in MODELS:
    for dataset_name, dataset_items in all_datasets.items():
        # Check if this combination should be skipped
        if (model, dataset_name) in SKIP_COMBINATIONS:
            print(f"\nSkipping {model} + {dataset_name} (already completed)")
            skipped_combinations += 1
            continue

        print(f"\n{'=' * 70}")
        print(f"Processing: {model} + {dataset_name}")
        print(f"Questions: {len(dataset_items)}")
        print(f"{'=' * 70}")

        # Separate file for each model-dataset combination
        responses_file = f"responses_{model.replace('-', '_')}_{dataset_name}.jsonl"
        print(f"Output file: {responses_file}")
        print("(If interrupted, re-run this script to resume from where it stopped)")

        generate_responses_for_dataset(dataset_items, model, client, responses_file)

        completed_combinations += 1
        print(f"\nCompleted {completed_combinations}/{pending_combinations} combinations")

# Final summary. total_questions is the number of questions loaded from the
# dataset files, independent of how many models were run or skipped.
print(
    "\n" + "=" * 70,
    "ALL COMPLETE!",
    f"Completed: {completed_combinations} combinations",
    f"Skipped: {skipped_combinations} combinations",
    f"Total questions processed: {total_questions}",
    "=" * 70,
    sep="\n",
)
