from flask import Flask, request, render_template, redirect, url_for, session
import os
import logging
from scene_understanding_app import SceneUnderstandingBackend

def setup_logger(name, log_file: str, level=logging.DEBUG):
    """Create (or fetch) the named logger and attach a file handler to it.

    The parent directory of ``log_file`` is created if it does not exist.
    A ``FileHandler`` is attached only when the logger has no handlers yet,
    so calling this repeatedly with the same ``name`` does not duplicate
    log lines.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the file the handler appends records to. A bare
            filename (no directory component) is accepted.
        level: Level set on the logger; defaults to ``logging.DEBUG``.

    Returns:
        The ``logging.Logger`` registered under ``name``.
    """
    log_dir = os.path.dirname(log_file)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Avoid adding multiple handlers to the logger. The handler is built
    # inside the guard because FileHandler opens its file immediately;
    # constructing it unconditionally leaked an open file on every repeat call.
    if not logger.handlers:
        handler = logging.FileHandler(log_file)
        handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        logger.addHandler(handler)

    return logger

# Study name doubles as the logger name and the log file's base name.
scene_understanding_study_name = 'scene_understanding_study'
# NOTE(review): relative path, so logs land under the process's current
# working directory — confirm the app is always launched from its own folder.
log_dir = 'logs'
# exist_ok=True replaces the exists()/makedirs() pair, which could race with
# another process creating the directory in between.
os.makedirs(log_dir, exist_ok=True)


scene_understanding_logging_filepath = os.path.join(log_dir, f"{scene_understanding_study_name}.log")
scene_understanding_logger = setup_logger(scene_understanding_study_name, scene_understanding_logging_filepath)

app = Flask(__name__)
# Flask signs session cookies with this key, so it should not live in source
# control. Prefer the FLASK_SECRET_KEY environment variable; the legacy value
# is kept only as a fallback so existing deployments and sessions keep working.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'SDCHV3932032JIOEDDLKNDNL')
# Absolute directory containing this file.
main_dir = os.path.dirname(os.path.abspath(__file__))

# Define the path to the JSON file and lock file (for regular samples)
samples_json_path = os.path.join(main_dir, 'samples_to_complete.json')
samples_lock_path = samples_json_path + '.lock'

# Directories for composable samples and results.
# NOTE(review): unlike samples_json_path, these are relative paths, so they
# resolve against the working directory unless the backend joins them with
# main_dir — confirm against SceneUnderstandingBackend.
composable_samples_dir = 'static/composable_samples'
scene_understanding_samples_dir = 'static/images/train'

scene_understanding_results_dir = 'results'
scene_understanding_surveys_json_path = os.path.join(scene_understanding_results_dir, 'user_surveys.json')
scene_understanding_surveys_lock_path = scene_understanding_surveys_json_path + '.lock'


def get_scene_understanding_backend():
    """Build a new ``SceneUnderstandingBackend`` configured for this study.

    Each route handler in this module calls this once per request, so no
    backend instance is cached between requests.

    The ``*_callback_name`` arguments name view functions defined in this
    module (``index``, ``questions_callback``, ``submit_answers``). The
    backend presumably hands them back as endpoint names (``process_question``
    passes ``page_data['redirect']`` to ``url_for``), so keep them in sync
    with the view function names.

    Returns:
        A freshly constructed ``SceneUnderstandingBackend``.
    """
    return SceneUnderstandingBackend(
        # Reuse the module constant so the backend's study name cannot drift
        # from the name used for the logger and log file.
        study_name=scene_understanding_study_name,
        logger=scene_understanding_logger,
        results_dir=scene_understanding_results_dir,
        main_dir=main_dir,
        samples_dir=scene_understanding_samples_dir,
        fixed_samples_filepath=scene_understanding_surveys_json_path,
        n_samples_for_user=4,
        n_target_samples=5,
        n_page_per_sample=4,
        landing_callback_name='index',
        sample_callback_name='questions_callback',
        submit_callback_name='submit_answers',
        landing_page_filename='questions.html',
        sample_page_filename='questions.html',
        submit_page_filename='results.html',
        survey_idx=-1,
    )
    
