import torch
import numpy as np
import cv2
from matplotlib import pyplot as plt

from samples.CLS2IDX import CLS2IDX
from src.utilities import min_max_normal

# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    ``mask`` is multiplied by 255 and cast to ``uint8`` before colour-mapping,
    so it is presumably expected to lie in [0, 1] (confirm against callers).
    The colour map is rescaled to [0, 1], added to ``img`` (cast to float32),
    and the sum is divided by its maximum so the largest value becomes 1.

    Args:
        img: Image array that can be added to the colour map; presumably
            H x W x 3 with values in [0, 1] (TODO confirm).
        mask: Array of importance values with the same spatial size as ``img``.

    Returns:
        The float32 overlay, scaled so that its maximum is 1.
    """
    mask_u8 = np.uint8(255 * mask)
    colour_map = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    overlay = np.float32(colour_map) / 255 + np.float32(img)
    return overlay / np.max(overlay)


def gen_grad_cam(baselines, index, image):
    """Compute a Grad-CAM style map from the last block's attention.

    The model is run on ``image`` with ``register_hook=True``, the score of
    class ``index`` is back-propagated, and the attention map of
    ``baselines.model.blocks[-1].attn`` is weighted by the per-head mean of its
    gradient. Negative values are clamped to zero and the result is min-max
    normalised.

    Args:
        baselines: Object exposing ``model``; the model must accept
            ``register_hook=True`` and its last block's ``attn`` must provide
            ``get_attn_gradients()`` and ``get_attention_map()``.
        index: Class index to explain. When ``None`` the arg-max of the model
            output is used.
        image: A single un-batched image tensor; a batch axis is added here.

    Returns:
        A 14 x 14 tensor with values in [0, 1]. When the weighted attention is
        constant (for instance entirely clamped to zero) an all-zero map is
        returned instead of the NaNs a 0 / 0 normalisation would produce.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    output = baselines.model(image.unsqueeze(0).to(device), register_hook=True)
    if index is None:
        index = np.argmax(output.detach().cpu().numpy())

    # Select the target logit with a one-hot mask so that backward() starts
    # from a scalar.
    one_hot = torch.zeros((1, output.size(-1)), dtype=torch.float32, device=output.device)
    one_hot[0, index] = 1
    score = torch.sum(one_hot * output)

    baselines.model.zero_grad()
    score.backward(retain_graph=True)

    last_attn = baselines.model.blocks[-1].attn
    grad = last_attn.get_attn_gradients()
    cam = last_attn.get_attention_map()

    # Keep the row of token 0 over the remaining tokens (presumably class token
    # -> patch tokens; confirm against the model) and lay each head out on the
    # 14 x 14 grid.
    cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
    grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)

    # One scalar weight per head: the spatial mean of its gradient.
    grad = grad.mean(dim=[1, 2], keepdim=True)
    cam = (cam * grad).mean(0).clamp(min=0)

    cam_min = cam.min()
    value_range = cam.max() - cam_min
    if value_range == 0:
        # Constant map: normalising would divide by zero and yield NaNs.
        return torch.zeros_like(cam)
    return (cam - cam_min) / value_range


def show_cam_on_image_helper(original_image, transformer_attribution):
    """Render an attribution map on top of an image tensor as a ``uint8`` array.

    The tensor's first axis is moved last (presumably C x H x W -> H x W x C;
    confirm against callers), converted to NumPy and passed through
    ``min_max_normal``. The attribution is then overlaid with
    ``show_cam_on_image``, scaled to 0-255 and run through OpenCV's
    RGB-to-BGR channel swap.

    Args:
        original_image: Three-axis image tensor.
        transformer_attribution: Attribution map matching the image's spatial
            size.

    Returns:
        The ``uint8`` overlay after the ``cv2.COLOR_RGB2BGR`` conversion.
    """
    channels_last = original_image.permute(1, 2, 0).data.cpu().numpy()
    normalised_image = min_max_normal(channels_last)

    overlay = show_cam_on_image(normalised_image, transformer_attribution)
    overlay_u8 = np.uint8(255 * overlay)
    return cv2.cvtColor(np.array(overlay_u8), cv2.COLOR_RGB2BGR)


