
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps
import sys

from torch_spotlight import spotlight

sys.path.insert(0, '../Common/')
from Blindspots import image_grid

def _load_bordered_image(path, is_error, transform = None):
    """Load one example image as RGB and frame it according to its error flag.

    Args:
        path: Anything accepted by ``PIL.Image.open``.
        is_error: Truthy to draw a red inner frame (5px) with a black outer
            frame (3px); falsy to draw a plain black frame (8px). Both variants
            add 8px per side, so framed images keep a common size.
        transform: Optional callable applied to the RGB image before framing.
            Its result is handed to ``ImageOps.expand``.

    Returns:
        The framed image.
    """
    # convert() returns an independent copy, so the source file can be closed
    # right away instead of staying open until garbage collection.
    with Image.open(path) as handle:
        img = handle.convert('RGB')

    if transform is not None:
        img = transform(img)

    if is_error:
        img = ImageOps.expand(img, border = 5, fill = (255, 0, 0))
        return ImageOps.expand(img, border = 3, fill = 0)
    return ImageOps.expand(img, border = 8, fill = 0)


def run_spotlight(embeddings, losses, ids, errors, # basic inputs
                  num_spotlights = 10, spherical = True, # spotlight config
                  num_steps = 5000, learning_rate = 1e-3, lr_patience = 10, # optimization config
                  barrier_scale = 10.0, min_weight_scale = 0.01, barrier_min_scale = 0.05, # objective function config
                  device = 'cuda:0', print_every = 500, d1 = 4, d2 = 5, assignment_threshold = 0.5, transform = None, out_dir = None):
    """Find several spotlights in sequence and plot the top examples of each.

    For every spotlight, ``spotlight.run_spotlight`` is called on the current
    losses. The highest-weighted ``d1 * d2`` examples are drawn as an image
    grid, and the losses are then scaled by ``1 - weight / max(weight)`` before
    the next spotlight is searched.

    Side effects: sets ``plt.rcParams["figure.figsize"]`` to ``(16, 20)``
    (this persists after the call), prints progress, and shows or saves one
    figure per spotlight.

    Args:
        embeddings: Point embeddings; ``len(embeddings)`` is taken as the
            number of points. Forwarded to ``spotlight.run_spotlight``.
        losses: Per-point losses. Not modified in place; a rescaled copy is
            used from the second spotlight onwards.
        ids: Indexable by point index; each entry is opened with
            ``PIL.Image.open``.
        errors: Per-point error flags, indexable by point index. Converted
            with ``np.asarray`` so plain sequences are accepted too.
        num_spotlights: Number of spotlights to find.
        spherical: Forwarded unchanged to ``spotlight.run_spotlight``.
        num_steps: Length of the barrier schedule.
        learning_rate: Forwarded unchanged to ``spotlight.run_spotlight``.
        lr_patience: Forwarded as ``scheduler_patience``.
        barrier_scale: Forwarded unchanged to ``spotlight.run_spotlight``.
        min_weight_scale: ``min_weight`` is ``int(min_weight_scale * num_points)``.
        barrier_min_scale: The barrier schedule ends at
            ``barrier_min_scale * min_weight``.
        device: Forwarded unchanged to ``spotlight.run_spotlight``.
        print_every: Forwarded unchanged to ``spotlight.run_spotlight``.
        d1, d2: Grid dimensions passed to ``image_grid``; ``d1 * d2`` examples
            are shown per spotlight.
        assignment_threshold: A point is assigned to a spotlight when its
            weight divided by the spotlight's maximum weight exceeds this.
        transform: Optional callable applied to each RGB image before framing.
        out_dir: If ``None`` each figure is shown; otherwise it is saved as
            ``<out_dir>/spotlight_<n>.png`` with ``n`` starting at 1.

    Returns:
        dict mapping the 0-based spotlight index to an array of the point
        indices assigned to that spotlight.
    """
    # Setup variables for spotlight
    num_points = len(embeddings)
    min_weight = int(min_weight_scale * num_points)
    barrier_min_x = barrier_min_scale * min_weight
    # NOTE(review): if min_weight rounds down to 0 the schedule would end at 0,
    # which np.geomspace is expected to reject - confirm for small datasets.
    barrier_x_schedule = np.geomspace(num_points - min_weight, barrier_min_x, num_steps)

    # Setup the variables for the visualization
    K = d1 * d2
    plt.rcParams["figure.figsize"] = (16,20)
    # Fancy indexing below needs an array; this is a no-op for ndarray input.
    errors = np.asarray(errors)

    # Run spotlight
    out = {}
    relative_weights = None  # max-normalized weights of the previous spotlight
    for i in range(num_spotlights):

        print('Finding Spotlight number', i + 1)

        if relative_weights is not None:
            print('Reducing losses based on past weights...')
            losses = losses * (1 - relative_weights)

        result = spotlight.run_spotlight(
            embeddings,
            losses,
            min_weight,
            spherical,
            barrier_x_schedule,
            barrier_scale,
            learning_rate,
            scheduler_patience = lr_patience,
            print_every = print_every,
            device = device)
        # Only the unnormalized weights (second return value) are needed here.
        weights_unnorm = result[1]

        # Normalize once; reused for the assignment and the next loss reduction.
        relative_weights = weights_unnorm / weights_unnorm.max()
        out[i] = np.where(relative_weights > assignment_threshold)[0]

        # Point indices of the K highest-weighted examples, best first.
        top_k = weights_unnorm.argsort()[::-1][:K]

        # NOTE(review): image_grid presumably expects d1 * d2 images; with
        # fewer points than that it receives fewer - confirm its behavior.
        imgs = [_load_bordered_image(ids[idx], errors[idx], transform) for idx in top_k]

        grid = image_grid(imgs, d1, d2)

        plt.imshow(grid)
        plt.axis('off')
        plt.title('Error Rate {}'.format(np.round(np.mean(errors[top_k]), 2)))
        if out_dir is None:
            plt.show()
        else:
            plt.savefig('{}/spotlight_{}.png'.format(out_dir, i + 1))
        plt.close()

        # Three blank lines between spotlights in the log.
        print('\n\n')

    return out
        