import json
import re
import time
import uuid
import logging
from typing import List, Optional

from backends.registry import BackendRegistry
from modules.token_tracker import MultiAxisTokenTracker

# Module-level logger; EndowmentGenerator._log forwards its messages here
# (only when the generator was created with verbose=True).
logger = logging.getLogger(__name__)

# Default system message sent with every persona-generation request (the
# default value of EndowmentGenerator's `system_prompt` argument). It asks the
# model for a bare JSON array of objects with `eid` and `endow_text` keys,
# which is what EndowmentGenerator._extract_json_array looks for and what the
# eid-suffixing loop in the generator expects.
SYSTEM_PROMPT = """You are an expert assistant trained to generate realistic, diverse, and demographically plausible personas for social science surveys. 

Each persona should include:
- `eid`: a short, lowercase, variable-safe identifier that encodes key traits (e.g., urban_liberal_30s_female). No punctuation or spaces.
- `endow_text`: a natural language description of the persona (1–5 sentences), written as if describing a survey respondent.

Instructions:
- Represent a wide range of age, gender, race/ethnicity, education, region, and political ideology.
- Avoid repetition of phrasing or demographic combinations across personas.
- Do not include explanations or formatting outside of the JSON array.
"""

class EndowmentGenerator:
    """
    Baseline persona generator (no entropy guidance, no attributes).
    Used for comparison against AEG.

    Each call to :meth:`generate` splits the request into batches of at most
    ``batch_size`` personas and sends one chat request per batch. It parses the
    JSON array in each reply and appends a random 6-hex-digit suffix (shared
    by the whole batch) to every ``eid``, so ids from different batches or
    runs are unlikely to collide.

    Attributes:
        backend_type (str): Backend provider ("openai", "gemini", "dummy").
        model_name (str): Model identifier.
        temperature (float): Sampling temperature.
        max_tokens (int): Max tokens per response.
        batch_size (int): Personas per batch (must be >= 1).
        delay (float): Sleep time (seconds) after every batch call.
        system_prompt (str): Instructional system message.
        retry_failed (bool): Retry this run's failed batches at the end of generate().
        clear_after_run (bool): Clear personas and failed prompts after each run.
        token_tracker (MultiAxisTokenTracker): Optional cost tracker.
        verbose (bool): Emit progress/warning messages through the module logger.
        max_retries (int): Attempts per failed prompt in retry_failed_batches().
        retry_base_delay (float): Wait (seconds) before the first retry attempt;
            doubles on every subsequent attempt.
        retry_max_delay (Optional[float]): Upper bound for a single retry wait;
            None means uncapped.
        personas (List[dict]): Every persona produced by this instance since
            the last clear().
        failed_prompts (List[str]): Prompts whose batch has not been recovered
            yet; persists across generate() calls unless cleared.
    """
    def __init__(
        self,
        backend_type: str = "openai",
        model_name: str = "gpt-4o-mini",
        temperature: float = 0.9,
        max_tokens: int = 1024,
        batch_size: int = 10,
        delay: float = 0.5,
        system_prompt: str = SYSTEM_PROMPT,
        retry_failed: bool = True,
        clear_after_run: bool = False,
        token_tracker: Optional[MultiAxisTokenTracker] = None,
        verbose: bool = True,
        max_retries: int = 10,
        retry_base_delay: float = 10.0,
        retry_max_delay: Optional[float] = None,
        **backend_kwargs
    ):
        """Store the configuration and build the chat backend.

        Any extra keyword arguments are forwarded unchanged to the backend
        class returned by ``BackendRegistry.get(backend_type)``.
        """
        self.backend_type = backend_type
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.batch_size = batch_size
        self.delay = delay
        self.system_prompt = system_prompt
        self.retry_failed = retry_failed
        self.clear_after_run = clear_after_run
        self.token_tracker = token_tracker
        self.verbose = verbose
        self.max_retries = max_retries
        self.retry_base_delay = retry_base_delay
        self.retry_max_delay = retry_max_delay

        backend_cls = BackendRegistry.get(self.backend_type)
        self.backend = backend_cls(
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **backend_kwargs
        )

        self.personas: List[dict] = []
        self.failed_prompts: List[str] = []

    def _log(self, msg: str, level: str = "info"):
        """Log ``msg`` at ``level`` (falls back to info) when verbose is on."""
        if self.verbose:
            log_fn = getattr(logger, level.lower(), logger.info)
            log_fn(msg)

    def _extract_json_array(self, text: str) -> List[dict]:
        """Return the first non-empty JSON array of objects embedded in ``text``.

        Each ``[`` in the text is tried as the start of a JSON value using
        ``json.JSONDecoder.raw_decode``. Unlike a regex, this copes with nested
        arrays/objects inside personas, with brackets inside string values, and
        with unparseable ``[...`` fragments that appear before the real payload
        (for example prose or format hints).

        Raises:
            ValueError: if ``text`` is not a string or contains no such array.
        """
        if not isinstance(text, str):
            raise ValueError(f"Expected a text response, got {type(text).__name__}.")

        decoder = json.JSONDecoder()
        start = text.find("[")
        while start != -1:
            try:
                value, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                value = None
            # Only accept a list of persona-like objects; anything else (e.g. a
            # list of strings) is skipped and the scan continues.
            if isinstance(value, list) and value and all(isinstance(item, dict) for item in value):
                return value
            start = text.find("[", start + 1)
        raise ValueError("No JSON array found in response.")

    def _build_prompt(self, n: int, topic: Optional[str] = None) -> str:
        """Build the user message asking for ``n`` personas (optionally on ``topic``)."""
        topic_line = f"The survey topic is: {topic}\n" if topic else ""
        return (
            f"Generate {n} diverse personas for a survey experiment. "
            "Each persona must include:\n"
            "- eid: short, lowercase identifier\n"
            "- endow_text: a short natural-language description\n\n"
            f"{topic_line}Ensure diversity across demographics. "
            "Return the result as a JSON array."
        )

    def _run_prompt(self, prompt: str, agent_id: str) -> List[dict]:
        """Send one prompt, record token usage, and return the parsed personas.

        Usage is logged to ``token_tracker`` (under ``agent_id``) before parsing,
        so tokens are accounted for even when the reply turns out to be
        unparseable. Every ``eid`` in the batch gets the same random 6-hex suffix.
        This method does not touch ``self.personas`` or ``self.failed_prompts``.

        Raises:
            Exception: whatever the backend raises, or ``ValueError`` /
                ``KeyError`` when the response cannot be interpreted.
        """
        response = self.backend.chat(
            system=self.system_prompt,
            user=prompt,
            return_usage=True,
        )
        # Backends may return either a plain string or a dict carrying the
        # answer text plus optional usage counters.
        if isinstance(response, dict):
            raw_output = response["answer_text"]
            usage = response.get("usage")
        else:
            raw_output, usage = response, None

        if self.token_tracker and usage:
            self.token_tracker.log(
                agent_id=agent_id,
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                cached_input_tokens=usage.get("cached_input_tokens", 0),
                model_name=self.model_name,
                module_name="EndowmentGenerator"
            )

        batch = self._extract_json_array(raw_output)
        suffix = uuid.uuid4().hex[:6]
        for p in batch:
            if "eid" in p:
                p["eid"] = f"{p['eid']}_{suffix}"
        return batch

    def generate(self, n: int, topic_description: Optional[str] = None) -> List[dict]:
        """Generate ``n`` personas in batches and return the ones produced by this run.

        Failed batches are recorded in ``self.failed_prompts``. If
        ``retry_failed`` is set, only the prompts that failed *during this call*
        are retried, so personas recovered for an earlier run (possibly with a
        different topic) never leak into this call's result. Leftover prompts
        from earlier runs can still be retried explicitly via
        :meth:`retry_failed_batches`.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")

        full_batch: List[dict] = []
        run_failures: List[str] = []
        for i in range(0, n, self.batch_size):
            batch_no = i // self.batch_size + 1
            curr_n = min(self.batch_size, n - i)
            prompt = self._build_prompt(curr_n, topic_description)

            try:
                batch = self._run_prompt(prompt, agent_id=f"endowment_generator_batch{batch_no}")
            except Exception as e:
                # Best-effort: record the failure and keep going with other batches.
                self._log(f"✗ Batch {batch_no} failed: {e}", level="warning")
                self.failed_prompts.append(prompt)
                run_failures.append(prompt)
            else:
                self.personas.extend(batch)
                full_batch.extend(batch)
                self._log(f"✓ Batch {batch_no}: {len(batch)} personas generated.")

            time.sleep(self.delay)

        if self.retry_failed and run_failures:
            recovered = self.retry_failed_batches(run_failures)
            full_batch.extend(recovered)

        if self.clear_after_run:
            self.clear()

        return full_batch

    def retry_failed_batches(self, prompts: Optional[List[str]] = None) -> List[dict]:
        """Retry failed prompts with exponential backoff and return recovered personas.

        Args:
            prompts: Prompts to retry. Defaults to every prompt currently in
                ``self.failed_prompts``.

        Each prompt gets up to ``max_retries`` attempts. Before attempt ``k``
        (0-based) the method sleeps ``retry_base_delay * 2**k`` seconds, capped
        at ``retry_max_delay`` when that is set. On success the personas are added
        to ``self.personas`` and one matching entry is removed from
        ``self.failed_prompts``.
        """
        self._log("\nRetrying failed prompts...")
        # Snapshot, so removing recovered prompts below cannot disturb iteration.
        pending = list(self.failed_prompts if prompts is None else prompts)
        recovered: List[dict] = []
        for i, prompt in enumerate(pending):
            for attempt in range(self.max_retries):
                # Exponential backoff: base, 2*base, 4*base, ... (optionally capped).
                wait_time = self.retry_base_delay * (2 ** attempt)
                if self.retry_max_delay is not None:
                    wait_time = min(wait_time, self.retry_max_delay)
                self._log(f"Retry {i + 1}, attempt {attempt + 1}/{self.max_retries}: waiting {wait_time:.0f}s...")
                time.sleep(wait_time)

                try:
                    batch = self._run_prompt(prompt, agent_id=f"endowment_generator_retry{i + 1}")
                except Exception as e:
                    self._log(f"✗ Retry {i + 1}, attempt {attempt + 1} failed: {e}", level="warning")
                    if attempt == self.max_retries - 1:
                        self._log(f"✗ Giving up on prompt {i + 1} after {self.max_retries} attempts", level="warning")
                    continue

                # Success path is outside the try so bookkeeping errors can
                # never trigger another attempt (which would duplicate personas).
                self.personas.extend(batch)
                recovered.extend(batch)
                # Identical prompts (same n and topic) are interchangeable, so
                # removing any one matching entry is correct.
                if prompt in self.failed_prompts:
                    self.failed_prompts.remove(prompt)
                self._log(f"✓ Retry {i + 1}: {len(batch)} personas recovered.")
                break  # Success - exit retry loop

        return recovered

    def save(self, filepath: str):
        """Write all accumulated personas to ``filepath`` as JSON Lines (UTF-8)."""
        with open(filepath, "w", encoding="utf-8") as f:
            for p in self.personas:
                f.write(json.dumps(p) + "\n")

    def clear(self):
        """Drop all accumulated personas and pending failed prompts."""
        self.personas.clear()
        self.failed_prompts.clear()