from flask import session
import json
import os
import logging
import random
from scene_understanding_loader import generate_landing_page, get_all_questions
from user_study_app import UserStudyApp


class SceneUnderstandingBackend(UserStudyApp):
    """
    Backend for the scene understanding survey.

    Each sample is shown as ``n_page_per_sample`` sequential question pages
    (four by default). Page 0 is the landing page; sample pages start at
    page_id 1.
    """
    DEFAULT_PAGES_PER_SAMPLE = 4
    # Seed for the RNG used when generating the surveys file, so that
    # regenerating it from the same sample folders is reproducible.
    SURVEY_SEED = 42

    def __init__(self,
                 study_name: str,
                 logger: logging.Logger,
                 results_dir: str,
                 main_dir: str,
                 samples_dir: str,
                 fixed_samples_filepath: str,
                 n_samples_for_user: int,
                 n_target_samples: int,
                 landing_callback_name: str,
                 sample_callback_name: str,
                 submit_callback_name: str,
                 landing_page_filename: str,
                 sample_page_filename: str,
                 submit_page_filename: str,
                 n_page_per_sample: int = DEFAULT_PAGES_PER_SAMPLE,
                 survey_idx: int = -1):
        """
        Forward all configuration to ``UserStudyApp``.

        Parameters are identical to the composable backend, except that the
        default ``n_page_per_sample`` is ``DEFAULT_PAGES_PER_SAMPLE`` (4) for
        scene understanding.
        """
        super().__init__(study_name=study_name,
                         logger=logger,
                         results_dir=results_dir,
                         main_dir=main_dir,
                         samples_dir=samples_dir,
                         fixed_samples_filepath=fixed_samples_filepath,
                         n_page_per_sample=n_page_per_sample,
                         n_samples_for_user=n_samples_for_user,
                         n_target_samples=n_target_samples,
                         landing_callback_name=landing_callback_name,
                         sample_callback_name=sample_callback_name,
                         submit_callback_name=submit_callback_name,
                         landing_page_filename=landing_page_filename,
                         sample_page_filename=sample_page_filename,
                         submit_page_filename=submit_page_filename,
                         survey_idx=survey_idx)

    # --------------------------------------------------------------------- #
    # PATH HELPERS
    # --------------------------------------------------------------------- #
    def _main_dir_basename(self) -> str:
        """Return the last path component of ``main_dir``.

        ``normpath`` strips trailing separators first, so ``"a/b/"`` yields
        ``"b"`` rather than ``""`` (which would turn the derived relative
        paths into absolute ones).
        """
        return os.path.basename(os.path.normpath(self.main_dir))

    def _sample_folderpath(self, sample_idx: int) -> str:
        """Return the folder of the user's ``sample_idx``-th sample.

        The path is ``<main_dir basename><os.sep><sample path>``. It is
        relative to the current working directory unless the stored sample
        path is itself absolute.
        """
        # NOTE(review): `user_survey` is not defined in this class; presumably
        # UserStudyApp exposes the list stored in the session by load_samples.
        return (
            self._main_dir_basename() + os.sep + self.user_survey[sample_idx]["path"]
        )

    # --------------------------------------------------------------------- #
    # SURVEY / SAMPLE SELECTION
    # --------------------------------------------------------------------- #
    def _build_surveys(self):
        """Generate the survey list from sample folders on disk.

        Every folder under ``<main_dir basename>/<samples_dir>`` that directly
        contains at least one ``.png`` file is treated as one sample. Each
        sample is repeated ``n_target_samples`` times, and the pool is
        shuffled. The pool is then cut into surveys of ``n_samples_for_user``
        samples each. Leftover samples that do not fill a whole survey are
        dropped.

        Returns:
            list[list[dict]]: One list per survey. Each entry is a dict with
            keys ``"path"`` (relative to the main_dir basename,
            ``/``-separated) and ``"ac_idx"``.

        Raises:
            ValueError: If ``n_samples_for_user`` is not positive, or if there
                are not enough samples to fill a single survey.
        """
        per_user = self.n_samples_for_user
        if per_user <= 0:
            raise ValueError(
                f"n_samples_for_user must be positive, got {per_user}"
            )

        # Private RNG instead of random.seed(): reseeding the module-level RNG
        # would make every other random call in the process deterministic.
        # Random(seed) produces the same sequence random.seed(seed) did.
        rng = random.Random(self.SURVEY_SEED)

        base = self._main_dir_basename()
        walk_root = base + os.sep + self.samples_dir

        all_samples = []
        for root, dirs, files in os.walk(walk_root):
            # os.walk order depends on the filesystem; sorting the dirnames in
            # place makes the traversal (and thus the surveys) reproducible.
            dirs.sort()
            if any(fname.endswith(".png") for fname in files):
                # relpath strips only the leading prefix (str.replace would
                # strip every occurrence). Normalize to "/" so the stored JSON
                # does not depend on the platform separator.
                sample_path = os.path.relpath(root, base).replace(os.sep, "/")
                # ac_idx is retained for compatibility, although it is not used
                # here. Drawing it also keeps the RNG sequence, and hence the
                # shuffle below, unchanged.
                ac_idx = rng.randint(0, 10)
                all_samples.append({"path": sample_path, "ac_idx": ac_idx})

        # Duplicate each sample n_target_samples times to build the pool.
        sample_pool = [
            sample for sample in all_samples for _ in range(self.n_target_samples)
        ]
        rng.shuffle(sample_pool)

        n_surveys = len(sample_pool) // per_user
        if n_surveys == 0:
            # Refuse to persist an empty surveys file: later loads would reuse
            # it and fail on every survey lookup.
            raise ValueError(
                f"Not enough samples under {walk_root!r} to build a survey: "
                f"{len(sample_pool)} sample slots, {per_user} needed per survey"
            )
        return [
            sample_pool[i * per_user:(i + 1) * per_user] for i in range(n_surveys)
        ]

    def load_samples(self, survey_idx: int):
        """
        Build (or load) the full list of surveys and assign one to the user.

        If ``fixed_samples_filepath`` exists, it is read as the survey list.
        Otherwise the list is generated from the sample folders and written
        to that path, so later calls reuse it.

        The user's survey index is chosen, in order of precedence, from:

        1. an explicit ``survey_idx >= 0``;
        2. ``session["<study_name>_assigned_survey_idx"]``;
        3. ``self.assign_free_sample()``.

        The chosen survey is stored in
        ``session["<study_name>_user_survey"]``.

        Args:
            survey_idx: Explicit survey index, or a negative value to let the
                session / ``assign_free_sample`` decide.

        Raises:
            ValueError: If the surveys must be generated and there are not
                enough samples (see ``_build_surveys``).
            IndexError: If the chosen survey index is out of range.
        """
        surveys_file = self.fixed_samples_filepath
        if os.path.exists(surveys_file):
            with open(surveys_file, "r", encoding="utf-8") as f:
                all_surveys = json.load(f)
        else:
            all_surveys = self._build_surveys()
            with open(surveys_file, "w", encoding="utf-8") as f:
                json.dump(all_surveys, f, indent=4)

        # ---------- Assign a survey to this user ----------
        session_key = f"{self.study_name}_assigned_survey_idx"
        if survey_idx >= 0:
            user_survey = all_surveys[survey_idx]
            self.assigned_survey_idx = survey_idx
        else:
            if session_key in session:
                self.assigned_survey_idx = session[session_key]
            else:
                # assign_free_sample() (not defined in this class) must set
                # self.assigned_survey_idx, which is read right below.
                self.assign_free_sample()
            user_survey = all_surveys[self.assigned_survey_idx]

        # Persist choice in session
        session[f"{self.study_name}_user_survey"] = user_survey
        session.modified = True

    def load_landing_page(self):
        """Return the render context for the first page (page_id == 0).

        The landing page is built from the user's first sample. It uses the
        same template as the sample pages.
        """
        page_data = generate_landing_page(self._sample_folderpath(0), self.main_dir)

        return {
            "render": self.sample_page_filename,   # uses same template as sample pages
            "image_data": page_data,
            "previous_answers": {},
            "page_id": 0,
            # NOTE(review): n_pages is not defined in this class; presumably
            # provided by UserStudyApp.
            "pages_left": self.n_pages,
        }

    def load_sample_page(self, page_id: int):
        """
        Return the render context for a sample question page.

        ``page_id`` is mapped to ``(sample_idx, question_idx)``, and the
        matching question set of that sample is served. Previously stored
        answers for the page are included so the form can be pre-filled.

        Args:
            page_id: 1-based page index (page 0 is the landing page).

        Raises:
            ValueError: If ``page_id`` is less than 1.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if page_id < 1:
            raise ValueError(f"Sample pages start at page_id == 1 (got {page_id})")

        # Map page_id → (sample_idx, question_idx): n_page_per_sample
        # consecutive pages per sample.
        sample_idx, question_idx = divmod(page_id - 1, self.n_page_per_sample)

        folderpath = self._sample_folderpath(sample_idx)

        # Fetch every question set for this sample and pick the one for this
        # page. NOTE(review): assumes get_all_questions returns at least
        # n_page_per_sample entries.
        all_question_sets = get_all_questions(folderpath, self.main_dir)
        page_data = all_question_sets[question_idx]

        # Answers are keyed by the stringified page id.
        previous_answers = self.answers.get(str(page_id), {})

        return {
            "render": self.sample_page_filename,
            "image_data": page_data,
            "previous_answers": previous_answers,
            "page_id": page_id,
            "pages_left": self.n_pages - page_id,
        }
