#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import sys
import time
import argparse
import numpy as np
from tqdm import tqdm
import re
import datetime
import PIL
import PIL.Image as Image
from PIL import ImageDraw
import torch
import torch.nn.functional as F
from torchvision import transforms
from pycocotools import mask
import pycocotools.mask as mask_util
from scipy import ndimage
from scipy.linalg import eigh
import json
import cv2
from detectron2.utils.colormap import random_color
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.ndimage import generate_binary_structure
from skimage import morphology

import dino
from crf import densecrf

# Preprocessing shared by every input image: convert the PIL image to a
# tensor, then normalise each channel with a fixed mean / standard deviation.
ToTensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
def IoU(mask1, mask2):
    """Return the mean intersection-over-union of two mask tensors.

    Both inputs are binarised with a 0.5 threshold. Intersection and union are
    counted over the last two dimensions and the resulting ratio is averaged
    over whatever dimensions remain.

    Returns:
        A Python float (NaN if a union is empty, because of the 0/0 division).
    """
    fg1 = (mask1 > 0.5).to(torch.bool)
    fg2 = (mask2 > 0.5).to(torch.bool)
    overlap = torch.logical_and(fg1, fg2).sum(dim=[-1, -2]).squeeze()
    combined = torch.logical_or(fg1, fg2).sum(dim=[-1, -2]).squeeze()
    ratio = overlap.to(torch.float) / combined
    return ratio.mean().item()

def random_color(rgb=False, maximum=255):
    """Draw a random 3-channel colour with integer values in [0, maximum].

    Args:
        rgb: when False the three sampled values are returned in reversed order.
        maximum: largest value a channel may take (inclusive).

    Returns:
        A numpy array holding the three channel values.
    """
    channels = np.random.randint(0, maximum + 1, size=3)
    return channels if rgb else channels[::-1]

def resize_pil(I, patch_size=16):
    """Resize an image so that both sides are multiples of ``patch_size``.

    Each side is rounded to the nearest multiple of ``patch_size`` and the
    image is resampled with LANCZOS.

    Returns:
        (resized image, original width, original height,
         number of patches along the width, number of patches along the height)
    """
    w, h = I.size

    # Snap each side to the nearest multiple of the patch size.
    target_w = int(round(w / patch_size)) * patch_size
    target_h = int(round(h / patch_size)) * patch_size

    # Size of the patch grid for the resized image.
    grid_w = target_w // patch_size
    grid_h = target_h // patch_size

    resized = I.resize((target_w, target_h), resample=Image.LANCZOS)
    return resized, w, h, grid_w, grid_h

def vis_mask(input, mask, mask_color):
    """Blend ``mask_color`` over the pixels of ``input`` selected by ``mask``.

    Pixels where ``mask > 0.5`` become 30% original colour + 70% mask colour.
    The input is copied, not modified, and the result is returned as a PIL
    image.
    """
    overlay = np.copy(input)
    selected = mask > 0.5
    blended = overlay[selected] * 0.3 + np.array(mask_color) * 0.7
    overlay[selected] = blended.astype(np.uint8)
    return Image.fromarray(overlay)

def vis_box(input, mask, box_color):
    """Draw the tight bounding box of ``mask`` onto ``input`` (in place).

    Args:
        input: image object accepted by ``ImageDraw.Draw``; it is drawn on
            directly and returned.
        mask: 2D array; every element > 0 counts as foreground.
        box_color: iterable of three channel values used as outline colour.

    Returns:
        ``input`` (unchanged when the mask has no foreground pixel).
    """
    ys, xs = np.where(mask > 0)
    if len(xs) == 0 or len(ys) == 0:
        print('mask is zero!')
        return input

    # Plain Python ints keep PIL independent of numpy scalar types.
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())
    print(x_min, y_min, x_max, y_max)

    # Honour the colour requested by the caller (it used to be overwritten
    # by a hard-coded red, which made the parameter meaningless).
    outline = tuple(int(c) for c in box_color)

    draw = ImageDraw.Draw(input)
    draw.rectangle([x_min, y_min, x_max, y_max], outline=outline, width=2)

    return input

def get_affinity_matrix(feats, tau, eps=1e-5):
    """Build a thresholded affinity matrix and its degree matrix.

    Columns of ``feats`` are L2-normalised (normalisation runs over dim 0) and
    compared pairwise by dot product. Entries above ``tau`` become 1, all
    others become ``eps``.

    Returns:
        (A, D): numpy arrays, the affinity matrix and the diagonal matrix of
        its row sums.
    """
    unit_feats = F.normalize(feats, p=2, dim=0)
    similarity = (unit_feats.transpose(0, 1) @ unit_feats).cpu().numpy()
    # Binarise: connected pairs get weight 1, the rest a tiny weight ``eps``.
    A = np.where(similarity > tau, 1.0, eps)
    D = np.diag(A.sum(axis=1))
    return A, D
