from typing import List, Tuple, Optional
import time
import json
import re
import random
import uuid
import logging
from backends.registry import BackendRegistry
from modules.token_tracker import MultiAxisTokenTracker

# Module-level logger. EndowmentModel requests a logger with this same name
# when no logger is passed to its constructor.
logger = logging.getLogger(__name__)

# Default system message for persona generation. It is the default value of
# EndowmentModel's `system_prompt` argument and is sent as the `system` part
# of every backend chat call. It names the two fields each persona must carry
# (`eid`, `endow_text`) and asks for no text outside the JSON array.
# NOTE: this string is runtime behavior -- editing it changes what the model
# is asked to produce.
SYSTEM_PROMPT = """You are an expert assistant trained to generate realistic, diverse, and demographically plausible personas for social science surveys. 

Each persona should include:
- `eid`: a short, lowercase, variable-safe identifier that encodes key traits (e.g., urban_liberal_30s_female). No punctuation or spaces.
- `endow_text`: a natural language description of the persona (1-5 sentences), written as if describing a survey respondent.

Instructions:
- Represent a wide range of age, gender, race/ethnicity, education, region, and political ideology.
- Avoid repetition of phrasing or demographic combinations across personas.
- Do not include explanations or formatting outside of the JSON array.
"""

class EndowmentModel:
    """
    A configurable persona generator that uses a backend LLM to create realistic,
    demographically diverse endowments for social science surveys.

    The backend class is looked up in ``BackendRegistry`` by ``backend_type``
    (e.g., "openai", "gemini", "dummy") and asked, batch by batch, for a JSON
    array of personas. Each returned persona carries an identifier (`eid`), a
    natural-language description (`endow_text`), and the metadata added here
    (`mode` and `attributes`).

    Attributes:
        backend_type (str): Backend provider name used for the registry lookup.
        model_name (str): LLM model name (backend-specific).
        temperature (float): Sampling temperature passed to the backend.
        max_tokens (int): Max tokens passed to the backend.
        batch_size (int): Number of personas requested per backend call.
        delay (float): Sleep duration (in seconds) after each backend call.
        system_prompt (str): System message sent with every request.
        retry_failed (bool): Whether failed batches are retried once at the end
            of `generate_endowments`.
        randomize_attributes (bool): Whether to randomly sample a subset of
            attributes when more than `max_attributes` are supplied.
        max_attributes (int): Max number of attributes to include in a prompt
            when `randomize_attributes` is enabled.
        clear_after_run (bool): Whether `generate_endowments` clears
            `personas` and `failed_prompts` before returning.
        random_state (random.Random): Generator seeded with the `seed`
            constructor argument; used for attribute sampling.
        logger (logging.Logger): Logger in use (the supplied one, or a logger
            named after this module).
        verbose (bool): Whether `_log` forwards messages to the logger.
        token_tracker (Optional[MultiAxisTokenTracker]): Optional tracker that
            receives token usage for every backend call that reports usage.
        backend: Backend instance created from the registry.
        personas (List[dict]): Personas accumulated so far.
        failed_prompts (List[Tuple[str, tuple, List[str]]]): One
            `(prompt, mode, attributes)` entry per batch that failed.
    """

    def __init__(
        self,
        backend_type: str = "openai",
        model_name: str = "gpt-4o-mini",
        temperature: float = 0.9,
        max_tokens: int = 2048,
        batch_size: int = 10,
        delay: float = 0.5,
        system_prompt: str = SYSTEM_PROMPT,
        retry_failed: bool = True,
        randomize_attributes: bool = True,
        max_attributes: int = 12,
        clear_after_run: bool = True,
        seed: Optional[int] = None,
        logger: Optional[logging.Logger] = None,
        verbose: bool = True,
        token_tracker: Optional[MultiAxisTokenTracker] = None,
        **backend_kwargs
    ):
        """
        Store the configuration and instantiate the backend.

        Args:
            backend_type (str): Name looked up in `BackendRegistry`.
            model_name (str): Model name forwarded to the backend.
            temperature (float): Sampling temperature forwarded to the backend.
            max_tokens (int): Token limit forwarded to the backend.
            batch_size (int): Personas requested per backend call.
            delay (float): Seconds to sleep after each backend call.
            system_prompt (str): System message for every request.
            retry_failed (bool): Retry failed batches after the main loop.
            randomize_attributes (bool): Sample attributes when there are too many.
            max_attributes (int): Sample size used when sampling attributes.
            clear_after_run (bool): Clear internal buffers after generation.
            seed (Optional[int]): Seed for the private `random.Random` instance.
            logger (Optional[logging.Logger]): Logger to use instead of the default.
            verbose (bool): Enable or silence `_log` output.
            token_tracker (Optional[MultiAxisTokenTracker]): Token usage sink.
            **backend_kwargs: Extra keyword arguments passed to the backend class.
        """
        self.backend_type = backend_type
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.batch_size = batch_size
        self.delay = delay
        self.system_prompt = system_prompt
        self.retry_failed = retry_failed
        self.randomize_attributes = randomize_attributes
        self.max_attributes = max_attributes
        self.clear_after_run = clear_after_run
        # A private generator keeps sampling reproducible without touching
        # the global `random` module state.
        self.random_state = random.Random(seed)

        self.logger = logger or logging.getLogger(__name__)
        self.verbose = verbose
        self.token_tracker = token_tracker

        # Instantiate backend via registry
        backend_cls = BackendRegistry.get(self.backend_type)
        self.backend = backend_cls(
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **backend_kwargs
        )
        self.personas: List[dict] = []
        # Each entry keeps everything needed to replay the request and to tag
        # the recovered personas: (prompt, mode, selected attributes).
        self.failed_prompts: List[Tuple[str, tuple, List[str]]] = []

    def _log(self, msg: str, level: str = "info"):
        """
        Internal logging utility with optional verbosity control.

        Args:
            msg (str): The message to log.
            level (str): Logging level ("debug", "info", "warning", "error",
                "critical"). Unknown names fall back to `info`.
        """
        if self.verbose:
            log_fn = getattr(self.logger, level.lower(), self.logger.info)
            log_fn(msg)

    @staticmethod
    def _mode_label(mode) -> str:
        """Render `mode` as a '+'-joined label for log lines and tracker ids."""
        try:
            # str() per item so non-string mode components cannot raise.
            return "+".join(str(part) for part in mode)
        except TypeError:
            # `mode` is not iterable; use its plain string form.
            return str(mode)

    def _extract_json_array(self, text: str) -> List[dict]:
        """
        Extract a JSON array of objects from a raw LLM response.

        Every position where an array of objects could start (`[` followed by
        `{`) is tried in order with a real JSON decoder, so text before or
        after the array is ignored and nested arrays/objects inside a persona
        are parsed correctly.

        Args:
            text (str): Model output text.

        Returns:
            List[dict]: Parsed list of persona dictionaries.

        Raises:
            ValueError: If no valid JSON array of objects is found.
        """
        decoder = json.JSONDecoder()
        for opening in re.finditer(r"\[\s*\{", text):
            try:
                # raw_decode parses one complete JSON value and tolerates
                # trailing text after it.
                candidate, _ = decoder.raw_decode(text[opening.start():])
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, list) and all(
                isinstance(item, dict) for item in candidate
            ):
                return candidate
        raise ValueError("No JSON array found in response.")

    def _build_prompt(self, attributes: List[str], n: int) -> str:
        """
        Construct a user prompt for persona generation.

        Args:
            attributes (List[str]): A list of attribute names.
            n (int): Number of personas to request.

        Returns:
            str: Fully formatted user prompt string.
        """
        attr_string = ", ".join(attributes)
        return (
            f"Generate {n} diverse persona(s) that vary meaningfully along the following attributes: {attr_string}.\n"
            "Each persona should reflect a distinct combination or value expression of these traits.\n"
            "Return a JSON array of dictionaries, each with:\n"
            "- `eid`: a short, lowercase, variable-safe identifier\n"
            "- `endow_text`: a brief natural-language description of the persona"
        )

    def _request_batch(
        self,
        prompt: str,
        mode: tuple,
        selected_attrs: List[str],
        agent_id: str,
    ) -> List[dict]:
        """
        Send one prompt to the backend and return the tagged personas.

        Args:
            prompt (str): User prompt for this batch.
            mode (tuple): Mode identifier stored on every persona.
            selected_attrs (List[str]): Attributes stored on every persona.
            agent_id (str): Identifier reported to the token tracker.

        Returns:
            List[dict]: Parsed personas with `mode` and `attributes` added and
            a shared random suffix appended to every `eid` present.

        Raises:
            Exception: Whatever the backend call or the JSON extraction raises;
            callers decide how to record the failure.
        """
        response = self.backend.chat(
            system=self.system_prompt,
            user=prompt,
            return_usage=True,
        )

        if isinstance(response, dict):
            raw_output = response["answer_text"]
            usage = response.get("usage")
        else:
            raw_output, usage = response, None

        # Usage is recorded before parsing so tokens spent on an unparseable
        # response are still accounted for.
        if self.token_tracker is not None and usage:
            self.token_tracker.log(
                agent_id=agent_id,
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                cached_input_tokens=usage.get("cached_input_tokens", 0),
                model_name=self.model_name,
                module_name="EndowmentModel"
            )

        batch = self._extract_json_array(raw_output)
        # One suffix per batch keeps ids distinct across batches even when the
        # model repeats an `eid`.
        suffix = uuid.uuid4().hex[:6]
        for persona in batch:
            if "eid" in persona:
                persona["eid"] = f"{persona['eid']}_{suffix}"
            persona["mode"] = mode
            # Copy so personas do not share one list object with each other
            # or with the caller's input.
            persona["attributes"] = list(selected_attrs)
        return batch

    def generate_endowments(self, attributes: List[str], n: int, mode: tuple) -> List[dict]:
        """
        Generate personas through the configured backend.

        Args:
            attributes (List[str]): List of attribute names to condition on.
            n (int): Total number of personas to request.
            mode (tuple): Mode identifier used to tag generated personas.

        Returns:
            List[dict]: Generated persona dictionaries with `eid`, `endow_text`
            (as produced by the model), `mode`, and `attributes`.

        Raises:
            ValueError: If `batch_size` is smaller than 1.

        Notes:
            - Requests are split into calls of at most `batch_size` personas.
            - A failed batch is logged and stored in `failed_prompts`; if
              `retry_failed` is True those prompts are retried once at the end.
            - Personas are also appended to `self.personas`. When
              `clear_after_run` is True, `self.personas` and
              `self.failed_prompts` are emptied before returning, so the
              returned list is then the only copy.
        """
        if self.batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        selected_attrs = attributes
        if self.randomize_attributes and len(attributes) > self.max_attributes:
            selected_attrs = self.random_state.sample(attributes, k=self.max_attributes)

        mode_label = self._mode_label(mode)
        full_batch = []
        for batch_no, start in enumerate(range(0, n, self.batch_size), start=1):
            curr_n = min(self.batch_size, n - start)
            prompt = self._build_prompt(selected_attrs, curr_n)

            try:
                batch = self._request_batch(
                    prompt,
                    mode,
                    selected_attrs,
                    agent_id=f"endowment_model_mode_{mode_label}_batch{batch_no}",
                )
            except Exception as e:
                self._log(f"✗ Batch {batch_no} failed: {e}", level="warning")
                self.failed_prompts.append((prompt, mode, selected_attrs))
            else:
                # Bookkeeping happens outside the try so a successful batch
                # can never be recorded as failed and then duplicated by a retry.
                self.personas.extend(batch)
                full_batch.extend(batch)
                self._log(f"✓ Batch {batch_no} [{mode_label}]: {len(batch)} personas generated.")

            time.sleep(self.delay)

        if self.retry_failed and self.failed_prompts:
            recovered_batch = self.retry_failed_batches()
            full_batch.extend(recovered_batch)

        if self.clear_after_run:
            self.clear()

        return full_batch

    def retry_failed_batches(self) -> List[dict]:
        """
        Retry all failed prompt batches and attempt to regenerate personas.

        Each stored entry is replayed once with its own prompt, mode and
        attribute list. Entries that succeed are removed from
        `failed_prompts`; entries that fail again stay there.

        Returns:
            List[dict]: List of successfully recovered personas.
        """
        self._log("\nRetrying failed prompts...")
        recovered_batch = []
        # Iterate over a snapshot because successful entries are removed from
        # the live list inside the loop.
        for attempt_no, failed in enumerate(list(self.failed_prompts), start=1):
            prompt, mode, selected_attrs = failed
            mode_label = self._mode_label(mode)
            try:
                batch = self._request_batch(
                    prompt,
                    mode,
                    selected_attrs,
                    agent_id=f"endowment_model_mode_{mode_label}_retry{attempt_no}",
                )
            except Exception as e:
                self._log(f"✗ Retry {attempt_no} failed: {e}", level="warning")
            else:
                self.personas.extend(batch)
                recovered_batch.extend(batch)
                self._log(f"✓ Retry {attempt_no}[{mode_label}]: {len(batch)} personas recovered.")
                self.failed_prompts.remove(failed)

            time.sleep(self.delay)

        return recovered_batch

    def save(self, filepath: str):
        """
        Save all personas currently held in `self.personas` to a file, one
        JSON object per line.

        Note that `generate_endowments` empties `self.personas` before
        returning when `clear_after_run` is True, in which case nothing is
        left here to write.

        Args:
            filepath (str): File path for output.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            for p in self.personas:
                f.write(json.dumps(p) + "\n")

    def clear(self):
        """
        Clears stored personas and failed prompts.
        Useful for resetting the model between runs.
        """
        self.personas.clear()
        self.failed_prompts.clear()