from flask import Flask, render_template, request, jsonify, send_from_directory, send_file
from task_flows import WCSTFlow
import json
import os
import logging
import re
from collections import defaultdict

"""
ngrok http 5000
"""

# The Flask application; every route in this module is registered on it.
app = Flask(__name__)

# WCSTFlow for the session in progress. It is created by the POST branch of
# index(); as a single module-level global, only one experiment is tracked
# at a time for all clients.
current_experiment = None


logging.basicConfig(level=logging.DEBUG)

# Task images are served from ./task_datasets, resolved against the current
# working directory at import time.
app.static_folder = os.path.abspath('task_datasets')
app.logger.info("Static folder set to: %s", app.static_folder)

def get_next_subject_and_session(experiment_logs_path='experiment_logs', *,
                                 paired_subject_limit=20, sessions_per_cycle=10):
    """Suggest the (subject, session) pair for the next WCST run.

    Scans ``experiment_logs_path`` for session folders whose names start with
    ``WCST_human_subject<N>_STA_OI_session<M>_<8 digits>_<6 digits>`` and
    applies the assignment rules to the highest subject number found:

    * No matching folders, or the directory does not exist yet: ``(1, 1)``.
    * Subjects ``1..paired_subject_limit`` use session == subject. If the
      latest subject has exactly one logged session, move on to the next
      subject. Otherwise repeat the latest subject. The first subject past
      the limit starts the session cycle at 1.
    * Subjects above ``paired_subject_limit`` each get the next session
      number, wrapping back to 1 after ``sessions_per_cycle``.

    Args:
        experiment_logs_path: Directory containing one folder per logged session.
        paired_subject_limit: Highest subject number that uses session == subject.
        sessions_per_cycle: Length of the session cycle used above the limit.

    Returns:
        ``(next_subject, next_session)`` as ints.
    """
    folder_pattern = re.compile(
        r'WCST_human_subject(\d+)_STA_OI_session(\d+)_\d{8}_\d{6}')

    try:
        entries = os.listdir(experiment_logs_path)
    except FileNotFoundError:
        # Nothing has been logged yet (e.g. fresh checkout): start from scratch.
        entries = []

    # subject number -> set of session numbers logged for that subject
    subjects_sessions = defaultdict(set)
    for folder in entries:
        match = folder_pattern.match(folder)
        if match:
            subjects_sessions[int(match.group(1))].add(int(match.group(2)))

    if not subjects_sessions:
        return 1, 1

    max_subject = max(subjects_sessions)
    sessions = subjects_sessions[max_subject]

    if max_subject <= paired_subject_limit:
        if len(sessions) != 1:
            # Latest subject has several sessions logged: suggest it again
            # with session == subject (unchanged original rule).
            return max_subject, max_subject
        next_subject = max_subject + 1
        if next_subject > paired_subject_limit:
            # Crossing into the cycled range. Start at session 1; session ==
            # subject would fall outside 1..sessions_per_cycle.
            return next_subject, 1
        return next_subject, next_subject

    last_session = max(sessions)
    if len(sessions) >= sessions_per_cycle or last_session >= sessions_per_cycle:
        # The cycle is complete: wrap around instead of exceeding the cycle length.
        return max_subject + 1, 1
    return max_subject + 1, last_session + 1