# def get_affinity_matrix(feats, tau, eps=1e-5):
#     # feats: (784, 3600)
#     feats = F.normalize(feats, p=2, dim=0)  # Cosine normalization along feature dim
#     sim = feats.T @ feats  # (3600, 3600), cosine similarity matrix

#     # Binary affinity (thresholding)
#     A = (sim > tau).float()  # torch.float32
#     A[A == 0] = eps  # Avoid zero entries

#     # Degree matrix (as 1D vector)
#     d_i = A.sum(dim=1)
#     D = torch.diag(d_i)  # Optional: or return d_i only to save memory
#     return A, D  # Both are torch.float32

def second_smallest_eigenvector(A, D):
    """Solve (D - A) v = lambda * D v and return the second-smallest eigenvector.

    Returns:
        A pair holding the same vector twice: first as an independent copy,
        then as a view into the solver output.
    """
    laplacian = D - A
    # Indices 1..2 select the 2nd and 3rd smallest eigenpairs; only the first
    # of those two is used.
    _, vecs = eigh(laplacian, D, subset_by_index=[1, 2])
    fiedler = vecs[:, 0]
    return np.copy(fiedler), fiedler

# def second_smallest_eigenvector(A, D, device):
#     L = D - A
#     diag_D = torch.diag(D)
#     inv_sqrt_diag = diag_D.pow(-0.5)
#     D_inv_sqrt = torch.diag(inv_sqrt_diag)
#     L_sym = D_inv_sqrt @ L @ D_inv_sqrt
#     n = A.shape[0]
#     X = torch.randn(n, 2, device=device, dtype=torch.float32)
#     Q, _ = torch.linalg.qr(X, mode="reduced")
#     eigvals, eigvecs = torch.lobpcg(A=L_sym, k=2, B=None, X=Q, niter=1000, tol=1e-5, largest=False)
#     # eigvals, eigvecs = torch.lobpcg(A=L_sym,k=2,X=X,niter=2000,tol=1e-6,largest=False)
#     y2 = eigvecs[:, 1]
#     second_smallest_vec = D_inv_sqrt @ y2
#     return second_smallest_vec.cpu().numpy()


def ncut(features, tau=0.15, eps=1e-5, eig_vecs=1):
    """
    Compute normalized-cut eigenvectors and eigenvalues for a batch of graphs.

    The affinity between two nodes is their cosine similarity; similarities
    that are not above ``tau`` are replaced by ``eps`` (values above ``tau``
    keep their similarity, i.e. the graph is thresholded but not binarised).
    The symmetric normalized Laplacian D^(-1/2) (D - A) D^(-1/2) is then
    decomposed with ``torch.linalg.eigh``.

    :param features: batched features of shape (batch_size, num_nodes, feature_dim)
    :param tau: threshold for the adjacency matrix
    :param eps: small weight given to pairs at or below the threshold, so every
     node keeps a non-zero degree
    :param eig_vecs: number of eigenvectors to return, starting at the
     second-smallest one. If eig_vecs is 1, only the second-smallest
     eigenvector is returned
    :return: (eigenvectors, eigenvalues) of shapes
     (batch_size, num_nodes, eig_vecs) and (batch_size, eig_vecs).
     If the eigen-decomposition fails, (None, None) is returned
    """
    features = F.normalize(features, p=2, dim=-1)
    A = torch.bmm(features, features.permute(0, 2, 1))
    A = torch.where(A > tau, A, torch.tensor(eps, device=A.device, dtype=A.dtype))

    # Degree of every node and the matrices derived from it.
    D = torch.sum(A, dim=2)
    D_diag = torch.diag_embed(D)
    D_over_sqrt = torch.diag_embed(torch.sqrt(1.0 / D))

    # Normalized Laplacian L = D^(-1/2) * (D - A) * D^(-1/2)
    L = torch.matmul(D_over_sqrt, torch.matmul(D_diag - A, D_over_sqrt))
    try:
        eigenvalues, eigenvectors = torch.linalg.eigh(L, UPLO='L')
    except RuntimeError:
        # torch reports linear-algebra failures as RuntimeError subclasses.
        # Anything else (e.g. KeyboardInterrupt) must not be swallowed here.
        print("eigh failed")
        return None, None

    # eigh returns eigenvalues in ascending order; skip the smallest one.
    eigenvalues = eigenvalues[:, 1:eig_vecs + 1]
    eigenvectors = eigenvectors[:, :, 1:eig_vecs + 1]

    return eigenvectors, eigenvalues

