import re
import ast
import logging
from typing import List, Tuple, Optional
from backends.registry import BackendRegistry
from modules.survey_converter import Survey
from modules.token_tracker import MultiAxisTokenTracker

# Module-level logger. AttributeLearner falls back to the same named logger
# (logging.getLogger(__name__)) when no custom logger is passed in.
logger = logging.getLogger(__name__)

# System prompt for survey-level (global) inference. The user message sent with it
# is the text of every training-split question, joined by newlines. The model is
# asked to reply with a bare Python-style list of double-quoted strings, which
# AttributeLearner extracts and parses with ast.literal_eval. AttributeLearner may
# append an attribute-count limit to this prompt at request time.
SYSTEM_PROMPT = """
You are an intelligent research assistant trained to analyze survey questions and infer which human attributes might influence how individuals respond.

You are given a set of training-only survey questions. Your task is to propose a list of relevant human attributes—such as demographics, beliefs, values, personality traits, or ideological leanings—that are likely to affect responses to these questions.

Carefully analyze the content and framing of the questions. Identify underlying factors that might shape how different people respond. Focus on attributes that are salient, discriminative, and potentially variable across respondents.

Respond **only** with a Python-style list of double-quoted strings. Do not include any explanation, headers, or prose before or after the list.

**Example output format:**
["attribute1", "attribute2", "attribute3", "attribute4"]
"""

# System prompt for per-question (local) inference. The user message sent with it
# is the text of a single training-split question. It uses the same list-only
# output contract as SYSTEM_PROMPT, and AttributeLearner may likewise append an
# attribute-count limit at request time.
QUESTION_ATTRIBUTE_PROMPT = """
You are an intelligent research assistant trained to analyze individual survey questions and infer which human attributes might influence how different people respond.

You are given a single survey question. Your task is to propose a list of relevant human attributes—such as demographics, beliefs, values, personality traits, or ideological leanings—that are likely to shape responses to this question.

Focus on underlying factors that would cause meaningful variation in answers across different types of people. Avoid generic or overly broad attributes.

Respond **only** with a Python-style list of double-quoted strings. Do not include any explanation, headers, or prose before or after the list.

**Example output format:**
["religious affiliation", "political ideology", "trust in government"]
"""