def process_landing(logger: logging.Logger, backend):
    """Render the survey's landing view.

    A POST request loads sample page 0 from the backend; any other method
    loads the backend's landing page.

    Args:
        logger: Logger for debug/error messages.
        backend: Study backend providing ``load_sample_page`` and
            ``load_landing_page``.

    Returns:
        The response from ``render_template`` for the page the backend chose.

    Raises:
        ValueError: If the backend's page data has no ``'render'`` entry.
    """
    logger.debug("Index route called. Redirecting to first question.")

    if request.method == 'POST':
        page_data = backend.load_sample_page(0)
    else:
        page_data = backend.load_landing_page()

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would surface as an opaque KeyError below.
    if 'render' not in page_data:
        logger.error("Invalid data returned from landing callback.")
        raise ValueError(f"Render key not found in info dictionary: {page_data}.")

    return render_template(page_data['render'],
                           image_data=page_data['image_data'],
                           previous_answers=page_data['previous_answers'],
                           page_id=page_data['page_id'],
                           pages_left=page_data['pages_left'])
    
def process_question(logger: logging.Logger, backend, page_id: int):
    """Handle a request for question page ``page_id``.

    A POST whose form field ``action`` is ``'next'`` or ``'back'`` asks the
    backend to move forward or backward from ``page_id``. Every other request
    just loads ``page_id``. The backend's reply is either rendered as a
    template or turned into a redirect.

    Args:
        logger: Logger for debug/error messages.
        backend: Study backend providing ``next_page``, ``previous_page`` and
            ``load_sample_page``.
        page_id: Page index taken from the URL.

    Raises:
        ValueError: If the backend reply has neither ``'render'`` nor
            ``'redirect'``.
    """
    method = request.method
    logger.debug(f"Questions route called with page_id: {page_id} with method {method}.")

    # Only POSTs carry a navigation action; everything else is a plain load.
    action = request.form.get('action') if method == 'POST' else None
    if action == 'next':
        page_data = backend.next_page(page_id)
    elif action == 'back':
        page_data = backend.previous_page(page_id)
    else:
        page_data = backend.load_sample_page(page_id)

    if 'render' in page_data:
        template_vars = {key: page_data[key]
                         for key in ('image_data', 'previous_answers', 'page_id', 'pages_left')}
        return render_template(page_data['render'], **template_vars)

    if 'redirect' in page_data:
        endpoint = page_data['redirect']
        logger.debug(f"(In webapp) Redirecting to {endpoint} with page_id: {page_data['page_id']}")
        return redirect(url_for(endpoint, page_id=page_data['page_id']))

    logger.error("Invalid data returned from questions callback.")
    raise ValueError(f"Invalid data returned from questions callback in {page_data}.")

def process_submit(logger: logging.Logger, backend):
    """Finish the survey and show the results page.

    Args:
        logger: Logger for the debug trace.
        backend: Study backend providing ``load_submit_page``.

    Returns:
        ``results.html`` rendered with the value returned by
        ``backend.load_submit_page()`` passed as ``confirmation_code``.
    """
    logger.debug("Submit route called.")
    return render_template('results.html',
                           confirmation_code=backend.load_submit_page())

# ------------------------------------------------------------------------------
# Scene Understanding Survey Callbacks
# ------------------------------------------------------------------------------
@app.route('/', methods=['GET', 'POST'])
def index():
    """Survey entry point.

    GET shows the backend's landing page; POST loads the first sample page.
    ``process_landing`` branches on ``request.method == 'POST'``, but Flask
    routes accept only GET (plus HEAD/OPTIONS) by default. POST must be
    listed explicitly, otherwise that branch is unreachable and POSTs get 405.
    """
    backend = get_scene_understanding_backend()
    return process_landing(scene_understanding_logger, backend)

@app.route('/questions/<int:page_id>', methods=['GET', 'POST'])
def questions_callback(page_id):
    """Delegate question page ``page_id`` to ``process_question`` with a fresh backend."""
    return process_question(scene_understanding_logger,
                            backend=get_scene_understanding_backend(),
                            page_id=page_id)

@app.route('/submit', methods=['GET', 'POST'])
def submit_answers():
    """Delegate the submit step to ``process_submit`` with a fresh backend.

    NOTE(review): GET is also accepted, so simply visiting /submit invokes
    ``backend.load_submit_page()`` — confirm that is intended.
    """
    return process_submit(scene_understanding_logger,
                          backend=get_scene_understanding_backend())


if __name__ == '__main__':
    # Local development server only. debug=True enables Flask's interactive
    # debugger and auto-reloader, so never expose this server publicly.
    app.run(debug=True, port=8004)