import os
import sys
import time
import argparse
import numpy as np
from tqdm import tqdm
import PIL.Image as Image
from PIL import ImageDraw
import torch
import torch.nn.functional as F
from torchvision import transforms
from scipy import ndimage
import json
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from utils.annotations_worker import CocoAnnotationsWorker
from networks import get_model 
from cutonce.crf import densecrf

# Image transformation applied to all images: convert to a tensor, then
# normalize each of the three channels with the mean/std constants below.
ToTensor = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(
                                (0.485, 0.456, 0.406),
                                (0.229, 0.224, 0.225)),])
# Visualization palette: color tuples that the single-image mode cycles
# through (index modulo the palette length) when drawing masks and boxes.
Colors = [
    (255, 0, 0),    # red
    (255, 165, 0),  # orange
    (255, 255, 0),  # yellow
    (0, 255, 0),    # green
    (0, 0, 255),    # blue
    (128, 0, 128),  # purple
    (255, 192, 203), # pink
]

def IoU(mask1, mask2):
    """Return the mean intersection-over-union of two masks as a Python float.

    Both tensors are binarized at 0.5. The last two dimensions are reduced,
    and any remaining leading dimensions are averaged. The result is NaN when
    neither mask has a pixel above the threshold (empty union).
    """
    bin_a = mask1 > 0.5
    bin_b = mask2 > 0.5
    overlap = (bin_a & bin_b).sum(dim=[-1, -2]).squeeze()
    combined = (bin_a | bin_b).sum(dim=[-1, -2]).squeeze()
    ratio = overlap.to(torch.float) / combined
    return ratio.mean().item()

def random_color(rgb=False, maximum=255):
    """Draw a random 3-component color with integer values in [0, maximum].

    The three sampled components are returned in reversed order unless
    ``rgb`` is true.
    """
    channels = np.random.randint(0, maximum + 1, size=3)
    return channels if rgb else channels[::-1]

def vis_mask(input, mask, mask_color) :
    """Blend ``mask_color`` at 50% opacity over the pixels where ``mask`` > 0.5.

    ``input`` is copied into a numpy array before blending, so the caller's
    image is not modified. The blended result is returned as a PIL image.
    """
    overlay = np.copy(input)
    selected = mask > 0.5
    tint = np.array(mask_color)
    overlay[selected] = (overlay[selected] * 0.5 + tint * 0.5).astype(np.uint8)
    return Image.fromarray(overlay)

def vis_box(input, mask, box_color):
    """Draw the tight bounding box of ``mask`` onto ``input``.

    :param input: image handed to ``ImageDraw.Draw``; the rectangle is drawn
        on this object and the same object is returned
    :param mask: 2-D array; pixels with a value > 0 belong to the object
    :param box_color: outline color as a 3-component sequence (components are
        converted to plain ints, so numpy integers are accepted)
    :return: ``input`` (unchanged when the mask has no positive pixel)
    """
    ys, xs = np.where(mask > 0)
    if len(xs) == 0:
        # Nothing to outline.
        return input

    # Plain ints keep the drawing call independent of numpy scalar types.
    left, right = int(xs.min()), int(xs.max())
    top, bottom = int(ys.min()), int(ys.max())

    # Honor the caller-supplied color (it used to be overwritten by a
    # hard-coded red, which made the parameter dead).
    outline = tuple(int(c) for c in box_color)

    draw = ImageDraw.Draw(input)
    draw.rectangle([left, top, right, bottom], outline=outline, width=2)

    return input

