import argparse
import gc
from itertools import product
from operator import itemgetter

import matplotlib.patches as patches
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torchvision.transforms.functional import affine

from mind_the_pad.paths import plot_folder

# Output directory for every figure written by this script.
# NOTE(review): the name says "wo_batchnorm", but batch norm is only disabled when
# --disable-batch-norm is passed; figures from runs without that flag land here too.
ssd_coco_shifted = plot_folder / 'ssd_coco_shifted_wo_batchnorm'
# exist_ok avoids the check-then-create race; parents=True also creates a missing plot_folder.
ssd_coco_shifted.mkdir(parents=True, exist_ok=True)


def arg_funcname(f):
    """Return the ``--kebab-case`` command-line flag derived from ``f.__name__``.

    Every underscore in the name becomes a hyphen, e.g. ``plot_x`` -> ``--plot-x``.
    """
    return f"--{'-'.join(f.__name__.split('_'))}"


def set_padding_mode_(model: nn.Module, padding_mode):
    """Overwrite ``padding_mode`` on every ``nn.Conv2d`` found in ``model`` (in place).

    Only the attribute is replaced; the padding amounts configured at construction are kept.
    Non-convolution modules are left untouched. Returns None.
    """
    convs = (module for module in model.modules() if isinstance(module, nn.Conv2d))
    for conv in convs:
        conv.padding_mode = padding_mode


def _parse_args():
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-batch-norm', action='store_true',
                        help='make every BatchNorm2d layer an identity op')
    parser.add_argument(arg_funcname(plot_images_with_bb_confidence), action='store_true',
                        help='save one figure per shifted image with its detections')
    parser.add_argument(arg_funcname(plot_global_confidence_map), action='store_true',
                        help='save a map of the confidence of the first kept box of each image')
    # argparse validates only command-line values against `choices`, and those are always
    # strings, so listing None there was dead; the None default still applies when omitted.
    parser.add_argument('--set-padding-mode', choices=['reflect', 'replicate', 'circular'], default=None,
                        help='override padding_mode of every Conv2d (default: leave it unchanged)')
    return parser.parse_args()


def _detect_shifted(ssd_model, tensor, img_shifts):
    """Shift each image of ``tensor`` by each of its (dx, dy) and run ``ssd_model`` on it.

    Args:
        ssd_model: detector called on a batch of one image; its output is indexed with [0] and [1].
        tensor: batch of prepared images, one per entry of ``img_shifts`` and in the same order.
        img_shifts: list of ``(url, [(dx, dy), ...])``.

    Returns:
        ``(tensor_shifted, (out0, out1))``: the shifted images (on CPU) in inference order, and
        the dim-0 concatenation (on CPU) of the model's first and second outputs.
    """
    tensor_shifted, out0, out1 = [], [], []
    with torch.no_grad():
        for image, (_, shifts) in zip(tensor, img_shifts):
            for dx, dy in shifts:
                # Warp once, outside autocast, so the model sees exactly the image that is later
                # plotted (autocast may run parts of the warp in reduced precision).
                shifted = affine(image, angle=0, translate=[dx, dy], scale=1.0, shear=[0, 0])
                with torch.cuda.amp.autocast():
                    output = ssd_model(shifted.unsqueeze(0))
                # Move results to the CPU right away so GPU memory does not grow with the number
                # of shifts (the per-iteration gc.collect() of the old loop freed nothing useful).
                out0.append(output[0].cpu())
                out1.append(output[1].cpu())
                tensor_shifted.append(shifted.cpu())
    return tensor_shifted, (torch.cat(out0, dim=0), torch.cat(out1, dim=0))


def main():
    """Run NVIDIA's SSD on shifted copies of a few images and optionally plot the detections.

    The model and its processing utils are fetched with ``torch.hub``, and the images are loaded from
    their URLs through those utils. The model is run on CUDA. Command-line flags (see ``_parse_args``)
    select whether batch norm is disabled, whether Conv2d padding is overridden and which plots are saved.
    """
    args = _parse_args()
    print(args)

    ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd')
    if args.disable_batch_norm:
        disable_batch_norm_(ssd_model)
    if args.set_padding_mode is not None:
        set_padding_mode_(ssd_model, args.set_padding_mode)

    utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
    ssd_model.to('cuda')
    ssd_model.eval()
    # (image url, list of (dx, dy) translations passed to torchvision's affine).
    # A denser sweep used previously: product(range(-70, 211, 8), range(-42, 225, 16)).
    img_shifts = [
        ('https://farm1.staticflickr.com/57/205590680_17aa8b608a_z.jpg', list(product([0], [0]))),
        ('http://farm4.staticflickr.com/3422/3755419503_bfeef876f0_z.jpg', list(product([0], [0]))),
    ]
    uris = list(map(itemgetter(0), img_shifts))
    # One (url, dx, dy) entry per model input, in the same order as the detections.
    url_shifts = [(url, dx, dy) for url, shift_list in img_shifts for (dx, dy) in shift_list]
    inputs = [utils.prepare_input(uri) for uri in uris]
    tensor = utils.prepare_tensor(inputs)
    tensor_shifted, detections_batch = _detect_shifted(ssd_model, tensor, img_shifts)

    results_per_input = utils.decode_results(detections_batch)
    best_results_per_input = [utils.pick_best(results, 0.40) for results in results_per_input]
    classes_to_labels = utils.get_coco_object_dictionary()
    if args.plot_images_with_bb_confidence:
        plot_images_with_bb_confidence(best_results_per_input, classes_to_labels, tensor_shifted, url_shifts)
    if args.plot_global_confidence_map:
        plot_global_confidence_map(best_results_per_input)


