import copy
import logging
import os
import random
import time
import warnings
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice

import numpy as np
import yaml
from tqdm import tqdm

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Resolve path to project root and config file.
# The project root is taken to be the parent of the directory containing this
# module; the default attribute bank is expected at
# <project root>/config/attribute_banks/attribute_bank.yaml.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/attribute_banks", "attribute_bank.yaml")

from modules.response_converter import Responses
from modules.survey_conductor import SurveyConductor
from modules.endowment_manager import ActiveEndowments
from generators.theme_variability_tracker import ThemeVariabilityTracker
from modules.dataclasses import AgentConfig

# Load the default attribute bank once, at import time.
# The encoding is given explicitly so that non-ASCII attribute text is decoded
# the same way on every platform instead of depending on the locale default.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    ATTRIBUTE_BANK = yaml.safe_load(f)

class ActiveEndowmentGenerator:
    """
    Orchestrates the active generation of agent endowments to maximize response variability
    in simulated survey experiments. This includes initial sampling, adaptive expansion
    based on mode/question entropy, response simulation, and logging.

    Attributes:
        survey: A Survey object containing question definitions and metadata.
        endowment_model: The model used to generate new endowments (LLM wrapper).
        attribute_learner: Object for extracting attributes from survey/questions.
        agent_config: Configuration for survey respondent agent instantiation and LLM model parameters.
        tracker: ThemeVariabilityTracker instance for monitoring response entropy.
        target_n: Total number of endowments to generate.
        initial_n: Number of endowments to sample per mode in the initial round.
        num_update_steps: Number of expansion steps allowed.
        parallel: Whether to run survey simulations in parallel.
        max_workers: Number of threads to use for parallel simulation.
        question_patch: Dictionary specifying question-specific sampling ratio and top-k.
        seed: Random seed for reproducibility.
        response_log: Stores all agent-level response records.
        mode_history: Stores mode and attribute metadata for each sampling stage.
    """

    # Defaults for question-targeted sampling; caller-supplied patches are layered on top.
    _DEFAULT_QUESTION_PATCH = {
        "fraction": 0.3,
        "top_k": 3,
        "min_repeats": 2,  # threshold to trigger reallocation of ineffective question modes
    }

    # Consecutive expansion rounds that may add nothing before generate() gives up.
    _MAX_STALLED_ROUNDS = 3

    def __init__(self, survey, endowment_model, attribute_learner, agent_config: "AgentConfig", attribute_bank=None, variability_tracker=None, target_n=100, initial_n=10, num_update_steps=4, seed=None, parallel=False, max_workers=10, verbose=True, question_patch: dict = None, logger=None):
        """
        Initialize the generator with all components needed for attribute-based
        endowment generation and adaptive sampling.

        Args:
            survey: The survey object containing questions.
            endowment_model: A generator that produces endowments based on attribute lists.
            attribute_learner: An object that can generate attributes from questions or the survey.
            agent_config: AgentConfig dataclass defining agent parameters.
            attribute_bank: Dictionary storing attributes by mode or question. When omitted,
                a private deep copy of the module-level ATTRIBUTE_BANK is used so that
                attributes learned by one generator never leak into another. A bank passed
                explicitly is used (and extended) in place.
            variability_tracker: Optional tracker to compute entropy-based weights.
            target_n: Total number of endowments to produce.
            initial_n: Initial number of endowments per mode.
            num_update_steps: Number of active update steps.
            seed: Seed for reproducible sampling. When given, the multinomial mode
                allocation is drawn from a dedicated seeded stream; when None, numpy's
                global random state is used.
            parallel: Whether to simulate agent responses in parallel.
            max_workers: Max threads if using parallel simulation.
            verbose: Whether to print log messages during operation.
            question_patch: Dict controlling question-based sampling in each update.
                Missing keys fall back to the defaults (fraction=0.3, top_k=3, min_repeats=2).
            logger: Optional logger instance for consistent logging across modules. Defaults to module-level logger if not provided.
        """
        self.survey = survey
        self.attribute_learner = attribute_learner
        self.endowment_model = endowment_model
        self.agent_config = agent_config

        # Compare against None so a caller-supplied tracker is never swapped out
        # just because it happens to evaluate as falsy.
        if variability_tracker is None:
            variability_tracker = ThemeVariabilityTracker(
                questions=self.survey.get_questions_by_split("train"),
                smoothing=0,
            )
        self.tracker = variability_tracker

        self.target_n = target_n
        self.initial_n = initial_n
        self.num_update_steps = num_update_steps
        self.seed = seed
        self.parallel = parallel
        self.verbose = verbose
        self.max_workers = max_workers

        self.endowment_manager = ActiveEndowments.from_endowment_list([])

        # Layer the caller's patch over the defaults so partial overrides work.
        self.question_patch = {**self._DEFAULT_QUESTION_PATCH, **(question_patch or {})}

        self.mode_history = []
        self.response_log = []

        # The shared module-level bank is mutated below (survey/question attributes),
        # so work on a private copy unless the caller handed us their own bank.
        if attribute_bank is None:
            attribute_bank = copy.deepcopy(ATTRIBUTE_BANK)
        self.attribute_bank = attribute_bank

        self.random_state = random.Random(seed)
        # Source of multinomial draws. The numpy seed is derived through a throwaway
        # random.Random so any seed type accepted there works here too, without
        # consuming values from self.random_state.
        if seed is None:
            self._np_random = np.random
        else:
            self._np_random = np.random.RandomState(random.Random(seed).getrandbits(32))

        self.logger = logger or logging.getLogger(__name__)

        self.qid_low_entropy_counts = Counter()

        # Learn the survey-level attributes and register them in the attribute bank
        self.add_survey_attributes()

    def add_survey_attributes(self):
        """
        Automatically extracts and registers attributes related to the overall survey.
        Updates the `survey` template and attribute lists in the attribute bank.
        """
        attributes = self.attribute_learner.generate_survey_attributes()

        # Ensure 'survey' is added only once to the template list
        if "survey" not in self.attribute_bank["templates"]:
            self.attribute_bank["templates"]["survey"] = ["survey"]

        self.attribute_bank["attributes"]["survey"] = attributes

        if self.verbose:
            print(f"[Attribute Bank Expansion] Added {len(attributes)} survey specific attributes: \n {attributes}")

    def add_question_attributes(self, qid, max_attributes=None):
        """
        Generates and registers attributes for a specific question (qid), if not already present.

        The attribute learner is only invoked when the bank has no entry for `qid`,
        so an existing entry costs no generation call.

        Args:
            qid: Question ID.
            max_attributes: Optional maximum number of attributes to generate.
        """
        # Ensure the qid is listed under the "question" template exactly once
        question_templates = self.attribute_bank["templates"].setdefault("question", [])
        if qid not in question_templates:
            question_templates.append(qid)

        # Avoid overwriting (and avoid regenerating) if already present
        if qid in self.attribute_bank["attributes"]:
            if self.verbose:
                print(f"[Attribute Bank] Attributes for QID '{qid}' already exist, skipping")
            return

        attributes = self.attribute_learner.generate_question_attributes(qid, max_attributes)
        self.attribute_bank["attributes"][qid] = attributes
        if self.verbose:
            print(f"[Attribute Bank Expansion] Added {len(attributes)} question specific attributes: \n {attributes}")

    def _log(self, msg, level: str = "info"):
        """
        Utility method for conditional logging with level control.

        Nothing is emitted unless `verbose` is truthy and a logger is set.

        Args:
            msg (str): The message to log.
            level (str): Logging level as string ("debug", "info", "warning", "error", "critical").
                Unknown levels fall back to "info".
        """
        if getattr(self, "verbose", False) and self.logger:
            log_fn = getattr(self.logger, level.lower(), self.logger.info)
            log_fn(msg)

    def generate(self):
        """
        Orchestrates the full active generation loop. Starts with initial sampling, then
        repeatedly expands the pool using entropy-weighted sampling until `target_n`
        endowments are generated.

        If several consecutive expansion rounds add no endowments at all, a warning
        is issued and the loop stops instead of spinning forever; the pool may then
        be smaller than `target_n`.

        Returns:
            An ActiveEndowments manager with all generated endowments.
        """
        self._log("Starting initial sampling...")
        self._initial_sampling()

        updates_done = 0
        stalled_rounds = 0

        while len(self) < self.target_n:
            size_before = len(self)
            remaining = self.target_n - size_before
            # Spread what is left over the remaining update steps (at least one).
            steps_left = max(1, self.num_update_steps - updates_done)
            expansion_batch_size = min(remaining, int(np.ceil(remaining / steps_left)))

            self._log(f"Expanding pool: current size = {size_before}, target = {self.target_n}")
            self._expand_pool(expansion_batch_size)
            updates_done += 1

            if len(self) > size_before:
                stalled_rounds = 0
                continue

            stalled_rounds += 1
            if stalled_rounds >= self._MAX_STALLED_ROUNDS:
                warnings.warn(
                    f"Endowment pool stuck at {len(self)}/{self.target_n} after "
                    f"{stalled_rounds} expansion rounds without growth; stopping early."
                )
                break

        self._log(f"Generation complete. Final pool size: {len(self)}")

        mode_counts = Counter(tuple(e['mode']) for e in self.endowment_manager.get_endowments())
        self._log(f"Unique modes sampled: {len(mode_counts)}")
        for mode, count in sorted(mode_counts.items(), key=lambda x: -x[1]):
            self._log(f"  Mode {mode}: {count} endowments")

        return self.endowment_manager

    def _initial_sampling(self):
        """
        Performs the initial round of sampling across predefined mode categories:
        core, thematic, theoretical, and survey-specific. If no structured modes
        are found, falls back to survey-specific mode only. Also runs the first batch
        of agent simulations to seed entropy estimation.
        """
        n = self.initial_n
        templates = self.attribute_bank["templates"]
        mode_n_list = []
        found_presets = False

        # Core modes are sampled on their own
        core_modes = templates.get("core")
        if core_modes:
            self._log("Queuing core modes:")
            for core_mode in core_modes:
                self._log(f"  Queued: core mode `{core_mode}` with n={n}")
                mode_n_list.append(((core_mode,), n))
            found_presets = True

        # Thematic and theoretical modes are each paired with the "core" key
        for label in ("thematic", "theoretical"):
            entries = templates.get(label)
            if not entries:
                continue
            self._log(f"Queuing {label} modes:")
            for entry in entries:
                mode = ("core", entry)
                self._log(f"  Queued: {label} mode `{mode}` with n={n}")
                mode_n_list.append((mode, n))
            found_presets = True

        if not found_presets:
            self._log("No preset core/thematic/theoretical modes found. Skipping structured sampling.")

        # Survey-specific mode is always included
        self._log(f"Queuing survey-specific mode:  n={n}")
        mode_n_list.append((("survey",), n))

        self._log(f"Sampling {len(mode_n_list)} modes in parallel...")
        new_endowments = [batch for batch in self._parallel_sample_modes(mode_n_list) if batch]

        self._log("Simulating survey responses...")
        self._track_responses(new_endowments)

    def _expand_pool(self, batch_total=10, question_patch=None):
        """
        Expands the endowment pool by sampling new agent modes based on entropy-aware
        softmax weights (from tracker) and question-specific entropy targeting.

        Args:
            batch_total: Total number of new agents to add.
            question_patch: Overrides default patch config for question sampling.
                Keys it omits fall back to `self.question_patch`.
        """
        mode_n_list = self._plan_expansion(batch_total, question_patch)

        self._log(f"Sampling {len(mode_n_list)} modes in parallel...")
        new_endowments = [batch for batch in self._parallel_sample_modes(mode_n_list) if batch]

        self._log(f"Simulating survey for {len(new_endowments)} sampled mode(s)...")
        self._track_responses(new_endowments)

    def _plan_expansion(self, batch_total, question_patch=None):
        """
        Decides which modes to sample in one expansion round and how many endowments each gets.

        A `fraction` of the batch is reserved for question-specific modes and split
        evenly over the low-entropy questions the tracker actually returns; the rest
        is distributed over the tracker's modes by a multinomial draw on the softmax
        weights. If the tracker returns no low-entropy questions, the whole batch
        goes to the base modes.

        Args:
            batch_total: Total number of new agents to allocate.
            question_patch: Optional overrides for `self.question_patch`.

        Returns:
            List of (mode, n) pairs with plain-int counts, base modes first.
        """
        patch = {**self.question_patch, **(question_patch or {})}

        weights = self.tracker.get_softmax_weights(temperature=0.3)
        modes = list(weights)

        # Clamp so an out-of-range fraction can never produce a negative quota.
        question_quota = min(batch_total, max(0, int(batch_total * patch["fraction"])))
        low_entropy_qids = []
        if question_quota > 0:
            low_entropy_qids = list(self.tracker.get_low_entropy_questions(top_k=patch["top_k"]))
        if not low_entropy_qids:
            question_quota = 0
        base_quota = batch_total - question_quota

        mode_n_list = []

        # === BASE MODES: thematic/core/theoretical ===
        if modes and base_quota > 0:
            probs = np.asarray([weights[m] for m in modes], dtype=float)
            total_weight = probs.sum()
            # Renormalise to absorb float drift; multinomial rejects sums above 1.
            if np.isfinite(total_weight) and total_weight > 0:
                probs = probs / total_weight
            draws = self._np_random.multinomial(base_quota, probs)
            for mode, count in zip(modes, draws):
                if count > 0:
                    count = int(count)  # keep numpy scalars out of downstream records
                    self._log(f"Queuing base mode: {mode} with n={count}")
                    mode_n_list.append((mode, count))
        elif base_quota > 0:
            self._log("Tracker returned no mode weights; base quota left unallocated", level="warning")

        # === QUESTION-SPECIFIC MODES: even allocation across the returned low-entropy questions ===
        if low_entropy_qids:
            min_repeats = patch["min_repeats"]

            for qid in low_entropy_qids:
                self.qid_low_entropy_counts[qid] += 1

            per_q, leftovers = divmod(question_quota, len(low_entropy_qids))

            # Highest-weighted mode, normalised to a tuple so it can be combined with a qid.
            top_mode = None
            if weights:
                top_mode = max(weights, key=weights.get)
                top_mode = (top_mode,) if isinstance(top_mode, str) else tuple(top_mode)

            for i, qid in enumerate(low_entropy_qids):
                n = per_q + (1 if i < leftovers else 0)
                if n <= 0:
                    continue

                if qid not in self.attribute_bank["attributes"]:
                    self.add_question_attributes(qid)
                else:
                    self._log(f"[Attribute Bank] Attributes for QID '{qid}' already exist, skipping")

                # Reallocate if ineffective: a question that keeps showing up as
                # low-entropy is combined with the currently strongest mode.
                if top_mode is not None and self.qid_low_entropy_counts[qid] >= min_repeats:
                    # dict.fromkeys de-duplicates while keeping a stable order, so the
                    # same combination always yields the same mode tuple.
                    mixed_mode = tuple(dict.fromkeys(top_mode + (qid,)))
                    self._log(f"Reallocating question mode '{qid}' to mixed mode: {mixed_mode} with n={n}")
                    mode_n_list.append((mixed_mode, n))
                else:
                    q_mode = (qid,)
                    self._log(f"Queuing question-specific mode: {q_mode} with n={n}")
                    mode_n_list.append((q_mode, n))

        return mode_n_list

    def _track_responses(self, endowment_batches):
        """
        Simulates the survey for the given batches and feeds the filtered records to the tracker.

        Args:
            endowment_batches: A list of endowment lists grouped by mode.
        """
        agent_records = self._simulate_survey_batch(endowment_batches)
        filtered_records = self.tracker.filter_records(agent_records, survey=self.survey)
        self.tracker.update_from_records(filtered_records)

    def _sample_mode(self, mode_tuple, n):
        """
        Generates `n` endowments for a specific mode (e.g., core, theme, or question).

        The union of the attribute lists of every key in the mode is passed to the
        endowment model; the result is added to the endowment manager and recorded
        in `mode_history`.

        NOTE(review): in parallel mode this runs on worker threads and mutates the
        shared endowment manager; this assumes `add_batch` is thread-safe - confirm.

        Args:
            mode_tuple: A tuple of attribute bank keys (e.g., ("core", "economics")).
                A bare string is treated as a one-element tuple.
            n: Number of endowments to generate.

        Returns:
            List of generated endowments for this mode; an empty list (after a
            warning) when none of the mode's keys has attributes in the bank.
        """
        if isinstance(mode_tuple, str):
            mode_tuple = (mode_tuple,)

        bank_attributes = self.attribute_bank["attributes"]
        attributes = sorted({attr for mode in mode_tuple for attr in bank_attributes.get(mode, [])})

        if not attributes:
            warnings.warn(f"No attributes found for mode {mode_tuple}")
            return []

        new_endowments = self.endowment_model.generate_endowments(attributes, n=n, mode=mode_tuple)
        self.endowment_manager.add_batch(new_endowments)
        self.mode_history.append({
            "mode": mode_tuple,
            "attributes": attributes,
            "n": n
        })
        return new_endowments

    def _parallel_sample_modes(self, mode_n_list):
        """
        Parallel wrapper around _sample_mode for a list of (mode_tuple, n) pairs.

        Runs sequentially when `parallel` is off or `max_workers` is at most 1.

        Returns:
            A list of endowment batches, one per mode, in the order of `mode_n_list`.
        """
        if not mode_n_list:
            return []

        if not self.parallel or self.max_workers <= 1:
            return [self._sample_mode(mode, n) for mode, n in mode_n_list]

        results = [None] * len(mode_n_list)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_idx = {
                executor.submit(self._sample_mode, mode, n): i
                for i, (mode, n) in enumerate(mode_n_list)
            }

            iterator = as_completed(future_to_idx)
            if self.verbose:
                iterator = tqdm(iterator, total=len(future_to_idx), desc="Sampling modes")

            # Futures finish in arbitrary order; the index map restores input order.
            for future in iterator:
                results[future_to_idx[future]] = future.result()

        return results

    def _run_conductor(self, endowments):
        """
        Runs one SurveyConductor over a chunk of endowments and returns its agent records.

        Sleeps for `agent_config.delay` first when that attribute is set and truthy.
        """
        delay = getattr(self.agent_config, "delay", None)
        if delay:
            time.sleep(delay)

        manager = ActiveEndowments.from_endowment_list(endowments)
        conductor = SurveyConductor(
            survey=self.survey,
            endowments=manager,
            agent_type=self.agent_config.agent_type,
            model_name=self.agent_config.model_name,
            formality=self.agent_config.formality,
            verbose=self.verbose
        )
        conductor.run(save=False, parallel_worker=self.parallel, **self.agent_config.agent_kwargs)
        return conductor.to_agent_records()

    def _simulate_survey_batch(self, endowment_batches) -> list[dict]:
        """
        Simulates survey responses for a batch of endowments using SurveyConductor.

        Every collected record is also appended to `response_log`. In parallel mode
        records arrive in chunk-completion order.

        Args:
            endowment_batches: A list of endowment lists grouped by mode. Empty or
                None batches are ignored.

        Returns:
            A flat list of agent-level response records (dict format); empty when
            there is nothing to simulate.
        """
        self._log(f"Simulating survey responses with dynamic chunking (parallel={self.parallel})")

        # Flatten all endowments into a single list
        flat_endowments = [e for batch in endowment_batches if batch for e in batch]
        total = len(flat_endowments)
        if total == 0:
            self._log("No endowments to simulate; skipping survey run")
            return []

        workers = max(1, self.max_workers)

        # Decide chunk size based on the worker count
        if self.parallel:
            factor = 2  # Number of chunks per worker
            chunk_size = max(1, total // (workers * factor))
        else:
            chunk_size = total  # process all in one chunk

        self._log(f"Chunking {total} endowments into chunks of size {chunk_size}")
        chunks = [flat_endowments[i:i + chunk_size] for i in range(0, total, chunk_size)]

        results = []
        if self.parallel:
            with ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [executor.submit(self._run_conductor, chunk) for chunk in chunks]
                for f in tqdm(as_completed(futures), total=len(futures), desc="Simulating agents", disable=not self.verbose):
                    records = f.result()
                    self.response_log.extend(records)
                    results.extend(records)
        else:
            for chunk in chunks:
                records = self._run_conductor(chunk)
                self.response_log.extend(records)
                results.extend(records)

        self._log(f"Collected {len(results)} agent response records from {len(chunks)} chunks")
        return results

    def export_mode_summary(self, path: str):
        """
        Saves a summary of endowments generated per mode to CSV.

        One row per mode (keys joined with "+"), with the total number of requested
        endowments and the de-duplicated attribute lists, sorted by total descending.
        An empty history produces a header-only file.

        Args:
            path (str): Path to save the summary file. Parent directories are
                created when the path has a directory component.
        """
        import pandas as pd

        columns = ["mode", "n", "attributes"]
        rows = [
            {
                "mode": "+".join(str(part) for part in entry["mode"]),
                "n": entry["n"],
                "attributes": ", ".join(entry["attributes"]),
            }
            for entry in self.mode_history
        ]

        df = pd.DataFrame(rows, columns=columns)
        if not df.empty:
            df = df.groupby("mode").agg({
                "n": "sum",
                # de-duplicate; sorted so the output is stable between runs
                "attributes": lambda x: "; ".join(sorted(set(x)))
            }).reset_index()
            df = df.sort_values("n", ascending=False)

        # A bare filename has no directory part, and makedirs("") would raise.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        df.to_csv(path, index=False)
        self._log(f"[Mode Summary Export] Saved to: {path}")

    def _flatten_response_log(self):
        """
        Turns `response_log` into one {"eid", "qid", "answer"} dict per agent/question pair.

        Responses are paired positionally with `survey.questions`; None answers
        become empty strings.
        """
        qids = [q["id"] for q in self.survey.questions]
        return [
            {
                "eid": agent_record["agent_id"],
                "qid": qid,
                "answer": answer if answer is not None else ""
            }
            for agent_record in self.response_log
            for qid, answer in zip(qids, agent_record["responses"])
        ]

    def export_responses_to_csv(self, path: str, format="code"):
        """
        Converts and saves all collected agent responses to a CSV file.

        Args:
            path: Path to save the file.
            format: Output format to use in Responses class (e.g., "code" or "text").
        """
        responses = Responses(self._flatten_response_log(), self.survey, output_format=format)
        responses.save(path)

    def export_attribute_bank(self, filepath: str = "attribute_bank.yaml"):
        """
        Saves the current attribute bank to a YAML file.

        Args:
            filepath: Target path for saving the attribute bank. Parent directories
                are created when the path has a directory component.
        """
        # Ensure directory exists
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # allow_unicode writes non-ASCII characters verbatim, so the file encoding
        # must be able to represent them regardless of the platform default.
        with open(filepath, "w", encoding="utf-8") as f:
            yaml.dump(self.attribute_bank, f, sort_keys=False, allow_unicode=True)

        self._log(f"[Attribute Bank Export] Saved to: {filepath}")

    def get_responses(self, output_format="answer"):
        """
        Converts the response log to a `Responses` object (for analysis or saving).

        Args:
            output_format (str): Format passed through to `Responses`; defaults to
                "answer" (the `responses_code` property uses "code").

        Returns:
            Responses: Structured Responses object.
        """
        return Responses(self._flatten_response_log(), self.survey, output_format=output_format, clean=True)

    @property
    def responses(self):
        """Responses object built from the response log with the default output format."""
        return self.get_responses()

    @property
    def responses_code(self):
        """Responses object built from the response log with output_format="code"."""
        return self.get_responses(output_format="code")

    def __len__(self):
        """
        Returns:
            The number of generated endowments.
        """
        return len(self.endowment_manager.endowments)