from operator import itemgetter

import torch
from itertools import product
from torchvision.transforms.functional import affine
from matplotlib import pyplot as plt
import matplotlib.patches as patches

from mind_the_pad.paths import plot_folder


def _shift_batch(tensor, img_shifts):
    """Translate every image by each of its (dx, dy) offsets and stack the results into one batch.

    Image ``i`` of ``tensor`` is paired with the shift list ``img_shifts[i][1]``. The output is
    therefore image-major, which is the same order as the ``(url, dx, dy)`` list built in ``main``.
    """
    shifted = [
        affine(image, angle=0, translate=(dx, dy), scale=1.0, shear=[0, 0])
        for image, (_, shifts) in zip(tensor, img_shifts)
        for dx, dy in shifts
    ]
    return torch.stack(shifted, dim=0)


def _plot_name(url, dx, dy):
    """Return the output file stem ``<url basename without extension>_shift=(dx, dy)``."""
    stem = url.split("/")[-1].split(".")[0]
    # Formatting the tuple keeps the historical "(dx, dy)" spelling of existing output files.
    return f'{stem}_shift={(dx, dy)}'


def _draw_detections(ax, image, bboxes, classes, confidences, classes_to_labels):
    """Show ``image`` on ``ax`` and overlay each detection as a red box with a label/confidence tag.

    Each row of ``bboxes`` is unpacked as ``(left, bot, right, top)``. The coordinates are
    multiplied by the image width/height, so they are presumably normalized to [0, 1] (TODO confirm
    against the NVIDIA SSD utils). ``classes`` are 1-based indices into ``classes_to_labels``.
    """
    ax.imshow(image)
    height, width = image.shape[:2]
    for (left, bot, right, top), label, confidence in zip(bboxes, classes, confidences):
        x, y = left * width, bot * height
        w, h = (right - left) * width, (top - bot) * height
        ax.add_patch(patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none'))
        ax.text(x, y, "{} {:.0f}%".format(classes_to_labels[label - 1], confidence * 100),
                bbox=dict(facecolor='white', alpha=0.5))


def main():
    """Run NVIDIA's pretrained SSD on translated copies of a few images and plot the detections.

    The SSD model and its processing utilities are downloaded via ``torch.hub``, so network access
    is needed, and the model is moved to ``'cuda'``. Each image is translated by every (dx, dy)
    offset in its shift list. All shifted copies are detected in a single batch; NOTE(review): this
    batch may need chunking on GPUs with little memory. Detections are filtered with
    ``pick_best(..., 0.40)``.

    One figure per (image, shift) pair is written to
    ``plot_folder / 'ssd_coco_shifted' / '<stem>_shift=(dx, dy)'``. No extension is given, so
    matplotlib's default savefig format applies.
    """
    ssd_coco_shifted = plot_folder / 'ssd_coco_shifted'
    # parents=True also creates plot_folder if it is missing; exist_ok avoids a check/create race.
    ssd_coco_shifted.mkdir(parents=True, exist_ok=True)

    ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd')
    utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
    ssd_model.to('cuda')
    ssd_model.eval()
    img_shifts = [
        ('http://images.cocodataset.org/val2017/000000397133.jpg', list(product([-200, 0, 20], [-45, 0, 50]))),
        ('http://images.cocodataset.org/val2017/000000037777.jpg', list(product([-100, 0, 75], [-160, 0, 50]))),
        ('http://images.cocodataset.org/val2017/000000252219.jpg', list(product([-150, 0, 90], [-125, -20, 0, 40, 210]))),
        ('https://farm4.staticflickr.com/3197/2892106186_0322b8ce50_z.jpg', list(product([-50, 0, 100], [-200, -20, 0, 100]))),
        ('http://farm6.staticflickr.com/5267/5587018821_781fa6e04e_z.jpg', list(product([-50, -20, 0, 30, 60], [-100, -70, -30, 0, 40, 60]))),
    ]
    uris = list(map(itemgetter(0), img_shifts))
    # One (url, dx, dy) entry per shifted image, in the same image-major order as _shift_batch.
    url_shifts = [(url, dx, dy) for url, shift_list in img_shifts for dx, dy in shift_list]
    inputs = [utils.prepare_input(uri) for uri in uris]
    tensor = utils.prepare_tensor(inputs)
    tensor_shifted = _shift_batch(tensor, img_shifts)
    with torch.no_grad():
        detections_batch = ssd_model(tensor_shifted)
    results_per_input = utils.decode_results(detections_batch)
    best_results_per_input = [utils.pick_best(results, 0.40) for results in results_per_input]
    classes_to_labels = utils.get_coco_object_dictionary()

    # Denormalize the whole batch once on the CPU (x / 2 + 0.5), channels-last for imshow.
    images = tensor_shifted.cpu().permute(0, 2, 3, 1) / 2 + 0.5
    for image, (url, dx, dy), (bboxes, classes, confidences) in zip(images, url_shifts, best_results_per_input):
        fig, ax = plt.subplots(1)
        _draw_detections(ax, image, bboxes, classes, confidences, classes_to_labels)
        fig.savefig(ssd_coco_shifted / _plot_name(url, dx, dy), bbox_inches='tight')
        # Close this specific figure so the plots do not pile up in pyplot's figure manager.
        plt.close(fig)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, never on import.
    main()