def ncut_eigh(features, tau=0.15, eps=1e-5):
    """
    Second-smallest eigenvector of the symmetric normalized Laplacian,
    computed with a dense ``torch.linalg.eigh``.

    Args:
        features: Tensor of shape (D, N), one column of features per node.
        tau: similarity threshold; pairs above it get affinity 1.
        eps: affinity assigned to every other pair.

    Returns:
        Tensor of shape (N,).
    """
    # Unit-normalise each node's feature column (normalisation runs over dim 0).
    unit = F.normalize(features, p=2, dim=0)
    num_nodes = unit.shape[1]

    # Binary affinity from pairwise cosine similarity.
    affinity = (unit.T @ unit > tau).float()
    affinity[affinity == 0] = eps

    # D^{-1/2} built from the node degrees.
    scale = affinity.sum(dim=1).pow(-0.5)
    D_inv_sqrt = torch.diag(scale)

    # L_sym = I - D^{-1/2} A D^{-1/2}
    identity = torch.eye(num_nodes, device=features.device)
    L_sym = identity - D_inv_sqrt @ affinity @ D_inv_sqrt

    # eigh returns eigenvalues in ascending order, so column 1 belongs to the
    # second-smallest eigenvalue.
    _, vecs = torch.linalg.eigh(L_sym, UPLO='L')
    return vecs[:, 1]

def affinity_second_eigenvector(feats, tau, eps=1e-5):
    """Second-smallest generalized eigenvector of (D - A, D) via LOBPCG.

    ``feats`` is indexed as (feature_dim, N): columns are L2-normalised and
    compared pairwise; pairs above ``tau`` get affinity 1, all others ``eps``.
    The problem is solved in the symmetric normalized space with
    ``torch.lobpcg`` (started from a random orthonormal basis, so the result
    depends on the torch RNG state) and mapped back with D^{-1/2}.

    Returns:
        Tensor of shape (N,).
    """
    device = feats.device
    unit = F.normalize(feats, p=2, dim=0)
    N = unit.shape[1]

    # Binary affinity from pairwise cosine similarity.
    A = (unit.T @ unit > tau).float()
    A[A == 0] = eps

    # Degree vector and D^{-1/2}.
    degree = A.sum(dim=1)
    D_inv_sqrt = torch.diag(degree.pow(-0.5))

    # Symmetric normalized Laplacian D^{-1/2} (D - A) D^{-1/2}.
    laplacian = torch.diag(degree) - A
    L_sym = D_inv_sqrt @ laplacian @ D_inv_sqrt

    # Random start, orthonormalised with a QR factorisation.
    start = torch.randn(N, 2, device=device, dtype=torch.float32)
    basis, _ = torch.linalg.qr(start)

    _, eigvecs = torch.lobpcg(A=L_sym, k=2, B=None, X=basis, niter=1000, tol=1e-5, largest=False)

    # Column 1 is the second-smallest pair in the normalized space;
    # x = D^{-1/2} y maps it back to the generalized problem.
    return D_inv_sqrt @ eigvecs[:, 1]

def check_num_fg_corners(bipartition, dims):
    """Count how many of the four grid corners are marked as foreground.

    ``bipartition`` is reshaped to ``dims`` and the values at the top-left,
    top-right, bottom-left and bottom-right positions are summed as ints.
    """
    grid = bipartition.reshape(dims)
    corner_values = (grid[0][0], grid[0][-1], grid[-1][0], grid[-1][-1])
    return sum(int(value) for value in corner_values)

def get_masked_affinity_matrix(painting, feats, mask, ps):
    """Zero out the features of every location already covered by a mask.

    ``mask`` is added to ``painting`` (with a new leading dimension) and the
    sum is clipped to {0, 1}. ``feats`` is viewed as a (dim, ps, ps) grid and
    multiplied by ``1 - painting``. Neither input tensor is modified.

    Returns:
        (masked features with the original (dim, num_patch) layout,
         updated painting)
    """
    dim, num_patch = feats.size()[0], feats.size()[1]

    covered = painting + mask.unsqueeze(0)
    covered = covered.masked_fill(covered > 0, 1)
    covered = covered.masked_fill(covered <= 0, 0)

    grid = feats.clone().view(dim, ps, ps)
    kept = ((1 - covered) * grid).view(dim, num_patch)
    return kept, covered