def ncut(features, tau=0.15, eps=1e-5, eig_vecs=1):
    """
    Compute the normalized cut eigenvectors and eigenvalues. This function uses pytorch to compute batched eigenvectors.
    :param features: batched features of shape (batch_size, num_nodes, feature_dim)
    :param tau: threshold for the adjacency matrix (cosine similarities above it become 1)
    :param eps: small value used for the below-threshold entries of the adjacency matrix to avoid division by zero
    :param eig_vecs: number of eigenvectors to compute. If eig_vecs is 1, then only the second-smallest eigenvector is
     returned
    :return: (eigenvectors, eigenvalues, A) where eigenvectors has shape (batch_size, num_nodes, eig_vecs),
     eigenvalues has shape (batch_size, eig_vecs) and A is the (batch_size, num_nodes, num_nodes) adjacency matrix.
     If the eigenvalue computation fails, then (None, None, None) is returned so that callers can still unpack
     three values.
    """
    features = F.normalize(features, p=2, dim=-1)
    similarity = torch.bmm(features, features.transpose(1, 2))
    # Binary affinity stored as float32: 1 above the threshold, eps elsewhere.
    A = torch.where(similarity > tau,
                    torch.ones_like(similarity, dtype=torch.float32),
                    torch.full_like(similarity, eps, dtype=torch.float32))

    # Degree of every node and its inverse square root
    degree = torch.sum(A, dim=2)
    inv_sqrt_degree = torch.sqrt(1.0 / degree)

    # Normalized Laplacian L = D^(-1/2) * (D - A) * D^(-1/2). Multiplying by a
    # diagonal matrix is a row/column scaling, so broadcasting replaces the
    # two dense matmuls.
    laplacian = torch.diag_embed(degree) - A
    L = inv_sqrt_degree[:, :, None] * (laplacian * inv_sqrt_degree[:, None, :])
    try:
        # Compute the eigenvectors and eigenvalues of L
        eigenvalues, eigenvectors = torch.linalg.eigh(L, UPLO='L')
    except RuntimeError:
        # torch reports eigh failures (e.g. non-convergence) as RuntimeError subclasses
        print("eigh failed")
        return None, None, None
    # Skip the first (smallest) eigenpair and keep the next `eig_vecs` ones.
    eigenvalues = eigenvalues[:, 1:eig_vecs + 1]
    eigenvectors = eigenvectors[:, :, 1:eig_vecs + 1]

    return eigenvectors, eigenvalues, A

def ncut2(features, tau=0.15, eps=1e-5, eig_vecs=1,
          k_neighbors=10, alpha=0.5, base_temp=1.0, chunk_size=200):
    """
    Normalized cut with a density-adaptive affinity.

    Cosine similarities are divided by a per-pair temperature
    ``base_temp + alpha * (density_i + density_j) / 2`` before thresholding, where the density of a node is the
    mean cosine similarity to its nearest neighbors (the single best match, expected to be the node itself, is
    dropped). Pairs in denser regions therefore get a larger temperature, i.e. a lower adapted similarity.

    :param features: batched features of shape (batch_size, num_nodes, feature_dim)
    :param tau: threshold applied to the temperature-scaled similarity
    :param eps: value used for below-threshold entries of the adjacency matrix to avoid division by zero
    :param eig_vecs: number of eigenvectors to return, starting from the second-smallest
    :param k_neighbors: number of nearest neighbors used for the density estimate (clamped to num_nodes - 1)
    :param alpha: weight of the density term in the temperature
    :param base_temp: temperature used when the density is zero
    :param chunk_size: number of rows processed at once, bounding the size of temporaries
    :return: (eigenvectors, eigenvalues, A) where eigenvectors has shape (batch_size, num_nodes, eig_vecs),
     eigenvalues has shape (batch_size, eig_vecs) and A is the (batch_size, num_nodes, num_nodes) adjacency matrix.
     If the eigenvalue computation fails, then (None, None, None) is returned so that callers can still unpack
     three values.
    """
    features = F.normalize(features, p=2, dim=-1)
    num_batches, num_nodes, _ = features.shape
    features_t = features.transpose(1, 2)

    # Pass 1: local density from the raw cosine similarities.
    # +1 because the best match is dropped; clamp so small graphs do not crash topk.
    k = min(k_neighbors + 1, num_nodes)
    local_densities = torch.zeros(num_batches, num_nodes, device=features.device)
    for start in range(0, num_nodes, chunk_size):
        stop = min(start + chunk_size, num_nodes)
        # Similarity of the current rows against every node: [B, chunk, N]
        cosine_sim = torch.bmm(features[:, start:stop], features_t)
        topk_values, _ = torch.topk(cosine_sim, k=k, dim=-1)
        local_densities[:, start:stop] = topk_values[..., 1:].mean(dim=-1)

    # Pass 2: temperature-scaled affinity, filled one row chunk at a time.
    A = torch.zeros(num_batches, num_nodes, num_nodes, device=features.device)
    for start in range(0, num_nodes, chunk_size):
        stop = min(start + chunk_size, num_nodes)
        cosine_sim = torch.bmm(features[:, start:stop], features_t)
        # Average density of each (row, column) pair via broadcasting.
        avg_density = (local_densities[:, start:stop, None] + local_densities[:, None, :]) / 2
        temperature = base_temp + alpha * avg_density
        A[:, start:stop] = cosine_sim / temperature

    # Binarize: 1 above the threshold, eps elsewhere.
    A = torch.where(A > tau, torch.ones_like(A), torch.full_like(A, eps))

    # Degree of every node and its inverse square root
    degree = torch.sum(A, dim=2)
    inv_sqrt_degree = torch.sqrt(1.0 / degree)

    # Normalized Laplacian L = D^(-1/2) * (D - A) * D^(-1/2). Multiplying by a
    # diagonal matrix is a row/column scaling, so broadcasting replaces the
    # two dense matmuls.
    laplacian = torch.diag_embed(degree) - A
    L = inv_sqrt_degree[:, :, None] * (laplacian * inv_sqrt_degree[:, None, :])
    try:
        # Compute the eigenvectors and eigenvalues of L
        eigenvalues, eigenvectors = torch.linalg.eigh(L, UPLO='L')
    except RuntimeError:
        # torch reports eigh failures (e.g. non-convergence) as RuntimeError subclasses
        print("eigh failed")
        return None, None, None
    # Skip the first (smallest) eigenpair and keep the next `eig_vecs` ones.
    eigenvalues = eigenvalues[:, 1:eig_vecs + 1]
    eigenvectors = eigenvectors[:, :, 1:eig_vecs + 1]

    return eigenvectors, eigenvalues, A

