import streamlit as st
import json
import os
from PIL import Image
import datetime
from openai import OpenAI
import dotenv
dotenv.load_dotenv()

# Page configuration
st.set_page_config(page_title="Image Evaluation Tool", layout="wide")

# Per-session defaults. Streamlit re-executes this script on every interaction,
# so each key is only seeded when it is absent; the factories hand every
# session its own fresh containers instead of sharing one object.
_SESSION_DEFAULTS = (
    ('current_idx', lambda: 0),
    ('annotations', dict),
    ('data_loaded', lambda: False),
    ('data', list),
    ('translated_prompts', dict),
)
for _key, _make_default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()

def calculate_satisfaction_rate(annotations=None):
    """Calculate the percentage of images rated as satisfactory.

    Args:
        annotations: Mapping of image id -> annotation record, where each
            record has an ``"evaluation"`` key ("Yes"/"No"). Defaults to
            ``st.session_state.annotations``. Passing a mapping explicitly lets
            callers compute stats for a subset (e.g. one dataset) and makes the
            function usable without a running Streamlit session.

    Returns:
        ``(satisfaction_rate, yes_count, total)``: the percentage (0-100) of
        records whose evaluation is ``"Yes"``, that count, and the number of
        records. ``(0, 0, 0)`` when there are no annotations.
    """
    if annotations is None:
        annotations = st.session_state.annotations

    total = len(annotations)
    if total == 0:
        return 0, 0, 0  # No annotations yet

    yes_count = sum(1 for record in annotations.values() if record["evaluation"] == "Yes")
    return yes_count / total * 100, yes_count, total

def translate_prompt(prompt):
    """Translate an English *prompt* into Chinese with gpt-4o-mini.

    Never raises. A missing/empty ``OPENAI_API_KEY`` or any error during the
    API call is reported by returning a human-readable message string in
    place of the translation.
    """
    try:
        key = os.environ.get('OPENAI_API_KEY')
        if key:
            system_msg = {
                "role": "system",
                "content": "You are a professional translator. Translate the English prompt to Chinese accurately and completely.",
            }
            user_msg = {
                "role": "user",
                "content": f"Please translate the following English prompt to Chinese:\n{prompt}",
            }
            completion = OpenAI(api_key=key).chat.completions.create(
                model="gpt-4o-mini",
                messages=[system_msg, user_msg],
                temperature=0.3,
            )
            return completion.choices[0].message.content
        return "No OpenAI API key set in environment variables"
    except Exception as err:
        return f"Translation failed: {err!s}"

def load_previous_annotations():
    """Return previously saved annotations keyed by image id.

    Reads ``./annotations/human_evaluation.json``, expected to be a JSON list
    of annotation records that each carry an ``id``. Returns an empty dict
    when the file does not exist, or when reading/parsing fails (the failure
    is shown in the sidebar instead of being raised).
    """
    annotations_file = "./annotations/human_evaluation.json"
    try:
        if not os.path.exists(annotations_file):
            return {}

        with open(annotations_file, 'r') as fh:
            records = json.load(fh)

        by_id = {}
        for record in records:
            by_id[record['id']] = record  # a later duplicate id wins

        st.sidebar.success(f"Loaded {len(by_id)} previous annotations")
        return by_id
    except Exception as e:
        st.sidebar.error(f"Failed to load previous annotations: {str(e)}")
        return {}

def find_first_unannotated(data):
    """Return the index of the first item in *data* without an annotation.

    Falls back to 0 (the beginning) when every item is already annotated or
    *data* is empty.
    """
    pending = (
        position
        for position, entry in enumerate(data)
        if entry['id'] not in st.session_state.annotations
    )
    return next(pending, 0)

def load_data():
    """Load the aligned-image dataset for the currently selected difficulty level.

    The level is read from ``st.session_state.selected_level`` (falling back
    to "easy"); ``main`` stores the sidebar selectbox value there before the
    "Load Dataset" button is handled.

    Returns:
        tuple: ``(data, level)``, the parsed JSON content and the level name,
        or ``([], "")`` when the file is missing or cannot be read/parsed.
        Problems are reported in the sidebar rather than raised.
    """
    try:
        # The level selector lives in main(), ahead of the Load button, so the
        # user's choice is already in session state when this runs.
        level = st.session_state.get('selected_level', "easy")

        # Build the dataset path and make sure it exists before opening it.
        data_file = f"./document/basic_understanding/aligned_image_json_dall_e_3/{level}_aligned_images.json"

        if not os.path.exists(data_file):
            st.sidebar.error(f"File not found: {data_file}")
            st.sidebar.info("Please check if the path is correct")
            return [], ""

        # JSON text is UTF-8; be explicit so non-ASCII prompts do not fail to
        # decode when the platform's default locale encoding differs
        # (e.g. cp1252/GBK on Windows).
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        st.sidebar.success(f"Loaded {len(data)} images from {level} dataset")
        return data, level
    except Exception as e:
        st.sidebar.error(f"Could not load data: {str(e)}")
        return [], ""

