import streamlit as st
import json
import os
from PIL import Image
import pandas as pd
from pathlib import Path
import datetime

# Page configuration: set the page title and use the wide layout
st.set_page_config(page_title="Image Recognition Quiz", layout="wide")

# Directory that receives the per-difficulty progress files and the exported
# CSV/JSON results. It is created up front so later writes find it in place.
RESULTS_DIR = "./results"
os.makedirs(RESULTS_DIR, exist_ok=True)

def get_progress_file_path(difficulty):
    """Return the path of the JSON progress file used for ``difficulty``.

    The file lives directly inside ``RESULTS_DIR`` and is named
    ``<difficulty>_progress.json``.
    """
    file_name = f"{difficulty}_progress.json"
    return "/".join((RESULTS_DIR, file_name))

# Initialize session state and attempt to load previous progress
if 'initialized' not in st.session_state:
    # Defaults for a brand-new session
    st.session_state.current_idx = 0
    st.session_state.answers = {}
    st.session_state.test_completed = False
    st.session_state.questions = []
    st.session_state.difficulty = "easy"
    st.session_state.initialized = True

    # Try to restore a previous run from the saved progress file. Every value
    # read from disk is validated before it replaces a default: the rest of the
    # app indexes with current_idx, calls .values() on answers, reads the
    # per-answer keys, and looks the difficulty up in a fixed list, so a
    # hand-edited or corrupted file would otherwise crash later, outside this
    # try block.
    try:
        progress_file = get_progress_file_path(st.session_state.difficulty)

        if os.path.exists(progress_file):
            with open(progress_file, 'r', encoding='utf-8') as f:
                progress_data = json.load(f)

            if not isinstance(progress_data, dict):
                raise ValueError("progress file does not contain a JSON object")

            saved_idx = progress_data.get('current_idx', 0)
            # bool is a subclass of int, so it is excluded explicitly
            if isinstance(saved_idx, int) and not isinstance(saved_idx, bool) and saved_idx >= 0:
                st.session_state.current_idx = saved_idx

            saved_answers = progress_data.get('answers', {})
            if isinstance(saved_answers, dict):
                # Keep only entries that carry the fields the score views read
                required_fields = {'user_answer', 'correct_answer', 'question', 'is_correct'}
                st.session_state.answers = {
                    key: entry
                    for key, entry in saved_answers.items()
                    if isinstance(entry, dict) and required_fields <= entry.keys()
                }

            saved_difficulty = progress_data.get('difficulty', "easy")
            if saved_difficulty in ("easy", "medium", "hard"):
                st.session_state.difficulty = saved_difficulty

            # Questions for this difficulty are loaded later from the sidebar
    except (OSError, ValueError) as e:
        # OSError: unreadable file; ValueError: invalid JSON/encoding or wrong shape
        st.sidebar.warning(f"Failed to load previous progress: {str(e)}")