@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the start page (GET) or start a new WCST session (POST).

    GET renders ``index.html`` with the suggested next subject and session
    numbers from ``get_next_subject_and_session()``.

    POST reads the optional ``subject_number`` and ``session_number`` form
    fields. If either is blank, the suggested pair is used instead. It then
    replaces the global ``current_experiment`` with a new ``WCSTFlow``,
    starts a data-logger session, and runs the first trial. The response is
    JSON with the trial's image URL, prompt, trial counters, and the
    subject/session numbers actually used. Non-integer form values produce a
    400 JSON error rather than an unhandled ``ValueError`` (HTTP 500).
    """
    global current_experiment
    if request.method != 'POST':
        next_subject, next_session = get_next_subject_and_session()
        return render_template('index.html', next_subject=next_subject, next_session=next_session)

    subject_number = request.form.get('subject_number')
    session_number = request.form.get('session_number')

    if not subject_number or not session_number:
        subject_number, session_number = get_next_subject_and_session()
    else:
        try:
            subject_number = int(subject_number)
            session_number = int(session_number)
        except ValueError:
            app.logger.warning(
                f"Invalid subject/session numbers: {subject_number!r}, {session_number!r}")
            return jsonify({'error': 'subject_number and session_number must be integers'}), 400

    # NOTE(review): one global experiment means concurrent participants would
    # overwrite each other's session.
    current_experiment = WCSTFlow(None, True, session_number, "English")  # alternative language value: "Chinese"
    current_experiment.data_logger.start_new_session(
        "WCST",
        None,
        "STA",
        "OI",
        True,
        session_number,
        subject_number,
        None
    )

    image_path, prompt = current_experiment.run_trial("OI", "STA")
    app.logger.info(f"Original image path: {image_path}")
    # Strip the CWD-relative dataset prefix so the path is relative to
    # app.static_folder, which is what /image/<path> expects.
    image_path = image_path.replace('./task_datasets/', '')
    full_image_path = os.path.join(app.static_folder, image_path)
    app.logger.info(f"Full image path: {full_image_path}")
    if not os.path.exists(full_image_path):
        # Logged only; the client still receives the URL and will get a 404.
        app.logger.error(f"Image file not found: {full_image_path}")
    return jsonify({
        'image_path': '/image/' + image_path,
        'prompt': prompt,
        'trial_number': current_experiment.current_trial + 1,
        'total_trials': current_experiment.total_trials,
        'subject_number': subject_number,
        'session_number': session_number
    })

@app.route('/submit_response', methods=['POST'])
def submit_response():
    """Score the participant's response and advance to the next trial.

    Expects a JSON object body containing a ``response`` field. The response
    is evaluated against the current trial, and the trial is finalized. If
    the experiment is then complete, the data-logger session is ended and
    ``{'complete': True}`` is returned. Otherwise the next trial is run, and
    its image URL, prompt, and trial counters are returned together with the
    feedback for the previous response.

    Error responses (JSON ``{'error': ...}``):
        409 if no experiment is in progress (none started, already finished,
            or the server restarted).
        400 if the body is not a JSON object with a ``response`` field.
    """
    global current_experiment
    app.logger.info(f"current_experiment: {current_experiment}")

    if current_experiment is None:
        return jsonify({'error': 'No experiment in progress; start a session first'}), 409

    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'response' not in data:
        return jsonify({'error': "Request body must be a JSON object with a 'response' field"}), 400
    response = data['response']

    feedback = current_experiment.evaluate_response(response, "OI", "STA", True)

    current_experiment.finalize_current_trial()
    if current_experiment.is_experiment_complete():
        current_experiment.data_logger.end_session()
        # Drop the finished experiment. A duplicate or late submission then
        # gets the 409 above instead of running past the last trial.
        current_experiment = None
        return jsonify({'complete': True})

    image_path, prompt = current_experiment.run_trial("OI", "STA")
    # Make the path relative to app.static_folder for the /image/<path> route.
    image_path = image_path.replace('./task_datasets/', '')
    app.logger.info(f"Next trial image path: {image_path}")
    return jsonify({
        'feedback': feedback,
        'image_path': '/image/' + image_path,
        'prompt': prompt,
        'trial_number': current_experiment.current_trial + 1,
        'total_trials': current_experiment.total_trials
    })

@app.route('/image/<path:filename>')
def serve_image(filename):
    """Serve a task image from ``app.static_folder`` as ``image/png``.

    ``filename`` comes straight from the URL. It is resolved against the
    static folder and rejected unless the result stays inside it; otherwise
    a filename containing ``..`` segments could read arbitrary files on the
    host. Rejected, missing, or non-regular files return a plain-text 404.
    """
    app.logger.info(f"Requesting image: {filename}")
    root = os.path.abspath(app.static_folder)
    full_path = os.path.abspath(os.path.join(root, filename))
    try:
        inside_root = os.path.commonpath([root, full_path]) == root
    except ValueError:
        # e.g. paths on different drives (Windows): cannot be inside root.
        inside_root = False
    if not inside_root:
        app.logger.warning(f"Rejected image path outside static folder: {filename}")
        return "Image not found", 404
    if not os.path.isfile(full_path):
        app.logger.error(f"Image file not found: {filename}")
        return "Image not found", 404
    try:
        return send_file(full_path, mimetype='image/png')
    except FileNotFoundError:
        # File removed between the isfile() check and the send.
        app.logger.error(f"Image file not found: {filename}")
        return "Image not found", 404

if __name__ == '__main__':
    # The Werkzeug debugger enabled by debug mode can execute arbitrary Python
    # from the browser, so it must not be reachable through a public tunnel
    # such as ngrok. Opt in explicitly for local work with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() not in ('', '0', 'false', 'no')
    app.run(debug=debug_enabled)