import json
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk
import os

class ScoreCalculator:
    """Tkinter GUI for manually grading vision-model responses.

    Items are read from ``./output/vision_accuracy_results.json``, a mapping
    of model name to a list of items that each carry an ``'image'`` (file
    path) and a ``'response'`` (model text). For each item the grader sees the
    image and the response, enters four scores and presses "Next". Scores
    saved by an earlier run are reloaded, so grading resumes at the first
    ungraded item. When every model is graded, per-item scores and aggregated
    accuracies are written to ``./output`` and the Tk main loop is stopped.
    """

    RESULTS_PATH = './output/vision_accuracy_results.json'
    SCORES_PATH = './output/vision_accuracy_scores.json'
    FINAL_SCORES_PATH = './output/vision_accuracy_final_scores.json'

    # A reported card count equal to this earns the single card-count point.
    EXPECTED_CARD_COUNT = 5
    # Each reported card above EXPECTED_CARD_COUNT subtracts this many points.
    PENALTY_PER_EXTRA_CARD = 0.5
    # Colour, shape and number sub-scores are each entered in 0..MAX_SUB_SCORE.
    MAX_SUB_SCORE = 5
    # Best possible item total: 1 card-count point + three full sub-scores (16).
    MAX_ITEM_SCORE = 1 + 3 * MAX_SUB_SCORE
    # Displayed images are scaled to this height in pixels (aspect ratio kept).
    IMAGE_HEIGHT = 400

    def __init__(self, master):
        """Build the UI on *master* and show the first ungraded item.

        Args:
            master: The Tk root (or toplevel) window hosting the application.
        """
        self.master = master
        self.master.title("Vision Score Calculator")
        self.master.geometry("1000x700")

        self._configure_style()

        self.load_results()
        self.load_existing_scores()
        # Resume at the first model that still has ungraded items.
        self.current_model = self.get_next_incomplete_model()

        self.create_widgets()
        self.load_current_image()

    def _configure_style(self):
        """Apply the 'clam' ttk theme and the application's colour scheme."""
        self.style = ttk.Style(self.master)
        self.style.theme_use('clam')

        self.bg_color = "#f0f0f0"
        self.accent_color = "#4a86e8"
        self.text_color = "#333333"
        self.progress_bg_color = "#e0e0e0"
        self.progress_color = "#2196F3"  # A more vibrant blue

        self.master.configure(bg=self.bg_color)
        self.style.configure("TFrame", background=self.bg_color)
        self.style.configure("TLabel", background=self.bg_color, foreground=self.text_color)
        self.style.configure("TButton", background="white", foreground=self.accent_color,
                             font=('Arial', 12, 'bold'))
        self.style.configure("Custom.Horizontal.TProgressbar",
                             troughcolor=self.progress_bg_color,
                             background=self.progress_color,
                             thickness=10)

    def load_results(self):
        """Load model responses into ``self.results`` and ``self.models``.

        If the results file is missing, an error is printed and both
        attributes are left empty.
        """
        try:
            with open(self.RESULTS_PATH, 'r', encoding='utf-8') as f:
                self.results = json.load(f)
        except FileNotFoundError:
            print("Error: vision_accuracy_results.json not found.")
            self.results = {}
        self.models = list(self.results.keys())

    def get_next_incomplete_model(self):
        """Return the first model (in results order) with ungraded items, or None."""
        return next((model for model in self.models if not self._is_complete(model)), None)

    def _is_complete(self, model):
        """Return True when every result item of *model* has a recorded score."""
        return len(self.scores[model]) >= len(self.results[model])

    def load_existing_scores(self):
        """Initialise ``self.scores`` and adopt scores saved by earlier runs.

        Every model in ``self.results`` starts with an empty list. A saved
        list is adopted only for models still present in the results.
        """
        self.scores = {model: [] for model in self.results}
        try:
            with open(self.SCORES_PATH, 'r', encoding='utf-8') as f:
                existing_scores = json.load(f)
        except FileNotFoundError:
            print("No existing scores found. Starting fresh.")
            return
        for model, scores in existing_scores.items():
            if model in self.scores:
                self.scores[model] = scores

    def create_widgets(self):
        """Build the layout: image and response on the left, controls on the right."""
        main_frame = ttk.Frame(self.master)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=20)

        # Left column: image on top, model response below.
        left_frame = ttk.Frame(main_frame)
        left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        image_frame = ttk.Frame(left_frame, borderwidth=2, relief="groove")
        image_frame.pack(pady=10, padx=10, fill=tk.BOTH, expand=True)
        self.image_label = ttk.Label(image_frame)
        self.image_label.pack(pady=10, padx=10)

        response_frame = ttk.Frame(left_frame, borderwidth=2, relief="groove")
        response_frame.pack(pady=10, padx=10, fill=tk.BOTH, expand=True)
        self.response_text = tk.Text(response_frame, height=15, width=50, wrap=tk.WORD, font=('Arial', 14))
        self.response_text.pack(pady=10, padx=10, fill=tk.BOTH, expand=True)

        # Right column: model picker, score entries, Next button, progress.
        right_frame = ttk.Frame(main_frame, borderwidth=2, relief="groove")
        right_frame.pack(side=tk.RIGHT, fill=tk.Y, padx=20, pady=20)

        ttk.Label(right_frame, text="Select Model:", font=('Arial', 14, 'bold')).pack(pady=(20, 10))
        self.model_var = tk.StringVar(value=self.current_model)
        self.model_dropdown = ttk.Combobox(right_frame, textvariable=self.model_var, values=self.models,
                                           state="readonly", font=('Arial', 12))
        self.model_dropdown.pack(pady=(0, 20))
        self.model_dropdown.bind("<<ComboboxSelected>>", self.on_model_change)

        score_frame = ttk.Frame(right_frame)
        score_frame.pack(pady=10)
        score_labels = ["Card Count (5):", "Color Count (0-5):", "Shape Count (0-5):", "Number Count (0-5):"]
        self.score_vars = [tk.StringVar(value="5") for _ in score_labels]
        for row, (label, var) in enumerate(zip(score_labels, self.score_vars)):
            ttk.Label(score_frame, text=label, font=('Arial', 12)).grid(row=row, column=0, padx=5, pady=10, sticky="e")
            ttk.Entry(score_frame, textvariable=var, width=5, font=('Arial', 12)).grid(row=row, column=1, padx=5, pady=10)

        ttk.Button(right_frame, text="Next", command=self.next_image).pack(pady=20)

        # Validation messages for the score entries. They are kept out of the
        # response box so the model's answer stays visible while correcting.
        self.status_label = ttk.Label(right_frame, text="", foreground="#c62828", wraplength=200,
                                      font=('Arial', 10))
        self.status_label.pack(pady=(0, 10))

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(right_frame, variable=self.progress_var, maximum=100,
                                            style="Custom.Horizontal.TProgressbar")
        self.progress_bar.pack(fill=tk.X, padx=10, pady=10)

        self.progress_label = ttk.Label(right_frame, text="Progress: 0/0", font=('Arial', 12))
        self.progress_label.pack(pady=5)

    def load_current_image(self):
        """Show the next ungraded item of the current model.

        If the current model is already fully graded, switch to the next
        incomplete model first. If there is none, finish the session.
        """
        if not self.current_model:
            self.display_error("All models have been processed.")
            return

        if self._is_complete(self.current_model) and not self._advance_to_next_model():
            return

        # Read the model only after the possible switch above, so the item
        # and the progress both belong to the model actually being graded.
        model = self.current_model
        current_index = len(self.scores[model])
        self._update_progress(current_index, len(self.results[model]))
        self._display_item(self.results[model][current_index])

    def _advance_to_next_model(self):
        """Select the next incomplete model, or finish if every model is graded.

        Returns:
            True if a model was selected. False if all models are complete;
            in that case final scores are saved and the main loop is stopped.
        """
        self.current_model = self.get_next_incomplete_model()
        if not self.current_model:
            self.calculate_final_scores()
            self.master.quit()
            return False
        # The combobox is bound to model_var, so this also updates its display.
        self.model_var.set(self.current_model)
        return True

    def _display_item(self, image_data):
        """Show one result item: its image and the model's response text.

        The image is scaled to IMAGE_HEIGHT pixels tall with its aspect ratio
        kept. If it cannot be opened, the previous image is cleared and an
        error message replaces the response text.
        """
        image_path = image_data['image']
        try:
            # The context manager closes the source file. resize() returns a
            # new, fully loaded image that stays usable afterwards.
            with Image.open(image_path) as image:
                width, height = image.size
                new_width = int((self.IMAGE_HEIGHT / height) * width)
                resized = image.resize((new_width, self.IMAGE_HEIGHT), Image.LANCZOS)
        except FileNotFoundError:
            self._clear_image()
            self.display_error(f"Image file not found: {image_path}")
            return
        except OSError as err:
            # e.g. the file exists but is not an image PIL can read.
            self._clear_image()
            self.display_error(f"Could not open image {image_path}: {err}")
            return

        photo = ImageTk.PhotoImage(resized)
        self.image_label.config(image=photo)
        # Hold a Python reference so the PhotoImage is not garbage-collected
        # while Tk is still displaying it.
        self.image_label.image = photo

        self.response_text.delete(1.0, tk.END)
        self.response_text.insert(tk.END, image_data['response'])

    def _clear_image(self):
        """Remove the displayed image so a stale picture is not graded."""
        self.image_label.config(image='')
        self.image_label.image = None

    def _update_progress(self, done, total):
        """Show *done* of *total* items graded for the current model."""
        percent = (done / total) * 100 if total else 0
        self.progress_var.set(percent)
        self.progress_label.config(text=f"Progress: {done}/{total}")

    def _set_status(self, message):
        """Show *message* under the Next button; an empty string clears it."""
        self.status_label.config(text=message)

    def on_model_change(self, event):
        """Switch grading to the model picked in the dropdown."""
        self.current_model = self.model_var.get()
        self._set_status("")
        self.load_current_image()

    def next_image(self):
        """Record the entered scores for the current item and show the next one.

        Invalid score entries leave the current item on screen so they can
        be corrected.
        """
        if not self.current_model:
            self.display_error("No model selected or all models completed.")
            return
        if not self.save_current_score():
            return
        # load_current_image advances to the next model when this one is done.
        self.load_current_image()

    def save_current_score(self):
        """Validate the score entries and append them for the current item.

        Returns:
            True if the score was recorded and the entries were reset to "5".
            False if an entry was invalid; an error is shown in the status
            label and nothing is recorded.
        """
        try:
            reported_count, color, shape, number = (int(var.get()) for var in self.score_vars)
        except ValueError:
            self._set_status("All scores must be whole numbers.")
            return False
        if reported_count < 0:
            self._set_status("Card count cannot be negative.")
            return False
        if not all(0 <= value <= self.MAX_SUB_SCORE for value in (color, shape, number)):
            self._set_status(f"Color, shape and number counts must be between 0 and {self.MAX_SUB_SCORE}.")
            return False

        card_score = 1 if reported_count == self.EXPECTED_CARD_COUNT else 0
        # Over-reporting cards is penalised per extra card; under-reporting
        # only forfeits the card-count point.
        if reported_count > self.EXPECTED_CARD_COUNT:
            penalty = self.PENALTY_PER_EXTRA_CARD * (reported_count - self.EXPECTED_CARD_COUNT)
        else:
            penalty = 0
        self.scores[self.current_model].append({
            "card_count": card_score,
            "color_count": color,
            "shape_count": shape,
            "number_count": number,
            "penalty_score": penalty,
        })

        for var in self.score_vars:
            var.set("5")
        self._set_status("")
        return True

    def calculate_final_scores(self):
        """Aggregate per-item scores into accuracies and save both to disk.

        Models with no graded items are skipped. Each accuracy is the points
        achieved divided by the maximum possible per item:
        - card count: 1 point;
        - colour, shape and number: MAX_SUB_SCORE points each;
        - overall: MAX_ITEM_SCORE points, after subtracting penalties.

        Because penalties are subtracted, the overall value can be negative.
        Save errors are printed rather than raised.
        """
        final_scores = {
            model: self._summarize_scores(scores)
            for model, scores in self.scores.items()
            if scores
        }

        try:
            self.save_results(self.SCORES_PATH, self.scores)
            print("Scores have been saved to 'vision_accuracy_scores.json'")

            self.save_results(self.FINAL_SCORES_PATH, final_scores)
            print("Final scores have been saved to 'vision_accuracy_final_scores.json'")
        except (OSError, json.JSONDecodeError) as err:
            print(f"Error: Unable to save final scores to file. ({err})")

    def _summarize_scores(self, scores):
        """Return the accuracy summary dict for one model's non-empty score list."""
        n = len(scores)
        card = sum(s['card_count'] for s in scores)
        color = sum(s['color_count'] for s in scores)
        shape = sum(s['shape_count'] for s in scores)
        number = sum(s['number_count'] for s in scores)
        penalty = sum(s['penalty_score'] for s in scores)
        return {
            "card_count_accuracy": card / n,
            "color_accuracy": color / (n * self.MAX_SUB_SCORE),
            "shape_accuracy": shape / (n * self.MAX_SUB_SCORE),
            "number_accuracy": number / (n * self.MAX_SUB_SCORE),
            "overall_accuracy": (card + color + shape + number - penalty) / (n * self.MAX_ITEM_SCORE),
        }

    def save_results(self, file_path, new_results):
        """Merge *new_results* into the JSON object stored at *file_path*.

        Models not in *new_results* are kept unchanged. For a model present
        in both:
        - list values (per-item scores) replace the stored list. In this app
          the lists already contain the stored items, because
          ``load_existing_scores`` seeds ``self.scores`` from the same file,
          so extending would duplicate them;
        - dict values (final accuracies) are merged key by key.

        The data is written to a temporary sibling file and then moved into
        place, so an interrupted write cannot truncate the existing file.

        Raises:
            OSError: if the file cannot be read or written.
            json.JSONDecodeError: if the existing file is not valid JSON.
        """
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)
        else:
            existing_results = {}

        for model, value in new_results.items():
            stored = existing_results.get(model)
            if isinstance(value, dict) and isinstance(stored, dict):
                stored.update(value)
            else:
                existing_results[model] = value

        tmp_path = f"{file_path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(existing_results, f, indent=2)
        os.replace(tmp_path, file_path)

    def display_error(self, message):
        """Replace the response text box contents with an error message."""
        self.response_text.delete(1.0, tk.END)
        self.response_text.insert(tk.END, f"Error: {message}")

if __name__ == "__main__":
    # Launch the grader: create the Tk root inline, attach the UI to it, and
    # block in the Tk event loop (reached through the app's stored master).
    app = ScoreCalculator(tk.Tk())
    app.master.mainloop()