import torch
import copy
import numpy as np
import matplotlib.patches as patches_plt
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os
import sys
import models_mae
from einops import rearrange
from torch.nn.functional import softmax
from PIL import Image

# Directory containing this file; its parent directory is appended to sys.path.
# NOTE(review): `import models_mae` at the top of the file executes before this
# append, so the append cannot help resolve that import. Move it above the
# import if that was the intent.
cwd = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(cwd))

# Per-channel ImageNet normalisation statistics (the standard torchvision
# values). Used in this file to undo normalisation before display/pasting.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def calculate_euclidean_distance(x, y):
    """
    Pairwise *squared* Euclidean distance between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape ``(n, d)``.
        y: tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)`` whose entry ``[i, j]`` is ``sum((x[i] - y[j]) ** 2)``.
        No square root is taken, despite the function name.
    """
    rows, cols, dim = x.size(0), y.size(0), x.size(1)
    lhs = x[:, None].expand(rows, cols, dim)
    rhs = y[None].expand(rows, cols, dim)
    return (lhs - rhs).pow(2).sum(2)


def calculate_cosine_similarity(x, y):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        x: tensor of shape ``(n, d)``.
        y: tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)``. Rows are L2-normalised with
        ``F.normalize`` (which guards against zero-norm rows with an epsilon)
        before the matrix product.
    """
    unit_x = F.normalize(x, p=2, dim=1)
    unit_y = F.normalize(y, p=2, dim=1)
    return unit_x.mm(unit_y.t())


def fill_to_full(arr, total=196):
    """
    Extend a list of patch indices into a full ordering of ``range(total)``.

    The indices in ``arr`` are kept first, in their original order. They are
    followed, in ascending order, by every index in ``range(total)`` that
    ``arr`` does not already contain. The result is typically used as an
    ``ids_shuffle`` row, so its first ``len(arr)`` entries are the kept patches.

    Args:
        arr: iterable of integer indices (list, tuple or numpy array). It is not
            modified.
        total: number of patches in the grid (196 = 14 x 14 by default).

    Returns:
        Tensor of shape ``(1, len(result))``. This is ``(1, total)`` when ``arr``
        holds unique indices from ``range(total)``.
    """
    ordering = list(arr)  # shallow copy; also accepts tuples/ndarrays
    present = set(ordering)  # O(1) membership instead of scanning the list
    ordering.extend(i for i in range(total) if i not in present)
    return torch.tensor(ordering).unsqueeze(0)


def convert_to_tensor(img):
    """
    Convert an image (or batch of images) to an NCHW torch tensor.

    - A numpy array is first wrapped with ``torch.tensor``.
    - Anything that is not 4-D is treated as a single HWC image. It gets a
      leading batch dimension and is permuted to NCHW.
    - A 4-D input whose last dimension is 3 is treated as NHWC and permuted.
    - Any other 4-D tensor is returned unchanged (assumed to be NCHW already).
    """
    tensor = torch.tensor(img) if isinstance(img, np.ndarray) else img
    if len(tensor.shape) != 4:
        # promote a single HWC image to a batch of one, then NHWC -> NCHW
        return tensor.unsqueeze(dim=0).permute(0, 3, 1, 2)
    if tensor.shape[-1] == 3:
        assert isinstance(tensor, torch.Tensor)
        return tensor.permute(0, 3, 1, 2)
    return tensor


def generate_arr_mask_for_evaluation(arr):
    """
    Build the visible-patch ordering for one of the arrangements ``a1`` .. ``a8``.

    On a 14 x 14 patch grid, one half of the rows (top or bottom) and one half
    of the columns (left or right) are marked visible; their union is kept.
    Arrangements come in pairs that share a layout: (a1, a8), (a2, a7),
    (a3, a6) and (a4, a5).

    Returns:
        ``(ids_shuffle, len_keep)``: the ``fill_to_full`` ordering with the
        visible patches first, and the number of visible patches.

    Raises:
        ValueError: if ``arr`` is not one of the known arrangements.
    """
    half = 7
    if arr in ('a1', 'a8'):
        rows, cols = slice(None, half), slice(None, half)
    elif arr in ('a2', 'a7'):
        rows, cols = slice(None, half), slice(half, None)
    elif arr in ('a3', 'a6'):
        rows, cols = slice(half, None), slice(half, None)
    elif arr in ('a4', 'a5'):
        rows, cols = slice(half, None), slice(None, half)
    else:
        raise ValueError("The arrangement is not in list!")

    grid = np.zeros((14, 14))
    grid[rows] = 1
    grid[:, cols] = 1
    visible = obtain_values_from_mask(grid)
    return fill_to_full(visible), len(visible)


def generate_large_mask_for_evaluation(arr):
    """
    Like ``generate_arr_mask_for_evaluation`` but with an 11-patch split.

    On the 14 x 14 grid, the first or last rows/columns are marked visible,
    split at index 11 instead of 7. Arrangement pairs (a1, a8), (a2, a7),
    (a3, a6) and (a4, a5) share a layout.

    Returns:
        ``(ids_shuffle, len_keep)`` as produced by ``fill_to_full`` and the
        visible-patch count.

    Raises:
        ValueError: if ``arr`` is not one of the known arrangements.
    """
    split = 11
    if arr in ('a1', 'a8'):
        rows, cols = slice(None, split), slice(None, split)
    elif arr in ('a2', 'a7'):
        rows, cols = slice(None, split), slice(split, None)
    elif arr in ('a3', 'a6'):
        rows, cols = slice(split, None), slice(split, None)
    elif arr in ('a4', 'a5'):
        rows, cols = slice(split, None), slice(None, split)
    else:
        raise ValueError("The arrangement is not in list!")

    grid = np.zeros((14, 14))
    grid[rows] = 1
    grid[:, cols] = 1
    visible = obtain_values_from_mask(grid)
    return fill_to_full(visible), len(visible)


def generate_mask_for_evaluation():
    """
    Ordering that keeps the top 7 rows and left 7 columns of the 14 x 14 grid.

    Returns:
        ``(ids_shuffle, len_keep)`` from ``fill_to_full`` plus the visible count.
    """
    grid = np.zeros((14, 14))
    grid[:7] = 1      # top rows visible
    grid[:, :7] = 1   # left columns visible
    visible = obtain_values_from_mask(grid)
    return fill_to_full(visible), len(visible)


def generate_mask_for_evaluation_2rows():
    """
    Ordering that keeps the top 9 rows and left 7 columns of the 14 x 14 grid.

    Returns:
        ``(ids_shuffle, len_keep)`` from ``fill_to_full`` plus the visible count.
    """
    grid = np.zeros((14, 14))
    grid[:9] = 1      # top rows visible (two more than generate_mask_for_evaluation)
    grid[:, :7] = 1   # left columns visible
    visible = obtain_values_from_mask(grid)
    return fill_to_full(visible), len(visible)


def generate_mask_for_evaluation_2rows_more_context():
    """
    Like ``generate_mask_for_evaluation_2rows`` but also keeps columns 12-13.

    Returns:
        ``(ids_shuffle, len_keep)`` from ``fill_to_full`` plus the visible count.
    """
    grid = np.zeros((14, 14))
    grid[:9] = 1       # top rows visible
    grid[:, :7] = 1    # left columns visible
    grid[:, 12:] = 1   # extra context on the right edge
    visible = obtain_values_from_mask(grid)
    return fill_to_full(visible), len(visible)

    
def obtain_values_from_mask(mask: np.ndarray):
    """
    Return the row-major indices of the "on" patches in a mask.

    Args:
        mask: either a 14 x 14 patch-level mask (any non-zero entry counts), or
            a 224 x 224 pixel-level mask. In the pixel case, a 16 x 16 patch
            counts only if its entries sum to exactly 256 (i.e. all ones for a
            binary mask).

    Returns:
        List of patch indices in ``range(196)``, in ascending order.
    """
    if mask.shape == (14, 14):
        return list(np.flatnonzero(mask))
    assert mask.shape == (224, 224)
    patch = 16
    grid = 224 // patch
    return [
        row * grid + col
        for row in range(grid)
        for col in range(grid)
        if np.sum(mask[row * patch:(row + 1) * patch, col * patch:(col + 1) * patch]) == patch ** 2
    ]


# RGB colour triples, 0-255 per channel.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
# PURPLE / YELLOW (#440154 / #FDE725) appear to be the two endpoint colours of
# matplotlib's 'viridis' colormap. new_extract_ignore_idx uses the same values
# (as local copies) for background / target pixels.
PURPLE = (0x44, 0x01, 0x54)
YELLOW = (0xFD, 0xE7, 0x25)


class TemperatureScaling:
    """Turn logits into probabilities with a softmax at a fixed temperature."""

    def __init__(self, temperature=1.0):
        # Values > 1 flatten the distribution, values < 1 sharpen it.
        self.temperature = temperature

    def scale_probs(self, logits):
        """Return ``softmax(logits / temperature)`` along the last dimension."""
        return logits.div(self.temperature).softmax(dim=-1)


def prepare_model(chkpt_dir, arch='mae_vit_large_patch16', device='cpu'):
    """
    Build ``models_mae.<arch>``, load checkpoint weights and move it to ``device``.

    The checkpoint is read onto the CPU and its ``'model'`` entry is loaded
    with ``strict=False``, so missing/unexpected keys are ignored silently.
    The model is not switched to eval mode here.
    """
    builder = getattr(models_mae, arch)
    model = builder()
    state = torch.load(chkpt_dir, map_location='cpu')['model']
    model.load_state_dict(state, strict=False)
    model.to(device)
    return model


def prepare_seggpt_model(model_attr, chkpt_dir, arch='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1', seg_type=None):
    """
    Build ``model_attr.<arch>``, load its checkpoint and put it in eval mode.

    Args:
        model_attr: module/namespace providing a zero-argument constructor named ``arch``.
        chkpt_dir: path to a checkpoint whose ``'model'`` entry is a state dict.
        arch: constructor name looked up on ``model_attr``.
        seg_type: if truthy, stored on the model as ``model.seg_type``.

    Returns:
        The model, on the CPU, in eval mode. Weights are loaded with
        ``strict=False``, so mismatched keys are ignored.
    """
    model = getattr(model_attr, arch)()
    if seg_type:
        model.seg_type = seg_type
    # Load tensors onto the CPU (as prepare_model does). The model itself stays
    # on the CPU, so this gives the same weights while also working on machines
    # without the GPU the checkpoint may have been saved from.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    return model


@torch.no_grad()
def generate_image(orig_image, model, ids_shuffle, len_keep: int, device: str = 'cpu'):
    """
    Reconstruct an image from its visible patches using the model's argmax tokens.

    ids_shuffle is [bs, 196]; its first ``len_keep`` entries are the visible patches.

    Returns:
        ``(orig_image, im_paste[0], mask)`` as produced by ``decode_raw_predicion``.
    """
    grid_side = 14
    mask, orig_image, logits = generate_raw_prediction(device, ids_shuffle, len_keep, model, orig_image)
    tokens = logits.argmax(dim=-1)
    pasted, mask, orig_image = decode_raw_predicion(mask, model, grid_side, orig_image, tokens)
    return orig_image, pasted[0], mask


def new_extract_ignore_idx(image_name, masks, class_ids, purple=False):
    """
    Turn a class-index annotation into a binary (or two-colour) mask image.

    Pixels equal to ``class_ids + 1`` are the target class.

    Args:
        image_name: used only in the warning printed when the class is absent.
        masks: annotation image/array convertible with ``np.array``; it is not
            modified. Assumed to be 2-D (H x W); TODO confirm for all callers.
        class_ids: zero-based class id. The annotation stores ``class_ids + 1``.
        purple: if False, return a 0/255 grayscale mask. If True, return an RGB
            image with target pixels YELLOW and all others PURPLE.

    Returns:
        ``(PIL.Image, boundary)`` where ``boundary = floor(mask / 255)``, computed
        from the untouched annotation values.
    """
    PURPLE = (0x44, 0x01, 0x54)
    YELLOW = (0xFD, 0xE7, 0x25)
    mask = np.array(masks)
    boundary = np.floor(mask / 255.)
    target = class_ids + 1
    # Compute the hit mask once. Doing two sequential in-place writes breaks
    # when target == 0: zeroing the non-target pixels would make them match.
    hit = mask == target
    if not purple:
        if target not in mask:
            print(f'ohno, {image_name} not contain {class_ids}')
        mask[~hit] = 0
        mask[hit] = 255
        return Image.fromarray(mask), boundary
    # Vectorised colouring instead of a per-pixel Python loop.
    color_mask = np.empty((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    color_mask[...] = PURPLE
    color_mask[hit] = YELLOW
    return Image.fromarray(color_mask), boundary


def get_query_support_mask(query_image, query_cmask, support_image, support_cmask, class_id, support_class):
    """
    Build the binary target masks for a query image and its support images.

    Args:
        query_image / support_image: names/paths, used only for warnings.
        query_cmask / support_cmask: annotation images (one per support image).
        class_id / support_class: zero-based class ids for query / support.

    Returns:
        ``(query_mask, support_masks)``; the boundary arrays are discarded.
    """
    query_mask, _ = new_extract_ignore_idx(query_image, query_cmask, class_id, purple=False)
    support_masks = [
        new_extract_ignore_idx(support_image[idx], cmask, support_class, purple=False)[0]
        for idx, cmask in enumerate(support_cmask)
    ]
    return query_mask, support_masks


def add_image_path(img_path, ann_path, query_name, support_name, query_class, support_class):
    """
    Resolve image/annotation paths and build the query/support target masks.

    Images are ``<img_path>/<name>.jpg`` and annotations ``<ann_path>/<name>.png``.
    Only the first element of ``query_name`` and of the class tensors is used.

    Returns:
        ``(query_image_path, support_image_path, query_mask, support_masks)``.
    """
    support_ids = [os.path.basename(name) for name in support_name]
    support_image_path = [os.path.join(img_path, sid) + '.jpg' for sid in support_ids]
    support_cmask = [Image.open(os.path.join(ann_path, sid) + '.png') for sid in support_ids]

    query_id = query_name[0]
    query_image_path = os.path.join(img_path, query_id) + '.jpg'
    query_cmask = Image.open(os.path.join(ann_path, query_id) + '.png')

    query_label = query_class.cpu().numpy()[0]
    support_label = support_class.cpu().numpy()[0]

    query_mask, support_masks = get_query_support_mask(
        query_image_path, query_cmask, support_image_path, support_cmask, query_label, support_label)
    return query_image_path, support_image_path, query_mask, support_masks


def calculate_js_divergence(p, q):
    """
    Jensen-Shannon divergence between probability distributions along the last dim.

    ``JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)`` with ``m = (p + q) / 2``.
    The result is symmetric and lies in ``[0, ln 2]``.

    Args:
        p, q: probability tensors that broadcast against each other, e.g.
            ``(1, V)`` and ``(K, V)``.

    Returns:
        Tensor with the last dimension reduced, e.g. ``(K,)``.
    """
    m = 0.5 * (p + q)
    # KL(a || m) = sum a * (log a - log m). The original code passed p.log() as
    # the *input* of F.kl_div, which computes KL(m || p) instead. That is not
    # JS, is unbounded, and is inf/nan when p has zeros. xlogy gives
    # 0 * log 0 = 0 exactly.
    kl_p_m = torch.xlogy(p, p) - torch.xlogy(p, m)
    kl_q_m = torch.xlogy(q, q) - torch.xlogy(q, m)
    return (0.5 * (kl_p_m + kl_q_m)).sum(-1)


@torch.no_grad()
def generate_image_by_knn_indice_weight_prob(orig_image, model, ids_shuffle, datastore, key, k, len_keep, args, device: str = 'cpu'):
    """
    Reconstruct an image, refining masked-patch tokens with a kNN datastore.

    For every masked position (mask == 1), the model's temperature-scaled token
    distribution is compared to the stored distributions for that same
    position using Jensen-Shannon divergence. The ``k`` closest entries are
    averaged with ``softmax(-divergence)`` weights and blended with the model's
    own distribution: ``(1 - alpha) * model + alpha * neighbours``. The argmax
    of the blend replaces the predicted token before decoding.

    Args:
        datastore: ``datastore[key][str(position)]`` is a sequence of dicts with a
            ``'key'`` tensor per entry. These are passed through the temperature
            softmax, so presumably they are stored logits; confirm with the writer.
        k: number of nearest entries to average (capped at the entry count).
        args: object with ``temp`` (temperature) and ``alpha`` (blend weight).

    Returns:
        ``(orig_image, im_paste[0], mask)`` like ``generate_image``.
    """
    mask, orig_image, x = generate_raw_prediction(device, ids_shuffle, len_keep, model, orig_image)

    temp_scaling = TemperatureScaling(temperature=args.temp)
    alpha = args.alpha
    position_store = datastore[key]
    num_patches = 14
    y = x.argmax(dim=-1)
    x_prob = temp_scaling.scale_probs(x)

    # Only masked positions are refined; visible patches keep the model output.
    for index in mask[0].nonzero(as_tuple=True)[0]:
        entries = position_store[str(index.item())]
        keys_prob = temp_scaling.scale_probs(torch.stack([item['key'] for item in entries])).to(device)
        query_prob = x_prob[0, index].unsqueeze(0).to(device)

        distances = calculate_js_divergence(query_prob, keys_prob)
        # Weights are a softmax over *all* entries; the top-k subset is
        # renormalised below.
        weights = softmax(-distances, dim=-1)
        _, nearest = torch.topk(distances, min(k, len(distances)), largest=False)
        nearest_weights = weights[nearest]
        neighbour_prob = (nearest_weights @ keys_prob[nearest]) / nearest_weights.sum()

        fused_prob = (1 - alpha) * x_prob[0, index] + alpha * neighbour_prob
        y[0, index] = fused_prob.argmax()

    im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y)
    return orig_image, im_paste[0], mask


@torch.no_grad()
def generate_image_by_knn_indice_weight_prob_fc_features(orig_image, model, ids_shuffle, datastore, key, k, len_keep, args, device: str = 'cpu'):
    """
    kNN token refinement weighted by similarity of pre-projection decoder features.

    For every masked position, the decoder feature before ``decoder_pred`` is
    compared (cosine similarity) with the stored ``'fc_feature'`` of every
    datastore entry for that position. The stored distributions are averaged
    with the normalised similarities as weights and blended with the model's
    distribution using ``args.alpha``.

    Args:
        datastore: ``datastore[key][str(position)]`` is a sequence of dicts with
            ``'key'`` (passed through the temperature softmax) and ``'fc_feature'``.
        k: accepted for signature compatibility; it is NOT used. Every entry
            for the position contributes.
        args: object with ``temp`` and ``alpha`` attributes.

    Returns:
        ``(orig_image, im_paste[0], mask)`` like ``generate_image``.
    """
    mask, orig_image, x, x_fc_features = generate_raw_prediction_features_before_fc(device, ids_shuffle, len_keep, model, orig_image)

    temp_scaling = TemperatureScaling(temperature=args.temp)
    alpha = args.alpha
    position_store = datastore[key]
    num_patches = 14
    y = x.argmax(dim=-1)
    x_prob = temp_scaling.scale_probs(x)

    for index in mask[0].nonzero(as_tuple=True)[0]:
        entries = position_store[str(index.item())]
        keys_prob = temp_scaling.scale_probs(torch.stack([item['key'] for item in entries])).to(device)
        stored_features = torch.stack([item['fc_feature'] for item in entries])
        similarity = calculate_cosine_similarity(x_fc_features[0, index].unsqueeze(0).to(device), stored_features.to(device))  # (1, K)

        # NOTE(review): cosine similarities can be negative, so this is not a
        # convex combination and it blows up if the similarities sum to ~0.
        weights = (similarity / similarity.sum(dim=-1, keepdim=True)).squeeze(0)  # (K,)
        neighbour_prob = weights @ keys_prob

        fused_prob = (1 - alpha) * x_prob[0, index] + alpha * neighbour_prob
        y[0, index] = fused_prob.argmax()

    im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y)
    return orig_image, im_paste[0], mask


@torch.no_grad()
def generate_image_by_knn_indice_weight_prob_pixel_patch(orig_image, model, ids_shuffle, datastore, key, k, len_keep, args, device: str = 'cpu'):
    """
    kNN token refinement weighted by similarity of reconstructed pixel patches.

    The image is first reconstructed from the model's argmax tokens and cut
    into 16 x 16 patches. For every masked position, that patch's pixels
    (flattened) are compared (cosine similarity) with the stored ``'feature'``
    of every datastore entry for the position. The stored distributions are
    averaged with the normalised similarities and blended with the model's
    distribution using ``args.alpha``.

    Args:
        datastore: ``datastore[key][str(position)]`` is a sequence of dicts with
            ``'key'`` (passed through the temperature softmax) and ``'feature'``.
        k: accepted for signature compatibility; it is NOT used.
        args: object with ``temp`` and ``alpha`` attributes.

    Returns:
        ``(orig_image, im_paste[0], mask)`` like ``generate_image``.
    """
    mask, orig_image, x = generate_raw_prediction(device, ids_shuffle, len_keep, model, orig_image)

    temp_scaling = TemperatureScaling(temperature=args.temp)
    alpha = args.alpha
    position_store = datastore[key]
    num_patches = 14
    y = x.argmax(dim=-1)

    # Initial reconstruction, cut into [num_patches**2, 16*16*C] pixel patches.
    initial_paste, _, _ = decode_raw_predicion(mask, model, num_patches, orig_image, y)
    x_patches = rearrange(initial_paste[0], '(h p1) (w p2) c -> (h w) (p1 p2 c)', p1=16, p2=16)

    x_prob = temp_scaling.scale_probs(x)

    for index in mask[0].nonzero(as_tuple=True)[0]:
        entries = position_store[str(index.item())]
        keys_prob = temp_scaling.scale_probs(torch.stack([item['key'] for item in entries])).to(device)
        stored_pixels = torch.stack([item['feature'] for item in entries])
        similarity = calculate_cosine_similarity(x_patches[index].unsqueeze(0).to(device), stored_pixels.to(device))  # (1, K)

        # NOTE(review): similarities can be negative; see the fc_features variant.
        weights = (similarity / similarity.sum(dim=-1, keepdim=True)).squeeze(0)  # (K,)
        neighbour_prob = weights @ keys_prob

        fused_prob = (1 - alpha) * x_prob[0, index] + alpha * neighbour_prob
        y[0, index] = fused_prob.argmax()

    im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y)
    return orig_image, im_paste[0], mask


def generate_raw_pred_for_train(orig_image, model, ids_shuffle, len_keep: int, device: str = 'cpu'):
    """
    Run ``generate_for_training`` per sample and concatenate the outputs.

    ids_shuffle is [bs, 196]; the same ordering is used for every sample.
    Gradients are not disabled here.

    Returns:
        ``(x, mask)``: the per-sample outputs concatenated along dim 0, and the
        mask from the last sample (identical for all samples, since they share
        ``ids_shuffle``).

    Raises:
        ValueError: if ``orig_image`` has no samples. Previously this failed
            with UnboundLocalError.
    """
    outputs = []
    mask = None
    for i in range(orig_image.shape[0]):
        x, mask = generate_for_training(orig_image[i], model, ids_shuffle, len_keep, device)
        outputs.append(x)
    if not outputs:
        raise ValueError('orig_image must contain at least one sample')
    # A single concatenation instead of re-copying the growing stack each step.
    return torch.cat(outputs), mask


@torch.no_grad()
def generate_feature_for_train(orig_image, model, ids_shuffle, len_keep: int, device: str = 'cpu'):
    """
    Decode VAE features from the model's argmax tokens for each sample in a batch.

    ids_shuffle is [bs, 196]; the same ordering is used for every sample.

    Returns:
        ``(features, mask_, images)``:
        - ``features``: per-sample decoder features concatenated along dim 0.
        - ``mask_``: pixel-level mask of the last sample.
        - ``images``: de-normalised NHWC images, one per sample, concatenated.

        For a batch of one this matches the previous behaviour.

    Raises:
        ValueError: if ``orig_image`` has no samples.
    """
    num_patches = 14
    features = []
    images = []
    mask_ = None
    for i in range(orig_image.shape[0]):
        x, mask = generate_for_training(orig_image[i], model, ids_shuffle, len_keep, device)
        y = x.argmax(dim=-1)
        # Pass only sample i and keep the batch untouched. Previously
        # `orig_image` was rebound to the decoded image here, so later
        # iterations indexed the wrong tensor, and the decoder always saw
        # sample 0.
        feature, image, mask_ = decode_raw_predicion_2_get_feature(mask, model, num_patches, orig_image[i:i + 1], y)
        features.append(feature)
        images.append(image)
    if not features:
        raise ValueError('orig_image must contain at least one sample')
    return torch.cat(features), mask_, torch.cat(images)


def decode_raw_predicion(mask, model, num_patches, orig_image, y):
    """
    Decode predicted token ids to pixels and paste them over the masked patches.

    Args:
        mask: ``[N, L]`` patch mask, 1 = removed, 0 = kept.
        model: provides ``vae.quantize.get_codebook_entry``, ``vae.decode``,
            ``patch_embed.patch_size`` and ``unpatchify``.
        num_patches: side length of the patch grid (``y.shape[-1] // num_patches``
            is used as both spatial sizes of the codebook grid).
        orig_image: normalised NCHW tensor. Only sample 0 is de-normalised.
        y: ``[N, L]`` token ids.

    Returns:
        ``(im_paste, mask, orig_image)``: pasted image, pixel-level NHWC mask,
        and the de-normalised NHWC image (batch of one), all int/float on CPU.
    """
    side = y.shape[-1] // num_patches
    codes = model.vae.quantize.get_codebook_entry(y.reshape(-1), [y.shape[0], side, side, -1])
    decoded = model.vae.decode(codes)
    decoded = F.interpolate(decoded, size=(224, 224), mode='bilinear').permute(0, 2, 3, 1)
    decoded = torch.clip(decoded * 255, 0, 255).int().detach().cpu()

    # Expand the per-patch mask to pixel resolution (1 = removed, 0 = kept).
    patch_dim = model.patch_embed.patch_size[0] ** 2 * 3
    pixel_mask = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, patch_dim))
    pixel_mask = pixel_mask.permute(0, 2, 3, 1).detach().cpu()

    # Undo ImageNet normalisation (first sample only) and restore a batch dim.
    image_hwc = orig_image.permute(0, 2, 3, 1)[0].cpu().detach()
    denorm = torch.clip((image_hwc * imagenet_std + imagenet_mean) * 255, 0, 255).int().unsqueeze(0)

    # Reconstruction on removed patches, original pixels on kept ones.
    im_paste = denorm * (1 - pixel_mask) + decoded * pixel_mask
    return im_paste, pixel_mask, denorm


def decode_raw_predicion_2_get_feature(mask, model, num_patches, orig_image, y):
    """
    Like ``decode_raw_predicion`` but return VAE decoder features instead of pixels.

    ``model.vae.decode_feature`` returns a pair; only its second element (the
    feature map) is kept. The feature map is resized to 224 x 224 and returned
    channels-last on the CPU.

    Returns:
        ``(feature, orig_image, mask)``: the features, the de-normalised NHWC
        image (sample 0 only, batch of one), and the pixel-level NHWC mask.
    """
    side = y.shape[-1] // num_patches
    codes = model.vae.quantize.get_codebook_entry(y.reshape(-1), [y.shape[0], side, side, -1])
    _, feature = model.vae.decode_feature(codes)
    feature = F.interpolate(feature, size=(224, 224), mode='bilinear').permute(0, 2, 3, 1).detach().cpu()

    # Expand the per-patch mask to pixel resolution (1 = removed, 0 = kept).
    patch_dim = model.patch_embed.patch_size[0] ** 2 * 3
    pixel_mask = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, patch_dim))
    pixel_mask = pixel_mask.permute(0, 2, 3, 1).detach().cpu()

    # Undo ImageNet normalisation (first sample only) and restore a batch dim.
    image_hwc = orig_image.permute(0, 2, 3, 1)[0].cpu().detach()
    denorm = torch.clip((image_hwc * imagenet_std + imagenet_mean) * 255, 0, 255).int().unsqueeze(0)

    return feature, denorm, pixel_mask


@torch.no_grad()
def generate_raw_prediction(device, ids_shuffle, len_keep, model, orig_image):
    """Run the MAE encoder/decoder step by step with a caller-chosen patch ordering.

    Instead of calling ``model(...)``, this re-implements the forward pass:
    patch embedding, masking by ``ids_shuffle``/``len_keep``, encoder blocks,
    then decoder blocks with the attention computed explicitly.

    Args:
        device: device the inputs are moved to.
        ids_shuffle: patch permutation (described elsewhere in this file as
            ``[bs, 196]``). The first ``len_keep`` entries are the visible patches.
        len_keep: number of visible patches.
        model: MAE-style model (``patch_embed``, ``pos_embed``, ``blocks``,
            ``decoder_embed``, ``decoder_blocks``, ``decoder_pred`` ...).
        orig_image: image(s) accepted by ``convert_to_tensor``.

    Returns:
        ``(mask, orig_image, x)``:
        - ``mask``: ``[N, L]``, 1 for removed and 0 for kept patches.
        - ``orig_image``: the NCHW tensor on ``device``.
        - ``x``: ``decoder_pred`` output without the cls token. Callers take
          ``argmax(dim=-1)`` to get token ids.
    """
    ids_shuffle = ids_shuffle.to(device)
    # make it a batch-like
    orig_image = convert_to_tensor(orig_image).to(device)
    temp_x = orig_image.clone().detach().to(device)
    # RUN ENCODER:
    # embed patches
    latent = model.patch_embed(temp_x.float())
    # add pos embed w/o cls token
    latent = latent + model.pos_embed[:, 1:, :]
    # masking is fixed by ids_shuffle / len_keep (no random mask ratio here)
    N, L, D = latent.shape  # batch, length, dim
    # inverse permutation: maps shuffled order back to patch order
    # (assuming ids_shuffle is a full permutation, e.g. from fill_to_full)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    latent = torch.gather(
        latent, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=latent.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)
    # append cls token
    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(latent.shape[0], -1, -1)
    latent = torch.cat((cls_tokens, latent), dim=1)
    # apply Transformer blocks
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)
    x = model.decoder_embed(latent)
    # append mask tokens to sequence
    mask_tokens = model.mask_token.repeat(
        x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
    # add pos embed
    x = x + model.decoder_pos_embed
    # apply Transformer blocks
    for block_num, blk in enumerate(model.decoder_blocks):
        # Manually unrolled decoder block (pre-norm attention + MLP residuals).
        # NOTE(review): attn_drop and any layer-scale / q-k norm modules of the
        # block are not applied here; confirm they are identity for this model.
        x_temp = blk.norm1(x)
        # explicit multi-head self-attention:
        B, N, C = x_temp.shape
        qkv = blk.attn.qkv(x_temp).reshape(
            B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
        # Attention weights are explicit here so they can be inspected/altered.
        attn = attn.softmax(dim=-1)

        x_temp = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_temp = blk.attn.proj(x_temp)
        x_temp = blk.attn.proj_drop(x_temp)
        # Continue with the original block's residual structure.
        x = x + blk.drop_path1(x_temp)

        x = x + blk.drop_path2(blk.mlp(blk.norm2(x)))
    x = model.decoder_norm(x)
    # predictor projection
    x = model.decoder_pred(x)
    # remove cls token
    x = x[:, 1:, :]

    return mask, orig_image, x


@torch.no_grad()
def generate_raw_prediction_features_before_fc(device, ids_shuffle, len_keep, model, orig_image):
    """Same forward pass as ``generate_raw_prediction``, also returning pre-projection features.

    Returns:
        ``(mask, orig_image, x, fc_feature)``. The first three are as in
        ``generate_raw_prediction``. ``fc_feature`` is the ``decoder_norm``
        output (the input to ``decoder_pred``) without the cls token.
    """
    ids_shuffle = ids_shuffle.to(device)
    # make it a batch-like
    orig_image = convert_to_tensor(orig_image).to(device)
    temp_x = orig_image.clone().detach().to(device)
    # RUN ENCODER:
    # embed patches
    latent = model.patch_embed(temp_x.float())
    # add pos embed w/o cls token
    latent = latent + model.pos_embed[:, 1:, :]
    # masking is fixed by ids_shuffle / len_keep (no random mask ratio here)
    N, L, D = latent.shape  # batch, length, dim
    # inverse permutation: maps shuffled order back to patch order
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    latent = torch.gather(
        latent, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=latent.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)
    # append cls token
    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(latent.shape[0], -1, -1)
    latent = torch.cat((cls_tokens, latent), dim=1)
    # apply Transformer blocks
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)
    x = model.decoder_embed(latent)
    # append mask tokens to sequence
    mask_tokens = model.mask_token.repeat(
        x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
    # add pos embed
    x = x + model.decoder_pos_embed
    # apply Transformer blocks
    for block_num, blk in enumerate(model.decoder_blocks):
        # Manually unrolled decoder block (pre-norm attention + MLP residuals).
        x_temp = blk.norm1(x)
        # explicit multi-head self-attention:
        B, N, C = x_temp.shape
        qkv = blk.attn.qkv(x_temp).reshape(
            B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
        # Shape is [B, num_heads, 1 + L, 1 + L] (e.g. [1, 16, 197, 197] for one
        # image with 16 heads).
        # Attention weights are explicit here so they can be inspected/altered.
        attn = attn.softmax(dim=-1)

        x_temp = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_temp = blk.attn.proj(x_temp)
        x_temp = blk.attn.proj_drop(x_temp)
        # Continue with the original block's residual structure.
        x = x + blk.drop_path1(x_temp)

        x = x + blk.drop_path2(blk.mlp(blk.norm2(x)))
    x = model.decoder_norm(x)
    # features fed to the final projection, cls token dropped
    fc_feature = x[:, 1:, :]
    # predictor projection
    x = model.decoder_pred(x)
    # remove cls token
    x = x[:, 1:, :]

    return mask, orig_image, x, fc_feature


@torch.no_grad()
def generate_decoder_embeddings(orig_image, model, ids_shuffle, len_keep, attribute: str = 'none', index: int = -1, device: str = 'cpu'):
    """
    Collect per-decoder-block query/key/value tensors as numpy arrays.

    ids_shuffle is [bs, 196]; its first ``len_keep`` entries are the visible patches.

    Args:
        attribute: ``'q'``, ``'k'`` or ``'v'``. Used only when ``index`` matches a
            block; any other value then raises KeyError.
        index: decoder block number at which to stop and return only
            ``attribute`` of that block. The default -1 never matches.

    Returns:
        Either the selected numpy array for block ``index``, or a list with one
        ``(q, k, v)`` tuple of numpy arrays per decoder block.
    """
    ids_shuffle = ids_shuffle.to(device)
    # make it a batch-like
    orig_image = convert_to_tensor(orig_image).to(device)
    temp_x = orig_image.clone().detach().to(device)

    # embed patches
    latent = model.patch_embed(temp_x.float())

    # add pos embed w/o cls token
    latent = latent + model.pos_embed[:, 1:, :]

    # masking is fixed by ids_shuffle / len_keep
    N, L, D = latent.shape  # batch, length, dim
    # inverse permutation: maps shuffled order back to patch order
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    latent = torch.gather(
        latent, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=latent.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    # append cls token
    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(latent.shape[0], -1, -1)
    latent = torch.cat((cls_tokens, latent), dim=1)

    # apply Transformer blocks
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)
    x = model.decoder_embed(latent)

    # append mask tokens to sequence
    mask_tokens = model.mask_token.repeat(
        x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
    # add pos embed
    x = x + model.decoder_pos_embed
    embeddings = []
    # apply Transformer blocks
    for block_num, blk in enumerate(model.decoder_blocks):
        # Manually unrolled decoder block.
        x_temp = blk.norm1(x)
        B, N, C = x_temp.shape
        qkv = blk.attn.qkv(x_temp).reshape(
            B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv.unbind(0)
        embeddings.append(
            (q.detach().cpu().numpy(), k.detach().cpu().numpy(), v.detach().cpu().numpy()))
        if block_num == index:
            return {
                'q': q.detach().cpu().numpy(),
                'k': k.detach().cpu().numpy(),
                'v': v.detach().cpu().numpy()
            }[attribute]
        attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
        # Shape is [B, num_heads, 1 + L, 1 + L].
        attn = attn.softmax(dim=-1)

        x_temp = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_temp = blk.attn.proj(x_temp)
        x_temp = blk.attn.proj_drop(x_temp)
        # Depending on the installed timm version, Block exposes a single
        # `drop_path` or `drop_path1`/`drop_path2`. The sibling functions in
        # this file use the latter, so accept both layouts.
        drop_path_attn = blk.drop_path1 if hasattr(blk, 'drop_path1') else blk.drop_path
        drop_path_mlp = blk.drop_path2 if hasattr(blk, 'drop_path2') else blk.drop_path
        x = x + drop_path_attn(x_temp)

        x = x + drop_path_mlp(blk.mlp(blk.norm2(x)))
    return embeddings


def show_image(image, title='', ax=None, patches=None, lines=None):
    """
    Show a normalised HWC image with optional patch boxes and connecting lines.

    Args:
        image: ``[H, W, 3]`` tensor, ImageNet-normalised. It is de-normalised
            for display.
        title: axes title.
        ax: axes to draw on. A new figure/axes is created when None.
        patches: iterable of patch indices on the 14 x 14 grid. A red outline is
            drawn around each one.
        lines: iterable of ``(start, end)`` patch-index pairs. A red line is drawn
            between the centres of the two patches.

    All drawing (boxes, lines, image, title, axis-off) targets ``ax``. The
    previous version routed lines/title/axis through ``plt``, i.e. the
    *current* axes, which is the wrong subplot when ``ax`` is passed in.
    """
    if ax is None:
        _, ax = plt.subplots()
    if patches is not None:
        for patch in patches:
            top, left = 16 * (patch // 14), 16 * (patch % 14)
            ax.add_patch(patches_plt.Rectangle(
                (left, top), 15, 15, linewidth=1, edgecolor='r', facecolor='none'))
    assert image.shape[2] == 3
    if lines is not None:
        for start, end in lines:
            start_row, start_col = 16 * (start // 14), 16 * (start % 14)
            end_row, end_col = 16 * (end // 14), 16 * (end % 14)
            # +8 moves from the patch's top-left corner to its centre
            ax.plot([end_col + 8, start_col + 8], [end_row + 8, start_row + 8], color="red", linewidth=1)
    ax.imshow(torch.clip((image.cpu().detach() * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    ax.set_title(title, fontsize=16)
    ax.axis('off')
    return

def generate_for_training(orig_image, model, ids_shuffle, len_keep: int, device: str = 'cpu'):
    """Forward one image through the MAE with a fixed patch ordering, keeping gradients.

    Unlike the inference helpers, this is not wrapped in ``torch.no_grad``. It
    calls each block's own forward instead of unrolling the attention.

    ids_shuffle is [bs, 196]. The input is given a batch dimension of 1 below,
    so presumably ids_shuffle has a single row; confirm with callers.
    Only ``ids_shuffle`` is moved to ``device``; ``orig_image`` is used where it is.

    Returns:
        ``(x, mask)``: ``decoder_pred`` output without the cls token, and the
        ``[1, L]`` patch mask (1 = removed, 0 = kept).
    """
    ids_shuffle = ids_shuffle.to(device)
    # make it a batch-like (assumes a single CHW image; TODO confirm)
    temp_x = orig_image.unsqueeze(0)

    # RUN ENCODER:
    # embed patches
    latent = model.patch_embed(temp_x.float())

    # add pos embed w/o cls token
    latent = latent + model.pos_embed[:, 1:, :]

    # masking is fixed by ids_shuffle / len_keep
    N, L, D = latent.shape  # batch, length, dim
    # inverse permutation: maps shuffled order back to patch order
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    latent = torch.gather(
        latent, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=latent.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    # append cls token
    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(latent.shape[0], -1, -1)
    latent = torch.cat((cls_tokens, latent), dim=1)

    # apply Transformer blocks
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)

    x = model.decoder_embed(latent)

    # append mask tokens to sequence
    mask_tokens = model.mask_token.repeat(
        x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)

    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token


    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
    # add pos embed
    x = x + model.decoder_pos_embed

    # apply Transformer blocks
    for blk in model.decoder_blocks:
        x = blk(x)
    x = model.decoder_norm(x)

    # predictor projection
    x = model.decoder_pred(x)

    # remove cls token
    x = x[:, 1:, :]
    return x, mask


@torch.no_grad()
def generate_decoder_attention_maps(orig_image, model, ids_shuffle, len_keep, index: int = -1, device: str = 'cpu'):
    """
    Collect the softmaxed self-attention maps of every decoder block.

    ids_shuffle is [bs, 196]; its first ``len_keep`` entries are the visible patches.

    Args:
        index: decoder block number at which to stop and return only that
            block's map. The default -1 never matches.

    Returns:
        Either one numpy array of shape ``[B, num_heads, 1 + L, 1 + L]`` for block
        ``index``, or a list of such arrays, one per decoder block.
    """
    ids_shuffle = ids_shuffle.to(device)
    # make it a batch-like
    orig_image = convert_to_tensor(orig_image).to(device)
    temp_x = orig_image.clone().detach().to(device)

    # embed patches
    latent = model.patch_embed(temp_x.float())

    # add pos embed w/o cls token
    latent = latent + model.pos_embed[:, 1:, :]

    # masking is fixed by ids_shuffle / len_keep
    N, L, D = latent.shape  # batch, length, dim
    # inverse permutation: maps shuffled order back to patch order
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    latent = torch.gather(
        latent, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=latent.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    # append cls token
    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(latent.shape[0], -1, -1)
    latent = torch.cat((cls_tokens, latent), dim=1)

    # apply Transformer blocks
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)
    x = model.decoder_embed(latent)

    # append mask tokens to sequence
    mask_tokens = model.mask_token.repeat(
        x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
    # add pos embed
    x = x + model.decoder_pos_embed
    attns = []
    # apply Transformer blocks
    for block_num, blk in enumerate(model.decoder_blocks):
        # Manually unrolled decoder block.
        x_temp = blk.norm1(x)
        B, N, C = x_temp.shape
        qkv = blk.attn.qkv(x_temp).reshape(
            B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
        # Shape is [B, num_heads, 1 + L, 1 + L] (e.g. [1, 16, 197, 197]).
        attn = attn.softmax(dim=-1)
        attns.append(attn.detach().cpu().numpy())
        if block_num == index:
            return attns[-1]

        x_temp = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_temp = blk.attn.proj(x_temp)
        x_temp = blk.attn.proj_drop(x_temp)
        # Depending on the installed timm version, Block exposes a single
        # `drop_path` or `drop_path1`/`drop_path2`. The sibling functions in
        # this file use the latter, so accept both layouts.
        drop_path_attn = blk.drop_path1 if hasattr(blk, 'drop_path1') else blk.drop_path
        drop_path_mlp = blk.drop_path2 if hasattr(blk, 'drop_path2') else blk.drop_path
        x = x + drop_path_attn(x_temp)

        x = x + drop_path_mlp(blk.mlp(blk.norm2(x)))
    return attns