class AttributeLearner:
    """
    Infers latent human attributes (e.g., demographics, beliefs, personality traits)
    that likely influence how individuals respond to survey questions, using a
    backend-agnostic language model interface.

    This class supports two modes of inference:
    - Global attribute inference across the full set of training survey questions
    - Local attribute inference for an individual training survey question

    The inferred attributes serve as building blocks for generating respondent
    endowments in simulation-based survey experiments.

    The backend class is resolved by name through ``BackendRegistry``, so multiple
    LLM providers can be plugged in. Each model response is expected to contain a
    Python-style list of attribute strings.

    Attributes:
        survey (Survey): The survey object containing question data.
        backend: Backend instance built from ``BackendRegistry.get(backend_type)``.
            It must expose ``chat(system=..., user=..., return_usage=...)``.
        system_prompt (str): Prompt template for global attribute inference.
        question_attribute_prompt (str): Prompt template for local attribute inference.
        max_attributes (int or None): Optional cap on the number of attributes
            returned per call. The cap is stated in the prompt and also enforced by
            truncating the parsed list.
        qid_to_text (dict): Maps each training-split question id to its prompt text.
        token_tracker (MultiAxisTokenTracker or None): Optional token-usage tracker.
        attributes (List[str]): Every attribute returned by any call on this
            instance, in call order. Duplicates across calls are kept.
    """
    def __init__(self,
                 survey: Survey,
                 backend_type: str = "openai",
                 model_name: str = "gpt-4o-mini",
                 temperature: float = 0.3,
                 max_tokens: int = 1024,
                 system_prompt: str = SYSTEM_PROMPT,
                 question_attribute_prompt: str = QUESTION_ATTRIBUTE_PROMPT,
                 max_attributes: Optional[int] = None,
                 verbose: bool = True,
                 logger: Optional[logging.Logger] = None,
                 token_tracker: Optional[MultiAxisTokenTracker] = None,
                 **backend_kwargs):
        """
        Initialize the learner and instantiate its LLM backend.

        The backend class is looked up by name in ``BackendRegistry``. It is then
        constructed with ``model_name``, ``temperature``, ``max_tokens`` and any extra
        ``backend_kwargs``.

        Args:
            survey (Survey): Survey exposing ``get_questions_by_split("train")`` and
                ``get_prompt_text(question)``. Only training questions are used.
            backend_type (str): Registry key of the LLM backend (e.g. 'openai', 'gemini').
            model_name (str): Model name passed to the backend. It is also recorded
                with token usage.
            temperature (float): Sampling temperature passed to the backend.
            max_tokens (int): Maximum response tokens passed to the backend.
            system_prompt (str): System prompt for survey-level (global) inference.
            question_attribute_prompt (str): System prompt for per-question (local) inference.
            max_attributes (int, optional): Per-call attribute cap. None (or 0) means no cap.
            verbose (bool): If False, internal log messages (including warnings) are suppressed.
            logger (logging.Logger, optional): Custom logger. Defaults to this module's logger.
            token_tracker (MultiAxisTokenTracker, optional): If given, the token usage
                reported by the backend is logged to it after each request.
            **backend_kwargs: Extra keyword arguments forwarded to the backend constructor
                (e.g., API keys, proxies).
        """
        self.survey = survey
        self.backend_type = backend_type
        self.model_name = model_name
        self.temperature = temperature
        self.system_prompt = system_prompt
        self.question_attribute_prompt = question_attribute_prompt
        self.max_tokens = max_tokens
        self.max_attributes = max_attributes
        self.verbose = verbose
        self.logger = logger or logging.getLogger(__name__)
        self.token_tracker = token_tracker
        # Only the training split is ever shown to the LLM. Held-out questions are
        # deliberately absent from this mapping.
        self.qid_to_text = {q["id"]: survey.get_prompt_text(q) for q in survey.get_questions_by_split("train")}

        # Instantiate the backend via the registry.
        backend_cls = BackendRegistry.get(self.backend_type)
        self.backend = backend_cls(
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **backend_kwargs
        )

        self.attributes: List[str] = []

    def _log(self, msg: str, level: str = "info"):
        """
        Log ``msg`` at ``level`` when ``self.verbose`` is True.

        Args:
            msg (str): The message to log.
            level (str): Logging level name ("debug", "info", "warning", "error",
                "critical"). Unknown names fall back to ``info``.
        """
        if self.verbose:
            log_fn = getattr(self.logger, level.lower(), self.logger.info)
            log_fn(msg)

    def _safe_parse_attribute_list(self, raw_text: Optional[str]) -> List[str]:
        """
        Extract a Python-style list of strings from raw LLM output.

        Candidate spans are tried in this order:
        1. The widest bracketed span, from the first ``[`` to the last ``]``. This
           tolerates ``]`` characters inside attribute strings.
        2. Every minimal ``[...]`` span, in order of appearance. This tolerates
           bracketed prose around the list.

        The first candidate that ``ast.literal_eval`` turns into a list of strings
        wins. ``literal_eval`` only evaluates Python literals, so model output is
        never executed as code.

        Args:
            raw_text (str or None): Raw text returned by the LLM.

        Returns:
            List[str]: Parsed attribute strings, stripped of surrounding whitespace,
            with empty entries dropped. Returns an empty list when nothing parses.
        """
        if not isinstance(raw_text, str):
            self._log("No list found in model output.", level="warning")
            return []

        candidates: List[str] = []
        widest = re.search(r"\[.*\]", raw_text, re.DOTALL)
        if widest:
            candidates.append(widest.group(0))
        for match in re.finditer(r"\[.*?\]", raw_text, re.DOTALL):
            span = match.group(0)
            if span not in candidates:
                candidates.append(span)

        if not candidates:
            self._log("No list found in model output.", level="warning")
            return []

        last_error: Optional[Exception] = None
        for span in candidates:
            try:
                parsed = ast.literal_eval(span)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                # These are the failure modes of literal_eval on malformed or
                # pathologically nested input. Try the next candidate span.
                last_error = e
                continue
            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
                return [x.strip() for x in parsed if x.strip()]

        self._log(
            f"Parsing failed: no bracketed span is a list of strings (last error: {last_error})",
            level="warning",
        )
        return []

    @staticmethod
    def _with_cap(prompt: str, cap: Optional[int]) -> str:
        """Return ``prompt`` with an explicit attribute-count limit appended when ``cap`` is truthy."""
        if not cap:
            return prompt
        return f"{prompt}\nDo not exceed {cap} attributes."

    def _record_usage(self, agent_id: str, usage: Optional[dict]) -> None:
        """Forward one request's token usage to ``self.token_tracker``, if one is attached."""
        # Test against None rather than truthiness, so that a tracker object which
        # happens to evaluate as falsy is still used.
        if self.token_tracker is None:
            return
        if not usage:
            self._log("Token tracker attached but usage data not returned.", level="warning")
            return
        self.token_tracker.log(
            agent_id=agent_id,
            input_tokens=usage.get("input_tokens", 0),
            output_tokens=usage.get("output_tokens", 0),
            cached_input_tokens=usage.get("cached_input_tokens", 0),
            model_name=self.model_name,
            module_name="AttributeLearner"
        )

    def _infer_attributes(self, system: str, user: str, agent_id: str, cap: Optional[int]) -> List[str]:
        """
        Send one chat request, record its token usage, and parse the reply.

        Args:
            system (str): System prompt, already including any count limit.
            user (str): User message (survey or question text).
            agent_id (str): Identifier recorded with the token usage.
            cap (int or None): If truthy, the parsed list is truncated to this length.

        Returns:
            List[str]: Parsed, capped attribute list. It may be empty.
        """
        result = self.backend.chat(system=system, user=user, return_usage=True)

        # chat() may return a dict carrying "answer_text" and "usage", or the raw text.
        if isinstance(result, dict):
            raw_output = result.get("answer_text")
            usage = result.get("usage")
        else:
            raw_output, usage = result, None

        self._record_usage(agent_id, usage)

        attributes = self._safe_parse_attribute_list(raw_output)
        if cap and len(attributes) > cap:
            # The prompt only asks the model to respect the cap, so enforce it here.
            self._log(
                f"Model returned {len(attributes)} attributes; truncating to {cap}.",
                level="info",
            )
            attributes = attributes[:cap]
        return attributes

    def generate_survey_attributes(self) -> List[str]:
        """
        Ask the LLM to analyze all training survey questions and extract a global
        list of relevant respondent attributes.

        The result is also appended to ``self.attributes``.

        Returns:
            List[str]: Attribute strings (e.g., ["political ideology", "age group"]),
            at most ``self.max_attributes`` long when a cap is set. Returns an empty
            list, without calling the backend, when there are no training questions.
        """
        if not self.qid_to_text:
            self._log(
                "No training questions available; skipping survey-level attribute inference.",
                level="warning",
            )
            return []

        full_survey_text = "\n".join(self.qid_to_text.values())
        cap = self.max_attributes

        attributes = self._infer_attributes(
            system=self._with_cap(self.system_prompt, cap),
            user=full_survey_text,
            agent_id="attribute_learner_survey",
            cap=cap,
        )

        self.attributes.extend(attributes)
        return attributes

    def generate_question_attributes(self, qid, max_attributes = None) -> List[str]:
        """
        Ask the LLM to analyze a single training survey question and infer the
        human attributes that would affect how people respond.

        The result is also appended to ``self.attributes``.

        Args:
            qid (str): ID of a training-split question to analyze.
            max_attributes (int or None): Per-call cap override. A falsy value
                (None or 0) falls back to ``self.max_attributes``.

        Returns:
            List[str]: Question-specific attribute strings, truncated to the
            effective cap when one is set.

        Raises:
            KeyError: If ``qid`` is not a training-split question of the survey.
        """
        try:
            question_text = self.qid_to_text[qid]
        except KeyError:
            raise KeyError(
                f"Question id {qid!r} is not a training-split question of this survey."
            ) from None

        cap = max_attributes or self.max_attributes

        attributes = self._infer_attributes(
            system=self._with_cap(self.question_attribute_prompt, cap),
            user=question_text,
            agent_id=f"attribute_learner_qid_{qid}",
            cap=cap,
        )

        self.attributes.extend(attributes)
        return attributes