def num_corners_on_border_mask(mask):
    """
    Count how many of the four corner pixels of a mask are set.
    :param mask: binary mask of shape (H, W)
    :return: integer in [0, 4] for a binary mask
    """
    corner_indices = ((0, 0), (0, -1), (-1, 0), (-1, -1))
    return sum(int(mask[row, col]) for row, col in corner_indices)

def compute_neighbor_diff_sum(tensor):
    """
    Sum, for every element of a 2-D tensor, the absolute differences to its four direct neighbors
    (up, down, left, right).
    Border handling: a missing neighbor is replaced by the element itself, so it contributes zero.

    NOTE: a function with the same name is defined again further down in this module; at import time
    that later definition replaces this one.
    """
    rows, cols = tensor.shape

    # Shifted copies of the input with the border row/column repeated.
    above = torch.cat([tensor[:1], tensor[:rows - 1]], dim=0)
    below = torch.cat([tensor[1:], tensor[rows - 1:]], dim=0)
    before = torch.cat([tensor[:, :1], tensor[:, :cols - 1]], dim=1)
    after = torch.cat([tensor[:, 1:], tensor[:, cols - 1:]], dim=1)

    total = torch.abs(tensor - above) + torch.abs(tensor - below)
    total = total + torch.abs(tensor - before)
    total = total + torch.abs(tensor - after)
    return total

def normalized_cut_score(W, mask):
    """Return cut(A, B) / assoc(A) + cut(A, B) / assoc(B) for a bipartition.

    ``mask`` is a boolean tensor selecting the nodes of side A; its negation
    is side B. ``W`` is the square affinity matrix indexed by node.
    """
    outside = ~mask
    rows_inside = W[mask]
    cut_weight = rows_inside[:, outside].sum()
    assoc_inside = rows_inside.sum()
    assoc_outside = W[outside].sum()
    return cut_weight / assoc_inside + cut_weight / assoc_outside

