
import pickle
import gradio as gr
import copy
import cv2
from PIL import Image
import numpy as np
from collections import Counter
import os

# Which model's classification outputs to browse: 'clip', 'blip' or 'resnet'.
model = 'resnet'
stat_file = f'./outputs/classification/{model}_visual.pkl'
# Output file for slices picked with the "save this slice" button.
selected_slice_save_name = "./outputs/selected_slice.pkl"

# Class whose slices are browsed. Alternative class names noted by the author
# (presumably other keys of the stats pickle -- verify before switching):
#   "teddy, teddy bear"
#   "American black bear, black bear, Ursus americanus, Euarctos americanus"
#   "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus"
#   "sloth bear, Melursus ursinus, Ursus ursinus"
target_class = "brown bear, bruin, Ursus arctos"

# NOTE: pickle.load can execute arbitrary code; only load locally generated stats.
with open(stat_file, 'rb') as f:
    stat_results = pickle.load(f)[target_class]

# Map every evaluated image path to the model's prediction for it; the UI uses
# this for the per-image prediction and the per-slice prediction distribution.
base_dir = './outputs/classification/data'
# np.load on an .npz returns a lazy NpzFile that keeps the archive open; use it
# as a context manager so the handle is closed once the arrays have been read
# (indexing the NpzFile materializes each array, so they stay valid afterwards).
with np.load(f"./outputs/classification/{model}_predictions.npz") as predictions:
    preds, gts, paths = predictions['preds'], predictions['labels'], predictions['paths']
# gts is zipped in only to keep the original truncation-to-shortest behavior.
prediction_dict = {path: pred for pred, _gt, path in zip(preds, gts, paths)}
# ---- Navigation / filter state shared by the Gradio callbacks ----
image_paths = []  # visuals of the slice currently on screen (set by display_image)
slices = list(stat_results.values())

slice_index = 0
image_index = 0
# Lowest-accuracy slices first.
slices = sorted(slices, key=lambda x: x["accuracy"])
selected_slice = []
selected_slice_index = []
# Slices are only read, never mutated, and filter_presented_slices stores plain
# references into `slices`; a shallow copy avoids duplicating every visuals list.
presented_slices = list(slices)

# attribute -> distinct tags in first-seen order; feeds the dropdowns and the
# "ALL" constraints. The loop variable is not named `slice`, which would shadow
# the builtin for the whole module.
tags = {}
for slice_info in slices:
    for attr, tag in slice_info["name"].items():
        attr_tags = tags.setdefault(attr, [])
        if tag not in attr_tags:
            attr_tags.append(tag)

# Start with every tag included and nothing excluded.
include_attrs = copy.deepcopy(tags)
exclude_attrs = {}

def filter_presented_slices():
    """Recompute ``presented_slices`` from ``slices`` and the current constraints.

    A slice is kept when at least one of its (attribute, tag) pairs appears in
    ``include_attrs`` and none of its pairs appears in ``exclude_attrs``. The
    relative order of ``slices`` is preserved.
    """
    global presented_slices

    def _hits(constraints, name):
        # True if any (attr, tag) of the slice name is listed in `constraints`.
        return any(
            attr in constraints and tag in constraints[attr]
            for attr, tag in name.items()
        )

    presented_slices = [
        candidate
        for candidate in slices
        if _hits(include_attrs, candidate["name"])
        and not _hits(exclude_attrs, candidate["name"])
    ]

def add_constraint(dropdown1, dropdown2):
    """Add an inclusion (ATTRIBUTE, TAG) pair and re-filter the slices.

    Args:
        dropdown1: selected attribute name, or 'ALL' to include every tag of
            every attribute.
        dropdown2: selected tag, or 'ALL' to include every tag of ``dropdown1``.

    Returns:
        Text for the "Include" textbox: one line per included attribute,
        formatted as ``<attr>: [tags]``.

    Side effects: updates ``include_attrs``, recomputes ``presented_slices`` via
    ``filter_presented_slices`` and resets ``image_index``/``slice_index`` to 0.
    """
    global include_attrs, image_index, slice_index
    attr, tag = dropdown1, dropdown2
    if attr == 'ALL':
        include_attrs = copy.deepcopy(tags)
    elif tag == 'ALL':
        # Copy instead of aliasing tags[attr]: lists in include_attrs are
        # appended to in place below, which must never alter the tag index.
        include_attrs[attr] = list(tags[attr])
    elif attr not in include_attrs:
        include_attrs[attr] = [tag]
    elif tag not in include_attrs[attr]:
        include_attrs[attr].append(tag)
    filter_presented_slices()
    image_index, slice_index = 0, 0
    return "".join(f"{name}: {values} \n" for name, values in include_attrs.items())