def navigate(direction, target_idx=None):
    """Move ``st.session_state.current_idx`` according to *direction*.

    Args:
        direction: One of
            ``"prev"``: step back one image (no-op on the first image);
            ``"next"``: step forward one image (no-op on the last image);
            ``"next_unannotated"``: go to the next image without an
            annotation, wrapping around to the start; stays put if no other
            image is unannotated;
            ``"jump"``: go to *target_idx*.
            Any other value is ignored.
        target_idx: 0-based destination for ``"jump"``, clamped to the valid
            range. When omitted, ``"jump"`` falls back to reading a sidebar
            number input (legacy behaviour, see the note in that branch).
    """
    data = st.session_state.data
    current = st.session_state.current_idx

    if direction == "prev":
        if current > 0:
            st.session_state.current_idx = current - 1
    elif direction == "next":
        if current < len(data) - 1:
            st.session_state.current_idx = current + 1
    elif direction == "next_unannotated":
        # Scan forward from the image after the current one, wrapping around
        # to the start; the current image itself is never chosen.
        n = len(data)
        for offset in range(1, n):
            candidate = (current + offset) % n
            if data[candidate]['id'] not in st.session_state.annotations:
                st.session_state.current_idx = candidate
                return
        # Every other image is annotated: stay at the current position.
    elif direction == "jump":
        if target_idx is not None:
            # Clamp so an out-of-range request can never leave current_idx invalid.
            st.session_state.current_idx = max(0, min(int(target_idx), len(data) - 1))
            return
        # NOTE: legacy path. When navigate("jump") runs as a button on_click
        # callback, this number_input is created inside the callback and
        # returns its default (the current position) straight away, so the
        # jump is effectively a no-op. Prefer passing target_idx.
        new_idx = st.sidebar.number_input("Jump to image:",
                                          min_value=1,
                                          max_value=len(data),
                                          value=current + 1) - 1
        st.session_state.current_idx = new_idx

