import tkinter as tk
from tkinter import ttk, messagebox
from PIL import Image, ImageTk
import threading
from collections import deque
import queue
from constants import TASK_DESCRIPTIONS, PROMPT_TYPES, PRESENTATION_MODES, MODELS, ROLE_DESCRIPTIONS 
from api_functions import ChatSession, make_api_request
import time
from datetime import datetime
import os
import re
import json
from constants import TPM_LIMITS, RPM_LIMITS, TPD_LIMITS, RPD_LIMITS, REQUEST_LOG_FILE, NUM_SESSIONS
from task_flows import WCSTFlow

class ExperimentGUI:
    def __init__(self, master):
        """Build the experiment window on *master* and initialise run state.

        Args:
            master: The Tk root (or Toplevel) that hosts the GUI.
        """
        self.master = master
        self.master.title("Executive Function Experiment")
        self.master.geometry("1200x900")

        style = ttk.Style()
        # 'winnative' only exists in Windows builds of Tk; theme_use() raises
        # TclError for an unknown theme, so keep the platform default elsewhere.
        if 'winnative' in style.theme_names():
            style.theme_use('winnative')
        style.configure('TLabel', font=('Helvetica', 12))
        style.configure('TButton', font=('Helvetica', 10))
        style.configure('TCheckbutton', font=('Helvetica', 10))
        style.configure('TCombobox', font=('Helvetica', 12))
        style.configure('TFrame', background='#f0f0f0')
        style.configure('Large.TButton', font=('Helvetica', 12))

        self.setup_variables()
        # NOTE: create_widgets() ends with preview_experiment(), which sets
        # chat_session/experiment_flow; the assignments below reset them again.
        self.create_widgets()

        # Timestamps (time.time()) used by the Stroop reaction-time measurement.
        self.image_present_time = None
        self.key_press_time = None
        # API usage monitoring; load_request_log() replaces these from disk.
        self.token_usage = {}
        self.request_count = {}
        self.request_log = []
        self.load_request_log()

        self.chat_session = None
        self.experiment_flow = None
        self.current_response = ''
        self.current_tokens = []
        self.response_queue = queue.Queue()
        # Stroop key -> colour name passed to human_response().
        self.color_mapping = {
            'R': 'red',
            'B': 'blue',
            'Y': 'yellow',
            'G': 'green'
        }
    
    def load_request_log(self):
        """Load request history and usage counters from ``REQUEST_LOG_FILE``.

        Sets ``self.request_log``, ``self.token_usage`` and
        ``self.request_count``. If the file is missing, is not valid JSON, or
        lacks one of the expected sections, every model in ``MODELS`` starts
        with zeroed ``{"minute": 0, "day": 0}`` counters. Models that are not
        present in a successfully loaded file also get zeroed counters.
        """
        try:
            with open(REQUEST_LOG_FILE, 'r', encoding='utf-8') as f:
                log_data = json.load(f)
            # Read every section before assigning so a partially valid file
            # cannot leave the instance with a mix of old and new state.
            request_log = log_data['requests']
            token_usage = log_data['token_usage']
            request_count = log_data['request_count']
        except FileNotFoundError:
            request_log, token_usage, request_count = [], {}, {}
        except (json.JSONDecodeError, KeyError, TypeError) as err:
            # A corrupt log must not stop the GUI from starting; warn instead.
            print(f"Warning: could not read {REQUEST_LOG_FILE} ({err!r}); "
                  f"starting with empty usage counters.")
            request_log, token_usage, request_count = [], {}, {}

        self.request_log = request_log
        self.token_usage = token_usage
        self.request_count = request_count
        # Ensure every configured model has counters so later lookups by model
        # name do not fail for models added after the log file was written.
        for model in MODELS:
            self.token_usage.setdefault(model, {"minute": 0, "day": 0})
            self.request_count.setdefault(model, {"minute": 0, "day": 0})

    def save_request_log(self):
        """Write request history and usage counters to ``REQUEST_LOG_FILE`` as JSON.

        The file holds an object with the keys ``requests``, ``token_usage``
        and ``request_count`` and overwrites any previous contents.
        """
        payload = dict(
            requests=self.request_log,
            token_usage=self.token_usage,
            request_count=self.request_count,
        )
        with open(REQUEST_LOG_FILE, 'w') as log_file:
            json.dump(payload, log_file)
    
    def setup_variables(self):
        """Create the Tk variables that back the control panel and info panel."""
        # Experiment configuration selections (filled in by create_widgets).
        for attr in ("task_var", "model_var", "prompt_type_var", "presentation_mode_var"):
            setattr(self, attr, tk.StringVar())

        # Impairment simulation is off until the checkbox is ticked.
        self.simulate_impairment_var = tk.BooleanVar(value=False)
        self.impairment_type_var = tk.StringVar()

        # Participant / session details.
        self.is_human_var = tk.BooleanVar()
        self.subject_number = tk.StringVar()
        self.session_number = tk.StringVar()
        self.language_var = tk.StringVar(value="English")

        # Values shown in the "Task Information" panel.
        for attr in (
            "current_trial_var",
            "current_rule_var",
            "correct_card_var",
            "correct_in_row_var",
            "correct_color_var",
            "is_match_var",
            "number_of_moves_var",
            "solution_path_var",
            "total_earnings_var",
            "decks_value_var",
        ):
            setattr(self, attr, tk.StringVar())

    def create_widgets(self):
        """Build every widget of the main window.

        Layout: a left column holding the control panel and a right column
        holding, top to bottom:
        - the task-information panel;
        - a horizontal split pane with the image on the left and the prompt on the right;
        - the response panel;
        - the feedback panel.

        The control panel contains the task/model/prompt/presentation selectors,
        the impairment and human-participant options, the subject/session
        inputs, the start/reset buttons and the language selector.

        Most containers use ``pack``, so statement order determines layout.
        Ends by rendering a preview for the default selections via
        ``preview_experiment()``.
        """
        # Main frame
        main_frame = ttk.Frame(self.master, padding="10")
        main_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=20)
        
        # Left frame for controls
        left_frame = ttk.Frame(main_frame, width=300)
        left_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 20))
        
        # Right frame for image and prompt
        right_frame = ttk.Frame(main_frame)
        right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        
        # Control Panel: each selector defaults to the first available option.
        control_frame = ttk.LabelFrame(left_frame, text="Control Panel", padding="10")
        control_frame.pack(fill=tk.X, pady=(0, 10))
        ttk.Label(control_frame, text="Task:").pack(anchor='w', pady=(0, 5))
        self.task_combobox=ttk.Combobox(control_frame, textvariable=self.task_var, values=list(TASK_DESCRIPTIONS.keys()), width=25)
        self.task_combobox.pack(pady=(0, 10))
        self.task_var.set(list(TASK_DESCRIPTIONS.keys())[0])
        self.task_combobox.bind("<<ComboboxSelected>>", self.on_task_selected)
        
        ttk.Label(control_frame, text="Model:").pack(anchor='w', pady=(0, 5))
        self.model_combobox=ttk.Combobox(control_frame, textvariable=self.model_var, values=MODELS, width=25)
        self.model_combobox.pack(pady=(0, 10))
        self.model_var.set(MODELS[0])

        ttk.Label(control_frame, text="Prompt Type:").pack(anchor='w', pady=(0, 5))
        self.prompt_type_combobox=ttk.Combobox(control_frame, textvariable=self.prompt_type_var, values=list(PROMPT_TYPES.keys()), width=25)
        self.prompt_type_combobox.pack(pady=(0, 10))
        self.prompt_type_var.set(list(PROMPT_TYPES.keys())[0])
        self.prompt_type_combobox.bind("<<ComboboxSelected>>", self.on_prompt_selected)

        ttk.Label(control_frame, text="Presentation Mode:").pack(anchor='w', pady=(0, 5))
        self.presentation_mode_combobox=ttk.Combobox(control_frame, textvariable=self.presentation_mode_var, values=PRESENTATION_MODES, width=25)
        self.presentation_mode_combobox.pack(pady=(0, 10))
        self.presentation_mode_var.set(PRESENTATION_MODES[0])
        self.presentation_mode_combobox.bind("<<ComboboxSelected>>", self.on_presentation_selected)
        
        ttk.Checkbutton(control_frame, text="Simulate Impairment", variable=self.simulate_impairment_var, command=self.toggle_impairment_simulation).pack(anchor='w', pady=(0, 10))

        # Disabled until "Simulate Impairment" is ticked (see toggle_impairment_simulation).
        ttk.Label(control_frame, text="Impairment Type:").pack(anchor='w', pady=(0, 5))
        self.impairment_type_combobox = ttk.Combobox(control_frame, textvariable=self.impairment_type_var, values=list(ROLE_DESCRIPTIONS.keys()), width=25, state='disabled')
        self.impairment_type_combobox.pack(pady=(0, 10))
        self.impairment_type_combobox.bind("<<ComboboxSelected>>", self.on_impairment_type_selected)

        ttk.Checkbutton(control_frame, text="Human Participant", variable=self.is_human_var, command=self.toggle_human_participant).pack(anchor='w', pady=(0, 10))

        # Frame for Subject and Session Number inputs
        self.number_frame = ttk.Frame(control_frame)
        self.number_frame.pack(fill=tk.X, pady=(0, 10))
        # Subject Number input (initially hidden): created here but only gridded
        # into row 0 by toggle_human_participant when "Human Participant" is ticked.
        self.subject_label = ttk.Label(self.number_frame, text="Subject Number:")
        self.subject_entry = ttk.Entry(self.number_frame, textvariable=self.subject_number, width=5)
        self.subject_ok = ttk.Button(self.number_frame, text="OK", command=lambda: self.validate_number(self.subject_number))

        # Session Number input (always visible, row 1)
        ttk.Label(self.number_frame, text="Session Number:").grid(row=1, column=0, sticky='w', padx=(0, 5))
        ttk.Entry(self.number_frame, textvariable=self.session_number, width=5).grid(row=1, column=1, sticky='w')
        ttk.Button(self.number_frame, text="OK", command=lambda: self.validate_number(self.session_number)).grid(row=1, column=2, sticky='w', padx=(5, 0))
        
        ttk.Button(control_frame, text="Start Experiment", command=self.start_experiment).pack(fill=tk.X, pady=(0, 5))
        ttk.Button(control_frame, text="Reset Experiment", command=self.reset_experiment).pack(fill=tk.X, pady=(0, 5))
        # ttk.Button(left_frame, text="Preview", command=self.preview_experiment).pack(fill=tk.X)
        ttk.Label(control_frame, text="Language:").pack(anchor='w', pady=(10, 5))
        self.language_combobox = ttk.Combobox(control_frame, textvariable=self.language_var, values=["English", "中文"], width=25)
        self.language_combobox.pack(pady=(0, 10))
        self.language_combobox.bind("<<ComboboxSelected>>", self.on_language_changed)

        # Info labels at the top of right frame (contents depend on task / participant type)
        self.info_frame = ttk.LabelFrame(right_frame, text="Task Information", padding="10")
        self.info_frame.pack(fill=tk.X, pady=(0, 10))
        self.update_info_frame()
        
        # PanedWindow for image and prompt
        paned_window = ttk.PanedWindow(right_frame, orient=tk.HORIZONTAL)
        paned_window.pack(fill=tk.BOTH, expand=True, pady=(0, 10))

        # Image display (left side of paned window)
        image_frame = ttk.Frame(paned_window)
        paned_window.add(image_frame, weight=1)
        ttk.Label(image_frame, text="Image:").pack(anchor='w', pady=(0, 5))
        self.image_frame = ttk.Frame(image_frame, borderwidth=2, relief='groove')
        self.image_frame.pack(fill=tk.BOTH, expand=True)
        self.image_label = ttk.Label(self.image_frame)
        # self.image_label.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        self.image_label.pack(fill=tk.BOTH, expand=True)

        # Prompt display (right side of paned window)
        prompt_frame = ttk.Frame(paned_window)
        paned_window.add(prompt_frame, weight=1)
        ttk.Label(prompt_frame, text="Prompt:").pack(anchor='w', pady=(0, 5))
        self.prompt_text = tk.Text(prompt_frame, wrap=tk.WORD, font=('Microsoft YaHei', 12)) # CJK-capable font (the language selector offers 中文); Helvetica is the alternative
        self.prompt_text.pack(fill=tk.BOTH, expand=True)

        # Set initial position of sash (divider): update() forces a layout pass so
        # winfo_width() is the real width; the image side gets 40% of it.
        right_frame.update()
        paned_window.sashpos(0, int(paned_window.winfo_width() * 0.4))

        # Response display
        self.response_frame = ttk.LabelFrame(right_frame, text="Response", padding="10")
        # self.response_frame = ttk.Frame(right_frame)
        self.response_frame.pack(fill=tk.BOTH, expand=True, pady=(10, 0))
        # ttk.Label(self.response_frame, text="Response:").pack(anchor='w')
        # AI response text (initially shown); read-only for the user.
        self.response_text = tk.Text(self.response_frame, height=5, width=60, wrap=tk.WORD, font=("Helvetica", 12))
        self.response_text.pack(fill=tk.BOTH, expand=True)
        self.response_text.config(state=tk.DISABLED)
        self.retry_button = ttk.Button(self.response_frame, text="Retry", command=self.retry_api_request)
        self.retry_button.pack(side=tk.BOTTOM, pady=5)
        self.retry_button.pack_forget()  # Hide the button initially
        
        # Human response widgets (initially hidden). The container frame is packed
        # but empty; update_human_response_widgets grids the controls for the
        # selected task when "Human Participant" is ticked.
        self.human_response_frame = ttk.Frame(self.response_frame)
        self.human_response_frame.pack(fill=tk.BOTH, expand=True)
        # WCST buttons (1-4); x=i binds the loop value at creation time.
        self.wcst_buttons = [ttk.Button(self.human_response_frame, text=str(i), command=lambda x=i: self.human_response(x), style='Large.TButton') for i in range(1, 5)]
        # Stroop instructions
        self.stroop_label = ttk.Label(self.human_response_frame, text="Press 'R' for Red, 'B' for Blue, 'Y' for Yellow, 'G' for Green")
        # N-back buttons
        self.nback_match_btn = ttk.Button(self.human_response_frame, text="Match", command=lambda: self.human_response(True), style='Large.TButton')
        self.nback_nomatch_btn = ttk.Button(self.human_response_frame, text="No Match", command=lambda: self.human_response(False), style='Large.TButton')
        # Tower of London widgets (populated by update_human_response_widgets)
        self.tol_frame = ttk.Frame(self.human_response_frame)
        self.tol_moves = []
        # Iowa Gambling buttons A-D
        self.igt_buttons = [ttk.Button(self.human_response_frame, text=chr(65+i), command=lambda x=chr(65+i): self.human_response(x), style='Large.TButton') for i in range(4)]
        
        # Feedback frame
        # self.feedback_frame = ttk.Frame(right_frame)
        self.feedback_frame = ttk.LabelFrame(right_frame, text="Feedback")
        self.feedback_frame.pack(fill=tk.BOTH, expand=True, pady=(10, 0))
        
        self.feedback_title_frame = ttk.Frame(self.feedback_frame)
        self.feedback_title_frame.pack(fill=tk.X,side=tk.TOP)
        # ttk.Label(self.feedback_title_frame, text="Feedback:").pack(side=tk.LEFT)
        
        # AI feedback buttons (initially shown; hidden for human participants).
        # All are packed side=RIGHT, so on screen they appear in reverse order:
        # Correct, Incorrect, A, B, C, D, Continue.
        self.ai_feedback_frame = ttk.Frame(self.feedback_title_frame)
        self.ai_feedback_frame.pack(side=tk.RIGHT)
        
        ttk.Button(self.ai_feedback_frame, text="Continue", command=self.continue_experiment).pack(side=tk.RIGHT, padx=(0, 5))
        ttk.Button(self.ai_feedback_frame, text="D", command=lambda: self.igt_calculate('D')).pack(side=tk.RIGHT, padx=(0, 35))
        ttk.Button(self.ai_feedback_frame, text="C", command=lambda: self.igt_calculate('C')).pack(side=tk.RIGHT, padx=(0, 5))
        ttk.Button(self.ai_feedback_frame, text="B", command=lambda: self.igt_calculate('B')).pack(side=tk.RIGHT, padx=(0, 5))
        ttk.Button(self.ai_feedback_frame, text="A", command=lambda: self.igt_calculate('A')).pack(side=tk.RIGHT, padx=(0, 5))
        ttk.Button(self.ai_feedback_frame, text="Incorrect", command=lambda: self.evaluate_response(False)).pack(side=tk.RIGHT, padx=(0, 35))
        ttk.Button(self.ai_feedback_frame, text="Correct", command=lambda: self.evaluate_response(True)).pack(side=tk.RIGHT)
        
        # Feedback text
        self.feedback_text = tk.Text(self.feedback_frame, height=3, width=60, wrap=tk.WORD, font=('Microsoft YaHei', 16))
        self.feedback_text.pack(fill=tk.BOTH, expand=True)
        
        # Render the first trial for the default selections.
        self.preview_experiment()
    
    def on_impairment_type_selected(self, event):
        self.preview_experiment()
    
    def toggle_impairment_simulation(self):
        if self.simulate_impairment_var.get():
            self.impairment_type_combobox.config(state='readonly')
        else:
            self.impairment_type_combobox.config(state='disabled')
            self.impairment_type_var.set('')  # Clear the selection when disabling
        self.preview_experiment()

    def on_language_changed(self, event):
        self.preview_experiment()

    def on_task_selected(self, event):
        self.update_info_frame()
        self.preview_experiment()
        if self.is_human_var.get():
            self.update_human_response_widgets()
    def on_prompt_selected(self, event):
        self.preview_experiment()
        if self.is_human_var.get():
            self.update_human_response_widgets()
    def on_presentation_selected(self, event):
        self.preview_experiment()
        if self.is_human_var.get():
            self.update_human_response_widgets()
    
    def update_info_frame(self):
        for widget in self.info_frame.winfo_children():
            widget.destroy()

        if self.is_human_var.get():
            ttk.Label(self.info_frame, text="Current Trial:").grid(row=0, column=0, sticky='e', padx=(0, 5))
            ttk.Label(self.info_frame, textvariable=self.current_trial_var).grid(row=0, column=1, sticky='w')
        else:
            if self.task_var.get() == "WCST":
                self.create_wcst_info()
            elif self.task_var.get() == "WCST_without_restriction":
                self.create_wcst_info()
            
        # Configure grid
        self.info_frame.columnconfigure(1, weight=1)
        self.info_frame.columnconfigure(3, weight=1)
    
    def create_wcst_info(self):
        """Lay out the WCST info panel: trial, rule, correct card and streak in a 2x2 grid."""
        fields = (
            (0, 0, "Current Trial:", self.current_trial_var),
            (0, 2, "Current Rule:", self.current_rule_var),
            (1, 0, "Correct Card:", self.correct_card_var),
            (1, 2, "Correct in Row:", self.correct_in_row_var),
        )
        for row, col, caption, var in fields:
            # Captions in column 2 get extra left padding to separate the two pairs.
            left_pad = 20 if col else 0
            ttk.Label(self.info_frame, text=caption).grid(row=row, column=col, sticky='e', padx=(left_pad, 5))
            ttk.Label(self.info_frame, textvariable=var).grid(row=row, column=col + 1, sticky='w')
    def create_stroop_info(self):
        """Lay out the Stroop info panel: current trial and correct colour on one row."""
        fields = (
            (0, "Current Trial:", self.current_trial_var),
            (2, "Correct Color:", self.correct_color_var),
        )
        for col, caption, var in fields:
            left_pad = 20 if col else 0
            ttk.Label(self.info_frame, text=caption).grid(row=0, column=col, sticky='e', padx=(left_pad, 5))
            ttk.Label(self.info_frame, textvariable=var).grid(row=0, column=col + 1, sticky='w')
    def create_nback_info(self):
        """Lay out the N-back info panel: trial and match flag, then the correct streak."""
        fields = (
            (0, 0, "Current Trial:", self.current_trial_var),
            (0, 2, "Is Match:", self.is_match_var),
            (1, 0, "Correct in Row:", self.correct_in_row_var),
        )
        for row, col, caption, var in fields:
            left_pad = 20 if col else 0
            ttk.Label(self.info_frame, text=caption).grid(row=row, column=col, sticky='e', padx=(left_pad, 5))
            ttk.Label(self.info_frame, textvariable=var).grid(row=row, column=col + 1, sticky='w')
    def create_tol_info(self):
        """Lay out the Tower of London info panel: trial, move count, solution path and streak."""
        fields = (
            (0, 0, "Current Trial:", self.current_trial_var),
            (0, 2, "Number of moves:", self.number_of_moves_var),
            (1, 0, "Solution Path:", self.solution_path_var),
            (1, 2, "Correct in Row:", self.correct_in_row_var),
        )
        for row, col, caption, var in fields:
            left_pad = 20 if col else 0
            ttk.Label(self.info_frame, text=caption).grid(row=row, column=col, sticky='e', padx=(left_pad, 5))
            ttk.Label(self.info_frame, textvariable=var).grid(row=row, column=col + 1, sticky='w')
    def create_igt_info(self):
        """Lay out the Iowa Gambling info panel: trial, total earnings and a full-width deck summary."""
        fields = (
            (0, 0, "Current Trial:", self.current_trial_var, {}),
            (0, 2, "Total Earnings:", self.total_earnings_var, {}),
            # The deck summary spans the remaining three columns of row 1.
            (1, 0, "Decks Value:", self.decks_value_var, {'columnspan': 3}),
        )
        for row, col, caption, var, value_opts in fields:
            left_pad = 20 if col else 0
            ttk.Label(self.info_frame, text=caption).grid(row=row, column=col, sticky='e', padx=(left_pad, 5))
            ttk.Label(self.info_frame, textvariable=var).grid(row=row, column=col + 1, sticky='w', **value_opts)
    
    def toggle_human_participant(self):
        """Reconfigure the GUI when the "Human Participant" checkbox changes.

        When the box is ticked:
        - the model, prompt-type and presentation selectors are disabled;
        - the model choice is cleared and prompt/presentation reset to their first options;
        - the subject-number input is shown;
        - the AI response text is swapped for the human response frame;
        - the AI feedback buttons are hidden.

        Unticking reverses the visibility changes. Impairment simulation is
        switched off in both cases. Finally the human response controls and
        the info panel are rebuilt.
        """
        human = self.is_human_var.get()

        selector_state = 'disabled' if human else 'normal'
        for selector in (self.model_combobox, self.prompt_type_combobox, self.presentation_mode_combobox):
            selector.config(state=selector_state)

        # Impairment simulation is not available for human participants.
        self.simulate_impairment_var.set(False)
        self.impairment_type_combobox.config(state='disabled')
        self.impairment_type_var.set('')

        if human:
            self.model_var.set('')
            self.prompt_type_var.set(list(PROMPT_TYPES.keys())[0])
            self.presentation_mode_var.set(PRESENTATION_MODES[0])
            self.subject_label.grid(row=0, column=0, sticky='w', padx=(0, 5))
            self.subject_entry.grid(row=0, column=1, sticky='w')
            self.subject_ok.grid(row=0, column=2, sticky='w', padx=(5, 0))
            self.response_text.pack_forget()
            self.human_response_frame.pack(fill=tk.BOTH, expand=True)
            self.ai_feedback_frame.pack_forget()
        else:
            for widget in (self.subject_label, self.subject_entry, self.subject_ok):
                widget.grid_remove()
            self.human_response_frame.pack_forget()
            self.response_text.pack(fill=tk.BOTH, expand=True)
            self.ai_feedback_frame.pack(side=tk.RIGHT)

        self.update_human_response_widgets()
        self.update_info_frame()
    
    def update_human_response_widgets(self):
        """Show the response controls that match the selected task for a human participant.

        All controls in ``human_response_frame`` are hidden first, and any
        Stroop ``<Key>`` binding left over from a previous task is removed, so
        only the current task's controls stay active. For AI runs nothing is
        shown. Both WCST variants use the same four card buttons.
        """
        # Hide all widgets first
        for widget in self.human_response_frame.winfo_children():
            widget.grid_forget()
        # '<Key>' on master is used only for Stroop answers (stroop_key_press
        # unbinds it too); drop it so key presses stop submitting Stroop answers
        # after switching task or participant type. Stroop re-binds it below.
        self.master.unbind('<Key>')

        if not self.is_human_var.get():
            return

        task = self.task_var.get()

        # Four equally weighted columns shared by all task layouts.
        for col in range(4):
            self.human_response_frame.grid_columnconfigure(col, weight=1)

        if task in ("WCST", "WCST_without_restriction"):
            for i, btn in enumerate(self.wcst_buttons):
                btn.grid(row=0, column=i, padx=5, pady=5, sticky='nsew')

        elif task == "Stroop":
            self.stroop_label.grid(row=0, column=0, columnspan=4, pady=10, sticky='nsew')
            self.master.bind('<Key>', self.stroop_key_press)

        elif task == "N-back":
            self.nback_match_btn.grid(row=0, column=1, padx=5, pady=5, sticky='nsew')
            self.nback_nomatch_btn.grid(row=0, column=2, padx=5, pady=5, sticky='nsew')

        elif task == "Tower_of_London":
            self.tol_frame.grid(row=0, column=0, columnspan=4, sticky='nsew')
            self.human_response_frame.grid_rowconfigure(0, weight=1)

            # Rebuild the ToL controls from scratch.
            for widget in self.tol_frame.winfo_children():
                widget.destroy()

            # 7 columns (6 move buttons + delete/submit), 3 rows.
            for i in range(7):
                self.tol_frame.grid_columnconfigure(i, weight=1)
            for i in range(3):
                self.tol_frame.grid_rowconfigure(i, weight=1)

            # First row: declared number of moves
            ttk.Label(self.tol_frame, text="Number of moves:").grid(row=0, column=0, columnspan=2, padx=5, pady=5, sticky='w')
            self.tol_entry = ttk.Entry(self.tol_frame, width=5)
            self.tol_entry.grid(row=0, column=2, padx=5, pady=5, sticky='w')

            # Second row: one button per peg-to-peg move (m=move binds each value)
            self.tol_move_buttons = [
                ttk.Button(self.tol_frame, text=move, command=lambda m=move: self.add_tol_move(m))
                for move in ["【1->2】", "【1->3】", "【2->1】", "【2->3】", "【3->1】", "【3->2】"]
            ]
            for i, btn in enumerate(self.tol_move_buttons):
                btn.grid(row=1, column=i, padx=2, pady=5, sticky='nsew')

            self.tol_delete_btn = ttk.Button(self.tol_frame, text="Delete", command=self.delete_last_tol_move)
            self.tol_delete_btn.grid(row=1, column=6, padx=(15, 2), pady=5, sticky='nsew')

            # Third row: moves entered so far (kept in sync with self.tol_moves) and Submit
            self.tol_moves_label = ttk.Label(self.tol_frame, text="Moves: " + " ".join(self.tol_moves))
            self.tol_moves_label.grid(row=2, column=0, columnspan=6, padx=5, pady=5, sticky='w')

            self.tol_submit_btn = ttk.Button(self.tol_frame, text="Submit", command=self.submit_tol_response)
            self.tol_submit_btn.grid(row=2, column=6, padx=5, pady=5, sticky='nsew')

            # Delete is only enabled while there are moves to remove.
            self.update_tol_delete_button()

        elif task == "Iowa_Gambling":
            for i, btn in enumerate(self.igt_buttons):
                btn.grid(row=0, column=i, padx=5, pady=5, sticky='nsew')
    
    def human_response(self, response,time_records = []):
        if self.task_var.get() == "WCST":
            self.experiment_flow.current_response = f"Selection: {response}"
        elif self.task_var.get() == "WCST_without_restriction":
            self.experiment_flow.current_response = f"Selection: {response}"
        feedback = self.experiment_flow.evaluate_response(response, self.presentation_mode_var.get(), self.prompt_type_var.get(), self.is_human_var.get())
        self.feedback_text.delete(1.0, tk.END)
        self.feedback_text.insert(tk.END, feedback)
        self.master.after(1000, self.continue_experiment)
        
    def stroop_key_press(self, event):
        key = event.char.upper()
        if key in ['R', 'B', 'Y', 'G']:
            time_record = []
            self.master.unbind('<Key>')
            self.key_press_time = time.time()
            reaction_time = self.key_press_time - self.image_present_time
            image_present_time = datetime.fromtimestamp(self.image_present_time).strftime(r'%Y-%m-%d-%H:%M:%S.%f')
            key_press_time = datetime.fromtimestamp(self.key_press_time).strftime(r'%Y-%m-%d-%H:%M:%S.%f')
            time_record = [image_present_time,key_press_time,reaction_time]
            full_color = self.color_mapping[key]
            self.human_response(full_color,time_record)
    
    def add_tol_move(self, move):
        self.tol_moves.append(move)
        self.update_tol_moves_label()
        self.update_tol_delete_button()
    
    def delete_last_tol_move(self):
        if self.tol_moves:
            self.tol_moves.pop()
            self.update_tol_moves_label()
            self.update_tol_delete_button()
            
    def update_tol_moves_label(self):
        self.tol_moves_label.config(text="Moves: " + " ".join(self.tol_moves))
    
    def update_tol_delete_button(self):
        if self.tol_moves:
            self.tol_delete_btn.config(state="normal")
        else:
            self.tol_delete_btn.config(state="disabled")
    
    def convert_response(self, response):
        moves = response.split('】 【')
        return [[int(x)-1 for x in move.strip('【】').split('->')] for move in moves]
    
    def submit_tol_response(self):
        num_moves = self.tol_entry.get()
        moves = " ".join(self.tol_moves)
        # response = f"{num_moves} {moves}"
        moves = self.convert_response(moves)
        response = [num_moves, moves]
        self.human_response(response)
        self.tol_moves = []
        self.update_human_response_widgets()  # Refresh the widgets
    
    def validate_number(self, var):
        """Check that *var* holds an integer from 1 to 100 and tell the user.

        A valid value shows a success dialog. Anything else shows an error
        dialog and clears *var*.
        """
        try:
            value = int(var.get())
        except ValueError:
            value = None
        if value is not None and 1 <= value <= 100:
            messagebox.showinfo("Success", f"Number set to {value}")
            return
        messagebox.showerror("Error", "Please enter a number between 1 and 100")
        var.set("")
    
    def display_prompt(self, prompt):
        """Replace the contents of the prompt pane with *prompt*."""
        pane = self.prompt_text
        pane.delete(1.0, tk.END)
        pane.insert(tk.END, prompt)
    
    def display_image(self, image_path):
        """Show the image at *image_path* scaled to fit the image pane, or clear it.

        Args:
            image_path: Path of the image file. A falsy value clears the pane.

        The aspect ratio is preserved, and each side is at least one pixel so
        ``resize`` never gets a zero size. The PhotoImage is stored on the
        label so Tk keeps a reference and the picture is not garbage-collected.
        """
        if not image_path:
            # Clear the current image
            self.image_label.config(image='')
            self.image_label.image = None
            return
        # Force a layout pass so winfo_width/height report the pane's real size.
        self.image_frame.update()
        frame_width = max(self.image_frame.winfo_width(), 1)
        frame_height = max(self.image_frame.winfo_height(), 1)
        # The context manager closes the file; resize() returns an independent copy.
        with Image.open(image_path) as image:
            original_width, original_height = image.size
            ratio = min(frame_width / original_width, frame_height / original_height)
            # Clamp to >= 1 px: a tiny (e.g. unmapped 1x1) pane would otherwise round
            # a side down to 0, which makes Image.resize raise ValueError.
            new_size = (max(int(original_width * ratio), 1), max(int(original_height * ratio), 1))
            resized = image.resize(new_size, Image.LANCZOS)
        photo = ImageTk.PhotoImage(resized)
        self.image_label.config(image=photo)
        self.image_label.image = photo
        self.image_label.place(relx=0.5, rely=0.5, anchor='center')
    
    def preview_experiment(self):
        """Build a preview flow for the current settings and show its first trial.

        No chat session is attached (``chat_session`` is reset to ``None``).
        The image is shown only for presentation modes "OI" and "OIT"; the
        prompt is always shown. For a task without a preview flow (anything
        other than the two WCST variants) the panes are cleared and
        ``experiment_flow`` is set to ``None``, instead of re-running a stale
        flow from a previously selected task.
        """
        task = self.task_var.get()
        prompt_type = self.prompt_type_var.get()
        presentation_mode = self.presentation_mode_var.get()
        if not all([task, prompt_type, presentation_mode]):
            messagebox.showerror("Error", "Please select all options before starting the experiment.")
            return
        self.chat_session = None
        # Pass None (no impairment) when simulation is on but no type is chosen yet,
        # rather than an empty string.
        if self.simulate_impairment_var.get():
            impairment_type = self.impairment_type_var.get() or None
        else:
            impairment_type = None

        if task == "WCST":
            self.experiment_flow = WCSTFlow(self.chat_session, self.is_human_var.get(), 1, self.language_var.get(), impairment_type=impairment_type)
        elif task == "WCST_without_restriction":
            self.experiment_flow = WCSTFlow(self.chat_session, self.is_human_var.get(), 1, self.language_var.get(), rule=False, impairment_type=impairment_type)
        else:
            self.experiment_flow = None
            self.display_image(None)
            self.display_prompt('')
            return

        image_path, prompt = self.experiment_flow.run_trial(presentation_mode, prompt_type)
        if presentation_mode in ["OI", "OIT"]:
            self.display_image(image_path)
        else:
            self.display_image(None)
        self.display_prompt(prompt)

    def update_info_labels(self,first_val='',second_val='',third_val=''):
        self.current_trial_var.set(f"{self.experiment_flow.current_trial + 1}/{self.experiment_flow.total_trials}")
        if self.task_var.get() == "WCST":
            self.current_rule_var.set(f"{first_val}")
            self.correct_card_var.set(f"{second_val}")
            self.correct_in_row_var.set(f"{third_val}")
        elif self.task_var.get() == "WCST_without_restriction":
            self.current_rule_var.set(f"{first_val}")
            self.correct_card_var.set(f"{second_val}")
            self.correct_in_row_var.set(f"{third_val}")

    def reset_experiment(self):
        """Return the GUI to a fresh, pre-experiment state.

        Clears:
        - subject and session numbers and every info-panel value;
        - the prompt, response and feedback panes, and the image;
        - pending Tower of London moves and queued responses;
        - the chat session and experiment flow.

        It also turns impairment simulation off, rebuilds the widgets and the
        preview for the current settings, reloads the request log from disk,
        and removes ``HTTP_PROXY`` from the process environment.
        """
        self.subject_number.set("")
        self.session_number.set("")

        # Reset info variables
        for info_var in (
            self.current_trial_var,
            self.current_rule_var,
            self.correct_card_var,
            self.correct_in_row_var,
            self.correct_color_var,
            self.is_match_var,
            self.number_of_moves_var,
            self.solution_path_var,
            self.total_earnings_var,
            self.decks_value_var,
        ):
            info_var.set("")

        # Clear text widgets (the response pane is read-only, so unlock it briefly)
        self.prompt_text.delete(1.0, tk.END)
        self.response_text.config(state=tk.NORMAL)
        self.response_text.delete(1.0, tk.END)
        self.response_text.config(state=tk.DISABLED)
        self.feedback_text.delete(1.0, tk.END)

        # Reset image
        self.display_image(None)

        # Reset experiment flow, chat session and any half-entered ToL answer
        self.chat_session = None
        self.experiment_flow = None
        self.current_response = ''
        self.current_tokens = []
        self.tol_moves = []

        self.simulate_impairment_var.set(False)
        self.impairment_type_var.set('')
        self.impairment_type_combobox.config(state='disabled')

        # Clear the response queue
        while not self.response_queue.empty():
            try:
                self.response_queue.get_nowait()
            except queue.Empty:
                break

        # Update the UI
        self.toggle_human_participant()
        self.update_info_frame()
        self.preview_experiment()
        self.load_request_log()

        # os.environ only accepts str values, so assigning None raised TypeError
        # on every reset; pop() removes the variable and tolerates its absence.
        os.environ.pop('HTTP_PROXY', None)
    
    def start_experiment(self):
        """Validate the control-panel selections and launch a single session.

        Shows an error dialog and returns early when:
        - a required option is missing;
        - a human participant has no Subject Number;
        - the Session Number is missing or not an integer;
        - the selected task is not supported.

        Otherwise creates the chat session (model runs only), the task flow and
        the data-logger session, then starts the first trial.
        """
        is_human = self.is_human_var.get()
        required = [self.task_var.get(), self.prompt_type_var.get(), self.presentation_mode_var.get()]
        if not is_human:
            # A model is only required when an LLM is the participant.
            required.append(self.model_var.get())
        if not all(required):
            messagebox.showerror("Error", "Please select all options before starting the experiment.")
            return
        if is_human and not self.subject_number.get():
            messagebox.showerror("Error", "Please enter a Subject Number for human participants.")
            return
        if not self.session_number.get():
            messagebox.showerror("Error", "Please enter a Session Number.")
            return
        try:
            session_index = int(self.session_number.get())
        except ValueError:
            # Previously an unhandled ValueError inside the Tk callback.
            messagebox.showerror("Error", "Session Number must be an integer.")
            return

        task = self.task_var.get()
        if task not in ("WCST", "WCST_without_restriction"):
            # Otherwise experiment_flow would stay None (or stale) below.
            messagebox.showerror("Error", f"Unsupported task: {task}")
            return

        if not is_human:
            self.chat_session = ChatSession(self.model_var.get())

        impairment_type = self.impairment_type_var.get() if self.simulate_impairment_var.get() else None

        if task == "WCST":
            self.experiment_flow = WCSTFlow(self.chat_session, is_human, session_index, self.language_var.get(), impairment_type=impairment_type)
        else:
            self.experiment_flow = WCSTFlow(self.chat_session, is_human, session_index, self.language_var.get(), rule=False, impairment_type=impairment_type)

        # Human sessions log the subject number; model sessions log the impairment.
        self.experiment_flow.data_logger.start_new_session(
            task,
            self.model_var.get(),
            self.prompt_type_var.get(),
            self.presentation_mode_var.get(),
            is_human,
            self.session_number.get(),
            self.subject_number.get() if is_human else None,
            None if is_human else impairment_type,
        )
        self.run_experiment()
    
    def run_experiment(self):
        """Present the next trial, or close the session when no trials remain.

        Human participants get the response widgets. For a model, the reply is
        fetched on a background thread and collected by ``check_response_queue``.
        """
        flow = self.experiment_flow
        if flow.is_experiment_complete():
            self.show_results()
            flow.data_logger.end_session()
            return

        mode = self.presentation_mode_var.get()
        image_path, prompt = flow.run_trial(mode, self.prompt_type_var.get())
        self.update_info_labels(*flow.info_list)

        shows_image = mode in ["OI", "OIT"]
        if shows_image:
            self.display_image(image_path)
            # Reference point for human reaction-time measurement.
            self.image_present_time = time.time()
        else:
            image_path = None  # text-only modes never send an image to the model
            self.display_image(None)
        self.display_prompt(prompt)

        if not self.is_human_var.get():
            self.update_response_text("Generating response...")
            worker = threading.Thread(target=self.make_api_requests, args=(prompt, image_path))
            worker.start()
            self.master.after(100, self.check_response_queue)
        else:
            self.update_human_response_widgets()
    
    def continue_experiment(self):
        """Finalize the current trial, clear the feedback pane and run the next one."""
        flow = self.experiment_flow
        flow.finalize_current_trial()
        feedback_pane = self.feedback_text
        feedback_pane.delete(1.0, tk.END)
        self.run_experiment()
    
    def evaluate_response(self, feedback):
        """Score the latest model reply and show the flow's feedback text.

        Args:
            feedback: Judgement supplied by the caller; forwarded unchanged to
                ``experiment_flow.evaluate_response``.
        """
        flow = self.experiment_flow
        # The flow scores whatever reply was stored last by make_api_requests.
        flow.current_response = self.current_response
        message = flow.evaluate_response(
            feedback,
            self.presentation_mode_var.get(),
            self.prompt_type_var.get(),
            tokens=self.current_tokens,
        )
        pane = self.feedback_text
        pane.delete(1.0, tk.END)
        pane.insert(tk.END, message)

    def make_api_requests(self, prompt, image_path):
        """Fetch the model's reply for one trial and hand it to the GUI thread.

        Runs on the worker thread started by ``run_experiment``. The reply, or a
        message starting with ``"Error:"``, is put on ``self.response_queue``.
        ``check_response_queue`` polls that queue on the Tk thread and shows the
        Retry button for error messages, so no widget is touched here.

        Args:
            prompt: Prompt text for the current trial.
            image_path: Image to attach, or None. Never sent in "OT" mode.
        """
        model = self.model_var.get()
        current_time = time.time()
        image = None if self.presentation_mode_var.get() == "OT" else image_path
        try:
            if not self.check_request_limits(model):
                self.response_queue.put("Error: Request limit exceeded. Please try again later.")
                return
            if self.chat_session:
                response, tokens = self.chat_session.get_response(prompt, image)
            else:
                response, tokens = make_api_request(model, prompt, image)
        except Exception as e:
            # Report every failure through the queue: an exception escaping this
            # thread would leave the GUI showing "Generating response..." forever.
            self.response_queue.put(f"Error: {str(e)}")
            return

        self.current_response = response
        self.current_tokens = tokens
        self.response_queue.put(self.current_response)
        try:
            self.update_usage(model, tokens, current_time)
        except Exception as e:
            # A bookkeeping failure must not queue a second message: it would be
            # picked up and displayed as the reply to the next trial.
            print(f"Failed to record API usage for {model}: {e}")

    def check_request_limits(self, model):
        """Return True if another request to *model* stays within its rate limits.

        Drops log entries older than one day. Recomputes the per-minute and
        per-day token/request totals for *model* from ``self.request_log`` and
        stores them in ``self.token_usage`` / ``self.request_count``. Compares
        them with the configured limits; falsy daily limits mean "no daily
        limit".
        """
        current_time = time.time()
        one_minute_ago = current_time - 60
        one_day_ago = current_time - 86400

        self.request_log = [req for req in self.request_log if req['time'] > one_day_ago]
        model_requests = [req for req in self.request_log if req['model'] == model]
        last_minute = [req for req in model_requests if req['time'] > one_minute_ago]

        minute_tokens = sum(req['tokens'] for req in last_minute)
        day_tokens = sum(req['tokens'] for req in model_requests)
        minute_requests = len(last_minute)
        day_requests = len(model_requests)

        # A model missing from the usage dicts (e.g. absent from a previously
        # saved log) starts with fresh counters instead of raising KeyError.
        token_usage = self.token_usage.setdefault(model, {"minute": 0, "day": 0})
        request_count = self.request_count.setdefault(model, {"minute": 0, "day": 0})
        token_usage["minute"] = minute_tokens
        token_usage["day"] = day_tokens
        request_count["minute"] = minute_requests
        request_count["day"] = day_requests

        if (minute_tokens >= TPM_LIMITS[model] or
            minute_requests >= RPM_LIMITS[model] or
            (TPD_LIMITS[model] and day_tokens >= TPD_LIMITS[model]) or
            (RPD_LIMITS[model] and day_requests >= RPD_LIMITS[model])):
            return False

        return True
    
    def update_usage(self, model, tokens, request_time):
        """Record one completed request for *model* and persist the request log.

        Args:
            model: Model name the request was sent to.
            tokens: Token counts returned with the reply. Index 2 is the value
                counted against the limits (presumably the total; confirm
                against ``get_response``).
            request_time: Epoch seconds at which the request was issued.
        """
        used = tokens[2]
        # Tolerate models missing from the usage dicts instead of raising KeyError.
        token_usage = self.token_usage.setdefault(model, {"minute": 0, "day": 0})
        request_count = self.request_count.setdefault(model, {"minute": 0, "day": 0})
        token_usage["minute"] += used
        token_usage["day"] += used
        request_count["minute"] += 1
        request_count["day"] += 1
        self.request_log.append({
            "model": model,
            "tokens": used,
            "time": request_time,
        })
        self.save_request_log()
    
    def retry_api_request(self):
        """Hide the Retry button and re-run the current trial from the start."""
        button = self.retry_button
        button.pack_forget()
        self.run_experiment()
    
    def check_response_queue(self):
        """Poll for the worker's reply every 100 ms and display it once it arrives.

        Replies starting with ``"Error:"`` reveal the Retry button; any other
        reply hides it. Polling stops after one message has been consumed.
        """
        try:
            response = self.response_queue.get_nowait()
        except queue.Empty:
            self.master.after(100, self.check_response_queue)
            return
        self.update_response_text(response)
        if not response.startswith("Error:"):
            self.retry_button.pack_forget()
        else:
            self.retry_button.pack(side=tk.BOTTOM, pady=5)
    
    def update_response_text(self, response):
        """Show *response* in the read-only response pane, highlighted as new."""
        pane = self.response_text
        tag = "new_response"
        pane.config(state=tk.NORMAL)  # writable only while being updated
        pane.delete(1.0, tk.END)
        pane.insert(tk.END, response)
        # Light-blue bold styling over the whole content marks it as the latest reply.
        pane.tag_configure(tag, background="#e6f3ff", font=("Helvetica", 12, "bold"))
        pane.tag_add(tag, "1.0", tk.END)
        pane.config(state=tk.DISABLED)

    def show_results(self):
        """Replace the feedback pane's contents with the completion notice."""
        pane = self.feedback_text
        pane.delete(1.0, tk.END)
        pane.insert(tk.END, "Experiment Complete")

