from openai import OpenAI
import re
import ast
import logging
from typing import List, Tuple, Optional
from modules.survey_converter import Survey, BinaryExtendedSurvey

# Module-level logger named after this module. AttributeLearner falls back to a
# logger with this same name when the caller does not supply one.
logger = logging.getLogger(__name__)

# Default system prompt for survey-wide attribute inference: it is the default
# value of AttributeLearner's `system_prompt` argument and is sent as the system
# message when all training questions are analysed together. When a maximum
# attribute count is configured, a short cap instruction is appended at call time.
# The reply format requested here (a Python-style list of double-quoted strings)
# is the format AttributeLearner._safe_parse_attribute_list tries to parse.
SYSTEM_PROMPT = """
You are an intelligent research assistant trained to analyze survey questions and infer which human attributes might influence how individuals respond.

You are given a set of training-only survey questions. Your task is to propose a list of relevant human attributes—such as demographics, beliefs, values, personality traits, or ideological leanings—that are likely to affect responses to these questions.

Carefully analyze the content and framing of the questions. Identify underlying factors that might shape how different people respond. Focus on attributes that are salient, discriminative, and potentially variable across respondents.

Respond **only** with a Python-style list of double-quoted strings. Do not include any explanation, headers, or prose before or after the list.

**Example output format:**
["attribute1", "attribute2", "attribute3", "attribute4"]
"""

# Default system prompt for per-question attribute inference: it is the default
# value of AttributeLearner's `question_attribute_prompt` argument and is sent as
# the system message when a single question is analysed. It requests the same
# list-of-strings reply format as SYSTEM_PROMPT, and the same optional cap
# instruction is appended at call time.
QUESTION_ATTRIBUTE_PROMPT = """
You are an intelligent research assistant trained to analyze individual survey questions and infer which human attributes might influence how different people respond.

You are given a single survey question. Your task is to propose a list of relevant human attributes—such as demographics, beliefs, values, personality traits, or ideological leanings—that are likely to shape responses to this question.

Focus on underlying factors that would cause meaningful variation in answers across different types of people. Avoid generic or overly broad attributes.

Respond **only** with a Python-style list of double-quoted strings. Do not include any explanation, headers, or prose before or after the list.

**Example output format:**
["religious affiliation", "political ideology", "trust in government"]
"""