def save_annotation(evaluation):
    """Record *evaluation* ("Yes"/"No") for the image at ``current_idx``.

    The record is stored in ``st.session_state.annotations`` under the image
    id (replacing any earlier verdict for that image). All annotations are
    then written to disk, and the view advances to the next unannotated image.
    """
    shown = st.session_state.data[st.session_state.current_idx]

    # Field order matters: it is the key order written to the JSON file.
    record = {field: shown[field] for field in ("id", "aspect", "prompt", "image_path")}
    record["evaluation"] = evaluation
    record["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    st.session_state.annotations[record["id"]] = record

    # Auto-save after every verdict.
    save_annotations()

    # Advance to the next image that still needs a verdict.
    navigate("next_unannotated")

def save_annotations():
    """Write every annotation in ``st.session_state.annotations`` to disk.

    The target ``./annotations/human_evaluation.json`` holds a JSON list of the
    annotation records and is rewritten in full on every call. ``json.dump``
    streams its output, so an error or interruption part-way through would
    otherwise leave a truncated file that no longer loads. The data is instead
    written to a temporary sibling file and moved into place with
    ``os.replace``, which leaves the previous file intact if anything fails.
    Success or failure is reported in the sidebar; errors are not raised.
    """
    filename = "./annotations/human_evaluation.json"
    tmp_filename = filename + ".tmp"
    try:
        os.makedirs("./annotations", exist_ok=True)

        annotations_list = list(st.session_state.annotations.values())

        with open(tmp_filename, 'w', encoding='utf-8') as f:
            json.dump(annotations_list, f, indent=2)
        # Swap the fully written file in only after the dump succeeded.
        os.replace(tmp_filename, filename)

        st.sidebar.success(f"Annotations saved to {filename}")
    except Exception as e:
        st.sidebar.error(f"Save failed: {str(e)}")

def main():
    """Render the evaluation app: sidebar controls plus the current image view."""
    st.title("Image Generation Quality Evaluation")

    # Load previously saved annotations once per browser session.
    if 'annotations_loaded' not in st.session_state:
        st.session_state.annotations = load_previous_annotations()
        st.session_state.annotations_loaded = True

    # The sidebar runs first so navigation changes apply before the image is drawn.
    _render_sidebar()

    data = st.session_state.data
    if st.session_state.data_loaded and st.session_state.current_idx < len(data):
        # Navigation keeps current_idx in range, so the final else-branch is
        # effectively never reached; announce completion here instead, while
        # still letting the user review the current image.
        if all(item['id'] in st.session_state.annotations for item in data):
            st.success("All images have been evaluated!")
        _render_current_item(data[st.session_state.current_idx])
    elif not st.session_state.data_loaded:
        st.info("Please click 'Load Dataset' in the sidebar to start evaluation")
    else:
        st.success("All images have been evaluated!")


def _render_sidebar():
    """Draw the control panel: dataset selection/loading, stats and navigation."""
    with st.sidebar:
        st.header("Control Panel")

        # The level must be chosen before "Load Dataset" is clicked; load_data()
        # reads it from session state.
        selected_level = st.selectbox(
            "Select difficulty level",
            ["easy", "medium", "hard"],
            index=0,  # Default to "easy"
        )
        st.session_state.selected_level = selected_level

        if st.button("Load Dataset"):
            data, level = load_data()
            if data:
                st.session_state.data = data
                st.session_state.data_loaded = True
                st.session_state.level = level
                # Resume at the first image that still needs a verdict.
                st.session_state.current_idx = find_first_unannotated(data)
                st.rerun()

        if not st.session_state.data_loaded:
            return

        num_images = len(st.session_state.data)

        # Evaluation stats with satisfaction rate.
        satisfaction_rate, yes_count, total = calculate_satisfaction_rate()
        st.write(f"Evaluated: {total}/{num_images}")
        st.write(f"Satisfaction Rate: {satisfaction_rate:.1f}% ({yes_count}/{total})")
        if total > 0:
            st.progress(satisfaction_rate / 100)

        st.divider()
        st.subheader("Navigation")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("Previous", use_container_width=True):
                navigate("prev")
        with col2:
            if st.button("Next", use_container_width=True):
                navigate("next")

        if st.button("Next Unannotated", use_container_width=True):
            navigate("next_unannotated")

        # Render the jump input on every run rather than inside a button
        # callback: a widget created in a callback returns its default value
        # immediately, so the jump would never move. The input is drawn after
        # Prev/Next so its default tracks the (possibly updated) position.
        jump_to = st.number_input(
            "Jump to image:",
            min_value=1,
            max_value=num_images,
            value=st.session_state.current_idx + 1,
        )
        if st.button("Jump to..."):
            st.session_state.current_idx = int(jump_to) - 1

        st.divider()
        if st.button("Manual Save"):
            save_annotations()


def _show_translation(current_item):
    """Show the Chinese translation of the item's prompt, caching successes."""
    prompt_id = current_item['id']
    translation = st.session_state.translated_prompts.get(prompt_id)
    if translation is None:
        translation = translate_prompt(current_item['prompt'])
        # translate_prompt reports failures as plain strings. Only cache real
        # translations so a transient error (or a key added later) can be retried.
        failed = not translation or translation.startswith(
            ("Translation failed:", "No OpenAI API key set in environment variables")
        )
        if failed:
            st.error(translation or "Translation failed: empty response")
            return
        st.session_state.translated_prompts[prompt_id] = translation

    st.subheader("Chinese Translation")
    st.write(translation)


def _render_current_item(current_item):
    """Show the current image, its prompt/metadata and the Yes/No verdict buttons."""
    idx = st.session_state.current_idx
    num_images = len(st.session_state.data)

    # Position of the image on screen (1-based), reaching 100% on the last one.
    st.progress((idx + 1) / num_images)
    st.write(f"Image {idx + 1} / {num_images}")

    col1, col2 = st.columns([3, 2])

    with col1:
        st.subheader("Image")
        try:
            # Close the underlying file once the image has been handed to Streamlit.
            with Image.open(current_item['image_path']) as image:
                # NOTE: newer Streamlit releases deprecate use_column_width in
                # favour of use_container_width; kept for older versions.
                st.image(image, caption=f"ID: {current_item['id']}", use_column_width=True)
        except Exception as e:
            st.error(f"Could not load image: {str(e)}")
            st.write(f"Image path: {current_item['image_path']}")

    with col2:
        st.subheader("Prompt")
        st.write(current_item['prompt'])

        if st.button("Translate Prompt"):
            _show_translation(current_item)

        st.subheader("Image Information")
        st.write(f"Category: {current_item['aspect']}")
        # Display-only fields: don't crash the page if a record lacks them.
        st.write(f"Difficulty: {current_item.get('level', 'N/A')}")
        st.write(f"Model: {current_item.get('model', 'N/A')}")

        st.divider()
        st.subheader("Evaluation")

        # These buttons run after the image above has already been drawn, so
        # rerun after saving. Otherwise the page keeps showing the old image
        # while current_idx has advanced, and the next verdict click would be
        # recorded against an image the user has not seen.
        col_yes, col_no = st.columns(2)
        with col_yes:
            if st.button("Yes - Matches prompt", use_container_width=True):
                save_annotation("Yes")
                st.rerun()
        with col_no:
            if st.button("No - Doesn't match prompt", use_container_width=True):
                save_annotation("No")
                st.rerun()

        annotation = st.session_state.annotations.get(current_item['id'])
        if annotation is not None:
            st.success("✓ This image has been annotated")
            st.write(f"Evaluation: {annotation['evaluation']}")


if __name__ == "__main__":
    main()