import math
import random
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Union

import pandas as pd

class BaseExperiment(ABC):
    """
    Abstract base class for experiments that compare agent responses to aggregate benchmarks.

    Attributes:
        responses (Responses): Response matrix; ``responses.responses`` maps qid -> {eid: code}.
        survey (Survey): Survey structure; ``survey.questions`` is a list of question dicts.
        endowments (EndowmentManager): Agent metadata and roles.
        aggregate_stats (dict): {qid: value} empirical or simulated ground-truth values.
        filter_binary (bool): Whether to restrict analysis to binary-coded questions.
        drop_na (bool): Whether to drop agents (columns) that have any missing response
            when building the augmented matrix.
    """

    def __init__(
        self,
        responses,
        survey,
        endowments,
        aggregate_stats: Optional[Dict[str, float]] = None,
        filter_binary: bool = False,
        drop_na: bool = False,
    ):
        """
        Initializes the experiment with response data, survey metadata, and agent endowments.

        Note:
            When ``filter_binary`` is True the passed-in ``responses`` and ``survey``
            objects are filtered IN PLACE (their ``responses`` / ``questions``
            attributes are reassigned).

        Args:
            responses (Responses): Encoded response matrix.
            survey (Survey): Survey instance.
            endowments (EndowmentManager): Holds agent roles and attributes.
            aggregate_stats (dict, optional): Empirical values for comparison (qid → value).
            filter_binary (bool): If True, retain only questions with ≤ 3 codes.
            drop_na (bool): If True, agents with any missing response are removed from
                the matrices returned by `get_augmented_matrix`.
        """
        self.responses = responses
        self.survey = survey
        self.endowments = endowments
        self.aggregate_stats = aggregate_stats or {}
        self.filter_binary = filter_binary
        self.drop_na = drop_na

        if self.filter_binary:
            self._filter_binary_questions()

    def _filter_binary_questions(self):
        """
        Filters the survey and response matrix to include only binary (or nearly binary) questions.

        Retains questions with 3 or fewer distinct coded responses; the third code
        accommodates an additional neutral answer (e.g., 'undecided', coded as 0.5).
        Questions without an ``answer_to_code`` mapping count as having zero codes
        and are therefore retained.
        """
        # A set keeps the two membership filters below O(1) per lookup.
        binary_qids = {
            q["id"]
            for q in self.survey.questions
            if len(q.get("answer_to_code", {})) <= 3
        }
        self.responses.responses = {
            qid: ans_dict
            for qid, ans_dict in self.responses.responses.items()
            if qid in binary_qids
        }
        self.survey.questions = [
            q for q in self.survey.questions if q["id"] in binary_qids
        ]

    def get_aggregate_stats(self):
        """
        Returns the aggregate statistics used as benchmark.

        Returns:
            dict: Mapping from question ID to aggregate value.
        """
        return self.aggregate_stats

    def get_augmented_matrix(self, proxy_only: bool = False) -> pd.DataFrame:
        """
        Returns a [qid × eid] DataFrame of agent responses, with an 'aggregate' column for benchmarks.

        Rows are sorted by qid; agent columns are sorted by eid and are followed by
        the 'aggregate' column. An empty response matrix yields an empty frame that
        still carries the 'aggregate' column.

        Args:
            proxy_only (bool): If True, include only proxy agent columns and the aggregate.

        Returns:
            pd.DataFrame: Matrix of responses with optional filtering.
        """
        matrix = self.responses.responses
        all_eids = sorted({eid for answers in matrix.values() for eid in answers})

        rows = []
        for qid in sorted(matrix):
            answers = matrix[qid]
            row = {eid: answers.get(eid) for eid in all_eids}
            row["qid"] = qid
            row["aggregate"] = self.aggregate_stats.get(qid)
            rows.append(row)

        if rows:
            df = pd.DataFrame(rows)
        else:
            # pd.DataFrame([]) has no 'qid' column, so set_index would raise KeyError.
            df = pd.DataFrame(columns=["qid", "aggregate"])
        df = df.set_index("qid")

        if self.drop_na:
            # Drop agents (columns) with any missing responses; 'aggregate' is never dropped.
            incomplete = [
                col for col in df.columns
                if col != "aggregate" and df[col].isnull().any()
            ]
            df = df.drop(columns=incomplete)

        if proxy_only:
            proxy_eids = set(self.endowments.get_eids_by_role("proxy"))
            # Walk the frame's own columns so the feature order is deterministic
            # (a set intersection would yield an arbitrary order).
            keep_cols = [
                col for col in df.columns
                if col != "aggregate" and col in proxy_eids
            ]
            df = df[keep_cols + ["aggregate"]]

        return df

    def get_feature_names(self, proxy_only: bool = False) -> list[str]:
        """
        Returns the ordered list of agent feature names (column names) used in the regression.

        Args:
            proxy_only (bool): If True, only return features for proxy agents.

        Returns:
            List of agent IDs used as regression features (every column of the
            augmented matrix except 'aggregate', in matrix order).
        """
        df = self.get_augmented_matrix(proxy_only=proxy_only)
        return [col for col in df.columns if col != "aggregate"]

    def get_dataframe_by_split(self, split: Union[str, List[str]], proxy_only: bool = False) -> pd.DataFrame:
        """
        Returns the augmented response matrix for one or more question splits.

        Args:
            split (str or list[str]): Name(s) of question split(s) (e.g., 'train', 'val', 'test').
            proxy_only (bool): If True, include only proxy agents and aggregate.

        Returns:
            pd.DataFrame: Filtered matrix of responses and aggregates. Questions of the
            split(s) that have no row in the response matrix are silently skipped.
        """
        split_names = [split] if isinstance(split, str) else split
        qids = [
            q["id"]
            for name in split_names
            for q in self.survey.get_questions_by_split(name)
        ]

        full_df = self.get_augmented_matrix(proxy_only=proxy_only)
        return full_df.loc[full_df.index.intersection(qids)]

    def sample_fraction(self, fraction: float, seed: int = 101):
        """
        Returns a new Experiment with a fraction of the questions retained,
        sampled proportionally by split, stratified by original question groups.

        Sampling is done at the level of ORIGINAL questions within each of the
        'train', 'valid' and 'test' splits; every binary question derived from a
        sampled original is kept. Each non-empty split keeps at least one original.
        A private RNG is used, so the global `random` state is left untouched.

        Args:
            fraction (float): Fraction of questions to retain (0 < fraction ≤ 1).
            seed (int): Random seed for reproducibility.

        Returns:
            BaseExperiment: A new experiment object with subsetted questions, responses, and stats.

        Raises:
            ValueError: If `fraction` is out of range or the survey lacks the
                binary/original question mappings.
        """
        self._validate_sampling_request(fraction)
        rng = random.Random(seed)

        original_qids = self._sample_original_qids(["train", "valid", "test"], fraction, rng)
        return self._subset_experiment(self._expand_to_binary_qids(original_qids))

    def sample_trainvalid_fraction(self, fraction: float, seed: int = 101):
        """
        Returns a new Experiment with a fraction of train+valid questions retained,
        while keeping test questions intact.

        Sampling follows the same per-split, original-question-level scheme as
        `sample_fraction`, restricted to the 'train' and 'valid' splits.

        Args:
            fraction (float): Fraction of train+valid questions to retain (0 < fraction <= 1).
            seed (int): Random seed for reproducibility.

        Returns:
            BaseExperiment: A new experiment object with subsetted questions, responses, and stats.

        Raises:
            ValueError: If `fraction` is out of range or the survey lacks the
                binary/original question mappings.
        """
        self._validate_sampling_request(fraction)
        rng = random.Random(seed)

        original_qids = self._sample_original_qids(["train", "valid"], fraction, rng)
        binary_qids = self._expand_to_binary_qids(original_qids)

        # Test questions are always kept in full, regardless of `fraction`.
        binary_qids.extend(q["id"] for q in self.survey.get_questions_by_split("test"))
        return self._subset_experiment(binary_qids)

    def _validate_sampling_request(self, fraction: float) -> None:
        """Raises ValueError unless `fraction` is in (0, 1] and the survey has both id maps."""
        if not (0 < fraction <= 1):
            raise ValueError("fraction must be between 0 and 1")
        if not (
            hasattr(self.survey, "original_to_binary_map")
            and hasattr(self.survey, "binary_to_original_map")
        ):
            raise ValueError("This method requires a BinaryExtendedSurvey with binary-original mappings.")

    def _sample_original_qids(self, split_names: List[str], fraction: float, rng: random.Random) -> list:
        """Samples, split by split, the original question ids to keep."""
        b2o = self.survey.binary_to_original_map
        sampled = []
        for split in split_names:
            # Sort the de-duplicated ids: set iteration order of strings varies between
            # interpreter runs (hash randomization), which would defeat the seed.
            original_qids = sorted({
                b2o[q["id"]]["original_id"]
                for q in self.survey.get_questions_by_split(split)
                if q["id"] in b2o
            })
            if not original_qids:
                continue
            k = max(1, int(len(original_qids) * fraction))
            sampled.extend(rng.sample(original_qids, k))
        return sampled

    def _expand_to_binary_qids(self, original_qids: list) -> list:
        """Expands original question ids to all of their binary question ids."""
        o2b = self.survey.original_to_binary_map
        return [bqid for oid in original_qids for bqid in o2b.get(oid, [])]

    def _subset_experiment(self, binary_qids: list):
        """
        Builds a new experiment of the same type restricted to `binary_qids`.

        The survey is cloned via its class's ``clone_with_subset``; the responses are
        cloned via ``clone_with_subset`` when available, otherwise deep-copied and
        filtered. ``source_path`` is carried over whenever the source responses
        define it. The current experiment's objects are not modified.
        """
        keep = set(binary_qids)

        subset_questions = [q for q in self.survey.questions if q["id"] in keep]
        new_survey = self.survey.__class__.clone_with_subset(self.survey, subset_questions)

        if hasattr(self.responses, "clone_with_subset"):
            new_responses = self.responses.clone_with_subset(binary_qids, subset_survey=new_survey)
        else:
            new_responses = deepcopy(self.responses)
            new_responses.responses = {
                qid: ans for qid, ans in new_responses.responses.items() if qid in keep
            }
            new_responses.questions = {
                qid: q for qid, q in new_responses.questions.items() if qid in keep
            }
            new_responses.survey = new_survey
            new_responses.clone = True
        if hasattr(self.responses, "source_path"):
            new_responses.source_path = self.responses.source_path

        new_aggregate_stats = {
            qid: val for qid, val in self.aggregate_stats.items() if qid in keep
        }

        # The subsetted stats are always forwarded: subclasses that derive their own
        # benchmark overwrite them after construction, while every other type keeps
        # the benchmarks that belong to the retained questions.
        return self.__class__(
            responses=new_responses,
            survey=new_survey,
            endowments=self.endowments,
            aggregate_stats=new_aggregate_stats,
            filter_binary=self.filter_binary,
            drop_na=self.drop_na,
        )
    