def save_progress():
    """Persist the current question index, answers and difficulty to disk.

    The data is written as JSON to the progress file of the current
    difficulty, together with a timestamp of the save.

    Returns:
        bool: True when the file was written; False when serialising or
        writing failed (the error is reported in the sidebar).
    """
    try:
        progress_data = {
            'current_idx': st.session_state.current_idx,
            'answers': st.session_state.answers,
            'difficulty': st.session_state.difficulty,
            'timestamp': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

        # Serialise before opening the file. json.dump() streams into the
        # handle, so a value that cannot be encoded would leave a truncated
        # file behind and destroy the previous save.
        payload = json.dumps(progress_data, indent=4)

        progress_file = get_progress_file_path(st.session_state.difficulty)

        # Recreate the target directory in case it was removed while the app
        # was running (the results export does the same before writing).
        os.makedirs(os.path.dirname(progress_file) or ".", exist_ok=True)

        with open(progress_file, 'w', encoding='utf-8') as f:
            f.write(payload)
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError: data not JSON-serialisable; OSError: write failure
        st.sidebar.error(f"Failed to save progress: {str(e)}")
        return False

    return True

# Define function to load questions
def load_questions(difficulty="easy"):
    """Read the question file for ``difficulty`` and return the usable entries.

    An entry is kept only when its ``need_elements`` flag is present and falsy
    (a missing flag counts as true, so the entry is dropped) and it provides
    the question text, the choices, the reference answer and the image path.
    Any failure is reported through ``st.error`` and yields an empty list.
    """
    required_fields = (
        'objective_question',
        'choice',
        'objective_reference_answer',
        'image_path',
    )
    try:
        source = f"./document/spatial_understanding/questions/{difficulty}_questions_modified.json"
        with open(source, 'r', encoding='utf-8') as handle:
            raw_items = json.load(handle)

        # Single pass: drop element-dependent entries, then require all fields
        return [
            item
            for item in raw_items
            if not item.get('need_elements', True)
            and all(field in item for field in required_fields)
        ]
    except Exception as e:
        st.error(f"Failed to load questions: {str(e)}")
        return []

# Define function to record answers
def record_answer(answer):
    """Store the user's answer to the current question, save, and rerun.

    The answer is stored in ``st.session_state.answers`` under the question's
    id (or its index when the question has no id). The quiz then either
    advances to the next question or is marked completed, the new state is
    written to the progress file, and a rerun is requested.

    Args:
        answer: The choice label picked by the user (e.g. ``"A"``).
    """
    idx = st.session_state.current_idx
    current_question = st.session_state.questions[idx]

    # JSON object keys are always strings. Normalising the key here keeps an
    # integer id from creating a second entry next to the string key that a
    # restored progress file contains.
    question_id = str(current_question.get('id', idx))
    reference_answer = current_question['objective_reference_answer']

    st.session_state.answers[question_id] = {
        'user_answer': answer,
        'correct_answer': reference_answer,
        'question': current_question['objective_question'],
        'is_correct': answer == reference_answer,
        'timestamp': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Mark quiz as completed if it's the last question, otherwise move on
    if idx >= len(st.session_state.questions) - 1:
        st.session_state.test_completed = True
    else:
        st.session_state.current_idx = idx + 1

    # Auto-save AFTER advancing, so a restored session resumes at the next
    # unanswered question instead of re-asking the one just answered.
    save_progress()

    # The score page is not rendered here: main() shows it on the rerun once
    # test_completed is set, so rendering it before rerunning is wasted work.
    st.rerun()

# Define function to show score
def show_score():
    """Render the results page.

    Shows the overall score, a table with one row per recorded answer, a
    button that exports the results as CSV and JSON into ``RESULTS_DIR``, and
    a button that resets the quiz state and starts over.
    """
    answers = st.session_state.answers
    total_questions = len(answers)
    correct_answers = sum(bool(entry['is_correct']) for entry in answers.values())
    if total_questions > 0:
        score_percentage = correct_answers / total_questions * 100
    else:
        score_percentage = 0

    st.title("Quiz Results")
    st.header(f"Score: {correct_answers}/{total_questions} ({score_percentage:.1f}%)")

    # One table row per recorded answer
    results_df = pd.DataFrame([
        {
            'Question': entry['question'],
            'Your Answer': entry['user_answer'],
            'Correct Answer': entry['correct_answer'],
            'Result': '✓' if entry['is_correct'] else '✗',
            'Timestamp': entry.get('timestamp', ''),
        }
        for entry in answers.values()
    ])
    st.dataframe(results_df, use_container_width=True)

    # Export the final results on request
    if st.button("Save Results"):
        # Make sure the target directory is there
        os.makedirs(RESULTS_DIR, exist_ok=True)

        # Both export files share one timestamped stem
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        export_stem = f"{RESULTS_DIR}/{st.session_state.difficulty}_results_{timestamp}"

        # Flat table as CSV
        results_df.to_csv(export_stem + ".csv", index=False)

        # Detailed record as JSON
        detailed_path = export_stem + ".json"
        with open(detailed_path, 'w', encoding='utf-8') as f:
            summary = {
                'score': {
                    'correct': correct_answers,
                    'total': total_questions,
                    'percentage': score_percentage
                },
                'answers': st.session_state.answers,
                'difficulty': st.session_state.difficulty,
                'completion_time': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
            json.dump(summary, f, indent=4)

        st.success("Results saved successfully!")

    # Restart: clear the quiz state, persist the cleared state, redraw
    if st.button("Start New Quiz"):
        st.session_state.current_idx = 0
        st.session_state.answers = {}
        st.session_state.test_completed = False
        save_progress()
        st.rerun()

def _render_sidebar():
    """Draw the settings sidebar: difficulty, question loading, saving, stats.

    Updates ``st.session_state`` and requests a rerun when the difficulty is
    changed or questions are (re)loaded.
    """
    levels = ["easy", "medium", "hard"]

    with st.sidebar:
        st.header("Settings")

        # Show progress restoration information if applicable
        if st.session_state.answers:
            completed = len(st.session_state.answers)
            st.success(f"Restored previous progress: {completed} questions answered")

        # Fall back to the first level when the stored difficulty is not one
        # of the known levels, instead of letting list.index raise ValueError.
        current_level = st.session_state.difficulty
        difficulty = st.selectbox(
            "Select Difficulty",
            levels,
            index=levels.index(current_level) if current_level in levels else 0,
        )

        # Reload questions if difficulty changes
        if difficulty != st.session_state.difficulty:
            if st.session_state.answers and not st.checkbox("Discard current progress?"):
                st.warning("Please confirm discarding current progress to change difficulty")
            else:
                st.session_state.difficulty = difficulty
                st.session_state.questions = load_questions(difficulty)
                st.session_state.current_idx = 0
                st.session_state.answers = {}
                save_progress()  # Save the new state
                st.rerun()

        # The checkbox has to exist BEFORE the button is evaluated. A widget
        # created inside the button branch is first drawn (unchecked) in the
        # very run that reruns immediately, so it could never be ticked.
        restart_requested = bool(st.session_state.answers) and st.checkbox("Start from beginning?")

        if st.button("Load Questions"):
            st.session_state.questions = load_questions(st.session_state.difficulty)

            # Shown when loading produced no questions although answers exist
            if not st.session_state.questions and st.session_state.answers:
                st.info("Loading questions to continue your progress...")

            # Only reset progress if there is none or a restart was requested
            if not st.session_state.answers or restart_requested:
                st.session_state.current_idx = 0
                st.session_state.answers = {}

            save_progress()  # Save the new state
            st.rerun()

        # Show manual save button
        if st.session_state.questions and st.button("Save Progress Manually"):
            if save_progress():
                st.success("Progress saved successfully!")

        # Completion stats
        if st.session_state.answers and st.session_state.questions:
            st.divider()
            st.subheader("Progress")
            completed = len(st.session_state.answers)
            total = len(st.session_state.questions)
            # Both collections are non-empty here. The ratio is capped because
            # st.progress rejects values above 1.0 and restored answers may
            # outnumber the currently loaded questions.
            st.progress(min(completed / total, 1.0))
            st.write(f"Completed: {completed}/{total}")

            correct = sum(1 for a in st.session_state.answers.values() if a['is_correct'])
            accuracy = correct / completed * 100
            st.write(f"Current accuracy: {accuracy:.1f}%")


def _render_image(question):
    """Show the image of ``question``, or an error message if it cannot be shown."""
    st.subheader("Image")
    try:
        image_path = question['image_path']
        if Path(image_path).exists():
            # Context manager closes the underlying file once it is displayed
            with Image.open(image_path) as image:
                st.image(image, use_column_width=True)
        else:
            st.error(f"Image file does not exist: {image_path}")
    except Exception as e:
        st.error(f"Unable to load image: {str(e)}")
        st.write(f"Image path: {question.get('image_path', 'Unknown')}")


def _render_question_panel(question):
    """Show the question text, its prompt and the four answer buttons."""
    st.subheader("Question")
    st.write(question['objective_question'])

    st.subheader("Prompt")
    st.write(question.get('prompt', "No prompt available"))

    st.subheader("Select an Answer")
    choice_map = question['choice']

    # Two columns: A and C on the left, B and D on the right
    for column, letters in zip(st.columns(2), (("A", "C"), ("B", "D"))):
        with column:
            for letter in letters:
                label = f"{letter}: {choice_map.get(letter, f'Option {letter}')}"
                if st.button(label, use_container_width=True):
                    record_answer(letter)

    # Check if this question has been answered before (if revisiting).
    # Answers restored from the JSON progress file always have string keys,
    # so the stringified id is tried as well.
    question_id = question.get('id', str(st.session_state.current_idx))
    answers = st.session_state.answers
    previous_answer = answers.get(question_id, answers.get(str(question_id)))
    if previous_answer is not None:
        st.info(f"You previously answered: {previous_answer['user_answer']}")


def main():
    """Render the quiz page, or the score page once the quiz is completed."""
    if st.session_state.test_completed:
        show_score()
        return

    st.title("Image Recognition Quiz")
    _render_sidebar()

    # Show prompt if no questions are loaded
    if not st.session_state.questions:
        st.info("Please load questions from the sidebar to start or continue the quiz")
        return

    total_questions = len(st.session_state.questions)

    # A restored index may not fit the question list that is loaded now;
    # clamp it so the lookup and the progress bar below stay valid.
    if not 0 <= st.session_state.current_idx < total_questions:
        st.session_state.current_idx = min(
            max(st.session_state.current_idx, 0), total_questions - 1
        )

    # Display current progress
    st.progress(st.session_state.current_idx / total_questions)
    st.write(f"Question {st.session_state.current_idx + 1} / {total_questions}")

    current_question = st.session_state.questions[st.session_state.current_idx]

    # Two-column layout: image on the left, question and answers on the right
    col1, col2 = st.columns([2, 1])
    with col1:
        _render_image(current_question)
    with col2:
        _render_question_panel(current_question)

# Script entry point: build the page only when this file runs as the main module
if __name__ == "__main__":
    main()