def get_masks(eigen_vec, dims, reverse):
    """Turn an eigenvector into a list of binary instance masks.

    Steps:
      1. Optionally negate the eigenvector (``reverse``).
      2. Threshold it at its mean to get a foreground/background bipartition.
      3. If at least three of the four grid corners are foreground (and
         ``reverse`` was not requested), flip both the eigenvector and the
         bipartition.
      4. Fill holes and label the connected foreground components.
      5. Rank all labels (background label 0 included) by the sum of the
         eigenvector over their area, drop the lowest-ranked label, and emit
         one mask per remaining label in rank order, stopping at the first
         label whose area is below 10% of the top-ranked label's area.

    Args:
        eigen_vec: flat numpy array that can be reshaped to ``dims``.
        dims: grid shape, (height, width).
        reverse: negate the eigenvector up front and disable the corner check.

    Returns:
        List of uint8 arrays of shape ``dims`` holding 0/1 values.

    Raises:
        ValueError: when no connected foreground component is found.
    """
    if reverse:
        print('------------  reverse eigen vector')
        eigen_vec = eigen_vec * -1

    bipartition = eigen_vec > np.average(eigen_vec)

    # Foreground touching most image corners is taken as a sign that the
    # partition is inverted.
    nc = check_num_fg_corners(bipartition, dims)
    if nc >= 3 and not reverse:
        eigen_vec = eigen_vec * -1
        bipartition = np.logical_not(bipartition)

    bipartition = ndimage.binary_fill_holes(bipartition.reshape(dims))

    # ndimage.label only labels non-zero regions, so background must be 0.
    objects, num_objects = ndimage.label(bipartition)
    if num_objects < 1:
        raise ValueError('num_objects < 1, algorithnm fail')

    labels, counts = np.unique(objects, return_counts=True)
    sums = ndimage.sum(eigen_vec.reshape(dims), labels=objects, index=labels)
    order = np.argsort(sums)[::-1]  # largest sum first

    # NOTE(review): dropping the last entry assumes the lowest-ranked label is
    # the background (label 0); nothing here enforces that.
    sort_labels = labels[order][:-1]
    sort_counts = counts[order][:-1]

    masks = []
    for idx, cc in enumerate(sort_labels):
        # Areas are always compared against the top-ranked component.
        if idx > 0 and sort_counts[idx] < sort_counts[0] * 0.1:
            break
        masks.append((objects == cc).astype(np.uint8))

    return masks

def maskcut_forward(device, feats, dims, scales, init_image_size, tau=0, reverse=False):
    """Run normalized cut on patch features and extract instance masks.

    Args:
        device: unused; kept for interface compatibility.
        feats: 2D feature tensor; it is transposed and given a batch
            dimension before being passed to ``ncut``
            (presumably (feature_dim, num_patches) - confirm with the backbone).
        dims: patch-grid shape passed on to ``get_masks``.
        scales: unused; kept for interface compatibility.
        init_image_size: unused; kept for interface compatibility.
        tau: affinity threshold forwarded to ``ncut``.
        reverse: forwarded to ``get_masks``.

    Returns:
        (list of binary masks, eigenvector as a flat numpy array)

    Raises:
        RuntimeError: when ``ncut`` reports a failed eigen-decomposition.
    """
    node_feats = feats.permute(1, 0).unsqueeze(0)
    vecs, _ = ncut(node_feats, tau)
    if vecs is None:
        # ncut signals failure with (None, None); fail with a clear message
        # instead of an AttributeError on None.
        raise RuntimeError('ncut failed to compute eigenvectors')

    second_smallest_vec = vecs.view(-1).cpu().numpy()
    bipartitions = get_masks(second_smallest_vec, dims, reverse)

    return bipartitions, second_smallest_vec

def maskcut(device, img_path, backbone, patch_size, tau, fixed_size=480, reverse=False):
    """Load an image, extract backbone features and run MaskCut on them.

    The image is converted to RGB, resized to a ``fixed_size`` square, snapped
    to a multiple of ``patch_size``, normalised and fed through ``backbone``.

    Returns:
        (list of binary masks, eigenvector, the fixed-size PIL image)
    """
    side = int(fixed_size)
    I_new = Image.open(img_path).convert('RGB').resize((side, side), PIL.Image.LANCZOS)
    I_resize, w, h, feat_w, feat_h = resize_pil(I_new, patch_size)

    batch = ToTensor(I_resize).unsqueeze(0).to(device)
    feat = backbone(batch)[0]

    bipartitions, eigen_vec = maskcut_forward(
        device,
        feat,
        [feat_h, feat_w],
        [patch_size, patch_size],
        [h, w],
        tau,
        reverse,
    )

    return bipartitions, eigen_vec, I_new