class AttributeLearner:
    """
    Uses a language model to infer latent human attributes (e.g., demographics, beliefs,
    values) that are likely to shape how individuals respond to survey questions.

    Attributes are inferred either globally (across all training questions of the
    survey) or locally (for a single training question). This enables generation of
    persona endowments for survey simulations.

    The model is asked to answer with a Python list of strings. A reply that cannot
    be parsed as such yields an empty list rather than an exception. Every attribute
    returned by either ``generate_*`` method is also appended to ``self.attributes``,
    which accumulates across calls and may therefore contain duplicates.
    """

    def __init__(self,
                 survey: Survey,
                 model: str = "gpt-4o",
                 temperature: float = 0.3,
                 max_tokens: int = 1024,
                 system_prompt: str = SYSTEM_PROMPT,
                 question_attribute_prompt: str = QUESTION_ATTRIBUTE_PROMPT,
                 max_attributes: Optional[int] = None,
                 verbose: bool = True,
                 logger: Optional[logging.Logger] = None,
                 ):
        """
        Initializes the AttributeLearner.

        Args:
            survey: A Survey object; its "train" split questions are indexed by their
                "id" field and rendered with ``survey.get_prompt_text``.
            model (str): OpenAI model to use (e.g., 'gpt-4o').
            temperature (float): Sampling temperature for the LLM.
            max_tokens (int): Max token length of LLM response.
            system_prompt (str): System prompt for full-survey attribute inference.
            question_attribute_prompt (str): System prompt for per-question attribute inference.
            max_attributes (int or None): If set to a positive number, the model is
                asked to stay within this count and the parsed result is truncated to it.
            verbose (bool): Whether to enable logging output.
            logger (Logger or None): Custom logger. Defaults to a logger named after this module.
        """
        self.survey = survey
        self.model = model
        self.temperature = temperature
        self.system_prompt = system_prompt
        self.question_attribute_prompt = question_attribute_prompt
        self.max_tokens = max_tokens
        self.max_attributes = max_attributes
        self.verbose = verbose
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        # Only questions from the "train" split are indexed; other ids are unknown here.
        self.qid_to_text = {
            q["id"]: survey.get_prompt_text(q)
            for q in survey.get_questions_by_split("train")
        }
        # Client is built with no explicit arguments, so configuration is left to the SDK.
        self.client = OpenAI()
        # Running record of every attribute returned so far (not de-duplicated).
        self.attributes: List[str] = []

    def _log(self, msg: str, level: str = "info"):
        """
        Internal logging utility with optional verbosity control.

        Nothing is emitted when ``self.verbose`` is false. An unrecognised level
        name falls back to ``info``.

        Args:
            msg (str): The message to log.
            level (str): Logging level ("debug", "info", "warning", "error", "critical").
        """
        if self.verbose:
            log_fn = getattr(self.logger, level.lower(), self.logger.info)
            log_fn(msg)

    def _safe_parse_attribute_list(self, raw_text: str) -> List[str]:
        """
        Safely extracts a Python-style list of strings from raw LLM output.

        The scan starts at the first ``[`` and tries successively later ``]``
        characters as the end of the list, so an attribute that itself contains
        brackets (e.g. ``"income [household]"``) does not cut the list short.
        Parsed strings are stripped of surrounding whitespace and empty ones dropped.

        Args:
            raw_text (str): Raw string returned by the LLM.

        Returns:
            List[str]: Parsed list of attribute strings. Returns empty list on failure.
        """
        start = raw_text.find("[")
        end = raw_text.find("]", start) if start != -1 else -1
        if end == -1:
            self._log("No list found in model output.", level="warning")
            return []

        last_error = None
        while end != -1:
            try:
                # literal_eval only accepts literals, so model output is never executed.
                parsed = ast.literal_eval(raw_text[start:end + 1])
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                # This "]" was probably inside a string; retry with the next one.
                last_error = e
                end = raw_text.find("]", end + 1)
                continue

            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
                return [x.strip() for x in parsed if x.strip()]

            self._log("Model output is not a list of strings.", level="warning")
            return []

        self._log(f"Parsing failed: {last_error}", level="warning")
        return []

    def _request_attributes(self,
                            system_prompt: str,
                            user_text: str,
                            max_attributes: Optional[int]) -> List[str]:
        """
        Sends one chat completion request and parses the reply into attributes.

        Shared by the survey-level and question-level entry points. The parsed
        attributes are appended to ``self.attributes`` before being returned.

        Args:
            system_prompt (str): System message to send (before any cap instruction).
            user_text (str): User message, i.e. the question text(s) to analyse.
            max_attributes (int or None): Cap on the number of attributes; a
                non-positive value or None means no cap.

        Returns:
            List[str]: Parsed attribute strings, possibly empty.
        """
        cap = max_attributes if max_attributes and max_attributes > 0 else None
        if cap is not None:
            system_prompt = system_prompt + f"\nDo not exceed {cap} attributes."

        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

        # The message content can be None (e.g. a refusal), so guard before stripping.
        content = response.choices[0].message.content
        if not content:
            self._log("Model returned no text content.", level="warning")
            return []

        attributes = self._safe_parse_attribute_list(content.strip())
        if cap is not None:
            # The prompt instruction is only a hint; enforce the cap on the result too.
            attributes = attributes[:cap]

        self.attributes.extend(attributes)
        return attributes

    def generate_survey_attributes(self) -> List[str]:
        """
        Uses the LLM to analyze all training survey questions and extract a global
        list of relevant respondent attributes.

        No request is made when the survey has no training questions.

        Returns:
            List[str]: List of attribute strings (e.g., ["political ideology", "age group"]).
        """
        if not self.qid_to_text:
            self._log("No training questions available; skipping attribute inference.",
                      level="warning")
            return []

        full_survey_text = "\n".join(self.qid_to_text.values())
        return self._request_attributes(
            self.system_prompt, full_survey_text, self.max_attributes
        )

    def generate_question_attributes(self, qid, max_attributes: Optional[int] = None) -> List[str]:
        """
        Uses the LLM to analyze a single survey question and infer relevant human
        attributes that would affect how people respond.

        Args:
            qid (str): ID of the question to analyze; must belong to the "train" split.
            max_attributes (int or None): Optional override for attribute cap. A falsy
                value (None or 0) falls back to the instance-level ``max_attributes``.

        Returns:
            List[str]: List of question-specific attribute strings.

        Raises:
            KeyError: If ``qid`` is not a known training question id.
        """
        if not max_attributes:
            max_attributes = self.max_attributes

        try:
            question_text = self.qid_to_text[qid]
        except KeyError:
            raise KeyError(f"Unknown training question id: {qid!r}") from None

        return self._request_attributes(
            self.question_attribute_prompt, question_text, max_attributes
        )