class SimulationExperiment(BaseExperiment):
    """
    Experiment class for simulated settings where ground truth is computed from agent responses.

    This subclass infers aggregate statistics from weighted responses of 'ground_truth' agents.
    """

    def __init__(self, responses, survey, endowments, aggregate_stats=None, filter_binary=False, drop_na=False):
        """
        Initializes a simulation experiment and computes aggregate statistics from ground-truth agents.

        Args:
            responses (Responses): Encoded agent responses.
            survey (Survey): Survey metadata.
            endowments (EndowmentManager): Agent metadata and roles.
            aggregate_stats (dict, optional): Forwarded to the base class, then
                overwritten by the internally computed statistics.
            filter_binary (bool): Restrict to binary-coded questions.
            drop_na (bool): Forwarded to the base class unchanged.
        """
        super().__init__(responses, survey, endowments, aggregate_stats, filter_binary, drop_na)
        # Computed after the base initializer so any filtering it applied is reflected.
        self.aggregate_stats = self._compute_aggregate_stats()

    def _compute_aggregate_stats(self):
        """
        Computes weighted averages of numeric response codes across ground-truth agents.

        Agents without the 'ground_truth' role are ignored. Codes that cannot be
        converted to float (including None) or that are NaN are treated as missing
        and do not contribute to either the weighted sum or the total weight.
        An endowment without a "weight" entry gets weight 1.0.

        Returns:
            dict: Mapping from question ID to weighted average code, or None when no
            ground-truth agent contributed a positive total weight for that question.
        """
        weights = {
            e["eid"]: e.get("weight", 1.0)
            for e in self.endowments.get_endowments_by_role("ground_truth")
        }

        stats = {}
        for qid, agent_dict in self.responses.responses.items():
            total_weight = 0.0
            weighted_sum = 0.0

            for eid, code in agent_dict.items():
                if eid not in weights:
                    continue
                try:
                    code_val = float(code)
                except (TypeError, ValueError):
                    continue  # skip non-numeric codes
                if math.isnan(code_val):
                    continue  # NaN marks a missing answer; it would poison the mean

                w = weights[eid]
                weighted_sum += code_val * w
                total_weight += w

            stats[qid] = weighted_sum / total_weight if total_weight > 0 else None

        return stats

class EmpiricalExperiment(BaseExperiment):
    """
    Experiment for real-world settings whose benchmark statistics come from outside.

    Unlike the base class, `aggregate_stats` is mandatory here.
    """

    def __init__(self, responses, survey, endowments, aggregate_stats, filter_binary=False, drop_na=False):
        """
        Builds an empirical experiment around externally supplied aggregate statistics.

        Args:
            responses (Responses): Agent response matrix.
            survey (Survey): Survey metadata.
            endowments (EndowmentManager): Agent role and identity mapping.
            aggregate_stats (dict): Required empirical benchmark values (qid → value).
            filter_binary (bool): Forwarded to the base class; if True, restrict to
                binary-coded questions.
            drop_na (bool): Forwarded to the base class unchanged.

        Raises:
            ValueError: If `aggregate_stats` is None.
        """
        stats_missing = aggregate_stats is None
        if stats_missing:
            raise ValueError("EmpiricalExperiment requires aggregate_stats to be provided.")

        super().__init__(
            responses=responses,
            survey=survey,
            endowments=endowments,
            aggregate_stats=aggregate_stats,
            filter_binary=filter_binary,
            drop_na=drop_na,
        )