def drop_constraint(dropdown1, dropdown2):
    """Add an exclusion (ATTRIBUTE, TAG) pair and re-filter the slices.

    Args:
        dropdown1: selected attribute name, or 'ALL' to exclude every tag of
            every attribute.
        dropdown2: selected tag, or 'ALL' to exclude every tag of ``dropdown1``.

    Returns:
        Text for the "Not Include" textbox: one line per excluded attribute,
        formatted as ``<attr>: [tags]``.

    Side effects: updates ``exclude_attrs``, recomputes ``presented_slices`` via
    ``filter_presented_slices`` and resets ``image_index``/``slice_index`` to 0.
    """
    global exclude_attrs, image_index, slice_index
    attr, tag = dropdown1, dropdown2
    if attr == 'ALL':
        exclude_attrs = copy.deepcopy(tags)
    elif tag == 'ALL':
        # Copy instead of aliasing tags[attr]: lists in exclude_attrs are
        # appended to in place below, which must never alter the tag index.
        exclude_attrs[attr] = list(tags[attr])
    elif attr not in exclude_attrs:
        exclude_attrs[attr] = [tag]
    elif tag not in exclude_attrs[attr]:
        exclude_attrs[attr].append(tag)
    filter_presented_slices()
    image_index, slice_index = 0, 0
    return "".join(f"{name}: {values} \n" for name, values in exclude_attrs.items())

def clear_constraint():
    """Reset both the include and the exclude constraints to empty.

    This does not call ``filter_presented_slices``, so the slices currently on
    screen stay as they are until the next Add/Drop.

    Returns:
        Empty strings for the "Include" and "Not Include" textboxes.
    """
    global include_attrs, exclude_attrs
    include_attrs = {}
    exclude_attrs = {}
    return ("", "")

def dropdown_change(dropdown):
    """Build the update for the TAGS dropdown after ATTRIBUTE changes.

    A concrete attribute offers "ALL" followed by that attribute's tags;
    selecting "ALL" leaves "ALL" as the only choice. The value resets to "ALL".
    """
    choices = ["ALL"]
    if dropdown != 'ALL':
        choices += tags[dropdown]
    return gr.update(choices=choices, value="ALL")

def display_image():
    """Render the current image of the current slice plus its description.

    Reads the navigation state (``presented_slices``, ``slice_index``,
    ``image_index``). It refreshes the global ``image_paths`` with the current
    slice's visuals, whose length next/prev image use as a modulus.

    Returns:
        ``(image_path, label_display)`` for the image component and the labels
        textbox. If the active filters left no slices, returns ``None`` and a
        short message instead of raising ``IndexError``.
    """
    global image_paths
    if not presented_slices:
        # image_paths is left as it was (not reset to []), since next/prev
        # image take its length as a modulus.
        return None, "no slices match the current constraints\n"

    current = presented_slices[slice_index]
    image_paths = current["visuals"]
    # Each visual entry begins with (path, correctness); anything after is ignored.
    image_path, image_acc = image_paths[image_index % len(image_paths)][:2]
    labels = current["name"]

    lines = [
        "selected\n" if is_selected() else "not selected\n",
        f"slices: {slice_index}/{len(presented_slices)}:\n",
        f"image: {image_index}/{current['count']}:\n",
        "acc:%.2f:\n" % current["accuracy"],
    ]
    # Share of each model prediction across all visuals of this slice.
    prediction_counts = Counter(prediction_dict[visual[0]] for visual in image_paths)
    distribution = "  ".join(
        "%d:%.2f" % (pred, n / len(image_paths))
        for pred, n in prediction_counts.items()
    )
    lines.append("slice distribution: " + distribution + "\n")
    lines.append(f"model prediction: {prediction_dict[image_path]}\n")
    lines.append("image correctness: %.2f\n" % image_acc)
    lines.append("\n")
    # Attribute names look like "<type>, <attribute>"; list them grouped by type.
    for attribute_type in ["main object", "background", "global"]:
        lines.append(f"{attribute_type}:\n")
        prefix_len = len(attribute_type) + 2  # strips "<type>, "
        for attr, tag in labels.items():
            if attr.split(", ")[0] == attribute_type:
                lines.append(f"\t {attr[prefix_len:]}:{tag}\n")

    return image_path, "".join(lines)

def next_image():
    """Advance to the next image of the current slice (wrapping) and render it."""
    global image_index
    image_index = (image_index + 1) % len(image_paths)
    return display_image()