def find_best_threshold(y, W, num_thresh=20):
    """Pick the threshold on ``y`` whose bipartition has the lowest normalized-cut score.

    Six candidates are tried: ``mean(y) + i * (max(y) - min(y)) / 100`` for
    i = 0..5, each capped at the midpoint between ``max(y)`` and ``mean(y)``.
    Candidates that put every element on the same side are skipped. If all
    candidates are skipped, the first candidate is used.

    :param y: 1-D tensor of node scores
    :param W: affinity matrix passed to ``normalized_cut_score``
    :param num_thresh: currently unused
    :return: (best threshold as a float, boolean tensor ``y > threshold``)
    """
    mean_value = torch.mean(y)
    highest = torch.max(y)
    lowest = torch.min(y)
    step = (highest - lowest) / 100
    ceiling = 0.5 * (highest + mean_value)

    steps = torch.arange(0, 6, device=y.device)
    candidates = torch.minimum(mean_value + steps * step, ceiling).tolist()

    best_thresh = candidates[0]
    best_score = float('inf')
    for candidate in candidates:
        selected = y > candidate
        num_selected = selected.sum()
        # Skip degenerate splits: one side empty would give a zero association.
        if num_selected == 0 or num_selected == len(selected):
            continue
        score = normalized_cut_score(W, selected)
        if score < best_score:
            best_score, best_thresh = score, candidate

    return best_thresh, y > best_thresh

def compute_neighbor_diff_sum(tensor, neighbors=4):
    """
    Average, for every element of a 2-D tensor, the absolute differences to its neighbors.
    Border handling: a missing neighbor is replaced by the nearest border element (replicate padding).

    :param tensor: 2-D tensor of shape (H, W)
    :param neighbors: 4 for the up/down/left/right neighbors, 8 to also include the diagonals
    :return: tensor of shape (H, W) holding the summed absolute differences divided by ``neighbors``
    :raises ValueError: if ``neighbors`` is neither 4 nor 8
    """
    if neighbors not in (4, 8):
        raise ValueError(f"neighbors must be 4 or 8, got {neighbors}")

    h, w = tensor.shape

    # Replicate-pad by one element on every side (corners included).
    rows = torch.cat([tensor[:1], tensor, tensor[-1:]], dim=0)
    padded = torch.cat([rows[:, :1], rows, rows[:, -1:]], dim=1)

    # (row, col) origin of each shifted window inside `padded`;
    # the unshifted input sits at (1, 1).
    offsets = [(0, 1), (2, 1), (1, 0), (1, 2)]  # up, down, left, right
    if neighbors == 8:
        offsets += [(0, 0), (0, 2), (2, 0), (2, 2)]  # the four diagonals

    diff_sum = torch.zeros_like(tensor)
    for dy, dx in offsets:
        diff_sum = diff_sum + torch.abs(tensor - padded[dy:dy + h, dx:dx + w])

    return diff_sum / neighbors

def filter_arrays_by_threshold(sort_sums, sort_counts, sort_labels, threshold=0.95):
    """Truncate three parallel sequences to their leading, dominant entries.

    For ``sort_sums`` and for ``sort_counts`` separately, find the shortest
    prefix whose cumulative share of the sequence total exceeds ``threshold``
    (the whole sequence if no prefix does). The shorter of the two prefix
    lengths is applied to all three sequences.

    :return: (sort_sums, sort_counts, sort_labels), each sliced to that length
    """
    def prefix_length(values):
        total = sum(values)
        running = 0
        for position, value in enumerate(values, start=1):
            running += value
            if running / total > threshold:
                return position
        return len(values)

    keep_items = min(prefix_length(sort_sums), prefix_length(sort_counts))
    return sort_sums[:keep_items], sort_counts[:keep_items], sort_labels[:keep_items]

