import logging
import os
from abc import ABC, abstractmethod
import json
from datetime import datetime
from typing import Dict, List
from flask import request, session
import uuid
from filelock import FileLock
from datetime import datetime, timedelta
from filelock import FileLock

class UserStudyApp(ABC):
    """Abstract base class for a Flask-backed, multi-page user study.

    Per-user state lives in the Flask ``session`` under keys prefixed with
    ``f"{study_name}_"``; on later requests those keys are copied back onto the
    instance as attributes (see ``load_attributes``). Which survey index each
    participant works on is tracked in a shared ``survey_statuses.json`` file in
    ``results_dir``, read and written under a ``FileLock``.

    Page layout: page 0 is the landing page, followed by
    ``n_page_per_sample`` pages per sample in the user's survey.

    Subclasses implement ``load_samples``, ``load_landing_page`` and
    ``load_sample_page``.
    """

    # "in progress" assignments older than this are released back to "free".
    status_timeout = timedelta(hours=2)

    def __init__(self,
                 study_name: str,
                 logger: logging.Logger,
                 results_dir: str,
                 main_dir: str,
                 samples_dir: str,
                 fixed_samples_filepath: str,
                 n_page_per_sample: int,
                 n_samples_for_user: int,
                 n_target_samples: int,
                 landing_callback_name: str,
                 sample_callback_name: str,
                 submit_callback_name: str,
                 landing_page_filename: str,
                 sample_page_filename: str,
                 submit_page_filename: str,
                 survey_idx: int = -1):
        """Store configuration and either start or resume the user's study session.

        :param study_name: prefix for every session key owned by this study.
        :param logger: logger used for all diagnostics.
        :param results_dir: base directory; holds the status file and one
            sub-directory per user.
        :param fixed_samples_filepath: JSON file whose length defines the number
            of surveys tracked in the status file.
        :param n_page_per_sample: number of pages shown per sample.
        :param landing_callback_name: redirect target for the landing page.
        :param sample_callback_name: redirect target for sample pages.
        :param submit_callback_name: redirect target for the submit page.
        :param survey_idx: forwarded to ``load_samples`` on a fresh session.

        Must be called inside a Flask request context (reads ``session`` and,
        for new users, ``request.args``).
        """
        self.logger = logger
        self.survey_idx = survey_idx
        self.study_name = study_name
        self.n_samples_for_user = n_samples_for_user
        self.n_page_per_sample = n_page_per_sample
        self.n_target_samples = n_target_samples
        self.base_results_dir = results_dir
        self.main_dir = main_dir
        self.samples_dir = samples_dir
        self.fixed_samples_filepath = fixed_samples_filepath
        self.landing_callback_name = landing_callback_name
        self.sample_callback_name = sample_callback_name
        self.submit_callback_name = submit_callback_name
        self.landing_page_filename = landing_page_filename
        self.sample_page_filename = sample_page_filename
        self.submit_page_filename = submit_page_filename
        self.status_file = os.path.join(self.base_results_dir, "survey_statuses.json")
        self.lock_file = self.status_file + ".lock"
        # Set by assign_free_sample(); restored from the session by load_attributes().
        self.assigned_survey_idx = None

        # Define all pages for this user study
        self.load_page_mapping = {
            "landing": self.load_landing_page,
            "survey": self.load_sample_page,
            "submit": self.load_submit_page
        }

        if f'{study_name}_user_survey' not in session:
            self.initialize_study()
        else:
            self.load_attributes()

    def initialize_study(self):
        """Start a fresh study session for a new participant.

        Generates a user id, creates the per-user results directory, records the
        Prolific identifiers from the query string (in the session and in
        ``prolific_pid.json``), lets the subclass load samples, and computes the
        total page count.
        """
        study_name = self.study_name
        os.makedirs(self.base_results_dir, exist_ok=True)

        # A single id names the results directory, is stored in the session and
        # the status file, and is returned as the confirmation code on submit.
        self.user_id = str(uuid.uuid4())
        self.results_dir = os.path.join(self.base_results_dir, self.user_id)
        os.makedirs(self.results_dir, exist_ok=True)

        session[f'{study_name}_user_id'] = self.user_id
        session[f'{study_name}_finished_pages'] = {}
        session[f'{study_name}_results_dir'] = self.results_dir

        # Prolific specific args
        prolific_pid = request.args.get('PROLIFIC_PID')
        study_id = request.args.get('STUDY_ID')
        session_id = request.args.get('SESSION_ID')

        if not prolific_pid or not study_id or not session_id:
            # All-or-nothing: a partial set of identifiers is treated as absent.
            prolific_pid = study_id = session_id = "None"
            self.logger.warning("PROLIFIC_PID, STUDY_ID, or SESSION_ID not found in request arguments.")

        self.logger.debug(f"Received PROLIFIC_PID: {prolific_pid}, STUDY_ID: {study_id}, SESSION_ID: {session_id}")

        session[f'{study_name}_PROLIFIC_PID'] = prolific_pid
        session[f'{study_name}_STUDY_ID'] = study_id
        session[f'{study_name}_SESSION_ID'] = session_id

        session.setdefault(f'{study_name}_answers', {})

        pid_file = os.path.join(self.results_dir, 'prolific_pid.json')
        with open(pid_file, 'w') as f:
            json.dump({
                'PROLIFIC_PID': prolific_pid,
                'STUDY_ID': study_id,
                'SESSION_ID': session_id
            }, f)

        # The subclass is expected to store the user's samples under "<study>_user_survey".
        self.load_samples(survey_idx=self.survey_idx)
        assert f'{study_name}_user_survey' in session, "User survey not found in session."
        self.user_survey = session[f'{study_name}_user_survey']
        self.answers = {}

        # total number of pages: landing page + n_page_per_sample per sample
        self.n_pages = 1 + len(self.user_survey) * self.n_page_per_sample
        session[f'{study_name}_n_pages'] = self.n_pages
        session.modified = True

    def load_attributes(self):
        """Restore per-user state from the Flask session onto ``self``.

        Every session key ``f"{study_name}_<name>"`` becomes attribute ``<name>``
        (e.g. ``user_id``, ``results_dir``, ``n_pages``, ``assigned_survey_idx``).
        """
        self.logger.debug("Loading attributes from session.")
        prefix = f'{self.study_name}_'
        for key in session.keys():
            # Require the full "<study>_" prefix so a different study whose name
            # merely starts with this one is ignored, and strip only the leading
            # prefix (str.replace would also remove later occurrences).
            if key.startswith(prefix):
                attrib_name = key[len(prefix):]
                value = session[key]
                self.logger.debug(f"Set attribute {attrib_name}: {value}")
                setattr(self, attrib_name, value)

    def _ensure_status_file(self):
        """Create the status file if missing, with every survey index marked free.

        The number of surveys is ``len()`` of the JSON document stored at
        ``fixed_samples_filepath``. Called while holding ``self.lock_file``.

        :raises FileNotFoundError: if the status file must be created but the
            fixed-samples file does not exist.
        """
        if os.path.exists(self.status_file):
            return

        surveys_file = self.fixed_samples_filepath
        if not os.path.exists(surveys_file):
            raise FileNotFoundError(
                f"Cannot create {self.status_file}: surveys file {surveys_file} does not exist."
            )
        with open(surveys_file, 'r') as f:
            all_surveys = json.load(f)

        n_surveys = len(all_surveys)
        self.logger.info(f"Number of surveys: {n_surveys}")
        initial = {
            str(i): {"status": "free", "status_time": "", "user_id": ""}
            for i in range(n_surveys)
        }
        with open(self.status_file, "w") as f:
            json.dump(initial, f, indent=2)

    def assign_free_sample(self) -> int:
        """Reserve a survey index for the current user and record it.

        Under the status-file lock: release "in progress" entries older than
        ``status_timeout``, then pick the lowest "free" index; if none is free,
        fall back to the lowest "in progress" index (handing it out twice), and
        failing that to index 0. The chosen index is marked "in progress" for
        ``self.user_id``.

        :return: the assigned index, also stored on ``self.assigned_survey_idx``
            and in the session.
        """
        with FileLock(self.lock_file):
            self._ensure_status_file()

            with open(self.status_file, "r") as f:
                statuses = json.load(f)

            now = datetime.now()
            # Timeout check: stale in-progress assignments become free again.
            for idx_str, info in statuses.items():
                if info["status"] == "in progress" and info["status_time"]:
                    last = datetime.fromisoformat(info["status_time"])
                    if now - last > self.status_timeout:
                        statuses[idx_str] = {"status": "free", "status_time": "", "user_id": ""}

            survey_idx = min(
                (int(i) for i, info in statuses.items() if info["status"] == "free"),
                default=None,
            )
            if survey_idx is None:
                # NOTE(review): when everything is completed, index 0 is re-marked
                # "in progress", replacing its "completed" record.
                survey_idx = min(
                    (int(i) for i, info in statuses.items() if info["status"] == "in progress"),
                    default=0,
                )

            statuses[str(survey_idx)] = {
                "status": "in progress",
                "status_time": now.isoformat(),
                "user_id": self.user_id
            }

            with open(self.status_file, "w") as f:
                json.dump(statuses, f, indent=2)

        self.assigned_survey_idx = survey_idx
        session[f'{self.study_name}_assigned_survey_idx'] = survey_idx
        session.modified = True
        return survey_idx

    def save_answers(self,
                     file_basename: str,
                     answers: Dict):
        """Write *answers* as JSON to ``<results_dir>/<file_basename>_<timestamp>.json``.

        Does nothing when *answers* is empty. The timestamp has one-second
        resolution, so two saves with the same basename in the same second
        overwrite each other.
        """
        if not answers:
            return

        os.makedirs(self.results_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        filename = os.path.join(self.results_dir, f'{file_basename}_{timestamp}.json')
        with open(filename, 'w') as f:
            json.dump(answers, f)
        self.logger.debug(f"Saved answers to {filename}.")

    def load_page(self, page_name: str, page_id: int):
        """Dispatch to the loader registered for *page_name* and return its result.

        :param page_name: key of ``load_page_mapping`` ("landing", "survey", "submit").
        :param page_id: page index; passed to the "survey" loader (whose
            signature requires it) and used for logging otherwise.
        :return: whatever the page loader returns (e.g. the confirmation code
            for "submit").
        """
        assert page_name in self.load_page_mapping, f"Page {page_name} not found in survey mapping."

        self.logger.info(f"Loading page: {page_name} with ID: {page_id}")
        loader = self.load_page_mapping[page_name]
        # load_sample_page(page_id) needs the index; the other loaders take no arguments.
        result = loader(page_id) if page_name == "survey" else loader()
        self.logger.info(f"Page {page_name} loaded successfully with ID: {page_id}.")
        return result

    @abstractmethod
    def load_samples(self, survey_idx: int):
        """
        Loads the samples for the user study, for this particular survey.

        Must store the user's samples in ``session[f"{study_name}_user_survey"]``
        (checked by ``initialize_study``).
        """
        raise NotImplementedError("Sample loading not implemented.")

    @abstractmethod
    def load_landing_page(self):
        """
        Loads the landing page for the user study.
        """
        raise NotImplementedError("Landing page generation not implemented.")

    @abstractmethod
    def load_sample_page(self, page_id: int):
        """
        Loads the sample page with index *page_id* for the user study.
        """
        raise NotImplementedError("Survey page generation not implemented.")

    def load_submit_page(self) -> str:
        """Mark the assigned survey as completed and return the confirmation code.

        The confirmation code is the participant's ``user_id``.

        :raises RuntimeError: if no survey index has been assigned to this user.
        """
        assigned_idx = getattr(self, 'assigned_survey_idx', None)
        if assigned_idx is None:
            raise RuntimeError("No survey index to submit; call load_samples() first.")

        self.logger.debug(f"Attempting to save status for survey index: {assigned_idx}")
        with FileLock(self.lock_file):
            self._ensure_status_file()
            with open(self.status_file, "r") as f:
                statuses = json.load(f)

            now = datetime.now()
            statuses[str(assigned_idx)] = {
                "status": "completed",
                "status_time": now.isoformat(),
                "user_id": self.user_id
            }

            with open(self.status_file, "w") as f:
                json.dump(statuses, f, indent=2)
            self.logger.debug(f"Status for survey index {assigned_idx} updated to completed.")

        confirmation_code = self.user_id
        self.logger.debug(f"Sending render request to submit page with confirmation code: {confirmation_code}")
        return confirmation_code

    def get_current_sample(self, page_id: int) -> str:
        """Return the sample shown on page *page_id*, or ``None``.

        Pages ``1..`` are grouped into runs of ``n_page_per_sample`` per sample
        in ``session[f"{study_name}_samples"]``. Returns ``None`` for the landing
        page (0) and for page ids past the last sample.
        """
        assert page_id >= 0, "Page ID cannot be negative."
        if page_id == 0:
            # Without this guard (0 - 1) // n == -1 would return the *last* sample.
            return None
        samples = session[f'{self.study_name}_samples']
        sample_index = (page_id - 1) // self.n_page_per_sample
        return samples[sample_index] if sample_index < len(samples) else None

    def next_page(self, current_page_id: int) -> Dict:
        """Record the submitted form for *current_page_id* and return the next redirect.

        For a non-negative page id, the page is marked finished and the posted
        form fields (except ``action``) are stored in the session answers and
        written to ``answers_<page_id>_<timestamp>.json``.

        :return: ``{'redirect': <callback name>, 'page_id': current_page_id + 1}``
            targeting the next sample page, or the submit page after the last page.
        """
        if current_page_id >= 0:
            page_key = str(current_page_id)
            session[f'{self.study_name}_finished_pages'][page_key] = '1'
            user_answers = {}
            for key in request.form.keys():
                if key == 'action':
                    continue
                values = request.form.getlist(key)
                # Multi-valued fields keep a list, single values are unwrapped.
                # NOTE(review): a "name[]" field with one selection is saved as a scalar.
                user_answers[key.rstrip('[]')] = values if len(values) > 1 else values[0]
            answers = session[f'{self.study_name}_answers']
            answers[page_key] = user_answers
            self.logger.debug(f"Answers updated for page {current_page_id}: {answers[page_key]}")

            # Save current answers to file
            self.save_answers(f'answers_{current_page_id}', answers[page_key])

        # Nested dicts were mutated in place; tell Flask to persist the session.
        session.modified = True

        if current_page_id < self.n_pages - 1:
            self.logger.debug(f"Sending redirect request {self.sample_callback_name} to next question with page_id: {current_page_id + 1}")
            return {'redirect': self.sample_callback_name, 'page_id': current_page_id + 1}
        self.logger.debug(f"Sending redirect request to submit page with page_id: {current_page_id + 1}")
        return {'redirect': self.submit_callback_name, 'page_id': current_page_id + 1}

    def previous_page(self, current_page_id: int) -> Dict:
        """Return the redirect for going back one page.

        Pages above 1 go to the previous sample page; page 1 (or 0) goes back
        to the landing page (page 0).
        """
        self.logger.debug(f"Going back to previous page with page_id: {current_page_id}")
        assert current_page_id >= 0, "Current page ID cannot be negative."
        if current_page_id > 1:
            self.logger.debug(f"Sending redirect request to previous question with page_id: {current_page_id - 1}")
            return {'redirect': self.sample_callback_name, 'page_id': current_page_id - 1}
        self.logger.debug(f"Sending redirect request to landing page with page_id: 0")
        return {'redirect': self.landing_callback_name, 'page_id': 0}