def resize_binary_mask(array, new_size):
    """Resize a binary mask through PIL and return it as a boolean array.

    The mask is scaled to 0/255 uint8, resized to ``new_size`` with PIL's
    default resampling, and every non-zero pixel of the result becomes True.
    """
    as_uint8 = array.astype(np.uint8) * 255
    resized = Image.fromarray(as_uint8).resize(new_size)
    return np.asarray(resized).astype(np.bool_)

def bbox_from_mask(mask: np.ndarray):
    """Return the tight bounding box of a 2D mask as [x, y, width, height].

    Columns / rows whose sum is non-zero are considered occupied. An
    all-zero mask makes ``np.min`` raise ValueError.
    """
    cols = np.where(mask.sum(axis=0))[0]
    rows = np.where(mask.sum(axis=1))[0]
    left, top = np.min(cols), np.min(rows)
    box_w = np.max(cols) - left + 1
    box_h = np.max(rows) - top + 1
    return np.array([left, top, box_w, box_h])

def IoU_bbox(mask1, mask2):
    """
    Compute the IoU between the bounding boxes of two masks.
    :param mask1: first mask, passed to ``bbox_from_mask``
    :param mask2: second mask, passed to ``bbox_from_mask``
    :return: intersection area divided by union area of the two boxes
    """
    ax, ay, aw, ah = bbox_from_mask(mask1)
    bx, by, bw, bh = bbox_from_mask(mask2)

    # Overlapping rectangle (may be degenerate when the boxes are disjoint).
    left = max(ax, bx)
    top = max(ay, by)
    right = min(ax + aw, bx + bw)
    bottom = min(ay + ah, by + bh)
    intersection_area = max(0, right - left) * max(0, bottom - top)

    union_area = aw * ah + bw * bh - intersection_area
    return intersection_area / union_area

def mask_post_processing_offical(bipartition, image_rgb, device='cpu'):
    """Upsample a patch-level mask to image size and refine it with dense CRF.

    Args:
        bipartition: 2D numpy mask at patch resolution.
        image_rgb: image convertible with ``np.array`` to an (H, W, C) array;
            its height and width define the output size.
        device: torch device on which the before/after IoU is computed.

    Returns:
        (mask, success). ``success`` is False, and the un-refined upsampled
        mask is returned, when the CRF result is empty, when its IoU with the
        input mask is below 0.5, or when the CRF step raises.
    """
    success = True
    # Take the target size from the image itself rather than from module
    # globals that only exist when this file runs as a script.
    image_np = np.array(image_rgb)
    height, width = image_np.shape[:2]

    bipartition = bipartition.astype(np.float32)
    bipartition = F.interpolate(
        torch.from_numpy(bipartition[None, None, :, :]),
        size=(height, width),
        mode='nearest',
    )[0][0].numpy()
    try:
        pseudo_mask = densecrf(image_np, bipartition)
        pseudo_mask = ndimage.binary_fill_holes(pseudo_mask >= 0.5)

        # Reject masks that the CRF changed too much.
        mask1 = torch.from_numpy(bipartition).to(device)
        mask2 = torch.from_numpy(pseudo_mask).to(device)
        iou = IoU(mask1, mask2)
        if np.sum(pseudo_mask) == 0 or iou < 0.5:
            return bipartition, False
        binary_mask = pseudo_mask
    except Exception:
        # Best effort: if the CRF fails for any reason keep the original mask.
        binary_mask = bipartition
        success = False
    return binary_mask, success