def get_masks(eigen_vec, A):
    """
    Split a 2-D eigenvector map into per-object binary masks.

    Steps:
      1. Orient the eigenvector: flip its sign when the above-mean region covers at least three image corners,
         or when the minimum is more than six times larger in magnitude than the maximum.
      2. Subtract the 8-neighbor average absolute difference from every value.
      3. Threshold at the mean and label the connected components of the above-mean region.
      4. Rank the components by their summed eigenvector value (largest first) and keep the dominant ones
         as selected by ``filter_arrays_by_threshold``.

    :param eigen_vec: 2-D CPU tensor holding the eigenvector reshaped to the patch grid
    :param A: affinity matrix; currently unused
    :return: list of uint8 arrays (one per kept component, 1 inside the component), possibly empty
    """
    bipartition = eigen_vec > torch.mean(eigen_vec)
    if num_corners_on_border_mask(bipartition) >= 3:
        eigen_vec = eigen_vec * -1
    elif torch.abs(torch.min(eigen_vec)) > 6 * torch.abs(torch.max(eigen_vec)):
        eigen_vec = eigen_vec * -1

    eigen_vec = eigen_vec - compute_neighbor_diff_sum(eigen_vec, neighbors=8)
    bipartition = eigen_vec > torch.mean(eigen_vec)

    objects, num_objects = ndimage.label(bipartition)
    if num_objects < 1:
        return []

    labels, counts = np.unique(objects, return_counts=True)
    sums = ndimage.sum(eigen_vec, labels=objects, index=labels)

    # Label 0 is the background of ndimage.label. Remove it explicitly
    # instead of assuming it ends up last after sorting by summed value.
    foreground = labels != 0
    labels, counts, sums = labels[foreground], counts[foreground], sums[foreground]

    order = np.argsort(sums)[::-1]
    _, _, kept_labels = filter_arrays_by_threshold(sums[order], counts[order], labels[order])

    return [(objects == label).astype(np.uint8) for label in kept_labels]

def process_eig_vecs(eig_vec):
    """
    Reshape flat per-node eigenvector values onto a square grid.

    :param eig_vec: tensor whose second dimension is the number of nodes; the reshape targets
        (shape[0], side, side), so the remaining dimensions must hold exactly one value per node
    :return: tensor of shape (eig_vec.shape[0], side, side) with side * side == number of nodes
    :raises ValueError: if the number of nodes is not a perfect square
    """
    num_nodes = eig_vec.shape[1]
    # Any square grid is accepted (e.g. 900 -> 30x30, 1156 -> 34x34, 3600 -> 60x60).
    side = int(round(num_nodes ** 0.5))
    if side * side != num_nodes:
        raise ValueError(f"Invalid eig vec shape: {eig_vec.shape}")
    return eig_vec.reshape(eig_vec.shape[0], side, side)

def maskcut_forward(feats, tau):
    """
    Run the normalized cut on one image's patch features and extract object masks.

    :param feats: tensor of shape (num_nodes, feature_dim); a batch dimension is added here
    :param tau: affinity threshold forwarded to ``ncut2``
    :return: (list of binary masks, 2-D eigenvector map on the CPU). When the eigendecomposition
        fails, ``([], None)`` is returned so callers see "no masks" instead of a crash.
    """
    feats = feats.unsqueeze(0)
    result = ncut2(feats, tau)
    # Index instead of unpacking so a failure tuple of any length is handled.
    vecs, A = result[0], result[-1]
    if vecs is None:
        return [], None

    second_smallest_vec = process_eig_vecs(vecs)[0].to("cpu")
    bipartitions = get_masks(second_smallest_vec, A)

    return bipartitions, second_smallest_vec

def maskcut(device, img_path, backbone, patch_size, tau, fixed_size=480):
    """
    Produce object masks for one image file.

    The image is loaded as RGB, resized to ``fixed_size`` x ``fixed_size`` (Lanczos), normalized with
    ``ToTensor`` and passed through ``backbone.get_last_qkv``; the key features are then handed to
    ``maskcut_forward``.

    :param device: torch device the image tensor is moved to
    :param img_path: path of the image file
    :param backbone: model exposing ``get_last_qkv(image_tensor)`` returning three values, the second
        of which is used as the key tensor
    :param patch_size: currently unused
    :param tau: affinity threshold forwarded to ``maskcut_forward``
    :param fixed_size: side length the image is resized to
    :return: (list of binary masks on the patch grid, 2-D eigenvector map)
    """
    # Close the file handle deterministically; convert() returns an independent image.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    img = img.resize((int(fixed_size), int(fixed_size)), Image.Resampling.LANCZOS)
    image_tensor = ToTensor(img).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed for the features or the cut.
    with torch.no_grad():
        _, k, _ = backbone.get_last_qkv(image_tensor)
        # presumably k is (batch, heads, tokens, head_dim), so this merges the
        # heads into one feature vector per token - TODO confirm against the backbone
        k = k.transpose(1, 2).reshape(image_tensor.shape[0], k.shape[2], -1)
        # Drop the first token (presumably the class token) of the single image.
        features = k[:, 1:, :][0]
        bipartitions, eigen_vec = maskcut_forward(features, tau)

    return bipartitions, eigen_vec

