from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
import time
import json
import re
import random
import uuid
import logging
from openai import OpenAI

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default system message for OpenAIEndowmentModel (used as the default value of
# its `system_prompt` argument). It asks the model for personas carrying an
# `eid` and an `endow_text` field and for output limited to a JSON array.
# NOTE(review): the `eid` rule says "No punctuation or spaces" while the example
# identifier uses underscores — confirm underscores are intended to be allowed.
SYSTEM_PROMPT = """You are an expert assistant trained to generate realistic, diverse, and demographically plausible personas for social science surveys. 

Each persona should include:
- `eid`: a short, lowercase, variable-safe identifier that encodes key traits (e.g., urban_liberal_30s_female). No punctuation or spaces.
- `endow_text`: a natural language description of the persona (1-5 sentences), written as if describing a survey respondent.

Instructions:
- Represent a wide range of age, gender, race/ethnicity, education, region, and political ideology.
- Avoid repetition of phrasing or demographic combinations across personas.
- Do not include explanations or formatting outside of the JSON array.
"""

class EndowmentModel(ABC):
    """
    Interface for persona ("endowment") generators.

    A concrete model receives a list of attribute names, a persona count and
    a sampling mode, and returns persona records. `generate_endowments` is the
    single method every subclass has to provide.
    """

    @abstractmethod
    def generate_endowments(self, attributes: List[str], n: int, mode: Tuple[str, ...]) -> List[dict]:
        """
        Produce `n` personas for the given attributes and mode.

        Args:
            attributes (List[str]): Human-related attribute names the personas
                are built around (e.g., age, ideology, personality traits).
            n (int): How many personas to produce.
            mode (Tuple[str, ...]): Identifier of the mode, used for tracking
                or conditioning.

        Returns:
            List[dict]: One dictionary per persona, each holding:
                - "eid": unique identifier
                - "endow_text": natural-language description
        """

class OpenAIEndowmentModel(EndowmentModel):
    """
    OpenAI-backed implementation of the EndowmentModel that uses GPT-based
    LLMs to generate realistic, demographically diverse personas.

    Attributes:
        model (str): OpenAI model name to use (e.g., "gpt-4").
        temperature (float): Sampling temperature.
        max_tokens (int): Maximum tokens allowed in each API completion.
        batch_size (int): Number of personas requested per prompt batch.
        delay (float): Seconds to sleep after each API call.
        system_prompt (str): Instructional system message for persona generation.
        retry_failed (bool): Whether to retry failed completions.
        randomize_attributes (bool): Whether to subsample from attributes if too many.
        max_attributes (int): Maximum number of attributes kept when subsampling.
        clear_after_run (bool): If True, clears stored personas after generation.
        random_state (random.Random): Generator built from the `seed` argument
            and used for attribute subsampling.
        logger (logging.Logger): Logger in use; the `logger` argument if given,
            otherwise the logger named after this module.
        verbose (bool): Whether to enable logging output.
        client (OpenAI): API client created in the constructor.
        personas (List[dict]): Personas accumulated since the last `clear()`.
        failed_prompts (list): `(prompt, mode, attributes)` tuples of batches
            that failed and have not been recovered yet.
    """

    def __init__(
        self,
        model: str = "gpt-4",
        temperature: float = 0.9,
        max_tokens: int = 2048,
        batch_size: int = 10,
        delay: float = 0.5,
        system_prompt: str = SYSTEM_PROMPT,
        retry_failed: bool = True,
        randomize_attributes: bool = True,
        max_attributes: int = 12,
        clear_after_run: bool = True,
        seed: Optional[int] = None,
        logger: Optional[logging.Logger] = None,
        verbose: bool = True
    ):
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.batch_size = batch_size
        self.delay = delay
        self.system_prompt = system_prompt
        self.retry_failed = retry_failed
        self.randomize_attributes = randomize_attributes
        self.max_attributes = max_attributes
        self.clear_after_run = clear_after_run
        self.random_state = random.Random(seed)

        self.logger = logger or logging.getLogger(__name__)
        self.verbose = verbose

        self.client = OpenAI()
        self.personas: List[dict] = []
        # Each entry is (prompt, mode, selected_attrs) so a retry can re-tag
        # recovered personas exactly like the original attempt would have.
        self.failed_prompts: List[Tuple[str, tuple, List[str]]] = []

    def _log(self, msg: str, level: str = "info"):
        """
        Internal logging utility with optional verbosity control.

        Args:
            msg (str): The message to log.
            level (str): Logging level ("debug", "info", "warning", "error", "critical").
                Unknown levels fall back to "info".
        """
        if not self.verbose:
            return
        log_fn = getattr(self.logger, level.lower(), None)
        # Guard against names that exist on the logger but are not log methods
        # (e.g. "name"), which would otherwise fail when called.
        if not callable(log_fn):
            log_fn = self.logger.info
        log_fn(msg)

    @staticmethod
    def _mode_label(mode) -> str:
        """Render `mode` for log messages without ever raising."""
        if isinstance(mode, str):
            return mode
        try:
            return "+".join(str(part) for part in mode)
        except TypeError:
            return str(mode)

    def _extract_json_array(self, text: str) -> List[dict]:
        """
        Extract a JSON array of objects from a raw LLM response.

        The text is scanned for `[` characters and a JSON value is decoded
        starting at each one; the first non-empty array consisting solely of
        objects is returned. Decoding (rather than regex matching) keeps
        nested arrays and brackets inside strings intact.

        Args:
            text (str): Model output text.

        Returns:
            List[dict]: Parsed list of persona dictionaries.

        Raises:
            ValueError: If the text is empty or no valid JSON array is found.
        """
        if not text:
            raise ValueError("No JSON array found in response.")

        decoder = json.JSONDecoder()
        start = text.find("[")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                parsed = None
            if (
                isinstance(parsed, list)
                and parsed
                and all(isinstance(item, dict) for item in parsed)
            ):
                return parsed
            start = text.find("[", start + 1)

        raise ValueError("No JSON array found in response.")

    def _build_prompt(self, attributes: List[str], n: int) -> str:
        """
        Construct a user prompt for persona generation.

        Args:
            attributes (List[str]): A list of attribute names.
            n (int): Number of personas to request.

        Returns:
            str: Fully formatted user prompt string.
        """
        attr_string = ", ".join(attributes)
        return (
            f"Generate {n} diverse persona(s) that vary meaningfully along the following attributes: {attr_string}.\n"
            "Each persona should reflect a distinct combination or value expression of these traits.\n"
            "Return a JSON array of dictionaries, each with:\n"
            "- `eid`: a short, lowercase, variable-safe identifier\n"
            "- `endow_text`: a brief natural-language description of the persona"
        )

    def _request_batch(self, prompt: str, mode: tuple, selected_attrs: List[str]) -> List[dict]:
        """
        Send one prompt to the API and return the parsed, tagged personas.

        Every persona receives a `mode` and an `attributes` entry; an existing
        `eid` gets a per-batch random suffix appended.

        Raises:
            Exception: Whatever the API client or the JSON extraction raises.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt}
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        content = response.choices[0].message.content
        batch = self._extract_json_array(content)
        # One suffix per batch keeps model-chosen eids distinct across batches.
        suffix = uuid.uuid4().hex[:6]
        for persona in batch:
            if "eid" in persona:
                persona["eid"] = f"{persona['eid']}_{suffix}"
            persona["mode"] = mode
            # Copy so personas do not share (or alias the caller's) list.
            persona["attributes"] = list(selected_attrs)
        return batch

    def generate_endowments(self, attributes: List[str], n: int, mode: Tuple[str, ...]) -> List[dict]:
        """
        Generate personas via OpenAI API using the configured model and prompt logic.

        Args:
            attributes (List[str]): List of attribute names to condition on.
            n (int): Total number of personas to request.
            mode (tuple): Mode identifier used to tag generated personas.

        Returns:
            List[dict]: List of generated persona dictionaries with `eid`, `endow_text`, `mode`, and `attributes`.

        Raises:
            ValueError: If `batch_size` is smaller than 1.

        Notes:
            - Automatically retries failed batches if `retry_failed` is True.
            - Uses `batch_size` to split into multiple API calls.
            - With `clear_after_run` enabled, stored personas and failed
              prompts are cleared before returning, so a later `save()`
              writes nothing from this run.
        """
        if self.batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        selected_attrs = list(attributes)
        if self.randomize_attributes and len(selected_attrs) > self.max_attributes:
            selected_attrs = self.random_state.sample(selected_attrs, k=self.max_attributes)

        mode_label = self._mode_label(mode)
        full_batch = []
        for batch_no, start in enumerate(range(0, n, self.batch_size), start=1):
            curr_n = min(self.batch_size, n - start)
            prompt = self._build_prompt(selected_attrs, curr_n)

            # Only the request/parse step is guarded: a batch must never be
            # both stored and queued for retry.
            try:
                batch = self._request_batch(prompt, mode, selected_attrs)
            except Exception as e:  # best-effort: record and move on
                self._log(f"✗ Batch {batch_no} failed: {e}", level="warning")
                self.failed_prompts.append((prompt, mode, selected_attrs))
            else:
                self.personas.extend(batch)
                full_batch.extend(batch)
                self._log(f"✓ Batch {batch_no} [{mode_label}]: {len(batch)} personas generated.")

            time.sleep(self.delay)

        if self.retry_failed and self.failed_prompts:
            recovered_batch = self.retry_failed_batches()
            full_batch.extend(recovered_batch)

        if self.clear_after_run:
            self.clear()

        return full_batch

    def retry_failed_batches(self) -> List[dict]:
        """
        Retry all failed prompt batches and attempt to regenerate personas.

        Each entry of `failed_prompts` is retried once; entries that succeed
        are removed from `failed_prompts`, the rest stay queued.

        Returns:
            List[dict]: List of successfully recovered personas.
        """
        self._log("\nRetrying failed prompts...")
        # Iterate over a snapshot because recovered entries are removed below.
        pending = self.failed_prompts[:]
        recovered_batch = []
        for attempt_no, entry in enumerate(pending, start=1):
            prompt, mode, selected_attrs = entry
            try:
                batch = self._request_batch(prompt, mode, selected_attrs)
            except Exception as e:  # best-effort: leave the entry queued
                self._log(f"✗ Retry {attempt_no} failed: {e}", level="warning")
            else:
                self.personas.extend(batch)
                recovered_batch.extend(batch)
                self.failed_prompts.remove(entry)
                self._log(
                    f"✓ Retry {attempt_no}[{self._mode_label(mode)}]: {len(batch)} personas recovered."
                )

            time.sleep(self.delay)

        return recovered_batch

    def save(self, filepath: str):
        """
        Save all currently stored personas to a file, one JSON object per line.

        Only personas still held in `self.personas` are written; anything
        removed by `clear()` is not included.

        Args:
            filepath (str): File path for output.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            for p in self.personas:
                f.write(json.dumps(p) + "\n")

    def clear(self):
        """
        Clears stored personas and failed prompts.
        Useful for resetting the model between runs.
        """
        self.personas.clear()
        self.failed_prompts.clear()