def mask_post_processing(mask, image_rgb, device='cpu'):
    """
    Post-processing of the mask. It upsamples the mask to the image size, runs
    dense CRF on a crop around the mask's bounding box (enlarged by a third of
    its width/height on each side) and pastes the result back.

    mask: numpy array of shape [height, width] with [0,1] values
    image_rgb: PIL image
    device: unused; kept for interface compatibility
    return: tuple - (mask as numpy array of shape [height, width] with [0,1] values, success flag)

    The un-refined upsampled mask is returned with ``success=False`` when the
    mask is empty, when the CRF result is empty, when the bounding-box IoU
    between the original and the refined mask is below 0.5, or when the CRF
    step raises.
    """
    success = True
    image_orig_size = image_rgb.size
    rescale_size = (image_orig_size[1], image_orig_size[0])
    # resize the mask to the original image size with nearest neighbor interpolation
    mask = mask.astype(np.float32)
    patches_mask = F.interpolate(torch.from_numpy(mask[None, None, :, :]), size=rescale_size, mode='nearest')[0][0].numpy()
    # An empty mask has no bounding box; report failure instead of crashing.
    if not patches_mask.any():
        return patches_mask, False
    # crop the mask by the (enlarged) bounding box
    x, y, w, h = bbox_from_mask(patches_mask)
    crop_x = (max(x - w//3, 0), min((x + w) + w//3, rescale_size[1]))
    crop_y = (max(y - h//3, 0), min((y + h) + h//3, rescale_size[0]))
    mask_cropped = patches_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]]
    # crop the image the same way
    img = np.asarray(image_rgb).copy()
    img_cropped = img[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1], :]
    # apply CRF to the crop
    try:
        pseudo_mask_crop = densecrf(img_cropped, mask_cropped)
        pseudo_mask_crop = ndimage.binary_fill_holes(pseudo_mask_crop >= 0.5)
        # paste the refined crop into a full-size mask
        pseudo_mask = np.zeros_like(patches_mask)
        pseudo_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]] = pseudo_mask_crop
        # an empty CRF result, or one whose bounding box differs too much from
        # the original, is not considered an object. The bbox helpers work on
        # numpy arrays, so the masks are passed as-is.
        if np.sum(pseudo_mask) == 0 or IoU_bbox(patches_mask, pseudo_mask) < 0.5:
            return patches_mask, False
        binary_mask = pseudo_mask
    except Exception:
        # in case crf failed for some reason use the original mask
        binary_mask = patches_mask
        success = False
    return binary_mask, success

def mask_post_processing_new(mask, image_rgb, device='cpu'):
    """
    Post-processing of the mask. It upsamples the mask to the image size, runs
    dense CRF on a crop around the mask's bounding box (enlarged by 33% of its
    width/height on each side) and pastes the result back. Unlike
    ``mask_post_processing`` the CRF output is not thresholded and no IoU
    check is applied to it.

    mask: numpy array of shape [height, width] with [0,1] values
    image_rgb: PIL image
    device: unused; kept for interface compatibility
    return: tuple - (mask as numpy array of shape [height, width] with [0,1] values, success flag)

    The un-refined upsampled mask is returned with ``success=False`` when the
    mask is empty or when the CRF step raises.
    """
    success = True
    image_orig_size = image_rgb.size
    rescale_size = (image_orig_size[1], image_orig_size[0])
    # resize the mask to the original image size with nearest neighbor interpolation
    mask = mask.astype(np.float32)
    patches_mask = F.interpolate(torch.from_numpy(mask[None, None, :, :]), size=rescale_size, mode='nearest')[0][0].numpy()
    # An empty mask has no bounding box; report failure instead of crashing.
    if not patches_mask.any():
        return patches_mask, False
    # crop the mask by the (enlarged) bounding box
    x, y, w, h = bbox_from_mask(patches_mask)
    factor = 0.33
    crop_x = (max(x - int(w*factor), 0), min((x + w) + int(w*factor), rescale_size[1]))
    crop_y = (max(y - int(h*factor), 0), min((y + h) + int(h*factor), rescale_size[0]))
    mask_cropped = patches_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]]
    # crop the image the same way
    img = np.asarray(image_rgb).copy()
    img_cropped = img[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1], :]
    # apply CRF to the crop
    try:
        pseudo_mask_crop = densecrf(img_cropped, mask_cropped)
        pseudo_mask_crop = ndimage.binary_fill_holes(pseudo_mask_crop)
        # paste the refined crop into a full-size mask
        pseudo_mask = np.zeros_like(patches_mask)
        pseudo_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]] = pseudo_mask_crop
        binary_mask = pseudo_mask
    except Exception:
        # in case crf failed for some reason use the original mask
        binary_mask = patches_mask
        success = False
    return binary_mask, success

def create_image_info(image_id, file_name, image_size,
                      date_captured=None,
                      license_id=1, coco_url="", flickr_url=""):
    """Return image_info in COCO style
    Args:
        image_id: the image ID
        file_name: the file name of each image
        image_size: image size in the format of (width, height)
        date_captured: the date this image info is created; defaults to the
            current UTC time at the moment of the call
        license_id: license of this image
        coco_url: url to COCO images if there is any
        flickr_url: url to flickr if there is any
    """
    if date_captured is None:
        # Evaluate the timestamp per call; a default expression in the
        # signature would be frozen at import time.
        date_captured = datetime.datetime.now(datetime.timezone.utc).isoformat(' ')

    image_info = {
            "id": image_id,
            "file_name": file_name,
            "width": image_size[0],
            "height": image_size[1],
            "date_captured": date_captured,
            "license": license_id,
            "coco_url": coco_url,
            "flickr_url": flickr_url
    }
    return image_info