def resize_binary_mask(array, new_size):
    """Resize a binary mask through PIL and return it as a boolean array.

    The mask is scaled to 0/255 uint8, resized to ``new_size`` with PIL's
    default resampling, and every non-zero output pixel becomes True.
    """
    as_uint8 = array.astype(np.uint8) * 255
    resized = Image.fromarray(as_uint8).resize(new_size)
    return np.asarray(resized).astype(np.bool_)

def bbox_from_mask(mask: np.ndarray):
    """Return the tight bounding box of a mask as ``[x, y, width, height]``.

    A column/row belongs to the box when its sum over the mask is non-zero.
    """
    cols = np.where(mask.sum(axis=0))[0]
    rows = np.where(mask.sum(axis=1))[0]
    left, top = np.min(cols), np.min(rows)
    box_width = np.max(cols) - left + 1
    box_height = np.max(rows) - top + 1
    return np.array([left, top, box_width, box_height])

def IoU_bbox(mask1, mask2):
    """
    Intersection-over-union of the bounding boxes of two masks.
    :param mask1: first mask, passed to ``bbox_from_mask``
    :param mask2: second mask, passed to ``bbox_from_mask``
    :return: area of the box intersection divided by the area of the box union
    """
    ax, ay, aw, ah = bbox_from_mask(mask1)
    bx, by, bw, bh = bbox_from_mask(mask2)

    # Overlap extent along each axis, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
    overlap_h = max(0, min(ay + ah, by + bh) - max(ay, by))
    intersection_area = overlap_w * overlap_h

    union_area = aw * ah + bw * bh - intersection_area
    return intersection_area / union_area

def mask_post_processing_offical(bipartition, image_rgb, device='cpu'):
    """
    Upsample a patch-grid mask to the image size and sanity-check it.

    The CRF refinement step is currently disabled, so the "refined" mask is the upsampled mask itself.
    The IoU check between the two is kept so the CRF can be re-enabled without changing the control flow.

    :param bipartition: numpy array of shape [height, width] with [0,1] values (patch grid)
    :param image_rgb: image whose ``size`` attribute is (width, height)
    :param device: device the masks are moved to for the IoU computation
    :return: tuple - (float32 mask of shape [image height, image width], success flag)
    """
    # Take the target size from the image itself rather than from
    # module-level globals that only exist when this file runs as a script.
    width, height = image_rgb.size
    bipartition = bipartition.astype(np.float32)
    bipartition = F.interpolate(torch.from_numpy(bipartition[None, None, :, :]), size=(height, width), mode='nearest')[0][0].numpy()
    try:
        pseudo_mask = bipartition

        # filter out the mask that have a very different pseudo-mask after the refinement
        mask1 = torch.from_numpy(bipartition).to(device)
        mask2 = torch.from_numpy(pseudo_mask).to(device)
        iou = IoU(mask1, mask2)
        if np.sum(pseudo_mask) == 0 or iou < 0.5:
            return bipartition, False
        return pseudo_mask, True
    except Exception:
        # best effort: in case the refinement/check failed use the original mask
        return bipartition, False

