import json
import re
import time
from typing import List, Optional
from openai import OpenAI, APIConnectionError

# Default system message for EndowmentGenerator (its `system_prompt` argument
# defaults to this). It is sent as the "system" role on every chat request and
# describes the persona schema (`eid`, `endow_text`) and diversity requirements.
# The per-batch user prompt adds the persona count, the optional topic and the
# explicit "Return the result as a JSON array" instruction.
# NOTE: this text is sent verbatim to the model. Any edit (including whitespace)
# changes the request payload.
SYSTEM_PROMPT = """You are an expert assistant trained to generate realistic, diverse, and demographically plausible personas for social science surveys. 

Each persona should include:
- `eid`: a short, lowercase, variable-safe identifier that encodes key traits (e.g., urban_liberal_30s_female). No punctuation or spaces.
- `endow_text`: a natural language description of the persona (1-2 sentences), written as if describing a survey respondent.

Instructions:
- Represent a wide range of age, gender, race, education, region, and political ideology.
- Avoid repetition of phrasing or demographic combinations across personas.
- Do not include explanations or formatting outside of the JSON array.
"""

class EndowmentGenerator:
    """
    A basic persona generator for survey experiments using OpenAI's chat API.

    This class generates diverse respondent personas represented as dictionaries with:
    - `eid`: a variable-safe identifier encoding key traits (e.g., "urban_liberal_30s_female")
    - `endow_text`: a short natural-language description of the persona

    Designed for vanilla use in simulation or pilot experiments, it supports:
    - Optional topical conditioning
    - Batched generation with retry logic
    - Export of results to JSONL format

    Attributes:
        topic_description (Optional[str]): Optional topic context for persona generation.
        model (str): OpenAI model name (default: "gpt-4").
        temperature (float): Sampling temperature.
        batch_size (int): Number of personas requested per prompt (must be positive).
        delay (float): Seconds to sleep after each API call to avoid rate limits.
        system_prompt (str): System message sent with every request.
        retry_failed (bool): Whether `generate` retries failed batches once at the end.
        max_tokens (int): Completion token limit per request (default: 1024).
        client (OpenAI): The OpenAI client used for all requests.
        personas (List[dict]): All personas parsed so far, in arrival order.
        failed_prompts (List[str]): User prompts whose batch has not succeeded (yet).
    """

    # Start of a candidate JSON array of objects inside free-form model output.
    _ARRAY_START = re.compile(r"\[\s*\{")

    def __init__(
        self,
        topic_description: Optional[str] = None,
        model: str = "gpt-4",
        temperature: float = 0.9,
        batch_size: int = 10,
        delay: float = 0.5,
        system_prompt: str = SYSTEM_PROMPT,
        retry_failed: bool = True,
        max_tokens: int = 1024,
    ):
        # Fail fast. A zero step would only surface later as an opaque range()
        # error in `generate`, and a negative one would silently generate nothing.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.topic_description = topic_description
        self.model = model
        self.temperature = temperature
        self.batch_size = batch_size
        self.delay = delay
        self.system_prompt = system_prompt
        self.retry_failed = retry_failed
        self.max_tokens = max_tokens
        self.client = OpenAI()
        self.personas: List[dict] = []
        self.failed_prompts: List[str] = []

    def _extract_json_array(self, text: Optional[str]) -> List[dict]:
        """
        Extracts the first JSON array of objects embedded in a model response.

        The reply may surround the array with prose or Markdown code fences. Every
        ``[`` followed by ``{`` is therefore tried as a starting point and decoded
        with a real JSON parser. Unlike a regex that stops at the first ``}]``,
        this handles nested lists/objects and brackets inside string values. It
        also skips a malformed candidate in favour of a later valid one.

        Args:
            text (Optional[str]): The raw text content returned by the model. It
                may be ``None`` when the API returns no message content.

        Returns:
            List[dict]: A non-empty list of persona dictionaries.

        Raises:
            ValueError: If no JSON array consisting solely of objects can be decoded.
        """
        if text:
            decoder = json.JSONDecoder()
            for candidate in self._ARRAY_START.finditer(text):
                try:
                    value, _ = decoder.raw_decode(text[candidate.start():])
                except json.JSONDecodeError:
                    continue
                # The candidate pattern guarantees a leading object, so the list
                # is non-empty. Also reject arrays mixing in non-dict items.
                if isinstance(value, list) and all(isinstance(item, dict) for item in value):
                    return value
        raise ValueError("No JSON array found in response.")

    def _build_prompt(self, count: int) -> str:
        """
        Constructs a user prompt for the OpenAI model based on the number of personas requested
        and the optional topic description.

        Args:
            count (int): Number of personas to generate in this prompt.

        Returns:
            str: The full prompt string to send to the model.
        """
        topic_line = f"The topic of the survey experiment is: {self.topic_description}.\n" if self.topic_description else ""
        return (
            f"Generate {count} diverse personas for a survey experiment. "
            "Each persona should be a dictionary with two fields:\n"
            "- eid: a short, lowercase, variable-safe identifier (e.g., 'urban_liberal_30s_female')\n"
            "- endow_text: a natural language description of the persona\n\n"
            f"{topic_line}Ensure diversity across gender, age, region, ideology, education, and race. "
            "Avoid duplication or overlapping categories. Return the result as a JSON array."
        )

    def _request_batch(self, prompt: str) -> List[dict]:
        """
        Sends one user prompt (with the system prompt) and parses the persona array.

        Shared by `generate` and `retry_failed_batches` so both issue identical requests.

        Args:
            prompt (str): The user prompt, as built by `_build_prompt`.

        Returns:
            List[dict]: The personas parsed from the model's reply.

        Raises:
            ValueError: If the reply contains no usable JSON array of objects.
            Exception: Errors raised by the OpenAI client propagate unchanged.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        return self._extract_json_array(response.choices[0].message.content)

    def generate(self, n: int) -> None:
        """
        Generates `n` personas in batches using the OpenAI API.

        Each batch requests at most `self.batch_size` personas. The last batch
        requests only the remainder. Parsed personas are appended to
        `self.personas`. The prompt of any failed batch is queued in
        `self.failed_prompts` and, if `self.retry_failed` is True, retried once
        after all batches have run. The model may return a different number
        of personas than requested, and this count is not enforced.

        Args:
            n (int): Total number of personas to request. Values <= 0 request nothing.
        """
        for batch_number, start in enumerate(range(0, n, self.batch_size), start=1):
            prompt = self._build_prompt(min(self.batch_size, n - start))
            try:
                batch = self._request_batch(prompt)
            except Exception as e:
                # Deliberate best-effort: report, queue for retry, keep going.
                print(f"✗ Batch {batch_number} failed: {e}")
                self.failed_prompts.append(prompt)
            else:
                self.personas.extend(batch)
                print(f"✓ Batch {batch_number}: {len(batch)} personas generated.")

            time.sleep(self.delay)

        if self.retry_failed:
            self.retry_failed_batches()

    def retry_failed_batches(self) -> None:
        """
        Re-attempts each prompt currently in `self.failed_prompts` once.

        `generate` calls this automatically when `self.retry_failed` is True. It can
        also be called directly, and it does not check that flag itself. Recovered
        personas are appended to `self.personas` and their prompt is removed from
        `self.failed_prompts`. Prompts that fail again remain queued for a later
        call. Does nothing when no prompts are queued.
        """
        if not self.failed_prompts:
            return

        print("\n Retrying failed prompts...")
        # Iterate over a snapshot, since successful prompts are removed from the live
        # list. Equal prompt strings are interchangeable, so `remove` (first match)
        # is correct even when several batches share an identical prompt.
        for attempt, prompt in enumerate(list(self.failed_prompts), start=1):
            try:
                batch = self._request_batch(prompt)
            except Exception as e:
                # Best-effort: leave the prompt queued in `failed_prompts`.
                print(f"✗ Retry {attempt} failed: {e}")
            else:
                self.personas.extend(batch)
                print(f"✓ Retry {attempt}: {len(batch)} personas recovered.")
                self.failed_prompts.remove(prompt)

            time.sleep(self.delay)

    def save(self, filepath: str) -> None:
        """
        Saves all generated personas to a JSON Lines (JSONL) file.

        Each persona is written as one JSON object per line. An existing file
        at `filepath` is overwritten.

        Args:
            filepath (str): Destination path for the output file.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(p) + "\n" for p in self.personas)
