import json
import openai
from utils import generate_multiple_samples

def read_seed_txt(seed_txt):
    """Return the non-blank lines of a seed file, stripped of surrounding whitespace.

    Args:
        seed_txt: Path to a UTF-8 text file with one seed question per line.

    Returns:
        List of stripped, non-empty lines in file order.
    """
    # Decode explicitly as UTF-8 so results don't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows would mangle non-ASCII seeds).
    with open(seed_txt, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [s for s in stripped if s]

def build_prompt(seeds, n_samples, openai_key):
    """Build the question-generation prompt from a list of seed examples.

    Args:
        seeds: Example questions to imitate; rendered as a numbered list.
        n_samples: Number of questions to request when ``openai_key`` is set.
        openai_key: Only its truthiness is used here. When truthy, a single
            request must return all ``n_samples`` questions, so the model is
            asked for a JSON array of strings (the format
            ``generate_synthetic_samples`` parses). When falsy, the prompt
            asks for exactly one question; ``generate_synthetic_samples``
            passes it to ``generate_multiple_samples`` with ``n=n_samples``.

    Returns:
        The prompt string.
    """
    joined = "\n".join(f"{i}. {s}" for i, s in enumerate(seeds, start=1))
    # Describe the actual number of seeds instead of a hard-coded count.
    n_examples = len(seeds)
    examples = f"the following {n_examples} example{'' if n_examples == 1 else 's'}"

    if openai_key:
        return (
            f"Generate {n_samples} new questions similar in style to {examples}:\n\n"
            f"{joined}\n\n"
            f"Respond with only a JSON array of {n_samples} strings, "
            f"one question per string, and no other text."
        )
    return (
        f"Generate 1 new question similar in style to {examples}:\n\n"
        f"{joined}\n\n"
        "New Question:"
    )

def _parse_json_list(raw):
    """Decode a JSON value from a chat reply, tolerating fences or surrounding prose."""
    text = raw.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Replies are often wrapped as ```json\n[...]\n``` or prefixed with a
    # sentence; the payload we want is the outermost [...] span.
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end < start:
        raise ValueError("Model response does not contain a JSON array")
    return json.loads(text[start:end + 1])


def generate_synthetic_samples(seed_txt, output_path, n_samples, model, openai_key=None, hf_fallback=None, hf_fallback_tokenizer=None):
    """Generate questions in the style of a seed file and save them as JSON.

    With ``openai_key``, one chat-completion request asks the OpenAI ``model``
    for ``n_samples`` questions returned as a JSON array. Without it,
    ``generate_multiple_samples`` is called with ``hf_fallback`` and
    ``hf_fallback_tokenizer`` to produce ``n_samples`` outputs.

    Args:
        seed_txt: Path to a text file with one example question per line.
        output_path: Path of the JSON file to write.
        n_samples: Number of questions to request.
        model: OpenAI model name (only used on the OpenAI path).
        openai_key: OpenAI API key; selects the OpenAI path when truthy.
        hf_fallback: Model object forwarded to ``generate_multiple_samples``
            (presumably a Hugging Face model — confirm against utils).
        hf_fallback_tokenizer: Tokenizer forwarded to ``generate_multiple_samples``.

    Returns:
        The list of ``{"generated_instruction": question}`` dicts that was
        written to ``output_path``, or ``[]`` if generation, parsing or
        saving failed (the error is printed).

    Raises:
        OSError: If ``seed_txt`` cannot be read; this happens before the
            best-effort ``try`` block and is not swallowed.
    """
    seeds = read_seed_txt(seed_txt)
    prompt = build_prompt(seeds, n_samples, openai_key)

    try:
        if openai_key:
            openai.api_key = openai_key
            print("Using OpenAI model...")
            # NOTE(review): ChatCompletion is the openai<1.0 SDK interface and
            # was removed in openai>=1.0 — confirm the pinned SDK version.
            response = openai.ChatCompletion.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a question generation assistant."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7
            )
            raw = response["choices"][0]["message"]["content"]
            parsed = _parse_json_list(raw)
        else:
            print("Using Hugging Face fallback model...")
            parsed = generate_multiple_samples(hf_fallback, hf_fallback_tokenizer, prompt, n=n_samples)

            print("\n↪ Raw outputs:")
            for i, r in enumerate(parsed, 1):
                print(f"{i}. {r}\n")

        # Both paths must yield plain strings; a JSON object, nested lists or
        # numbers inside the array count as a malformed response.
        if not isinstance(parsed, list) or not all(isinstance(q, str) for q in parsed):
            raise ValueError("Expected a list of strings")

        formatted = [{"generated_instruction": q.strip()} for q in parsed]
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(formatted, f, indent=2)
        print(f"✓ Saved {len(formatted)} samples to {output_path}")
        return formatted

    except Exception as e:
        # Deliberate best-effort boundary: any API, parsing or I/O failure is
        # reported and turned into an empty result instead of propagating.
        print(f"✗ Generation failed: {e}")
        return []