def mask_post_processing(mask, image_rgb, device='cpu'):
    """
    Post-processing of the mask. It performs crf and returns the final mask in the original image size.
    In case of crf failure, it returns the original mask.
    mask: numpy array of shape [height, width] with [0,1] values
    image_rgb: PIL image
    device: unused; kept so existing callers keep working
    return: tuple - (mask as numpy array of shape [height, width] with [0,1] values, success flag)
    """
    image_orig_size = image_rgb.size
    rescale_size = (image_orig_size[1], image_orig_size[0])
    # resizes the mask to the original image size with nearest neighbor interpolation
    patches_mask = F.interpolate(torch.from_numpy(mask[None, None, :, :]), size=rescale_size, mode='nearest')[0][0].numpy()
    # an empty mask has no bounding box to crop to, so report it as "not an object"
    if not patches_mask.any():
        return patches_mask, False
    # crop the mask by the bounding box, enlarged by a third of its size on every side
    x, y, w, h = bbox_from_mask(patches_mask)
    crop_x = (max(x - w//3, 0), min((x + w) + w//3, rescale_size[1]))
    crop_y = (max(y - h//3, 0), min((y + h) + h//3, rescale_size[0]))
    mask_cropped = patches_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]]
    # crop the image by the bounding box
    img = np.asarray(image_rgb).copy()
    img_cropped = img[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1], :]
    # apply CRF to the bounding box
    try:
        pseudo_mask_crop = densecrf(img_cropped, mask_cropped)
        pseudo_mask_crop = ndimage.binary_fill_holes(pseudo_mask_crop >= 0.5)
        # create a pseudo mask with the same size as the original image
        pseudo_mask = np.zeros_like(patches_mask)
        pseudo_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]] = pseudo_mask_crop
        # in case crf did not provide a mask or the bounding boxes of the original mask and the pseudo mask
        # are too different, we consider the mask as not an object.
        # IoU_bbox works on numpy data, so the arrays are passed directly.
        if np.sum(pseudo_mask) == 0 or IoU_bbox(patches_mask, pseudo_mask) < 0.5:
            return patches_mask, False
        return pseudo_mask, True
    except Exception:
        # best effort: in case crf failed for some reason use the original mask
        return patches_mask, False

def mask_post_processing_new(mask, image_rgb, device='cpu'):
    """
    Post-processing of the mask without CRF: the mask is resized to the original image size with nearest
    neighbor interpolation and its holes are filled. (The CRF-based variant is ``mask_post_processing``.)
    mask: numpy array of shape [height, width] with [0,1] values
    image_rgb: image whose ``size`` attribute is (width, height)
    device: unused; kept so existing callers keep working
    return: tuple - (boolean numpy array of shape [image height, image width], success flag, always True)
    """
    image_width, image_height = image_rgb.size
    # resizes the mask to the original image size with nearest neighbor interpolation
    upsampled = F.interpolate(torch.from_numpy(mask[None, None, :, :]), size=(image_height, image_width), mode='nearest')[0][0].numpy()
    return ndimage.binary_fill_holes(upsampled), True