def get_importance_maps(attribution_generator, baselines, original_image,
                        class_index=None, method="transformer_attribution"):
    """Return a 224 x 224 importance map for ``original_image``.

    Args:
        attribution_generator: Object exposing ``generate_LRP(image, method=,
            index=)``; used for every ``method`` except ``"attn_gradcam"``.
        baselines: Object passed to ``gen_grad_cam`` when ``method`` is
            ``"attn_gradcam"``.
        original_image: A single un-batched image tensor.
        class_index: Class to explain; ``None`` is forwarded unchanged to the
            underlying generator.
        method: Attribution method. ``"full_lrp"`` output is treated as
            already being 224 x 224; every other method's output is treated as
            a 14 x 14 patch grid and upsampled bilinearly by a factor of 16.

    Returns:
        A 224 x 224 NumPy array passed through ``min_max_normal``.
    """
    if method == "attn_gradcam":
        attribution = gen_grad_cam(baselines, class_index, original_image).detach()
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        attribution = attribution_generator.generate_LRP(
            original_image.unsqueeze(0).to(device), method=method, index=class_index
        ).detach()

    if method == "full_lrp":
        attribution = attribution.reshape(1, 1, 224, 224)
    else:
        # interpolate() expects (batch, channel, H, W).
        attribution = attribution.reshape(1, 1, 14, 14)
        attribution = torch.nn.functional.interpolate(attribution, scale_factor=16,
                                                      mode='bilinear')

    # The tensor is already detached; go straight to host memory rather than
    # hopping through the compute device first.
    attribution = attribution.reshape(224, 224).cpu().numpy()
    return min_max_normal(attribution)


def print_top_classes(predictions, **kwargs):
    """Print the five highest-scoring classes of the first row of ``predictions``.

    A header line is printed, followed by one line per class showing its
    index, its ``CLS2IDX`` name (padded so the columns line up), the raw
    prediction value and the softmax probability in percent.

    Args:
        predictions: 2-D tensor of class scores; softmax and top-k are taken
            along dimension 1 and only row 0 is reported.
        **kwargs: Accepted and ignored.
    """
    prob = torch.softmax(predictions, dim=1)
    _, top_indices = predictions.data.topk(5, dim=1)
    class_indices = top_indices[0].tolist()

    class_names = [CLS2IDX[cls_idx] for cls_idx in class_indices]
    # Width of the longest name, used to align the value column.
    max_str_len = max((len(name) for name in class_names), default=0)

    print('Top 5 classes:')
    for cls_idx, name in zip(class_indices, class_names):
        label = '\t{} : {}'.format(cls_idx, name)
        padding = ' ' * (max_str_len - len(name)) + '\t\t'
        scores = 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
        print(label + padding + scores)


def plot_images(*images):
    """Show ``images`` in a grid of matplotlib subplots with the axes hidden.

    Exactly six images are laid out as 2 rows x 3 columns; any other count is
    laid out in a single row.

    Args:
        *images: One or more objects accepted by ``Axes.imshow``.

    Returns:
        ``(fig, axes)`` exactly as returned by ``plt.subplots``: ``axes`` is a
        2-D array for the 2 x 3 grid, a 1-D array for a row of several images,
        and a single ``Axes`` object for one image.

    Raises:
        ValueError: If no image is given.
    """
    images = list(images)
    if not images:
        raise ValueError("plot_images() requires at least one image")

    if len(images) == 6:
        n_rows, n_cols = 2, 3
    else:
        n_rows, n_cols = 1, len(images)

    fig, axes = plt.subplots(n_rows, n_cols)
    # plt.subplots squeezes a 1 x 1 grid down to a bare Axes, which has no
    # flatten(); wrap it so the loop below handles every layout.
    flattened_axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
    for ax, image in zip(flattened_axes, images):
        ax.imshow(image)
        ax.axis('off')

    return fig, axes


def show_before_after_perturbation(*images):
    """Plot ``images`` with ``plot_images`` and title the first three subplots.

    The titles 'original image', 'before perturbation' and 'after
    perturbation' are applied to the first three axes in row-major order, i.e.
    the top row of the 2 x 3 grid or the first three panels of a single row.
    When fewer than three subplots exist only the available ones are titled.

    Args:
        *images: Images forwarded unchanged to ``plot_images``.

    Returns:
        The ``(fig, axes)`` pair returned by ``plot_images``.
    """
    fig, axes = plot_images(*images)
    titles = ['original image', 'before perturbation', 'after perturbation']

    # ``axes`` is 2-D only for the 2 x 3 grid; a single row comes back 1-D and
    # a single plot as a bare Axes, so normalise to a flat sequence first.
    flat_axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
    for ax, title in zip(flat_axes, titles):
        ax.set_title(title)

    return fig, axes
