import os
import argparse
import blobfile as bf
from functools import partial
import pickle
from PIL import Image
from PIL import ImageTk

import tkinter as tk
import tkinter.font as tkfont
import tkinter.messagebox as msgbox

def _load_photo(path, resolution):
    """Load the image at *path*, resize it to ``resolution`` x ``resolution`` and wrap it for Tk.

    The source file is opened in a ``with`` block so its handle is released right
    away instead of whenever the PIL image object is garbage-collected.
    """
    with Image.open(path) as img:
        resized = img.resize((resolution, resolution))
    return ImageTk.PhotoImage(resized)


def main():
    """Run the interactive Tk feedback collector.

    Lists every image under ``--data_dir`` and shows, one at a time, those that have
    no answer stored yet in ``--feedback_path``. Each answer is recorded as
    Yes -> 0, No -> 1, Undecidable -> None in a ``{image_path: response}`` dict.
    That dict is pickled to ``--feedback_path`` when the session completes, or when
    the user chooses to save on an early exit (Escape / window close).

    Returns:
        Process exit status: 1 when the arguments or data directory are unusable,
        0 otherwise.
    """
    # Module-level state shared with the Tk callbacks below. ``curr_image`` must stay
    # referenced from Python, otherwise Tk shows a blank image once the PhotoImage
    # object is garbage-collected.
    global curr_image, feedbacks, curr_filename, image_display

    args = create_argparser().parse_args()

    if not os.path.isdir(args.data_dir):
        print("ERROR: Provided data directory does not exist!")
        return 1
    if not args.feedback_path:
        # Fail fast: otherwise the whole session runs and only crashes when saving.
        print("ERROR: --feedback_path must be provided!")
        return 1

    img_list = _list_image_files_recursively(args.data_dir)
    if not img_list:
        print("ERROR: No image files found in the provided data directory!")
        return 1

    if os.path.isfile(args.feedback_path):
        # NOTE: pickle.load can execute arbitrary code; only point --feedback_path
        # at files produced by this tool.
        with open(args.feedback_path, "rb") as f:
            feedbacks = pickle.load(f)
    else:
        feedbacks = {}

    # Lazily yields only images without a stored answer. Membership is checked when
    # each item is requested, so a resumed session skips answered images anywhere
    # in the list, not only at its beginning.
    pending = (path for path in img_list if path not in feedbacks)

    curr_filename = next(pending, None)
    if curr_filename is None:
        print("Feedback already completed.")
        return 0

    root = tk.Tk()
    root.title("Feedback data collector")
    root.geometry("640x480")

    frame = tk.Frame(root)
    frame.pack()

    question = tk.Label(frame, text=f"Does/Is the following image (have) {args.pos_feature}?", font=tkfont.Font(family="Arial", size=15), pady=10)
    question.grid(row=0, column=0, columnspan=3, pady=6)

    # PhotoImage needs an existing Tk root, so the first image is loaded only now.
    curr_image = _load_photo(curr_filename, args.resolution)

    image_display = tk.Label(frame, image=curr_image, pady=10)
    image_display.grid(row=1, column=0, columnspan=3, pady=6)

    def store_response_n_show_next_image(response):
        """Record *response* for the current image and show the next unanswered one."""
        global feedbacks, curr_filename, curr_image

        if curr_filename is None:
            # Session already finished; ignore stray clicks/keypresses instead of
            # storing an answer under the key None.
            return

        feedbacks[curr_filename] = response
        curr_filename = next(pending, None)

        if curr_filename is not None:
            curr_image = _load_photo(curr_filename, args.resolution)
            image_display.config(image=curr_image)
        else:
            save_feedback_and_quit()

    def save_feedback_and_quit():
        """End the session, saving the answers when complete or when the user agrees."""
        if all(path in feedbacks for path in img_list):
            msgbox.showinfo("End of session", "Feedback session is completed. Thanks!")
            save_results = True
            quit_program = True
        else:
            # askyesnocancel: True = save and quit, False = quit without saving,
            # None (Cancel) = keep labelling.
            save_results = msgbox.askyesnocancel("Save and quit", "Feedback session is yet in progress. Will you save the intermediate results?")
            quit_program = save_results is not None

        if quit_program:
            # quit() only stops mainloop(); the save below still runs.
            root.quit()

        if save_results:
            print("Saving feedbacks ...")
            with open(args.feedback_path, "wb") as f:
                pickle.dump(feedbacks, f)

    # Response encoding: Yes -> 0, No -> 1, Undecidable -> None.
    yes_btn = tk.Button(frame, text="Yes (Y)", width=10, height=3, command=partial(store_response_n_show_next_image, 0))
    no_btn = tk.Button(frame, text="No (N)", width=10, height=3, command=partial(store_response_n_show_next_image, 1))
    abs_btn = tk.Button(frame, text="Undecidable (D)", width=10, height=3, command=partial(store_response_n_show_next_image, None))

    yes_btn.grid(row=2, column=0, padx=6, pady=20, sticky='news')
    no_btn.grid(row=2, column=1, padx=6, pady=20, sticky='news')
    abs_btn.grid(row=2, column=2, padx=6, pady=20, sticky='news')

    # Keyboard shortcuts mirror the buttons. Both letter cases are bound so they
    # still work with Caps Lock on. ``r=response`` binds the value per iteration
    # (avoids the late-binding closure pitfall).
    for key, response in (("y", 0), ("n", 1), ("d", None)):
        for keysym in (key, key.upper()):
            root.bind(keysym, lambda event, r=response: store_response_n_show_next_image(r))
    root.bind("<Escape>", lambda event: save_feedback_and_quit())

    # Closing the window goes through the same save/confirm path as Escape.
    root.protocol("WM_DELETE_WINDOW", save_feedback_and_quit)
    root.mainloop()
    return 0


def _list_image_files_recursively(data_dir):
    """Return the paths of all image files under ``data_dir``, depth-first.

    Every directory level is visited in sorted name order. An entry whose extension
    (case-insensitive) is jpg, jpeg, png or gif is collected as-is. Any other entry
    is descended into if it is a directory.
    """

    def _walk(directory):
        for name in sorted(bf.listdir(directory)):
            path = bf.join(directory, name)
            _, dot, suffix = name.rpartition(".")
            if dot and suffix.lower() in ("jpg", "jpeg", "png", "gif"):
                yield path
            elif bf.isdir(path):
                yield from _walk(path)

    return list(_walk(data_dir))


def add_dict_to_argparser(parser, default_dict):
    """Register one ``--<key>`` option on *parser* for each entry of *default_dict*.

    The default value also determines the converter. ``None`` is parsed as ``str``,
    ``bool`` via ``str2bool``, and anything else via ``type(default)``.
    """
    for name, default in default_dict.items():
        if isinstance(default, bool):
            converter = str2bool
        elif default is None:
            converter = str
        else:
            converter = type(default)
        parser.add_argument(f"--{name}", default=default, type=converter)


def str2bool(v):
    """Convert a command-line string such as "yes"/"no" to a bool.

    Booleans are returned unchanged. Recognised case-insensitive spellings are
    yes/true/t/y/1 -> True and no/false/f/n/0 -> False.

    Adapted from
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def create_argparser():
    """Build the command-line parser for the feedback collector.

    Options: ``--data_dir``, ``--feedback_path`` and ``--pos_feature`` (strings,
    default empty) and ``--resolution`` (int, default 100).
    """
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(
        parser,
        {
            "data_dir": "",
            "feedback_path": "",
            "pos_feature": "",
            "resolution": 100,
        },
    )
    return parser


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None is treated as 0),
    # so shell callers can detect failures.
    raise SystemExit(main())