if __name__ == "__main__":
    # Script entry point. Two modes:
    #   * single-image mode (--img-path is set): segment one image and write
    #     "<image name>-cutonce.png" with the masks and boxes drawn on it;
    #   * dataset mode (--img-path is None): segment every file in
    #     --dataset-path, collect COCO-style annotations and evaluate them.
    # NOTE: --img-path has a non-None default, so dataset mode is only
    # reached by changing that default to None.
    parser = argparse.ArgumentParser('MaskCut script')
    # default arguments
    parser.add_argument('--out-dir', type=str, default='mask_annotation', help='output directory')
    parser.add_argument('--model-name', type=str, default='dino_b8', choices=['dino_b8', 'dino_s8', 'dino_b16', 'dino_s16', 'dinov2_b14', 'dinov2_s14'], help='which architecture')
    parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')
    parser.add_argument('--patch-size', type=int, default=8, choices=[16, 8], help='patch size')
    parser.add_argument('--nb-vis', type=int, default=20, choices=[1, 200], help='nb of visualization')
    parser.add_argument('--img-path', type=str,
                        # default=None,
                        default='imgs/000000006471.jpg',
                        help='single image visualization')

    # additional arguments
    parser.add_argument('--dataset-path', type=str, default="/data/xxx/datasets/coco/val2017", help='path to the dataset')
    parser.add_argument('--tau', type=float, default=0.15, help='threshold used for producing binary graph')
    parser.add_argument('--num-folder-per-job', type=int, default=1, help='the number of folders each job processes')
    parser.add_argument('--job-index', type=int, default=0, help='the index of the job (for imagenet: in the range of 0 to 1000/args.num_folder_per_job-1)')
    parser.add_argument('--fixed_size', type=int, default=480, help='rescale the input images to a fixed size')
    parser.add_argument("--resume", action="store_true", help="Resume from previous run")

    args = parser.parse_args()

    print(args)

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # The patch size reported by the model factory is the one passed on below.
    backbone, patch_size = get_model(args.model_name, device)

    start_time = time.time()
    os.makedirs(args.out_dir, exist_ok=True)

    out_file_basename=f'{args.out_dir}/coco_train_tau{args.tau}-cutonce.json'
    save_period = 100
    ts = time.time()

    single_image = args.img_path is not None

    if single_image:
        img_files = [args.img_path]
    else:
        img_files = sorted(os.listdir(args.dataset_path))
        worker_dir = 'tmp_e2e'
        ann_worker = CocoAnnotationsWorker(worker_dir)
        skipped_images_file = os.path.join(worker_dir, "skipped_images.txt")

        if args.resume:
            img_files = ann_worker.resume(img_files)
        else:
            ann_worker.cleanup()
            os.makedirs(worker_dir, exist_ok=True)

    pseudo_mask_list = []
    for index, img_name in enumerate(tqdm(img_files, desc="Creating pseudo labels")):
        # get image path
        if single_image:
            img_path = img_name
        else:
            img_path = os.path.join(args.dataset_path, img_name)
            image_id = int(img_name.split('.')[0])

        bipartitions, eigen_vec = maskcut(device, img_path, backbone, patch_size, args.tau, fixed_size=args.fixed_size)

        # `width` and `height` stay module-level names on purpose: the
        # alternative post-processing function may read them as globals.
        I = Image.open(img_path).convert('RGB')
        width, height = I.size
        # A fresh list per image, so anything that kept a reference to the
        # previous image's list is not emptied behind its back.
        pseudo_mask_list = []
        num_masks = len(bipartitions)
        for idx, bipartition in enumerate(bipartitions):
            pseudo_mask, success = mask_post_processing_new(bipartition, I)
            if not success:
                continue
            pseudo_mask = pseudo_mask.astype(np.uint8)
            # Earlier (higher ranked) masks get a higher score, in (0.5, 1.0].
            score = 1.0 - idx/(2*num_masks)
            pseudo_mask_list.append({'data':pseudo_mask, 'score':score})

        if not single_image:
            success = ann_worker.add_image_ann(image_id=image_id,file_name=img_name,height=height,width=width,image_masks=pseudo_mask_list)
            if not success:
                print(f"Failed to add image {img_name} to the annotations")
                with open(skipped_images_file, "a") as f:
                    f.write(f"{img_name}\n")
                continue

            if (index+1) % save_period == 0:
                ann_worker.flush_and_save_anns()

    if not single_image:
        te = time.time()
        ann_worker.flush_and_save_anns()
        print(f'Total time cost: {te-ts} seconds')

        files = os.listdir(worker_dir)
        anns_files = [f for f in files if f.endswith('.json')]
        anns_paths = [os.path.join(worker_dir, fname) for fname in anns_files]
        anns = CocoAnnotationsWorker.collect_to_single_ann_dict(anns_paths)

        with open(out_file_basename, "w") as f:
            json.dump(anns, f, indent=2)
            print(f'dumping {out_file_basename}')

        from eval_coco_json import eval_coco_json
        eval_coco_json('coco', out_file_basename)
    else:
        num_masks = len(pseudo_mask_list)
        if num_masks == 0:
            print('=== segment fail ===')
        print('number of instances:', num_masks)
        vis = np.array(I)

        # Draw all masks first, then all boxes, so boxes stay on top.
        for idx, pseudo_mask in enumerate(pseudo_mask_list):
            vis = vis_mask(vis, pseudo_mask['data'], Colors[idx%len(Colors)])
        # With no masks the canvas is still a numpy array; convert it so the
        # box drawing and save() below always operate on a PIL image.
        if isinstance(vis, np.ndarray):
            vis = Image.fromarray(vis)
        for idx, pseudo_mask in enumerate(pseudo_mask_list):
            vis = vis_box(vis, pseudo_mask['data'], Colors[idx%len(Colors)])

        file_name = os.path.basename(args.img_path)
        name_without_ext = os.path.splitext(file_name)[0]
        output_file = f"{name_without_ext}-cutonce.png"
        vis.save(output_file)
        print(f'dumping {output_file}')
        end_time = time.time()
        print(f'Time cost: {end_time - start_time} seconds')
