
from ipywidgets import widgets, HBox, VBox
import matplotlib
from matplotlib.path import Path
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps
from tkinter import simpledialog as sd
from tkinter import Tk

def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row-major into a single RGB image with ``rows`` x ``cols`` cells.

    Every cell takes the size of ``imgs[0]``; image ``i`` is pasted with its
    top-left corner at cell ``(row, col) = divmod(i, cols)``. Cells with no image
    stay black (the canvas is created with PIL's default fill). The images are
    presumably all the same size as ``imgs[0]`` — a larger one would spill into
    neighbouring cells (TODO confirm for callers passing mixed sizes). Images past
    ``rows * cols`` fall outside the canvas and are not visible.

    Args:
        imgs: Sequence of PIL images to tile.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)`` where
        ``(w, h) = imgs[0].size``.

    Raises:
        IndexError: If ``imgs`` is empty (``imgs[0]`` is needed for the cell size).
    """
    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid

class Status:
    """Mutable UI state shared between the interface's event callbacks.

    Callers read ``cords`` and ``mode`` directly; the small helper methods let
    callbacks mutate the state through a reference bound at definition time.
    """

    def __init__(self):
        # Vertices clicked so far while in 'group' mode, each an [x, y] pair.
        self.cords = []
        # Active interaction mode; the interface switches between
        # 'neighbors' and 'group'.
        self.mode = 'neighbors'

    def append_cords(self, v):
        """Add vertex ``v`` to the end of ``cords``."""
        self.cords += [v]

    def clear_cords(self):
        """Empty ``cords`` in place, keeping the same list object."""
        del self.cords[:]

    def set_mode(self, v):
        """Make ``v`` the active interaction mode."""
        self.mode = v
        
def interface(embedding, confidence, error, file, name, use_confidence = True, transform = None, point_s = 2, point_s_on_click = 3):
    """Build an interactive explorer for a 2-D embedding of one class's images.

    The left panel scatters ``embedding`` coloured by ``1 - confidence`` (or by
    ``error`` when ``use_confidence`` is False). In 'Explore' mode a click on
    that panel shows the 12 nearest images. In 'Select Group' mode clicks add
    polygon vertices; 'Process Group' then shows up to 20 randomly sampled images
    from inside the polygon, and 'Reset Group' starts over. Images are tiled in
    the right panel with misclassified ones framed in red, and the panel title
    summarises confidence or error rate.

    Args:
        embedding: (N, 2) array of point coordinates (indexed as
            ``embedding[:, 0]`` / ``embedding[:, 1]``).
        confidence: Length-N numeric array; used for colouring and the title
            when ``use_confidence`` is True (presumably in [0, 1], since it is
            shown as a percentage — confirm with callers).
        error: Length-N array; entries equal to 1 mark misclassified images.
        file: Length-N sequence of image paths openable by PIL.
        name: Class name shown in the figure title.
        use_confidence: Colour and summarise by confidence instead of error.
        transform: Optional callable applied to each loaded RGB PIL image.
        point_s: Marker size of the background scatter.
        point_s_on_click: Marker size for highlighted / clicked points.

    Note:
        ``display`` is not imported here. This relies on the IPython/Jupyter
        builtin of that name.
    """
    import os

    plt.ion()

    f, a = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 2]})
    f.set_size_inches(15, 7.5)
    f.suptitle('Class: {}'.format(name))

    if use_confidence:
        a0_color = 1 - confidence
        a0_title = 'Lighter colors indicate lower confidence'
    else:
        a0_color = error
        a0_title = 'Lighter colors indicate errors'

    # Setup defaults and status

    def default_a0():
        """Redraw the embedding scatter on the left axes from scratch."""
        a[0].clear()
        a[0].set_facecolor('silver')
        a[0].scatter(embedding[:, 0], embedding[:, 1], c = a0_color, s = point_s, cmap = 'viridis', alpha = 0.5)
        a[0].axes.xaxis.set_visible(False)
        a[0].axes.yaxis.set_visible(False)
        a[0].set_title(a0_title)

    def default_a1():
        """Clear the right (image grid) axes and hide its ticks."""
        a[1].clear()
        a[1].axes.xaxis.set_visible(False)
        a[1].axes.yaxis.set_visible(False)

    default_a0()
    default_a1()

    status = Status()

    # Common plotting utilities
    def format_grid(chosen):
        """Load, transform and frame the images at indices ``chosen``.

        Images are ordered by ``error`` so correct and misclassified ones are
        grouped. Each gets an 8-pixel frame: 5 px red plus 3 px black when
        misclassified, otherwise 8 px black.
        """
        chosen = chosen[np.argsort(error[chosen])]
        imgs = []
        for idx in chosen:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded copy.
            with Image.open(file[idx]) as src:
                img = src.convert('RGB')

            if transform is not None:
                img = transform(img)

            if error[idx] == 1:
                img = ImageOps.expand(img, border = 5, fill = (255, 0, 0))
                img = ImageOps.expand(img, border = 3, fill = 0)
            else:
                img = ImageOps.expand(img, border = 8, fill = 0)

            imgs.append(img)

        return imgs

    def show_summary(grid, summary_idx):
        """Show ``grid`` on the right axes, titled with a summary of ``summary_idx``."""
        default_a1()
        a[1].imshow(grid)
        if use_confidence:
            a[1].set_title('Average Confidence: {:.0f}%'.format(np.round(100 * np.mean(confidence[summary_idx]), 0)))
        else:
            a[1].set_title('Error Rate: {:.0f}%'.format(np.round(100 * np.mean(error[summary_idx]), 0)))

    def show_message(msg):
        """Clear the right axes and show ``msg`` as its title."""
        default_a1()
        a[1].set_title(msg)

    # Handle user clicks
    def onclick(event, status = status):
        # Only clicks on the embedding panel carry embedding coordinates. On the
        # image grid xdata/ydata are pixel coordinates, and outside any axes they
        # are None.
        if event.inaxes is not a[0] or event.xdata is None or event.ydata is None:
            return
        x, y = event.xdata, event.ydata

        if status.mode == 'group':
            if len(status.cords) == 0:
                default_a0()
            status.append_cords([x, y])
            a[0].scatter([x], [y], c = 'red', s = point_s_on_click, alpha = 1.0)

        elif status.mode == 'neighbors':
            d1 = 3
            d2 = 4
            # argpartition requires kth < N, so never ask for more points than exist.
            K = min(d1 * d2, len(embedding))
            if K == 0:
                return
            # Find the K nearest neighbors by squared Euclidean distance
            dist = np.sum((embedding - [x, y])**2, axis = 1)
            # kth = K - 1 puts the K smallest distances (unordered) in the first K slots
            chosen = np.argpartition(dist, K - 1)[:K]
            grid = image_grid(format_grid(chosen), d1, d2)
            default_a0()
            a[0].scatter(embedding[chosen, 0], embedding[chosen, 1], c = 'red', s = point_s_on_click, alpha = 1.0)
            show_summary(grid, chosen)
            f.show()

    f.canvas.mpl_connect('button_press_event', onclick)

    # Switch Mode Buttons
    def switch_group(b, status = status):
        status.set_mode('group')
        status.clear_cords()
        default_a0()
        default_a1()

    def switch_neighbors(b, status = status):
        status.set_mode('neighbors')
        status.clear_cords()
        default_a0()
        default_a1()

    # Save Figure Button
    def save_figure(b):
        # A hidden Tk root hosts the dialog. Destroy it afterwards so repeated
        # clicks don't accumulate live Tk instances.
        root = Tk()
        root.withdraw()
        try:
            save_name = sd.askstring('Input', 'Please enter a filename', parent = root)
        finally:
            root.destroy()
        if save_name:
            os.makedirs('./Outputs/interface', exist_ok = True)
            f.savefig('./Outputs/interface/{}'.format(save_name), bbox_inches = 'tight', pad_inches = 0.25)

    # Get Group Button
    def get_group(b, status = status):
        if status.mode != 'group':
            return
        d1 = 4
        d2 = 5
        K = d1 * d2
        # Path() rejects an empty vertex list, and fewer than 3 vertices enclose nothing.
        if len(status.cords) < 3:
            show_message('Select at least 3 points to define a group')
            f.show()
            return
        # Find all of the points that are in the chosen region
        indices = np.where(Path(status.cords).contains_points(embedding))[0]
        if len(indices) == 0:
            # image_grid needs at least one image, and the mean of nothing is NaN.
            show_message('No points inside the selected region')
            f.show()
            return
        # Sample up to d1 * d2 of those points to display in a grid
        chosen = np.random.choice(indices, min(K, len(indices)), replace = False)
        grid = image_grid(format_grid(chosen), d1, d2)
        default_a0()
        # Polygon is a Patch, so it goes through add_patch (add_line is for Line2D).
        a[0].add_patch(plt.Polygon(status.cords, closed = True, color = 'red', alpha = 0.25))
        # The title summarises every enclosed point, not just the displayed sample.
        show_summary(grid, indices)
        f.show()

    def reset_group(b, status = status):
        if status.mode == 'group':
            # Reset for next loop
            default_a0()
            default_a1()
            status.clear_cords()

    def make_button(description, handler):
        """Create an ipywidgets Button wired to ``handler``."""
        button = widgets.Button(description = description)
        button.on_click(handler)
        return button

    group_button = make_button('Mode: Select Group', switch_group)
    explore_button = make_button('Mode: Explore', switch_neighbors)
    save_button = make_button('Save Figure', save_figure)
    process_button = make_button('Process Group', get_group)
    reset_button = make_button('Reset Group', reset_group)

    # Show plot and wait for interaction
    right_box = VBox([group_button, process_button, reset_button])
    left_box = VBox([explore_button, save_button])
    box = HBox([left_box, right_box])
    display(box)
    f.show()