def prev_image():
    """Step back to the previous image of the current slice (wrapping) and render it."""
    global image_index
    # Python's % is non-negative for a positive divisor, so 0 wraps to the last image.
    image_index = (image_index - 1) % len(image_paths)
    return display_image()

def next_slice():
    """Move to the next presented slice (wrapping), restart at its first image,
    and render it.

    When the filters leave no slices, ``slice_index`` is left unchanged instead
    of raising ``ZeroDivisionError``; ``display_image`` still renders the view.
    """
    global image_index, slice_index
    if presented_slices:
        slice_index = (slice_index + 1) % len(presented_slices)
    image_index = 0
    return display_image()

def prev_slice():
    """Move to the previous presented slice (wrapping), restart at its first
    image, and render it.

    When the filters leave no slices, ``slice_index`` is left unchanged instead
    of raising ``ZeroDivisionError``; ``display_image`` still renders the view.
    """
    global image_index, slice_index
    if presented_slices:
        # Python's % is non-negative for a positive divisor, so 0 wraps to the end.
        slice_index = (slice_index - 1) % len(presented_slices)
    image_index = 0
    return display_image()

def is_selected():
    """Return True if the slice on screen is already in ``selected_slice``.

    Slices are matched by equality of their ``"name"`` dict (attribute -> tag),
    not by identity.
    """
    return any(
        saved["name"] == presented_slices[slice_index]["name"]
        for saved in selected_slice
    )

def save_slice():
    """Add the slice on screen to ``selected_slice`` and persist the selection.

    A slice whose name is already selected is not added again. After a new
    addition, the whole selection is rewritten to ``selected_slice_save_name``
    as ``{"slice": selected_slice}``. Always returns the re-rendered view from
    ``display_image``.
    """
    if is_selected():
        return display_image()
    selected_slice.append(presented_slices[slice_index])
    with open(selected_slice_save_name, 'wb') as out_file:
        pickle.dump({"slice": selected_slice}, out_file)
    return display_image()

# Gradio UI: attribute/tag filter controls, slice and image navigation, and the
# image + description panel.
with gr.Blocks() as interface:
    with gr.Row():
        dropdown1 = gr.Dropdown(choices=["ALL",] + list(tags.keys()), value=list(tags.keys())[0], label="ATTRIBUTE")
        dropdown2 = gr.Dropdown(choices=["ALL",] + list(tags.values())[0], value="ALL", label="TAGS")
        button1 = gr.Button("Add")
        button2 = gr.Button("Drop")
        button3 = gr.Button("Clear")

    with gr.Row():
        textbox1 = gr.Textbox(label="Include")
        textbox2 = gr.Textbox(label="Not Include")

    with gr.Row():
        prev_s_btn = gr.Button("last slice")
        next_s_btn = gr.Button("next slice")
        save_s_btn = gr.Button("save this slice")
    with gr.Row():
        prev_btn = gr.Button("last image")
        next_btn = gr.Button("next image")
    with gr.Row():
        image = gr.Image(label="image")
        label_text = gr.Textbox(label="labels")

    # Refresh the TAGS choices whenever the ATTRIBUTE changes.
    dropdown1.change(dropdown_change, inputs=[dropdown1], outputs=[dropdown2,])
    # add/drop re-filter the slices and reset slice_index/image_index to 0, so
    # re-render right afterwards. Without this, the panel keeps showing the old
    # slice and the next "next slice" click skips the first filtered slice.
    # NOTE(review): .then() chaining needs a Gradio release that provides it;
    # confirm against the pinned version.
    button1.click(add_constraint, inputs=[dropdown1, dropdown2], outputs=[textbox1,]).then(
        display_image, inputs=[], outputs=[image, label_text,])
    button2.click(drop_constraint, inputs=[dropdown1, dropdown2], outputs=[textbox2,]).then(
        display_image, inputs=[], outputs=[image, label_text,])
    button3.click(clear_constraint, inputs=[], outputs=[textbox1, textbox2])

    next_s_btn.click(next_slice,
                     inputs=[],
                     outputs=[image, label_text,])

    prev_s_btn.click(prev_slice,
                     inputs=[],
                     outputs=[image, label_text,])

    save_s_btn.click(save_slice,
                     inputs=[],
                     outputs=[image, label_text,])

    next_btn.click(next_image,
                   inputs=[],
                   outputs=[image, label_text,])

    prev_btn.click(prev_image,
                   inputs=[],
                   outputs=[image, label_text,])

    # Render the first slice when the page loads.
    interface.load(display_image,
                   inputs=[],
                   outputs=[image, label_text,])
interface.launch()