import streamlit as st
import json
import os
import glob
from typing import Dict, List, Any, Optional
from datetime import datetime
import pandas as pd
import threading
import time
from filelock import FileLock
import base64
from PIL import Image
import io
import random

# Page configuration (title, icon, wide layout, sidebar open by default).
# NOTE(review): Streamlit (at least in older releases) requires
# st.set_page_config to be the first Streamlit command executed in the script,
# so keep this call above the CSS injection and every other st.* call.
st.set_page_config(
    page_title="GSM8K-V Annotation Tool",
    page_icon="📝",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS
st.markdown("""
<style>
    /* 减少整体页面边距 */
    .main .block-container {
        padding-top: 1rem !important;
        padding-bottom: 1rem !important;
        padding-left: 1rem !important;
        padding-right: 1rem !important;
        max-width: 100% !important;
    }
    
    /* 减少组件之间的间距 */
    .element-container {
        margin-bottom: 0.5rem !important;
    }
    
    /* 减少metric组件的间距 */
    div[data-testid="metric-container"] {
        background-color: white;
        border: 1px solid #dee2e6;
        padding: 0.5rem;
        border-radius: 0.25rem;
        margin-bottom: 0.25rem !important;
    }
    
    /* 减少列之间的间距 */
    .stColumns > div {
        padding-left: 0.25rem !important;
        padding-right: 0.25rem !important;
    }
    
    /* 减少文本区域的间距 */
    .stTextArea > div > div {
        margin-bottom: 0.5rem !important;
    }
    
    /* 减少文本输入的间距 */
    .stTextInput > div > div {
        margin-bottom: 0.5rem !important;
    }
    
    /* 减少按钮的间距 */
    .stButton > button {
        margin-bottom: 0.25rem !important;
    }
    
    /* 减少进度条的间距 */
    .stProgress {
        margin-bottom: 0.5rem !important;
    }
    
   
    
    .original-question {
        font-size: 1.5rem;
        font-weight: 500;
        line-height: 1.6;
        color: #333;
        margin: 0.5rem 0 !important;
        padding: 0.75rem;
        background-color: white;
        border-radius: 0.5rem;
        border-left: 4px solid #1f77b4;
    }
    
    
    .ground-truth {
        font-size: 1.5rem;
        font-weight: bold;
        color: #1f77b4;
    }
    
    /* 图片容器优化 */
    .image-container {
        background-color: #f8f9fa;
        padding: 0.75rem;
        border-radius: 0.5rem;
        border: 1px solid #dee2e6;
        margin-bottom: 0.5rem !important;
    }
    
  
    
    /* 历史容器优化 */
    .history-container {
        background-color: #f8f9fa;
        padding: 0.75rem;
        border-radius: 0.5rem;
        border: 1px solid #dee2e6;
        margin-bottom: 0.5rem !important;
    }
    
    /* 自定义按钮样式 */
    div[data-testid="column"]:first-child .stButton > button {
        background-color: #28a745 !important;
        border-color: #28a745 !important;
        color: white !important;
        font-weight: bold !important;
    }
    
    div[data-testid="column"]:first-child .stButton > button:hover {
        background-color: #218838 !important;
        border-color: #1e7e34 !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(40, 167, 69, 0.3) !important;
    }
    
    div[data-testid="column"]:nth-child(2) .stButton > button {
        background-color: #dc3545 !important;
        border-color: #dc3545 !important;
        color: white !important;
        font-weight: bold !important;
    }
    
    div[data-testid="column"]:nth-child(2) .stButton > button:hover {
        background-color: #c82333 !important;
        border-color: #bd2130 !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(220, 53, 69, 0.3) !important;
    }
    
    div[data-testid="column"]:nth-child(3) .stButton > button {
        background-color: #6c757d !important;
        border-color: #6c757d !important;
        color: white !important;
        font-weight: bold !important;
    }
    
    div[data-testid="column"]:nth-child(3) .stButton > button:hover {
        background-color: #5a6268 !important;
        border-color: #545b62 !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(108, 117, 125, 0.3) !important;
    }
    
    /* 按钮过渡效果 */
    .stButton > button {
        transition: all 0.3s ease !important;
        border-radius: 0.5rem !important;
        padding: 0.5rem 1rem !important;
        font-size: 1rem !important;
        margin-bottom: 0.25rem !important;
    }
    
    /* 减少侧边栏的间距 */
    .css-1d391kg {
        padding-top: 1rem !important;
    }
    
    /* 减少标题的间距 */
    h1, h2, h3 {
        margin-top: 0.5rem !important;
        margin-bottom: 0.5rem !important;
    }
    
    /* 减少段落间距 */
    p {
        margin-bottom: 0.5rem !important;
    }
    
    /* 图片间距优化 */
    .stImage {
        margin-bottom: 0.25rem !important;
    }
    
    /* 警告和信息框间距优化 */
    .stAlert {
        margin-bottom: 0.5rem !important;
        padding: 0.5rem !important;
    }
    
    /* 成功消息间距优化 */
    .stSuccess {
        margin-bottom: 0.5rem !important;
        padding: 0.5rem !important;
    }
    
    /* 错误消息间距优化 */
    .stError {
        margin-bottom: 0.5rem !important;
        padding: 0.5rem !important;
    }
    
    /* 移除不必要的边距 */
    .block-container > div {
        gap: 0.5rem !important;
    }
</style>
""", unsafe_allow_html=True)

class AnnotationManager:
    """Persist per-question annotations in a JSON file shared by annotators.

    On disk the file maps ``question_id -> {annotator -> {"result", "comment",
    "human_answer", "timestamp"}}``. Every read and write of the file happens
    while holding a ``FileLock`` on ``<annotations_file>.lock``, and writes go
    through a temp file that is swapped in with ``os.replace``.
    """

    # Number of attempts for load/save before giving up.
    _MAX_RETRIES = 3
    # Base delay in seconds; attempt k (0-based) waits _RETRY_DELAY * (k + 1),
    # i.e. a linear backoff.
    _RETRY_DELAY = 0.1
    # Seconds to wait for the file lock before raising.
    _LOCK_TIMEOUT = 10

    def __init__(self, annotations_file: str = "annotations.json"):
        """Initialize annotation manager with filelock support.

        Args:
            annotations_file: Path of the shared JSON annotations file.
        """
        self.annotations_file = annotations_file
        self.lock_file = f"{annotations_file}.lock"
        self.file_lock = FileLock(self.lock_file)
        self.annotations = self.load_annotations()

    def _read_from_disk(self) -> Dict[str, Any]:
        """Return the parsed annotations file, or ``{}`` if it does not exist.

        The caller must already hold ``self.file_lock``.
        """
        if not os.path.exists(self.annotations_file):
            return {}
        with open(self.annotations_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _write_atomically(self, data: Dict[str, Any]) -> None:
        """Write *data* as JSON to a temp file, then swap it into place.

        ``os.replace`` overwrites the destination atomically on both POSIX and
        Windows, so there is never a moment with no annotations file (the old
        Windows path removed the file before renaming). The temp file is
        deleted if anything fails before the swap. The caller must already
        hold ``self.file_lock``.

        Raises:
            Whatever ``json.dump`` / file I/O raises; the target file is left
            untouched in that case.
        """
        temp_file = f"{self.annotations_file}.tmp"
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
                f.flush()
                os.fsync(f.fileno())  # make sure the bytes are on disk before swapping
            os.replace(temp_file, self.annotations_file)
        finally:
            # Only still present when the write or the swap failed.
            if os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except OSError:
                    pass

    def load_annotations(self) -> Dict[str, Any]:
        """Load existing annotations from file while holding the file lock.

        Retries up to ``_MAX_RETRIES`` times with a linearly growing delay.

        Returns:
            The annotations dict, or ``{}`` if the file is missing or every
            attempt failed (in which case an error is shown in the UI).
        """
        last_error: Optional[Exception] = None
        for attempt in range(self._MAX_RETRIES):
            try:
                with self.file_lock.acquire(timeout=self._LOCK_TIMEOUT):
                    return self._read_from_disk()
            except Exception as e:
                last_error = e
                if attempt < self._MAX_RETRIES - 1:
                    time.sleep(self._RETRY_DELAY * (attempt + 1))
        st.error(f"Error loading annotations after {self._MAX_RETRIES} attempts: {last_error}")
        return {}

    def save_annotations(self) -> bool:
        """Write ``self.annotations`` to disk atomically under the file lock.

        NOTE(review): this writes the in-memory snapshot as-is, so anything
        other sessions saved since the last load/refresh is overwritten.
        ``add_annotation`` performs its own reload-merge-write instead.

        Returns:
            True on success; False after ``_MAX_RETRIES`` failed attempts
            (an error is shown in the UI).
        """
        last_error: Optional[Exception] = None
        for attempt in range(self._MAX_RETRIES):
            try:
                with self.file_lock.acquire(timeout=self._LOCK_TIMEOUT):
                    self._write_atomically(self.annotations)
                    return True
            except Exception as e:
                last_error = e
                if attempt < self._MAX_RETRIES - 1:
                    time.sleep(self._RETRY_DELAY * (attempt + 1))
        st.error(f"Error saving annotations after {self._MAX_RETRIES} attempts: {last_error}")
        return False

    def add_annotation(self, question_id: str, annotator: str, result: str, comment: str = "", human_answer: str = "") -> bool:
        """Record *annotator*'s verdict for *question_id*.

        The file is re-read under the lock so concurrent annotators' writes are
        merged rather than overwritten. The annotation is rejected when this
        annotator already annotated the question or it already has two
        annotations.

        Returns:
            True when the annotation was written to disk, False otherwise
            (a warning or error is shown in the UI).
        """
        try:
            with self.file_lock.acquire(timeout=self._LOCK_TIMEOUT):
                latest_annotations = self._read_from_disk()
                existing = latest_annotations.get(question_id, {})

                if annotator in existing:
                    st.warning("⚠️ You have already annotated this question!")
                    return False

                if len(existing) >= 2:
                    st.warning("⚠️ This question has been fully annotated by others!")
                    return False

                latest_annotations.setdefault(question_id, {})[annotator] = {
                    "result": result,
                    "comment": comment,
                    "human_answer": human_answer,
                    "timestamp": datetime.now().isoformat()
                }

                self._write_atomically(latest_annotations)
                # Adopt the merged state only once it is safely on disk.
                self.annotations = latest_annotations
                return True

        except Exception as e:
            st.error(f"Error adding annotation: {e}")
            return False

    def get_annotation_count(self, question_id: str) -> int:
        """Return how many annotators have annotated *question_id*."""
        return len(self.annotations.get(question_id, {}))

    def has_annotator_annotated(self, question_id: str, annotator: str) -> bool:
        """Return True if *annotator* has already annotated *question_id*."""
        return annotator in self.annotations.get(question_id, {})

    def is_fully_annotated(self, question_id: str) -> bool:
        """Return True once *question_id* has annotations from 2 people."""
        return self.get_annotation_count(question_id) >= 2

    def get_annotation_results(self, question_id: str) -> Dict[str, Dict[str, str]]:
        """Return the ``annotator -> annotation`` mapping for *question_id* ({} if none)."""
        return self.annotations.get(question_id, {})

    def refresh_annotations(self):
        """Reload annotations from disk to pick up other sessions' writes."""
        self.annotations = self.load_annotations()

class ImagePreloader:
    """Thread-safe LRU cache of base64-encoded PNG images keyed by question ID.

    ``preload_images`` is run from background threads (see
    ``DataLoader.preload_next_images``), while the Streamlit script thread
    reads through ``get_cached_images``.
    """

    def __init__(self, data_loader: 'DataLoader', cache_size: int = 5):
        """Initialize image preloader with LRU cache.

        Args:
            data_loader: Object providing ``find_images_for_qid(qid)``.
            cache_size: Maximum number of question IDs kept in the cache.
        """
        self.data_loader = data_loader
        self.cache_size = cache_size
        self.image_cache = {}  # qid -> list of {'path', 'data', 'name'} dicts
        self.cache_order = []  # cached qids, least recently used first
        # Guards image_cache and cache_order; never held during image I/O.
        self.preload_lock = threading.Lock()

    def _encode_image_to_base64(self, image_path: str) -> Optional[str]:
        """Return *image_path* as a base64 PNG (downscaled to fit 1000x1000), or None on error."""
        try:
            with Image.open(image_path) as img:
                # Resize large images to reduce memory usage
                if img.width > 1000 or img.height > 1000:
                    img.thumbnail((1000, 1000), Image.Resampling.LANCZOS)

                # PNG cannot store CMYK (used by some JPEGs); convert first.
                frame = img.convert('RGB') if img.mode == 'CMYK' else img

                buffer = io.BytesIO()
                frame.save(buffer, format='PNG')
                return base64.b64encode(buffer.getvalue()).decode()
        except Exception as e:
            print(f"Error encoding image {image_path}: {e}")
            return None

    def _decode_base64_to_image(self, encoded_data: str):
        """Decode base64 back to a PIL image, or None on error."""
        try:
            image_data = base64.b64decode(encoded_data)
            return Image.open(io.BytesIO(image_data))
        except Exception as e:
            print(f"Error decoding image: {e}")
            return None

    def _load_encoded_images(self, qid: str) -> Optional[List[Dict[str, Any]]]:
        """Encode every existing image of *qid*.

        Returns:
            The list of ``{'path', 'data', 'name'}`` entries, or None if any
            existing image failed to encode. A partial list must not be cached,
            because the renderer would then silently show fewer scenes.
        """
        encoded_images = []
        for img_path in self.data_loader.find_images_for_qid(qid):
            if not os.path.exists(img_path):
                continue
            encoded = self._encode_image_to_base64(img_path)
            if not encoded:
                return None
            encoded_images.append({
                'path': img_path,
                'data': encoded,
                'name': os.path.basename(img_path)
            })
        return encoded_images

    def preload_images(self, question_ids: List[str]):
        """Encode and cache images for each ID in *question_ids* not cached yet.

        Disk access and encoding run without the lock, so ``get_cached_images``
        on the render thread is never blocked behind a slow preload. The cache
        is re-checked under the lock before inserting, so two concurrent
        preloads of the same ID cannot create duplicate LRU entries.
        """
        for qid in question_ids:
            with self.preload_lock:
                if qid in self.image_cache:
                    continue

            encoded_images = self._load_encoded_images(qid)
            if encoded_images is None:
                # Leave uncached so rendering falls back to the file paths.
                continue

            with self.preload_lock:
                if qid in self.image_cache:
                    continue  # another preload thread finished first
                self.image_cache[qid] = encoded_images
                self.cache_order.append(qid)

                # Maintain cache size (LRU eviction)
                while len(self.cache_order) > self.cache_size:
                    oldest_qid = self.cache_order.pop(0)
                    self.image_cache.pop(oldest_qid, None)

    def get_cached_images(self, question_id: str) -> List[Dict[str, Any]]:
        """Return cached entries for *question_id* ([] on miss) and mark it most recently used."""
        with self.preload_lock:
            if question_id not in self.image_cache:
                return []
            if question_id in self.cache_order:
                self.cache_order.remove(question_id)
                self.cache_order.append(question_id)
            return self.image_cache[question_id]

class DataLoader:
    """Load the question list from JSON and locate each question's scene images."""

    # Extensions accepted as scene images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, json_file_path: str, image_base_path: str):
        """Initialize data loader with image preloading.

        Args:
            json_file_path: JSON file holding the list of question dicts.
            image_base_path: Root directory searched recursively for images.
        """
        self.json_file_path = json_file_path
        self.image_base_path = image_base_path
        self.data = self.load_json_data()
        self.image_preloader = ImagePreloader(self, cache_size=10)

    def load_json_data(self) -> List[Dict[str, Any]]:
        """Load JSON data from file; show an error and return [] on failure."""
        try:
            with open(self.json_file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            st.error(f"Error loading JSON data: {e}")
            return []

    def find_images_for_qid(self, qid: str) -> List[str]:
        """Return paths of all images named ``<qid>_*`` under ``image_base_path``.

        A single recursive glob is used (with ``recursive=True`` the ``**``
        component also matches the base directory itself). ``qid`` and the
        base path are glob-escaped so IDs containing ``[``, ``*`` or ``?``
        match literally. Extensions are checked case-insensitively against
        ``IMAGE_EXTENSIONS``.

        Returns:
            De-duplicated, normalized file paths sorted by scene number, with
            ties broken by path.
        """
        pattern = os.path.join(glob.escape(self.image_base_path), '**', f"{glob.escape(str(qid))}_*")

        unique_images = []
        seen = set()
        for found in glob.glob(pattern, recursive=True):
            normalized_path = os.path.normpath(found)
            if normalized_path in seen:
                continue
            if os.path.splitext(normalized_path)[1].lower() not in self.IMAGE_EXTENSIONS:
                continue
            if not os.path.isfile(normalized_path):
                continue
            seen.add(normalized_path)
            unique_images.append(normalized_path)

        # _extract_scene_number never raises; the path makes ties deterministic.
        unique_images.sort(key=lambda p: (self._extract_scene_number(p), p))
        return unique_images

    def _extract_scene_number(self, img_path: str) -> int:
        """Return the digits of the 2nd ``_``-separated filename part as an int (0 if none)."""
        try:
            img_name = os.path.basename(img_path)
            if '_' in img_name:
                scene_part = img_name.split('_')[1]
                scene_num = ''.join(filter(str.isdigit, scene_part))
                return int(scene_num) if scene_num else 0
        except Exception:
            pass
        return 0

    def get_available_questions(self, annotator: str, annotation_manager: 'AnnotationManager') -> List[Dict[str, Any]]:
        """Return questions *annotator* may still annotate, in a per-user stable order.

        Annotations are reloaded from disk first. A question is skipped when it
        has no ``question_id``, already has two annotations, or was already
        annotated by *annotator*.

        The shuffle uses a private ``random.Random(annotator)``. A ``str`` seed
        is hashed deterministically (SHA-512), so a user sees the same order
        across server restarts; the previous ``hash(annotator)`` is salted per
        process. This also leaves the global ``random`` state untouched, which
        other sessions/threads share.
        """
        annotation_manager.refresh_annotations()

        available = []
        for item in self.data:
            qid = item.get('question_id', '')
            if not qid:
                continue
            if annotation_manager.is_fully_annotated(qid):
                continue
            if annotation_manager.has_annotator_annotated(qid, annotator):
                continue
            available.append(item)

        random.Random(annotator).shuffle(available)
        return available

    def preload_next_images(self, current_index: int, available_questions: List[Dict[str, Any]], preload_count: int = 3):
        """Start a daemon thread preloading images around *current_index*.

        Covers up to *preload_count* questions after the current one and up to
        *preload_count* before it (in case the user navigates back).
        """
        next_qids = []

        for i in range(current_index + 1, min(current_index + preload_count + 1, len(available_questions))):
            qid = available_questions[i].get('question_id', '')
            if qid:
                next_qids.append(qid)

        for i in range(max(0, current_index - preload_count), current_index):
            qid = available_questions[i].get('question_id', '')
            if qid:
                next_qids.append(qid)

        if next_qids:
            threading.Thread(
                target=self.image_preloader.preload_images,
                args=(next_qids,),
                daemon=True
            ).start()

def auto_jump_to_next_question():
    """Advance ``st.session_state.current_index`` to the following queued question.

    When the index would run past the end of ``available_questions``, it wraps
    to 0 and ``refresh_questions`` is set to True so the caller can rebuild the
    queue. Does nothing if the session holds no (or an empty) question list.
    """
    questions = st.session_state.get('available_questions')
    if not questions:
        return

    candidate = st.session_state.get('current_index', 0) + 1
    if candidate >= len(questions):
        # Past the end: restart from the top and request a fresh question list.
        st.session_state.current_index = 0
        st.session_state.refresh_questions = True
    else:
        st.session_state.current_index = candidate

def render_progress_info(annotation_manager: AnnotationManager, data_loader: DataLoader, annotator: str):
    """Render the progress metrics row and the overall progress bar.

    Shows the total number of questions, how many have two annotations, how
    many are still open to *annotator*, and how many *annotator* has done.
    All four numbers come from a single reload of the annotations file:
    ``get_available_questions`` refreshes *annotation_manager* itself, and the
    remaining counts reuse that snapshot. Previously the file was reloaded
    twice, so the counts could disagree if another annotator saved in between.
    """
    st.markdown('<div class="progress-container">', unsafe_allow_html=True)

    # This call reloads annotations from disk (DataLoader.get_available_questions).
    available_for_annotator = len(data_loader.get_available_questions(annotator, annotation_manager))

    question_ids = [item.get('question_id', '') for item in data_loader.data]
    total_questions = len(question_ids)
    fully_annotated = sum(1 for qid in question_ids if annotation_manager.is_fully_annotated(qid))
    annotated_by_current = sum(
        1 for qid in question_ids if annotation_manager.has_annotator_annotated(qid, annotator)
    )

    metrics = (
        ("Total Questions", total_questions),
        ("Fully Annotated", fully_annotated),
        ("Available for You", available_for_annotator),
        ("Your Annotations", annotated_by_current),
    )
    for column, (label, value) in zip(st.columns(4), metrics):
        with column:
            st.metric(label, value)

    if total_questions > 0:
        progress = fully_annotated / total_questions
        st.progress(progress, text=f"Overall Progress: {fully_annotated}/{total_questions} ({progress:.1%})")

    st.markdown('</div>', unsafe_allow_html=True)

def render_question_display(question_data: Dict[str, Any], user_answer: str = ""):
    """Render the question ID, the original question text and the ground truth.

    The ground truth is revealed only once *user_answer* contains non-blank
    text; until then a "Please Answer First" placeholder is shown.

    Dataset text is HTML-escaped before being placed inside the raw-HTML
    (``unsafe_allow_html``) divs. Otherwise characters such as ``<`` or ``&``
    in a question or answer would break the markup, or inject HTML.
    """
    import html  # local import keeps this block self-contained

    st.markdown('<div class="question-container">', unsafe_allow_html=True)

    st.markdown(f"### Question ID: {question_data.get('question_id', 'N/A')}")

    st.markdown('<div class="question-label">Original Question:</div>', unsafe_allow_html=True)
    question_text = html.escape(str(question_data.get('original_question', 'N/A')))
    st.markdown(f'<div class="original-question">{question_text}</div>', unsafe_allow_html=True)

    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="answer-container">', unsafe_allow_html=True)
    if user_answer and user_answer.strip():
        ground_truth = html.escape(str(question_data.get('math_ground_truth', 'N/A')))
        st.markdown(f'<div class="ground-truth">Ground Truth: {ground_truth}</div>', unsafe_allow_html=True)
    else:
        st.markdown('<div class="ground-truth">Ground Truth: Please Answer First</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

def render_images(images: List[str], show_header: bool = True):
    """Show scene images from disk in a grid of three columns per row.

    Args:
        images: Image file paths, already in display order.
        show_header: When True, emit the ``image-container`` wrapper and an
            "Images (N scenes)" heading around the grid.
    """
    if not images:
        st.warning("No images found for this question.")
        return

    if show_header:
        st.markdown('<div class="image-container">', unsafe_allow_html=True)
        st.markdown(f"### Images ({len(images)} scenes)")

    per_row = 3
    for row_start in range(0, len(images), per_row):
        columns = st.columns(per_row)
        row = images[row_start:row_start + per_row]
        for column, path in zip(columns, row):
            with column:
                try:
                    # "<qid>_<scene>.<ext>" -> "<scene>"
                    _, sep, rest = os.path.basename(path).partition('_')
                    label = rest.split('_', 1)[0].split('.', 1)[0] if sep else 'Unknown'
                    st.image(path, caption=f"Scene {label}")
                except Exception as exc:
                    st.error(f"Error loading image {path}: {exc}")

    if show_header:
        st.markdown('</div>', unsafe_allow_html=True)

def render_images_optimized(images: List[str], question_id: str, data_loader: DataLoader):
    """Display a question's scene images, preferring the preloader's cache.

    Cached entries are embedded inline as base64 PNG data URIs, three per row.
    On a cache miss the path-based ``render_images`` grid is used instead,
    with its own heading suppressed so it is not printed twice.
    """
    if not images:
        st.warning("No images found for this question.")
        return

    st.markdown('<div class="image-container">', unsafe_allow_html=True)
    st.markdown(f"### Images ({len(images)} scenes)")

    cached = data_loader.image_preloader.get_cached_images(question_id)

    if not cached:
        render_images(images, show_header=False)
    else:
        per_row = 3
        for row_start in range(0, len(cached), per_row):
            columns = st.columns(per_row)
            for column, entry in zip(columns, cached[row_start:row_start + per_row]):
                with column:
                    try:
                        _, sep, rest = entry['name'].partition('_')
                        label = rest.split('_', 1)[0].split('.', 1)[0] if sep else 'Unknown'
                        img_tag = (
                            f'<img src="data:image/png;base64,{entry["data"]}" '
                            f'style="width: 100%; border-radius: 0.5rem;" '
                            f'alt="Scene {label}"/>'
                        )
                        st.markdown(img_tag, unsafe_allow_html=True)
                        st.caption(f"Scene {label}")
                    except Exception as exc:
                        st.error(f"Error displaying cached image: {exc}")

    st.markdown('</div>', unsafe_allow_html=True)

def _answers_match(answer: str, ground_truth: Any) -> bool:
    """Return True if *answer* equals *ground_truth*, numerically when possible.

    Both sides are stripped and first compared as strings. If that fails and
    both parse as numbers (after dropping thousands separators and one leading
    ``$``), they are compared as ``Decimal``. So "18", "18.0", "$18" and
    "1,000" match 18 / 1000.
    """
    from decimal import Decimal, InvalidOperation

    def as_number(text: str) -> Decimal:
        cleaned = text.replace(',', '').strip()
        if cleaned.startswith('$'):
            cleaned = cleaned[1:].strip()
        return Decimal(cleaned)

    left = answer.strip()
    right = str(ground_truth).strip()
    if left == right:
        return True
    try:
        return as_number(left) == as_number(right)
    except InvalidOperation:
        return False


def render_human_answer_section(question_id: str, question_data: Dict[str, Any]):
    """Render the input where the annotator solves the problem themselves.

    Once something is entered it is compared with ``math_ground_truth`` (see
    ``_answers_match``), and a success or warning message is shown.

    Returns:
        The current contents of the answer box ('' when empty).
    """
    st.markdown('<div class="human-answer-container">', unsafe_allow_html=True)
    st.markdown("### 🧠 Human Solution")
    st.markdown("**Solve this problem yourself based on the images:**")

    # NOTE: on_change=None is just the default. Streamlit still reruns the
    # script when the user commits the input (Enter / focus loss).
    human_answer = st.text_input(
        "Your calculated answer:",
        key=f"human_answer_{question_id}",
        placeholder="Enter your solution here...",
        on_change=None
    )

    if human_answer:
        ground_truth = question_data.get('math_ground_truth', 'N/A')
        if _answers_match(human_answer, ground_truth):
            st.success(f"✅ Your answer matches the ground truth: {ground_truth}")
        else:
            st.warning(f"⚠️ Your answer: {human_answer} | Ground truth: {ground_truth}")

    st.markdown('</div>', unsafe_allow_html=True)

    return human_answer

def render_annotation_section(question_id: str, annotator: str, annotation_manager: AnnotationManager):
    """Render the comment box and the Correct / Incorrect / Skip buttons.

    Only a notice is shown when this annotator has already annotated the
    question, or when it already has two annotations.

    A button click saves the verdict via ``annotation_manager.add_annotation``,
    together with the comment and the annotator's own answer (read from the
    ``human_answer_<question_id>`` session key). After a successful save it
    advances to the next question and reruns the script; a failed save shows
    an error.
    """
    st.markdown('<div class="annotation-section">', unsafe_allow_html=True)

    st.markdown("### 📝 Make your annotation:")

    # Pick up annotations saved by other sessions since the last rerun.
    annotation_manager.refresh_annotations()

    if annotation_manager.has_annotator_annotated(question_id, annotator):
        st.warning("⚠️ You have already annotated this question!")
    elif annotation_manager.is_fully_annotated(question_id):
        st.info("ℹ️ This question has been fully annotated by others.")
    else:
        human_answer = st.session_state.get(f"human_answer_{question_id}", "")

        comment = st.text_area(
            "Comment (optional):",
            key=f"comment_{question_id}",
            placeholder="Add your thoughts, reasoning, or observations about this question...",
            height=100,
            on_change=None
        )

        # (button label, stored result, success message, widget key prefix)
        actions = (
            ("✅ Correct", "correct", "✅ Marked as Correct!", "correct_btn"),
            ("❌ Incorrect", "incorrect", "❌ Marked as Incorrect!", "incorrect_btn"),
            ("⏭️ Skip", "skip", "⏭️ Skipped!", "skip_btn"),
        )
        for column, (label, result, done_message, key_prefix) in zip(st.columns([1, 1, 1]), actions):
            with column:
                clicked = st.button(label, use_container_width=True, key=f"{key_prefix}_{question_id}")
                if not clicked:
                    continue
                saved = annotation_manager.add_annotation(question_id, annotator, result, comment, human_answer)
                if saved:
                    st.success(done_message)
                    auto_jump_to_next_question()
                    st.rerun()
                else:
                    st.error("❌ Failed to save annotation")

    st.markdown('</div>', unsafe_allow_html=True)

def render_navigation(available_questions: List[Dict], current_index: int, annotator: str):
    """Render the navigation bar and return the question index to display.

    The bar offers Previous/Next buttons, a position indicator, and a
    "jump to question ID" text box with a Go button.

    Args:
        available_questions: Question dicts the annotator can still work on;
            each one's id is read with ``.get('question_id')``.
        current_index: Index into ``available_questions`` shown before this
            rerun. Out-of-range values are clamped into the valid range.
        annotator: Current annotator name (not used here; kept for interface
            compatibility with callers).

    Returns:
        The (possibly updated) index into ``available_questions``. When the
        list is empty, ``current_index`` is returned unchanged.
    """
    st.markdown('<div class="navigation-container">', unsafe_allow_html=True)
    
    if not available_questions:
        st.info("🎉 No more questions available for annotation!")
        st.markdown('</div>', unsafe_allow_html=True)
        return current_index
    
    last_index = len(available_questions) - 1
    # Guard against a stale/out-of-range index so the lookups below cannot raise.
    current_index = min(max(current_index, 0), last_index)
    
    col1, col2, col3, col4, col5 = st.columns([1, 1, 2, 1, 1])
    
    with col1:
        if st.button("⬅️ Previous", disabled=(current_index <= 0), key="nav_prev"):
            current_index = max(0, current_index - 1)
    
    with col2:
        if st.button("Next ➡️", disabled=(current_index >= last_index), key="nav_next"):
            current_index = min(last_index, current_index + 1)
    
    # Handle the QID jump (columns 4 and 5) BEFORE filling column 3. Placement
    # is fixed by st.columns, so the layout is unchanged, but the position
    # indicator now reflects a jump made in this same rerun instead of showing
    # the previously selected question.
    with col4:
        target_qid = st.text_input(
            "Jump to QID:", 
            key="jump_qid", 
            label_visibility="collapsed",
        )
    
    with col5:
        if st.button("Go", key="nav_go"):
            # Strip copy-paste whitespace and compare as strings so ids stored
            # as numbers in the JSON still match what the user typed.
            wanted = (target_qid or "").strip()
            match = None
            if wanted:
                match = next(
                    (idx for idx, q in enumerate(available_questions)
                     if str(q.get('question_id', '')) == wanted),
                    None,
                )
            if match is not None:
                current_index = match
            elif wanted:  # Only show error when an ID was entered but not found
                st.error(f"Question ID {wanted} not found or not available.")
    
    with col3:
        st.markdown(f"**Question {current_index + 1} of {len(available_questions)} available**")
        current_qid = available_questions[current_index].get('question_id', 'Unknown')
        st.markdown(f"*ID: {current_qid}*")
    
    st.markdown('</div>', unsafe_allow_html=True)
    return current_index

def render_annotation_history(question_id: str, annotation_manager: AnnotationManager):
    """Render every recorded annotation for ``question_id``.

    Nothing is rendered when ``annotation_manager.get_annotation_results``
    returns an empty/falsy value. Otherwise, for each annotator it shows the
    result, the optional human answer and comment, and the timestamp.

    Args:
        question_id: Id of the question whose history is shown.
        annotation_manager: Source of annotation records; the result is
            iterated as ``{annotator: {"result", "comment", "human_answer",
            "timestamp"}}`` (presumably a dict — confirm in AnnotationManager).
    """
    annotations = annotation_manager.get_annotation_results(question_id)
    
    if annotations:
        st.markdown('<div class="history-container">', unsafe_allow_html=True)
        st.markdown("### 📋 Annotation History:")
        
        for annotator, data in annotations.items():
            result = data.get("result", "")
            comment = data.get("comment", "")
            human_answer = data.get("human_answer", "")
            timestamp = data.get("timestamp", "")
            
            # Format the timestamp as "YYYY-MM-DD HH:MM". A trailing "Z" is
            # rewritten to "+00:00" because older datetime.fromisoformat
            # versions reject the "Z" suffix. Unparseable values are shown raw.
            if timestamp:
                try:
                    dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
                    formatted_time = dt.strftime("%Y-%m-%d %H:%M")
                except (ValueError, TypeError, AttributeError):
                    # ValueError: not ISO-8601; TypeError/AttributeError: the
                    # stored timestamp is not a string.
                    formatted_time = timestamp
            else:
                formatted_time = "Unknown"
            
            st.markdown(f"**{annotator}** ({formatted_time}):")
            st.markdown(f"- **Result:** {result}")
            
            if human_answer:
                st.markdown(f"- **Human Answer:** {human_answer}")
            
            if comment:
                st.markdown(f"- **Comment:** {comment}")
            
            st.markdown("---")
        
        st.markdown('</div>', unsafe_allow_html=True)

def main():
    """Run one Streamlit rerun of the annotation app.

    Flow:
      1. Sidebar: validate the annotator name against ``VALID_USERS`` and check
         that the JSON file and image directory exist. Any failure renders a
         message and returns early.
      2. Build ``AnnotationManager`` / ``DataLoader`` and get the annotator's
         available questions. The list is cached in ``st.session_state`` so
         its order stays stable across reruns.
      3. Render progress, navigation, the current question, its images, the
         human-answer section, the annotation history and the annotation
         buttons. When nothing is left, render a completion summary instead.
    """
    st.markdown('<div class="main-header">📝 GSM8K-V Annotation Tool</div>', unsafe_allow_html=True)
    
    VALID_USERS = [
        "user1",
        "user2", 
    ]
    
    # Sidebar configuration
    with st.sidebar:
        st.header("⚙️ Configuration")
        
        # Annotator selection with validation. Surrounding whitespace (e.g. from
        # copy-paste) is stripped so "user1 " is treated the same as "user1".
        annotator = st.text_input(
            "Annotator Name:",
            value=st.session_state.get('annotator', ''),
            placeholder="Enter your name"
        ).strip()
        
        if annotator:
            # Validate if user is in the valid user list
            if annotator in VALID_USERS:
                st.session_state.annotator = annotator
                st.success(f"👤 Logged in as: **{annotator}**")
            else:
                st.error(f"❌ Access denied. '{annotator}' is not authorized.")
                st.warning("Please contact administrator for access.")
                st.info("💡 Make sure you are using the correct username provided by your administrator.")
                return
        else:
            st.warning("Please enter your annotator name to continue.")
            return
        
        # File configuration
        st.header("📁 File Configuration")
        
        json_file = st.text_input(
            "JSON File Path:",
            value="./meta_data.json"
        )
        
        # NOTE(review): machine-specific default; other users must override it in the UI.
        image_path = st.text_input(
            "Image Base Path:",
            value="/home/yanyuchen/projects/gsm8k-v/data"
        )
        
        # Check file existence
        if not os.path.exists(json_file):
            st.error(f"JSON file not found: {json_file}")
            return
        
        if not os.path.exists(image_path):
            st.error(f"Image path not found: {image_path}")
            return
        
        st.success("✅ Files found")
        
        # Additional info
        st.header("ℹ️ Instructions")
        st.info("""
        **How to use:**
        1. Questions are automatically assigned to you in a consistent random order
        2. Review the question and images carefully
        3. Solve the problem yourself (Human Solution)
        4. Add your comment(reasons) if you choose Incorrect or Skip
        5. Choose your annotation:
           - ✅ **Correct**: images match question and ground truth
           - ❌ **Incorrect**: images don't match question or ground truth
           - ⏭️ **Skip**: Uncertain or problematic
        """)
    
    # Initialize managers
    annotation_manager = AnnotationManager()
    data_loader = DataLoader(json_file, image_path)
    
    if not data_loader.data:
        st.error("No data loaded from JSON file.")
        return
    
    # Per-annotator cache of the question list, kept so the order stays stable
    # while navigating.
    cache_key = f"cached_questions_{annotator}"
    
    # An explicit refresh request (set elsewhere) resets the position and
    # drops the cached list so it is rebuilt below.
    if st.session_state.get('refresh_questions', False):
        st.session_state.current_index = 0
        st.session_state.refresh_questions = False
        st.session_state.pop(cache_key, None)
    
    if cache_key not in st.session_state:
        available_questions = data_loader.get_available_questions(annotator, annotation_manager)
        st.session_state[cache_key] = available_questions
        st.session_state['available_questions'] = available_questions  # Also store without user prefix for auto-jump
    else:
        available_questions = st.session_state[cache_key]
        # Re-query every rerun so questions that became unavailable drop out.
        current_available = data_loader.get_available_questions(annotator, annotation_manager)
        cached_ids = {q.get('question_id') for q in available_questions}
        current_ids = {q.get('question_id') for q in current_available}
        # Compare membership as well as length: one question leaving while
        # another arrives keeps the count equal but must still refresh the
        # cache, otherwise a no-longer-available question could be shown.
        if len(current_available) != len(available_questions) or cached_ids != current_ids:
            available_questions = current_available
            st.session_state[cache_key] = available_questions
            st.session_state['available_questions'] = available_questions
    
    # Initialize session state
    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    
    # Ensure index is within bounds (the list may have shrunk)
    if st.session_state.current_index >= len(available_questions):
        st.session_state.current_index = max(0, len(available_questions) - 1)
    
    # Preload images for next questions
    if available_questions:
        data_loader.preload_next_images(st.session_state.current_index, available_questions)
    
    # Render progress information
    render_progress_info(annotation_manager, data_loader, annotator)
    
    # Navigation (only manual navigation allowed here)
    old_index = st.session_state.current_index
    st.session_state.current_index = render_navigation(available_questions, st.session_state.current_index, annotator)
    
    # If index changed manually, trigger preloading
    if old_index != st.session_state.current_index and available_questions:
        data_loader.preload_next_images(st.session_state.current_index, available_questions)
    
    # Main content
    if available_questions:
        current_question = available_questions[st.session_state.current_index]
        question_id = current_question.get('question_id', '')
        
        # Get user answer from session state first
        user_answer = st.session_state.get(f"human_answer_{question_id}", "")
        
        # Display question with conditional ground truth
        render_question_display(current_question, user_answer)
        
        # Find and display images (optimized)
        images = data_loader.find_images_for_qid(question_id)
        render_images_optimized(images, question_id, data_loader)
        
        # Human answer section
        render_human_answer_section(question_id, current_question)
        
        # Show annotation history
        render_annotation_history(question_id, annotation_manager)
        
        # Annotation section (auto-jump happens here)
        render_annotation_section(question_id, annotator, annotation_manager)
        
    else:
        st.success("🎉 Congratulations! You have completed all available annotations!")
        st.balloons()
        
        # Show summary
        st.markdown("### Your Annotation Summary:")
        your_annotations = sum(1 for item in data_loader.data 
                              if annotation_manager.has_annotator_annotated(item.get('question_id', ''), annotator))
        st.metric("Total Annotations Made", your_annotations)

if __name__ == "__main__":
    # Entry point: run the app when this file is executed as the main script
    # (presumably launched via `streamlit run <this file>`).
    main()