def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
                           image_size=None, bounding_box=None):
    """Return annotation info in COCO style
    Args:
        annotation_id: the annotation ID
        image_id: the image ID
        category_info: the information on categories (its "id" is used)
        binary_mask: a 2D numpy array where high values represent the object;
            it is binarised at half of its maximum. The caller's array is
            left untouched.
        image_size: image size in the format of (width, height); when given,
            the mask is resized to it first
        bounding_box: the bounding box for detection task. If bounding_box is not provided,
        we will generate one according to the binary mask.
    Returns:
        The annotation dict, or None when the mask has no foreground area.
    """
    # Work on a copy so the thresholding below does not modify the caller's data.
    binary_mask = np.array(binary_mask)
    upper = np.max(binary_mask)
    lower = np.min(binary_mask)
    thresh = upper / 2.0
    binary_mask[binary_mask > thresh] = upper
    binary_mask[binary_mask <= thresh] = lower
    if image_size is not None:
        binary_mask = resize_binary_mask(binary_mask.astype(np.uint8), image_size)

    # A single RLE encoding serves area, bounding box and segmentation.
    rle = mask.encode(np.asfortranarray(binary_mask.astype(np.uint8)))

    area = mask.area(rle)
    if area < 1:
        return None

    if bounding_box is None:
        bounding_box = mask.toBbox(rle)

    # RLE counts are bytes; JSON needs text.
    rle['counts'] = rle['counts'].decode('ascii')

    annotation_info = {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_info["id"],
        "iscrowd": 0,
        "area": area.tolist(),
        "score": 1.0, # pycocotools need this field
        # asarray lets callers pass the box as a list or as an array
        "bbox": np.asarray(bounding_box).tolist(),
        "segmentation": rle,
        "width": binary_mask.shape[1],
        "height": binary_mask.shape[0],
    }

    return annotation_info

# necessary info used for coco style annotations
INFO = {
    "description": "COCO: pseudo-masks with MaskCut",
    "url": "https://github.com/facebookresearch/CutLER",
    "version": "1.0",
    "year": 2023,
    "contributor": "Xudong Wang",
    "date_created": datetime.datetime.now(datetime.timezone.utc).isoformat(' ')
}

LICENSES = [
    {
        "id": 1,
        "name": "Apache License",
        "url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE"
    }
]

# only one class, i.e. foreground
CATEGORIES = [
    {
        'id': 1,
        'name': 'fg',
        'supercategory': 'fg',
    },
]


def convert(text):
    """Map a digit-only string to int and any other string to lower case."""
    return int(text) if text.isdigit() else text.lower()


def natrual_key(key):
    """Sort key that orders embedded numbers numerically ("natural" order)."""
    return [convert(c) for c in re.split('([0-9]+)', key)]


# The script collects a flat list of annotation dicts. (A COCO-style dict
# built from INFO / LICENSES / CATEGORIES used to be assigned here and was
# immediately overwritten by this list.)
output = []

category_info = {
    "is_crowd": 0,
    "id": 1
}