def plot_images_with_bb_confidence(best_results_per_input, classes_to_labels, tensor_shifted, url_shifts):
    """Save one figure per shifted image with its detected boxes and confidences drawn on top.

    Args:
        best_results_per_input: per image, a ``(bboxes, classes, confidences)`` triple. Box coordinates
            are multiplied by 300 to get pixels (presumably normalized to [0, 1] — TODO confirm).
        classes_to_labels: label names, indexed with ``class_id - 1``.
        tensor_shifted: per image, a CHW tensor mapped back to display range with ``/ 2 + 0.5``
            (assumes the inputs were normalized to roughly [-1, 1] — TODO confirm).
        url_shifts: per image, ``(url, dx, dy)``; used to build the output file name.

    Each figure is written to ``ssd_coco_shifted`` as ``<url file stem>_shift=(dx, dy).png``.
    """
    for image_idx, (bboxes, classes, confidences) in enumerate(best_results_per_input):
        print(f'{image_idx = }')
        fig, ax = plt.subplots(1)
        # Show the shifted, denormalized image...
        image = tensor_shifted[image_idx].cpu().permute(1, 2, 0) / 2 + 0.5
        ax.imshow(image)
        # ...with detections
        for bbox, class_id, confidence in zip(bboxes, classes, confidences):
            left, bot, right, top = bbox
            x, y, w, h = [val * 300 for val in [left, bot, right - left, top - bot]]
            print(x, y, w, h)
            rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(x, y, "{} {:.0f}%".format(classes_to_labels[class_id - 1], confidence * 100),
                    bbox=dict(facecolor='white', alpha=0.5))
        fname_image = url_shifts[image_idx][0].split("/")[-1].split(".")[0]
        shift_image = url_shifts[image_idx][1:]
        # Explicit suffix: without one, matplotlib infers the format from any '.' in the name
        # (e.g. a float shift such as (0.5, 1)) instead of falling back to PNG.
        plot_image = ssd_coco_shifted / f'{fname_image}_shift={shift_image}.png'
        fig.savefig(plot_image, bbox_inches='tight')
        # Close this specific figure rather than whichever one happens to be current.
        plt.close(fig)


def plot_global_confidence_map(best_results_per_input, image_size=300):
    """Save a map holding, per pixel, the highest confidence of any image's first kept box covering it.

    For each image only ``bboxes[0]`` and ``confidences[0]`` are used; images without boxes are skipped.
    NOTE(review): whether index 0 is the most or the least confident box depends on the ordering
    produced by the decoding/picking utilities — confirm this is the intended box.

    Args:
        best_results_per_input: per image, a ``(bboxes, classes, confidences)`` triple. Box coordinates
            are multiplied by ``image_size`` to get pixels (presumably normalized to [0, 1] — TODO confirm).
        image_size: side length in pixels of the square map; 300 reproduces the previous behavior.

    The map is written to ``ssd_coco_shifted / 'global_confidence_map.png'``.
    """
    global_map = np.zeros((image_size, image_size))
    for image_idx, (bboxes, _classes, confidences) in enumerate(best_results_per_input):
        print(f'{image_idx = }')
        if len(bboxes) == 0:
            continue
        left, bot, right, top = bboxes[0]
        x, y, w, h = [val * image_size for val in [left, bot, right - left, top - bot]]
        print(x, y, w, h)
        x, y, w, h = int(x), int(y), int(w), int(h)
        # Negative slice bounds would wrap around in numpy, dropping boxes that stick out of the
        # top/left edge (or painting the opposite side). Clamp both bounds at 0; bounds past the
        # far edge are already truncated by slicing.
        rows = slice(max(y, 0), max(y + h, 0))
        cols = slice(max(x, 0), max(x + w, 0))
        global_map[rows, cols] = np.maximum(global_map[rows, cols], confidences[0])
    fig, ax = plt.subplots()
    ax.imshow(global_map, vmin=0.0, vmax=1.0)
    fig.savefig(ssd_coco_shifted / 'global_confidence_map.png')
    plt.close(fig)


def disable_batch_norm_(ssd_model):
    """Turn every ``BatchNorm2d`` inside ``ssd_model`` into a pass-through, in place.

    Each matching module gets an instance-level ``forward`` that returns its input unchanged.
    The module itself (with its parameters and buffers) stays in the model. Returns None.
    """
    def _passthrough(x):
        return x

    for bn in filter(lambda m: isinstance(m, nn.BatchNorm2d), ssd_model.modules()):
        bn.forward = _passthrough


# Run the experiment only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