class ExperimentParallelGUI:
    def __init__(self, master):
        """Build the parallel-session window and initialise per-session state.

        Twenty session slots are set up, each with its own chat session,
        experiment flow, reply buffer, token counts and response queue.
        Rate-limit bookkeeping is loaded from the request log at start-up.
        """
        self.master = master
        master.title("Parallel Executive Function Experiment")
        master.geometry("3000x1000")
        slots = 20

        # Registries filled in by create_widgets().
        self.session_widgets = []
        self.valid_checkboxes = []
        self.auto_run = tk.BooleanVar(value=False)
        # Error / pause / stop tracking used by the request pipeline.
        self.error_counts = [0 for _ in range(slots)]
        self.max_retries = 5
        self.paused_sessions = set()
        self.stopped_sessions = set()

        self.setup_variables()
        self.create_widgets()

        self.token_usage = {}
        self.request_count = {}
        self.request_log = []
        self.load_request_log()

        self.chat_sessions = [None for _ in range(slots)]
        self.experiment_flows = [None for _ in range(slots)]
        self.current_responses = ['' for _ in range(slots)]
        self.current_tokens = [[] for _ in range(slots)]
        self.response_queues = [queue.Queue() for _ in range(slots)]
        # One pending-request deque and last-request timestamp per model.
        self.request_queues = {name: deque() for name in MODELS}
        self.last_request_time = dict.fromkeys(MODELS, 0)
          
    def setup_variables(self):
        """Create the Tk variables bound to the control panel and session checkboxes."""
        # Task / model / prompt type / presentation mode pickers.
        (self.task_var,
         self.model_var,
         self.prompt_type_var,
         self.presentation_mode_var) = [tk.StringVar() for _ in range(4)]

        # Impairment controls; the type picker is enabled by toggle_impairment_simulation.
        self.simulate_impairment_var = tk.BooleanVar(value=False)
        self.impairment_type_var = tk.StringVar()

        self.is_human_var = tk.BooleanVar()
        # One "Valid" flag per session panel; every session starts selected.
        self.session_valid_vars = [tk.BooleanVar(value=True) for _ in range(20)]
    
    def create_widgets(self):
        """Build the window: control panel on the left, session panels on the right.

        The control panel holds:
        - the task / model / prompt-type / presentation-mode pickers;
        - the impairment and human-participant toggles;
        - bulk Valid (de)selection buttons;
        - the Start All / Reset All buttons and the Auto Run switch.

        The right side is a vertically scrollable canvas holding 20 session
        panels, 7 per row. Each panel's widgets are recorded in
        ``self.session_widgets``.
        """
        main_frame = ttk.Frame(self.master, padding="10")
        main_frame.pack(fill=tk.BOTH, expand=True)

        control_frame = ttk.Frame(main_frame, width=60)
        control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 20))

        sessions_frame = ttk.Frame(main_frame)
        sessions_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        # Control panel: one labelled combobox per experiment setting, each preset to its first option.
        ttk.Label(control_frame, text="Task:").grid(row=0, column=0, sticky='w', padx=(0, 5))
        self.task_combobox = ttk.Combobox(control_frame, textvariable=self.task_var, values=list(TASK_DESCRIPTIONS.keys()), width=20)
        self.task_combobox.grid(row=0, column=1, padx=(0, 10))
        self.task_var.set(list(TASK_DESCRIPTIONS.keys())[0])

        ttk.Label(control_frame, text="Model:").grid(row=1, column=0, sticky='w', padx=(0, 5))
        self.model_combobox = ttk.Combobox(control_frame, textvariable=self.model_var, values=MODELS, width=20)
        self.model_combobox.grid(row=1, column=1, padx=(0, 10))
        self.model_var.set(MODELS[0])

        ttk.Label(control_frame, text="Prompt Type:").grid(row=2, column=0, sticky='w', padx=(0, 5))
        self.prompt_type_combobox = ttk.Combobox(control_frame, textvariable=self.prompt_type_var, values=list(PROMPT_TYPES.keys()), width=20)
        self.prompt_type_combobox.grid(row=2, column=1, padx=(0, 10))
        self.prompt_type_var.set(list(PROMPT_TYPES.keys())[0])

        ttk.Label(control_frame, text="Presentation Mode:").grid(row=3, column=0, sticky='w', padx=(0, 5))
        self.presentation_mode_combobox = ttk.Combobox(control_frame, textvariable=self.presentation_mode_var, values=PRESENTATION_MODES, width=20)
        self.presentation_mode_combobox.grid(row=3, column=1, padx=(0, 10))
        self.presentation_mode_var.set(PRESENTATION_MODES[0])

        # The impairment picker starts disabled; toggle_impairment_simulation enables it.
        ttk.Checkbutton(control_frame, text="Simulate Impairment", variable=self.simulate_impairment_var, command=self.toggle_impairment_simulation).grid(row=4, column=1, sticky='w')
        ttk.Label(control_frame, text="Impairment Type:").grid(row=5, column=0, sticky='w', padx=(0, 5))
        self.impairment_type_combobox = ttk.Combobox(control_frame, textvariable=self.impairment_type_var, values=list(ROLE_DESCRIPTIONS.keys()), width=20, state='disabled')
        self.impairment_type_combobox.grid(row=5, column=1, padx=(0, 10))

        ttk.Checkbutton(control_frame, text="Human Participant", variable=self.is_human_var, command=self.toggle_human_participant).grid(row=6, column=0, columnspan=2, sticky='w')

        select_all_button = ttk.Button(control_frame, text="Select All Valid", command=self.select_all_valid)
        select_all_button.grid(row=7, column=0, pady=(0, 10), padx=(0, 5), sticky='ew')

        deselect_all_button = ttk.Button(control_frame, text="Deselect All Valid", command=self.deselect_all_valid)
        deselect_all_button.grid(row=7, column=1, pady=(0, 10), padx=(5, 0), sticky='ew')

        ttk.Button(control_frame, text="Start All Experiments", command=self.start_all_experiments).grid(row=8, column=0, columnspan=2, padx=(0, 5), sticky='ew')
        ttk.Button(control_frame, text="Reset All Experiments", command=self.reset_all_experiments).grid(row=9, column=0, columnspan=2, padx=(0, 5), sticky='ew')

        ttk.Checkbutton(control_frame, text="Auto Run", variable=self.auto_run).grid(row=10, column=0, columnspan=2, sticky='w')

        # Create a canvas and scrollbar
        canvas = tk.Canvas(sessions_frame)
        scrollbar = ttk.Scrollbar(sessions_frame, orient="vertical", command=canvas.yview)
        scrollable_frame = ttk.Frame(canvas)

        # Keep the canvas scroll region in sync with the inner frame's size.
        scrollable_frame.bind(
            "<Configure>",
            lambda e: canvas.configure(
                scrollregion=canvas.bbox("all")
            )
        )

        canvas.create_window((0, 0), window=scrollable_frame, anchor="nw")
        canvas.configure(yscrollcommand=scrollbar.set)

        # Pack the canvas and scrollbar
        canvas.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")

        # Create session frames
        for i in range(20):
            session_frame = ttk.LabelFrame(scrollable_frame, text=f"Session {i+1}")
            # Earlier 4-column layout, kept for reference:
            # session_frame.grid(row=i//4, column=i%4, padx=5, pady=5, sticky='nsew')
            # Current layout: 7 panels per row.
            session_frame.grid(row=i//7, column=i%7, padx=5, pady=5, sticky='nsew')

            session_widgets = {}

            # NOTE: this rebinds control_frame to the per-panel header row; the
            # left-hand control panel is not referenced after this point.
            control_frame = ttk.Frame(session_frame)
            control_frame.grid(row=0, column=0, sticky='w')

            session_widgets['valid_check'] = ttk.Checkbutton(control_frame, text="Valid", variable=self.session_valid_vars[i])
            session_widgets['valid_check'].pack(side=tk.LEFT)
            # valid_checkboxes holds the BooleanVars (not the Checkbutton widgets);
            # select_all_valid / deselect_all_valid set them directly.
            self.valid_checkboxes.append(self.session_valid_vars[i])

            # Add the new Stop button next to the Valid checkbox
            # (i=i binds the current index; a bare closure would see the last loop value).
            session_widgets['stop_button'] = ttk.Button(control_frame, text="Stop", command=lambda i=i: self.stop_experiment(i))
            session_widgets['stop_button'].pack(side=tk.LEFT, padx=(5, 0))

            info_frame = ttk.Frame(session_frame)
            info_frame.grid(row=1, column=0, sticky='nsew')
            ttk.Label(info_frame, text="Current Trial:").grid(row=0, column=0, sticky='w')
            session_widgets['trial_label'] = ttk.Label(info_frame, text="0/0")
            session_widgets['trial_label'].grid(row=0, column=1, sticky='w')

            # Add more labels for additional information.
            # The column-0 labels are empty placeholders; update_info_labels writes
            # the full text (e.g. "Current Rule: ...") into the column-1 labels.
            ttk.Label(info_frame, text="").grid(row=1, column=0, sticky='w')
            session_widgets['rule_label'] = ttk.Label(info_frame, text="")
            session_widgets['rule_label'].grid(row=1, column=1, sticky='w')

            ttk.Label(info_frame, text="").grid(row=2, column=0, sticky='w')
            session_widgets['correct_card_label'] = ttk.Label(info_frame, text="")
            session_widgets['correct_card_label'].grid(row=2, column=1, sticky='w')

            ttk.Label(info_frame, text="").grid(row=3, column=0, sticky='w')
            session_widgets['correct_in_row_label'] = ttk.Label(info_frame, text="")
            session_widgets['correct_in_row_label'].grid(row=3, column=1, sticky='w')

            # Read-only reply pane; update_response_text enables it only while writing.
            response_frame = ttk.LabelFrame(session_frame, text="Response")
            response_frame.grid(row=2, column=0, sticky='nsew')
            session_widgets['response_text'] = tk.Text(response_frame, height=3, width=30, wrap=tk.WORD)
            session_widgets['response_text'].pack(fill=tk.BOTH, expand=True)
            session_widgets['response_text'].config(state=tk.DISABLED)

            feedback_frame = ttk.LabelFrame(session_frame, text="Feedback")
            feedback_frame.grid(row=3, column=0, sticky='nsew')
            session_widgets['feedback_text'] = tk.Text(feedback_frame, height=3, width=30, wrap=tk.WORD)
            session_widgets['feedback_text'].pack(fill=tk.BOTH, expand=True)

            # Per-session manual controls; Correct/Incorrect are flagged as GUI presses.
            button_frame = ttk.Frame(session_frame)
            button_frame.grid(row=4, column=0, sticky='nsew')
            ttk.Button(button_frame, text="Correct", command=lambda i=i: self.evaluate_response(i, True, auto_press_gui=True)).grid(row=0, column=0, padx=2)
            ttk.Button(button_frame, text="Incorrect", command=lambda i=i: self.evaluate_response(i, False, auto_press_gui=True)).grid(row=0, column=1, padx=2)
            ttk.Button(button_frame, text="Continue", command=lambda i=i: self.continue_experiment(i)).grid(row=0, column=2, padx=2)
            ttk.Button(button_frame, text="Retry", command=lambda i=i: self.retry_experiment(i)).grid(row=0, column=3, padx=2)


            self.session_widgets.append(session_widgets)

    def stop_experiment(self, session_number):
        """Stop *session_number* until the next reset.

        Marks the session stopped, shows a notice in its response pane,
        disables its Stop button and removes it from the paused set. Calling it
        again for an already-stopped session does nothing.
        """
        if session_number in self.stopped_sessions:
            return
        self.stopped_sessions.add(session_number)
        self.update_response_text(session_number, "Session stopped manually.")
        self.session_widgets[session_number]['stop_button'].config(state=tk.DISABLED)
        # A stopped session should no longer be counted as paused.
        self.paused_sessions.discard(session_number)
    
    # Retry handling lives in the retry_experiment method further down this class.

    def toggle_impairment_simulation(self):
        """Enable the impairment-type picker only while "Simulate Impairment" is ticked."""
        enabled = self.simulate_impairment_var.get()
        if not enabled:
            self.impairment_type_combobox.config(state='disabled')
            self.impairment_type_var.set('')  # drop any stale selection
            return
        self.impairment_type_combobox.config(state='readonly')

    def select_all_valid(self):
        """Tick the "Valid" flag of every session panel."""
        for valid_flag in self.valid_checkboxes:
            valid_flag.set(1)

    def deselect_all_valid(self):
        """Untick the "Valid" flag of every session panel."""
        for valid_flag in self.valid_checkboxes:
            valid_flag.set(0)
    
    def toggle_human_participant(self):
        """Lock the model-specific controls while a human participant is selected.

        Always switches impairment simulation off. For a human, the model is
        cleared and prompt type / presentation mode are reset to their first
        options.
        """
        is_human = self.is_human_var.get()
        combo_state = 'disabled' if is_human else 'normal'
        for combo in (self.model_combobox, self.prompt_type_combobox, self.presentation_mode_combobox):
            combo.config(state=combo_state)
        # Impairment simulation only applies to model participants.
        self.simulate_impairment_var.set(False)
        self.impairment_type_combobox.config(state='disabled')
        self.impairment_type_var.set('')
        if not is_human:
            return
        self.model_var.set('')
        self.prompt_type_var.set(list(PROMPT_TYPES.keys())[0])
        self.presentation_mode_var.set(PRESENTATION_MODES[0])
    
    def start_all_experiments(self):
        """Start every session whose "Valid" box is ticked.

        Shows an error instead when the task, prompt type or presentation mode
        is missing. For a non-human run it also errors when no model is
        selected, because each model session builds a ``ChatSession`` from
        that name.
        """
        required = [self.task_var.get(), self.prompt_type_var.get(), self.presentation_mode_var.get()]
        if not self.is_human_var.get():
            required.append(self.model_var.get())
        if not all(required):
            messagebox.showerror("Error", "Please select all options before starting the experiments.")
            return

        for session_number, valid_var in enumerate(self.session_valid_vars):
            if valid_var.get():
                self.start_experiment(session_number)
    
    def start_experiment(self, session_number):
        """Create the flow and logger session for one slot, then run its first trial.

        Model runs also get a fresh ``ChatSession``. Raises KeyError for a task
        other than "WCST" or "WCST_without_restriction".
        """
        if not self.is_human_var.get():
            self.chat_sessions[session_number] = ChatSession(self.model_var.get())

        task = self.task_var.get()
        if task not in ("WCST", "WCST_without_restriction"):
            raise KeyError(task)

        impairment_type = None
        if self.simulate_impairment_var.get():
            impairment_type = self.impairment_type_var.get()

        label = session_number + 1  # sessions are numbered from 1 for display and logging
        flow = WCSTFlow(
            self.chat_sessions[session_number],
            self.is_human_var.get(),
            label,
            "English",
            rule=task != "WCST_without_restriction",
            impairment_type=impairment_type,
        )
        self.experiment_flows[session_number] = flow

        flow.data_logger.start_new_session(
            task,
            self.model_var.get(),
            self.prompt_type_var.get(),
            self.presentation_mode_var.get(),
            self.is_human_var.get(),
            str(label),
            None,
            impairment_type,
        )

        self.run_experiment(session_number)
    
    def run_experiment(self, session_number):
        """Advance *session_number* by one trial, or close it once all trials are done.

        Does nothing for a paused or stopped session, or for a slot with no
        experiment flow (never started, or reset). For model runs, the reply
        is requested on a worker thread and polled via ``check_response_queue``.
        """
        if session_number in self.paused_sessions or session_number in self.stopped_sessions:
            return
        flow = self.experiment_flows[session_number]
        if flow is None:
            # e.g. Retry clicked before Start, or a scheduled retry firing after a reset.
            return
        if flow.is_experiment_complete():
            self.show_results(session_number)
            flow.data_logger.end_session()
            return

        image_path, prompt = flow.run_trial(
            self.presentation_mode_var.get(),
            self.prompt_type_var.get()
        )

        self.update_info_labels(session_number)

        if not self.is_human_var.get():
            self.update_response_text(session_number, "Generating response...")
            threading.Thread(target=self.make_api_requests, args=(session_number, prompt, image_path)).start()
            self.master.after(100, lambda: self.check_response_queue(session_number))
    
    def continue_experiment(self, session_number):
        """Finalize the session's current trial, clear its feedback and run the next.

        Does nothing if the slot has no experiment flow (not started, or reset
        while a delayed Auto Run continuation was pending).
        """
        flow = self.experiment_flows[session_number]
        if flow is None:
            return
        flow.finalize_current_trial()
        self.clear_feedback_text(session_number)
        self.run_experiment(session_number)
    
    def retry_experiment(self, session_number):
        self.run_experiment(session_number)
    
    def evaluate_response(self, session_number, feedback, auto_press_gui=False):
        """Score the session's latest reply and show the flow's feedback text.

        Args:
            session_number: Zero-based session slot.
            feedback: Judgement passed to the flow. This is True/False from the
                Correct/Incorrect buttons, or the selection returned by
                ``parse_feedback`` in Auto Run mode.
            auto_press_gui: True when triggered by a GUI button; forwarded to
                the flow as ``auto_press``.

        Does nothing if the slot has no experiment flow (e.g. a late reply
        arriving after the session was reset).
        """
        flow = self.experiment_flows[session_number]
        if flow is None:
            return
        flow.current_response = self.current_responses[session_number]
        feedback = flow.evaluate_response(
            feedback,
            self.presentation_mode_var.get(),
            self.prompt_type_var.get(),
            tokens=self.current_tokens[session_number],
            is_auto=self.auto_run.get(),
            auto_press=auto_press_gui,
        )
        self.update_feedback_text(session_number, feedback)
    
    """
    def make_api_requests(self, session_number, prompt, image_path):
        model = self.model_var.get()
        current_time = time.time()
        if not self.check_request_limits(model):
            self.response_queues[session_number].put("Error: Request limit exceeded. Please try again later.")
            return
        
        try:
            if self.presentation_mode_var.get() == "OT":
                response, tokens = self.chat_sessions[session_number].get_response(prompt, None)
            else:
                response, tokens = self.chat_sessions[session_number].get_response(prompt, image_path)  # We're not using images in this version
            
            self.current_responses[session_number] = response
            self.current_tokens[session_number] = tokens
            self.response_queues[session_number].put(self.current_responses[session_number])
            self.update_usage(model, tokens, current_time)
            self.error_counts[session_number] = 0  # 重置错误计数
        except Exception as e:
            self.response_queues[session_number].put(f"Error: {str(e)}")
            self.error_counts[session_number] += 1
            
            if self.error_counts[session_number] >= self.max_retries:
                self.paused_sessions.add(session_number)
                self.update_response_text(session_number, f"Session paused due to multiple errors: {str(e)}")
            else:
                self.master.after(5000, lambda: self.retry_experiment(session_number))
    """
    
    def make_api_requests(self, session_number, prompt, image_path):
        """Send the trial prompt now, or hold it back until the model's rate limits allow.

        Runs on the worker thread started by ``run_experiment``; stopped
        sessions are ignored. When ``check_request_limits`` rejects the
        request, it is appended to ``self.request_queues[model]`` and
        ``process_queue`` is scheduled:
        - at the end of the 60 s window opened by the model's last request; or
        - after 5 s when that window has already passed (for example when
          ``last_request_time`` is still 0 but the loaded log already exceeds
          a limit).

        A held-back request is only sent once ``process_queue`` sees the
        limiter agree.
        """
        if session_number in self.stopped_sessions:
            return
        model = self.model_var.get()
        current_time = time.time()

        if not self.check_request_limits(model):
            wait_time = 60 - (current_time - self.last_request_time[model])
            if wait_time <= 0:
                # Still over a limit although the last request is over a minute
                # old: re-check on process_queue's 5 s cadence instead of
                # sending the request anyway.
                wait_time = 5
            self.request_queues[model].append((session_number, prompt, image_path))
            self.master.after(int(wait_time * 1000), lambda: self.process_queue(model))
            return

        self.process_request(model, session_number, prompt, image_path)

    def process_queue(self, model):
        """Release the oldest request held back for *model* once its limits allow.

        This is scheduled on the Tk event loop through ``master.after``. The
        request itself is sent on a worker thread, because ``process_request``
        waits for the API reply and would otherwise freeze the GUI. While
        *model* is still over its limits, it re-checks every 5 s.
        """
        pending = self.request_queues[model]
        if not pending:
            return
        if not self.check_request_limits(model):
            self.master.after(5000, lambda: self.process_queue(model))
            return
        session_number, prompt, image_path = pending.popleft()
        threading.Thread(
            target=self.process_request,
            args=(model, session_number, prompt, image_path),
        ).start()

    def process_request(self, model, session_number, prompt, image_path):
        """Send one prompt for *session_number* and report the outcome via its response queue.

        Called on the worker thread of ``make_api_requests``, or via
        ``process_queue`` for requests that the rate limiter held back.

        - On success the reply is queued and usage is recorded.
        - On failure an ``"Error: ..."`` message is queued and the trial is
          retried after 5 s.
        - After ``self.max_retries`` consecutive failures the session is paused
          and the pause notice is queued instead.

        At most one message is queued per call. The display is updated only
        by ``check_response_queue``, so the error and pause texts no longer
        race each other.
        """
        if session_number in self.stopped_sessions:
            return
        image = None if self.presentation_mode_var.get() == "OT" else image_path
        try:
            response, tokens = self.chat_sessions[session_number].get_response(prompt, image)
        except Exception as e:
            self.error_counts[session_number] += 1
            if self.error_counts[session_number] >= self.max_retries:
                self.paused_sessions.add(session_number)
                self.response_queues[session_number].put(f"Session paused due to multiple errors: {str(e)}")
            else:
                self.response_queues[session_number].put(f"Error: {str(e)}")
                self.master.after(5000, lambda: self.retry_experiment(session_number))
            return

        self.current_responses[session_number] = response
        self.current_tokens[session_number] = tokens
        self.response_queues[session_number].put(self.current_responses[session_number])
        self.error_counts[session_number] = 0
        self.last_request_time[model] = time.time()
        try:
            self.update_usage(model, tokens, time.time())
        except Exception as e:
            # The reply was already delivered; a bookkeeping failure must not be
            # reported as a request error or trigger a retry of the trial.
            print(f"Failed to record API usage for session {session_number}: {e}")
    
    def check_response_queue(self, session_number):
        """Poll the session's response queue every 100 ms until a message arrives.

        The message is shown in the session's response pane. With Auto Run on,
        a selection recognised by ``parse_feedback`` is scored and the next
        trial is continued 500 ms later. Polling ends for stopped sessions.
        """
        if session_number in self.stopped_sessions:
            return
        try:
            response = self.response_queues[session_number].get_nowait()
        except queue.Empty:
            self.master.after(100, lambda: self.check_response_queue(session_number))
            return

        self.update_response_text(session_number, response)
        if not self.auto_run.get():
            return
        feedback = self.parse_feedback(response)
        if feedback is None:
            print(f"Unable to parse feedback for session {session_number}")
            return
        self.evaluate_response(session_number, feedback)
        self.master.after(500, lambda: self.continue_experiment(session_number))
    
    def parse_feedback(self, response):
        """Return the card number chosen in *response*, or None.

        For both WCST variants this is the first digit that follows
        "Selection:" on the same line. Other tasks always yield None.
        """
        if self.task_var.get() not in ("WCST", "WCST_without_restriction"):
            return None
        found = re.search(r'Selection:.*?(\d)', response)
        return int(found.group(1)) if found else None
    
    def update_response_text(self, session_number, response):
        """Replace the contents of the session's read-only response pane with *response*."""
        pane = self.session_widgets[session_number]['response_text']
        pane.config(state=tk.NORMAL)  # the pane stays disabled except while being written
        pane.delete(1.0, tk.END)
        pane.insert(tk.END, response)
        pane.config(state=tk.DISABLED)
    
    def update_feedback_text(self, session_number, feedback):
        """Replace the contents of the session's feedback pane with *feedback*."""
        pane = self.session_widgets[session_number]['feedback_text']
        pane.delete(1.0, tk.END)
        pane.insert(tk.END, feedback)
    
    def clear_feedback_text(self, session_number):
        """Empty the session's feedback pane."""
        self.session_widgets[session_number]['feedback_text'].delete(1.0, tk.END)
    
    def update_info_labels(self, session_number):
        """Refresh the trial / rule / card / streak labels of one session panel.

        When the slot has no experiment flow (e.g. right after a reset), the
        labels return to their initial values ("0/0" and empty) instead of
        still showing the previous run.
        """
        widgets = self.session_widgets[session_number]
        flow = self.experiment_flows[session_number]
        if flow is None:
            widgets['trial_label'].config(text="0/0")
            for key in ('rule_label', 'correct_card_label', 'correct_in_row_label'):
                widgets[key].config(text="")
            return

        # Update trial information
        widgets['trial_label'].config(text=f"{flow.current_trial + 1}/{flow.total_trials}")

        # Update task-specific information
        if isinstance(flow, WCSTFlow):
            widgets['rule_label'].config(text=f"Current Rule: {flow.info_list[0]}")
            widgets['correct_card_label'].config(text=f"Correct Card: {flow.info_list[1]}")
            widgets['correct_in_row_label'].config(text=f"Correct in Row: {flow.info_list[2]}")
    
    def show_results(self, session_number):
        """Put the completion notice into the session's feedback pane."""
        notice = "Experiment Complete"
        self.update_feedback_text(session_number, notice)
    
    def reset_all_experiments(self):
        """Reset all twenty session slots, then reload the request log from disk."""
        for slot in range(20):
            self.reset_experiment(slot)
        # Rate-limit counters are re-read so they reflect the persisted log.
        self.load_request_log()
    
    def reset_experiment(self, session_number):
        """Return one session slot to its initial, not-started state.

        - Drops the slot's chat session, flow, last reply and token counts.
        - Drains its response queue.
        - Discards any of its requests still held in the per-model rate-limit
          queues; otherwise they would be sent later against the cleared chat
          session.
        - Clears its panel and its error, paused and stopped state.
        - Switches the shared impairment controls off.
        """
        self.chat_sessions[session_number] = None
        self.experiment_flows[session_number] = None
        self.current_responses[session_number] = ''
        self.current_tokens[session_number] = []

        self.simulate_impairment_var.set(False)
        self.impairment_type_var.set('')
        self.impairment_type_combobox.config(state='disabled')

        while not self.response_queues[session_number].empty():
            try:
                self.response_queues[session_number].get_nowait()
            except queue.Empty:
                break

        # Each pending entry is (session_number, prompt, image_path); snapshot
        # the deque before filtering so it is not iterated while being changed.
        for pending in self.request_queues.values():
            for request in [r for r in list(pending) if r[0] == session_number]:
                pending.remove(request)

        self.update_info_labels(session_number)
        self.update_response_text(session_number, "")
        self.clear_feedback_text(session_number)
        self.error_counts[session_number] = 0
        self.paused_sessions.discard(session_number)
        self.stopped_sessions.discard(session_number)
        self.session_widgets[session_number]['stop_button'].config(state=tk.NORMAL)
    
    def load_request_log(self):
        try:
            with open(REQUEST_LOG_FILE, 'r') as f:
                log_data = json.load(f)
                self.request_log = log_data['requests']
                self.token_usage = log_data['token_usage']
                self.request_count = log_data['request_count']
        except FileNotFoundError:
            self.request_log = []
            self.token_usage = {model: {"minute": 0, "day": 0} for model in MODELS}
            self.request_count = {model: {"minute": 0, "day": 0} for model in MODELS}
    
    def save_request_log(self):
        log_data = {
            'requests': self.request_log,
            'token_usage': self.token_usage,
            'request_count': self.request_count,
        }
        with open(REQUEST_LOG_FILE, 'w') as f:
            json.dump(log_data, f)
    
    def check_request_limits(self, model):
        current_time = time.time()
        one_minute_ago = current_time - 60
        one_day_ago = current_time - 86400
        
        self.request_log = [req for req in self.request_log if req['time'] > one_day_ago]

        minute_tokens = sum(req['tokens'] for req in self.request_log if req['time'] > one_minute_ago and req['model'] == model)
        day_tokens = sum(req['tokens'] for req in self.request_log if req['model'] == model)
        minute_requests = sum(1 for req in self.request_log if req['time'] > one_minute_ago and req['model'] == model)
        day_requests = sum(1 for req in self.request_log if req['model'] == model)
        
        self.token_usage[model]["minute"] = minute_tokens
        self.token_usage[model]["day"] = day_tokens
        self.request_count[model]["minute"] = minute_requests
        self.request_count[model]["day"] = day_requests

        if (minute_tokens >= TPM_LIMITS[model] or
            minute_requests >= RPM_LIMITS[model] or
            (TPD_LIMITS[model] and day_tokens >= TPD_LIMITS[model]) or
            (RPD_LIMITS[model] and day_requests >= RPD_LIMITS[model])):
            return False

        return True
    
    def update_usage(self, model, tokens, request_time):
        # self.token_usage[model]["minute"] += tokens[2]
        # self.token_usage[model]["day"] += tokens[2]
        # self.request_count[model]["minute"] += 1
        # self.request_count[model]["day"] += 1
        request_info = {
            "model": model,
            "tokens": tokens[2],
            "time": request_time
        }
        self.request_log.append(request_info)
        self.save_request_log()