if __name__ == "__main__":
    # Command-line entry point. Two modes:
    #   * --img-path given: segment one image and save a visualisation PNG
    #     (masks + boxes) in the current directory.
    #   * otherwise: segment every file in --dataset-path and dump the
    #     collected annotations as JSON into --out-dir.

    parser = argparse.ArgumentParser('MaskCut script')
    # default arguments
    parser.add_argument('--out-dir', type=str, default='mask_annotation', help='output directory')
    parser.add_argument('--vit-arch', type=str, default='base', choices=['base', 'small'], help='which architecture')
    parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')
    parser.add_argument('--patch-size', type=int, default=8, choices=[16, 8], help='patch size')
    parser.add_argument('--nb-vis', type=int, default=20, choices=[1, 200], help='nb of visualization')
    parser.add_argument('--img-path', type=str, default=None, help='single image visualization')

    # additional arguments
    parser.add_argument('--dataset-path', type=str, default="/data/xxx/datasets/coco/val2017", help='path to the dataset')
    parser.add_argument('--tau', type=float, default=0.15, help='threshold used for producing binary graph')
    parser.add_argument('--num-folder-per-job', type=int, default=1, help='the number of folders each job processes')
    parser.add_argument('--job-index', type=int, default=0, help='the index of the job (for imagenet: in the range of 0 to 1000/args.num_folder_per_job-1)')
    parser.add_argument('--fixed_size', type=int, default=480, help='rescale the input images to a fixed size')
    parser.add_argument('--pretrain_path', type=str, default=None, help='path to pretrained model')
    parser.add_argument('--N', type=int, default=1, help='the maximum number of pseudo-masks per image')

    args = parser.parse_args()

    # (architecture, patch size) -> (default checkpoint url, feature dimension)
    pretrained = {
        ('base', 8): ("https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 768),
        ('small', 8): ("https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth", 384),
        ('base', 16): ("https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 768),
        ('small', 16): ("https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", 384),
    }
    default_url, feat_dim = pretrained[(args.vit_arch, args.patch_size)]
    # An explicit --pretrain_path always wins over the default checkpoint.
    url = args.pretrain_path if args.pretrain_path is not None else default_url

    print(args)

    backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)

    msg = 'Load {} pre-trained feature...'.format(args.vit_arch)
    print(msg)

    # Prefer the second GPU as before, but fall back to the first one on
    # single-GPU machines instead of failing with an invalid device ordinal.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device("cuda:{}".format(gpu_index))
    else:
        device = torch.device("cpu")

    backbone.eval()
    backbone.to(device)

    start_time = time.time()
    if args.out_dir is not None:
        # makedirs also handles nested paths and already-existing directories
        os.makedirs(args.out_dir, exist_ok=True)

    image_id, segmentation_id = 1, 1
    image_names = set()

    single_image = args.img_path is not None
    if single_image:
        pbar = tqdm([args.img_path])
        pseudo_mask_list = []
    else:
        img_list = sorted(os.listdir(args.dataset_path))
        pbar = tqdm(img_list)

    for img_name in pbar:
        # get image path and id
        if single_image:
            img_path = img_name
            image_id = img_name.split('/')[-1].split('.')[0]
        else:
            img_path = os.path.join(args.dataset_path, img_name)
            # dataset file names are expected to be numeric (e.g. COCO ids)
            image_id = int(img_name.split('.')[0])

        # get pseudo-masks for each image using MaskCut
        bipartitions, eigen_vec, I_new = maskcut(device, img_path, backbone, args.patch_size, args.tau, fixed_size=args.fixed_size)

        I = Image.open(img_path).convert('RGB')
        width, height = I.size
        for bipartition in bipartitions:
            pseudo_mask, success = mask_post_processing_new(bipartition, I)
            if not success:
                continue
            # construct binary pseudo-masks at the original image size
            pseudo_mask = Image.fromarray(np.uint8(pseudo_mask*255))
            pseudo_mask = np.asarray(pseudo_mask.resize((width, height)))

            if single_image:
                pseudo_mask_list.append(pseudo_mask)
                continue

            # create coco-style image info (once per image)
            # NOTE: image_info is currently not written to the output file.
            if img_name not in image_names:
                image_info = create_image_info(
                    image_id, "{}/{}".format(args.dataset_path, img_name), (width, height))
                image_names.add(img_name)

            # create coco-style annotation info
            annotation_info = create_annotation_info(
                segmentation_id, image_id, category_info, pseudo_mask.astype(np.uint8), None)
            if annotation_info is not None:
                output.append(annotation_info)
                segmentation_id += 1

    if single_image:
        if len(pseudo_mask_list) == 0:
            print('=== segment fail ===')
        print('number of instances:', len(pseudo_mask_list))

        canvas = np.array(I)
        for pseudo_mask in pseudo_mask_list:
            canvas = vis_mask(canvas, pseudo_mask, random_color(rgb=True))
        for pseudo_mask in pseudo_mask_list:
            canvas = vis_box(canvas, pseudo_mask, random_color(rgb=True))
        if not isinstance(canvas, Image.Image):
            # No mask was drawn, so canvas is still the raw array.
            canvas = Image.fromarray(canvas)

        file_name = os.path.basename(args.img_path)
        name_without_ext = os.path.splitext(file_name)[0]
        output_file = f"{name_without_ext}-askonce.png"
        canvas.save(output_file)
        print(f'dumping {output_file}')
        end_time = time.time()
        print(f'Time cost: {str(datetime.timedelta(milliseconds=int((end_time - start_time)*1000)))}')
        sys.exit()

    # save annotations
    json_name = '{}/coco_train_fixsize{}_tau{}-Mask-improve.json'.format(args.out_dir, args.fixed_size, args.tau, )

    with open(json_name, 'w') as output_json_file:
        json.dump(output, output_json_file, indent=2)
    print(f'dumping {json_name}')

    