# AggregateResponses supports empirical experiments: it loads aggregate response data (per-question answer proportions) and converts it into binary values aligned with the question format of a BinaryExtendedSurvey object.

import json
from modules.survey_converter import BinaryExtendedSurvey

class AggregateResponses:
    """Aggregate answer distributions projected onto a survey's binary questions.

    The raw input maps each original question id to a ``{label: proportion}``
    distribution. On construction it is converted into ``{binary question id:
    value}`` using the questions exposed by ``survey.original_questions`` and
    ``survey.questions``.

    Attributes:
        raw: The aggregate dict as provided (or as loaded from ``json_path``).
        survey: The survey object passed to the constructor.
        original_question_map: ``{id: question}`` for ``survey.original_questions``.
        question_map: ``{id: question}`` for ``survey.questions``.
        binary: The converted ``{binary question id: value}`` dict.
    """

    def __init__(self, aggregate_dict: dict = None, survey=None, json_path: str = None):
        """
        Args:
            aggregate_dict (dict): Optional. {qid: {label: proportion}}
            survey (BinaryExtendedSurvey): Required. Survey with original and binary questions
            json_path (str): Optional. Path to JSON file storing aggregate_dict.
                Only read when `aggregate_dict` is None.

        Raises:
            ValueError: If `survey` is missing, or if neither `aggregate_dict`
                nor `json_path` is given.
        """
        if survey is None:
            raise ValueError("`survey` must be provided")

        if aggregate_dict is None:
            if not json_path:
                raise ValueError("Either `aggregate_dict` or `json_path` must be provided")
            # JSON text is UTF-8; relying on the locale default would
            # mis-decode non-ASCII answer labels on some platforms.
            with open(json_path, "r", encoding="utf-8") as f:
                aggregate_dict = json.load(f)
            print(f"[INFO] Loaded aggregate data from {json_path}")

        self.raw = aggregate_dict
        self.survey = survey
        self.original_question_map = {q["id"]: q for q in survey.original_questions}
        self.question_map = {q["id"]: q for q in survey.questions}

        # Group the survey's questions by the original question they derive
        # from, once, so each aggregate entry does not rescan the whole list.
        # Insertion order within each group follows survey.questions.
        self._derived_by_base = {}
        for q in survey.questions:
            self._derived_by_base.setdefault(q.get("base_id"), []).append(q)

        self.binary = self._convert_to_binary()

    def _convert_to_binary(self):
        """
        Convert categorical distributions to expected binary values.

        Entries whose question id is not among the survey's original
        questions are skipped with a ``[SKIP]`` message.

        Returns:
            dict: {qid: binary value}
        """
        binary_result = {}

        for qid, dist in self.raw.items():
            if qid not in self.original_question_map:
                print(f"[SKIP] Unknown question ID: {qid}")
                continue

            self._process_item(qid, dist, binary_result)

        return binary_result

    def _process_item(self, qid, dist, binary_result):
        """
        Write the binary value(s) for one original question into `binary_result`.

        A question with exactly two options maps to the binary question with
        the same id; its value is the proportion of the label assigned to
        code 1. Any other question maps to every survey question whose
        ``base_id`` equals `qid`; each receives the proportion of its
        ``base_answer`` label. Labels absent from `dist` count as 0.0.

        Args:
            qid: Id of the original question (must be in original_question_map).
            dist (dict): {label: proportion} for that question.
            binary_result (dict): Output dict, updated in place.
        """
        original_question = self.original_question_map.get(qid)

        code_to_answer = original_question.get("code_to_answer", {})
        n_options = len(code_to_answer)

        if n_options == 2:
            binary_q = self.question_map.get(qid)
            if binary_q is None:
                print(f"[WARN] Missing binary question entry for {qid}")
                return
            bin_code_to_answer = binary_q.get("code_to_answer", {})
            label_for_1 = bin_code_to_answer.get(1)
            if label_for_1 is None:
                # Mappings that went through JSON carry string keys.
                label_for_1 = bin_code_to_answer.get("1")
            if label_for_1 is not None:
                binary_result[qid] = dist.get(label_for_1, 0.0)
            else:
                print(f"[WARN] Label for code 1 not found in {qid}")
            return

        for binary_q in self._derived_by_base.get(qid, ()):
            binary_result[binary_q["id"]] = dist.get(binary_q["base_answer"], 0.0)

    def get_binary_value(self, question_id: str):
        """Get the expected binary value for a question, or None if absent."""
        return self.binary.get(question_id)

    def get_all_binary(self):
        """Return the full binary dict (the internal object, not a copy)."""
        return self.binary