import os
import json
import argparse
from tqdm import tqdm
from os.path import join as ospj
import cv2
import torch
from torch.utils.data import Dataset
import numpy as np
import open3d as o3d
from pytorch3d.ops import ball_query
from dataset.utils import DATA_PATH, load_label_from_ori, load_pc_from_ori
from src.lib_o3d import batch_geo_feature,estimate_geo_feature
from pytorch3d.ops import sample_farthest_points
from process.conn import get_pc_mask_online
import pickle

import h5py

all_categories = ['Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar', 'Knife', 'Lamp', 'Laptop', 'Motorbike', 'Mug', 'Pistol', 'Rocket', 'Skateboard', 'Table']

def load_img(img_path):
    """Load the ten rendered views ``0.png`` .. ``9.png`` found in ``img_path``.

    Each file is read with OpenCV as a colour image and converted from BGR to
    RGB. The stacked result is divided by 255.

    Args:
        img_path: directory that contains the numbered PNG files.

    Returns:
        np.ndarray: the ten views stacked along a new first axis, scaled by 1/255.
    """
    views = [
        cv2.cvtColor(
            cv2.imread(os.path.join(img_path, str(view_id) + ".png"), cv2.IMREAD_COLOR),
            cv2.COLOR_BGR2RGB,
        )
        for view_id in range(10)
    ]
    return np.array(views) / 255

def load_pc_norm(path):
    """Read a point cloud file and return its points concatenated with its normals.

    The previous implementation ignored ``path`` and always read one hard-coded
    file; the argument is now honoured.

    Args:
        path: location of a point cloud file readable by
            ``o3d.io.read_point_cloud`` (presumably a ``.ply`` that stores
            normals -- confirm against callers).

    Returns:
        np.ndarray: ``points`` and ``normals`` joined along the last axis.
    """
    cloud = o3d.io.read_point_cloud(path)
    points = np.array(cloud.points)
    normals = np.array(cloud.normals)
    return np.concatenate([points, normals], axis=-1)

# from torkit3d.ops.sample_farthest_points import sample_farthest_points
# def fps(pc,num_groups = 2048):
#     pc = torch.from_numpy(pc).unsqueeze(dim=0).cuda()
#     fps_idx = sample_farthest_points(pc, num_groups)
#     centers = batch_index_select(pc, fps_idx, dim=1)
#     return centers.squeeze(dim=0).cpu().numpy(), fps_idx.squeeze(dim=0).cpu().numpy()


def fps(xyz, K=2048):
    """Run farthest point sampling on a single point cloud on the GPU.

    Args:
        xyz: NumPy array holding one point cloud (assumed (N, 3) -- TODO confirm).
        K: number of points requested from ``sample_farthest_points``.

    Returns:
        tuple: ``(sampled_points, sampled_indices)`` as NumPy arrays, with the
        batch dimension that was added for the call removed again.
    """
    cloud = torch.from_numpy(xyz).cuda()
    # sample_farthest_points works on batches, so add a leading batch axis.
    sampled, picked = sample_farthest_points(cloud[None, :], K=K)
    sampled_np = sampled[0].cpu().numpy()
    picked_np = picked[0].cpu().numpy()
    return sampled_np, picked_np

def nearest_neighbors(point_cloud, sampled_indices, k=1):
    """Assign every point to its closest sampled point.

    Args:
        point_cloud (np.ndarray): point coordinates, one row per point.
        sampled_indices (np.ndarray): indices into ``point_cloud`` that select
            the sampled points.
        k (int): accepted for interface compatibility; not used by the body.

    Returns:
        np.ndarray: for each point, the position (within ``sampled_indices``)
        of the sampled point with the smallest Euclidean distance.
    """
    points = torch.from_numpy(point_cloud)
    anchors = points[torch.from_numpy(sampled_indices)]

    # Pairwise Euclidean distances between all points and the sampled points.
    pairwise = torch.cdist(points, anchors, p=2)

    return torch.argmin(pairwise, dim=1).numpy()

def generate_is_seen(idx, npoint):
    """Build a per-view 0/1 table marking which point indices occur in ``idx``.

    Args:
        idx (torch.Tensor): integer tensor whose first axis enumerates views;
            entries equal to -1 are ignored.
        npoint (int): number of points, i.e. the width of the result.

    Returns:
        torch.Tensor: tensor of shape ``(nview, npoint)`` with the dtype and
        device of ``idx``; entry ``[v, p]`` is 1 when ``p`` appears in view ``v``.
    """
    is_seen = idx.new_zeros((idx.shape[0], npoint))

    for view, view_idx in enumerate(idx):
        # Drop the -1 placeholders, then count how often each index occurs.
        hits = torch.bincount(view_idx[view_idx != -1], minlength=npoint)
        # Any positive count means "seen"; cap the counts at 1.
        is_seen[view] = torch.clamp(hits, max=1)

    return is_seen

def get_is_seen(idx, npoint):
    """Mark, per view, which point indices appear in ``idx``.

    The original body contained an unreachable tail after its first ``return``
    and an ``is_numpy`` flag that was never used; both are removed. The result
    is always a torch tensor, exactly as before.

    Args:
        idx (np.ndarray | torch.Tensor): integer indices whose first axis
            enumerates views; -1 entries are ignored. NumPy input is converted
            with ``torch.from_numpy``.
        npoint (int): number of points, i.e. the width of the result.

    Returns:
        torch.Tensor: ``(nview, npoint)`` tensor with the dtype of ``idx``;
        entry ``[v, p]`` is 1 when point ``p`` occurs in view ``v``, else 0.
    """
    if isinstance(idx, np.ndarray):
        idx = torch.from_numpy(idx)

    nview = idx.shape[0]
    is_seen = torch.zeros((nview, npoint), dtype=idx.dtype, device=idx.device)
    for view in range(nview):
        seen = torch.unique(idx[view].reshape(-1))
        # -1 is the "no point" placeholder and must not be used as an index.
        is_seen[view, seen[seen != -1]] = 1
    return is_seen



def get_seen_data(pc, label, idx, coor, pc_fpfh, pc_norm):
    """Split a point cloud into points referenced by ``idx`` and the rest.

    Args:
        pc: per-point array, indexed along its first axis.
        label: per-point labels; its length defines the number of points.
        idx: integer array of point indices with -1 as placeholder.
        coor: array indexed as ``coor[:, point]``.
        pc_fpfh: per-point array, indexed along its first axis.
        pc_norm: per-point array, indexed along its first axis.

    Returns:
        tuple: ``(pc_seen, label_seen, pc_unseen, pc_label_unseen, pc_idx_seen,
        coor_seen, pc_fpfh_seen, pc_norm_seen, idx_seen)`` where ``idx_seen``
        are the sorted unique indices found in ``idx`` and ``pc_idx_seen`` is
        ``idx`` re-expressed as positions inside ``idx_seen`` (-1 kept as -1).
    """
    num_points = label.shape[0]

    referenced = np.unique(idx[idx != -1].flatten())

    # Lookup table: original point index -> position in the "seen" subset.
    old_to_new = -np.ones(num_points, dtype=idx.dtype)
    old_to_new[referenced] = np.arange(referenced.shape[0], dtype=idx.dtype)
    remapped = old_to_new[idx]
    # idx == -1 was looked up as "last element" above; restore the placeholder.
    remapped[idx == -1] = -1

    not_referenced = np.isin(np.arange(num_points), referenced, invert=True)

    return (
        pc[referenced],
        label[referenced],
        pc[not_referenced],
        label[not_referenced],
        remapped,
        coor[:, referenced],
        pc_fpfh[referenced],
        pc_norm[referenced],
        referenced,
    )

def sample_pc(pc, label, idx, coor, pc_fpfh, pc_norm, K=2048*16):
    """Randomly subsample a point cloud to at most ``K`` points.

    When the cloud already has ``K`` points or fewer, the inputs are returned
    unchanged. (The original early return referenced locals that had not been
    assigned yet and raised ``UnboundLocalError``.)

    Args:
        pc: per-point array, indexed along its first axis.
        label: per-point labels; its length defines the number of points.
        idx: integer array of point indices with -1 as placeholder.
        coor: array indexed as ``coor[:, point]``.
        pc_fpfh: per-point array, indexed along its first axis.
        pc_norm: per-point array, indexed along its first axis.
        K: maximum number of points to keep.

    Returns:
        tuple: ``(pc, label, idx, coor, pc_fpfh, pc_norm)`` restricted to the
        sampled points. In the returned ``idx`` every entry is the position of
        the point inside the sampled subset, or -1 when the entry was -1 or the
        point was not sampled.
    """
    N = label.shape[0]

    if N <= K:
        # Nothing to drop: hand the inputs back untouched.
        return pc, label, idx, coor, pc_fpfh, pc_norm

    # Random subset without replacement; converted to NumPy so it can index
    # the NumPy inputs directly.
    idx_sampled = torch.randperm(N)[:K].numpy()

    # Lookup table: original point index -> position in the sampled subset
    # (-1 for points that were not sampled).
    mapping = -np.ones(N, dtype=idx.dtype)
    mapping[idx_sampled] = np.arange(idx_sampled.shape[0], dtype=idx.dtype)
    pc_idx_seen = mapping[idx]
    # idx == -1 was looked up as "last element" above; restore the placeholder.
    pc_idx_seen[idx == -1] = -1

    pc_seen = pc[idx_sampled]
    label_seen = label[idx_sampled]
    coor_seen = coor[:, idx_sampled]
    pc_fpfh = pc_fpfh[idx_sampled]
    pc_norm = pc_norm[idx_sampled]

    return pc_seen, label_seen, pc_idx_seen, coor_seen, pc_fpfh, pc_norm

def get_fpfh(pc, voxel_size=0.01):
    """Thin wrapper around ``estimate_geo_feature``.

    Args:
        pc: point cloud passed straight through to ``estimate_geo_feature``.
        voxel_size: forwarded as the ``voxel_size`` keyword.

    Returns:
        tuple: the two values produced by ``estimate_geo_feature``, in the same
        order (presumably normals followed by FPFH features -- confirm against
        ``src.lib_o3d``).
    """
    normals, features = estimate_geo_feature(pc, voxel_size=voxel_size)
    return normals, features


def get_ball_index(pc):
    """Query, for every point, its neighbours within a fixed radius on the GPU.

    Calls ``ball_query`` with the cloud as both query and reference set, using
    ``K=20`` and ``radius=0.01``.

    Args:
        pc: NumPy array holding one point cloud (assumed (N, 3) -- TODO confirm).

    Returns:
        np.ndarray: the second element of the ``ball_query`` result with the
        batch axis removed (presumably the neighbour indices -- confirm against
        the pytorch3d documentation).
    """
    batched = torch.from_numpy(pc).cuda().unsqueeze(dim=0).float()
    query_result = ball_query(batched, batched, K=20, radius=0.01)
    neighbours = query_result[1]
    return neighbours.squeeze(dim=0).cpu().numpy()

def calc_mask_conf(mask_label):
    """Return the most frequent label in ``mask_label`` and its share.

    Args:
        mask_label: 1-D array of non-negative integer labels.

    Returns:
        tuple: ``(label, confidence)`` where ``label`` is the arg-max of the
        label histogram (at least 10 bins) and ``confidence`` is its count
        divided by the total count plus ``1e-6``.
    """
    counts = np.bincount(mask_label, minlength=10)
    majority = np.argmax(counts)
    # The small epsilon keeps the ratio defined for an empty input.
    confidence = counts[majority] / (counts.sum() + 1e-6)
    return majority, confidence

def get_edge_index(edges):
    """Convert a square adjacency matrix into a ``2 x E`` edge-index tensor.

    Args:
        edges: matrix indexed as ``edges[i, j]``; an entry equal to 1 marks an
            edge from ``i`` to ``j``. Only the leading ``N x N`` part is read,
            where ``N = edges.shape[0]``.

    Returns:
        torch.Tensor: stacked ``[sources, targets]``, listed in row-major order.
    """
    size = edges.shape[0]
    pairs = [
        (row, col)
        for row in range(size)
        for col in range(size)
        if edges[row, col] == 1
    ]
    sources = [row for row, _ in pairs]
    targets = [col for _, col in pairs]
    return torch.stack([torch.tensor(sources), torch.tensor(targets)])
    
def get_PN(mask_num, mask_group_ind, inverse_mapping, mask_area, flyd_edges, mask_label, id2mask, pc_idx, img, pc_label):
    """Sample (anchor, positive, negative) point triples from 2-D masks.

    For every usable mask, each of its points becomes an anchor, a positive is
    drawn (with replacement) from the same mask, and a negative is drawn from
    another mask that has no edge to it in ``flyd_edges``. Negative masks are
    chosen with probability proportional to the absolute difference in mask area.

    Args:
        mask_num: number of masks; masks are addressed as ``0 .. mask_num - 1``.
        mask_group_ind: per-mask sequence of point indices.
        inverse_mapping: array used to remap the point indices of a mask
            (``inverse_mapping[pc_ind]``).
        mask_area: per-mask area values (converted to a NumPy array here).
        flyd_edges: matrix indexed as ``flyd_edges[i]``; a 0 entry marks a mask
            that may serve as a negative for mask ``i``.
        mask_label: per-view label images, indexed as ``mask_label[view]``.
        id2mask: maps a mask id to its ``(view, label)`` pair.
        pc_idx: per-view pixel-to-point index map, -1 where no point projects.
        img: only used by the nested debugging helper ``save_img``, which is
            never called from the active code.
        pc_label: only referenced from commented-out code.

    Returns:
        np.ndarray: array built from the list of ``(anchor, positive, negative)``
        triples.

    Side effects:
        Creates the directory ``./output/PN`` if it does not exist and consumes
        NumPy's global random state.
    """
    
    # A mask counts as valid when more than 90% of its pixels map to a point
    # (pc_idx != -1) ...
    valid_mask = []
    for i in range(mask_num):
        view, label = id2mask[i]
        img_ind = mask_label[view]==label
        pc_ind = pc_idx[view,img_ind]
        pc_ind = pc_ind[pc_ind!=-1]
        valid_mask.append(pc_ind.shape[0]/img_ind.sum()>0.9)
    valid_mask = np.array(valid_mask)
    mask_area = np.array(mask_area)
    # ... and its area exceeds 100.
    valid_mask = valid_mask & (mask_area>100)
    PN_triples = []  # [(anchor_idx, pos_idx, neg_idx)]
    
    # Debugging helper: writes the view of mask ``i`` with the mask painted over.
    # It is only referenced from commented-out calls below.
    def save_img(i, mask_label, id2mask, img, save_dir):
        view, label = id2mask[i]
        img_ind = mask_label[view]==label
        cur_img = (img[view]*255).astype(np.uint8)
        cur_img[img_ind] = np.array([128,0,255])
        cv2.imwrite(save_dir, cur_img)
        
    save_dir = "./output/PN"
    os.makedirs(save_dir, exist_ok=True)
    for i in range(mask_num):
        # Anchor masks must be valid and not larger than 5000.
        if not valid_mask[i] or mask_area[i]>5000:
            continue
        pc_ind = mask_group_ind[i]
        num_points = len(pc_ind)
        if num_points < 2:
            continue
        pc_ind = inverse_mapping[pc_ind]

        # One positive per anchor, drawn with replacement from the same mask.
        pos_sample_idx = np.random.choice(pc_ind, size=num_points, replace=True) # shape: [num_points]

        # Build the sampling probability over candidate negative masks.
        tmp_mask_area = np.abs(mask_area - mask_area[i])
        valid_neg_mask = (flyd_edges[i] == 0) & valid_mask
        tmp_mask_area[~valid_neg_mask] = 0

        if np.sum(tmp_mask_area) == 0:
            continue  # no mask is available to draw negatives from

        # prob = tmp_mask_area / np.sum(tmp_mask_area)
        alpha = 1.0  
        prob = tmp_mask_area ** alpha
        prob = prob / np.sum(prob)

        # save_img(i, mask_label, id2mask, img, os.path.join(save_dir,"cur.png"))
        
        # For every anchor pick one negative mask, then one point inside it.
        neg_mask_ids = np.random.choice(mask_num, size=num_points, p=prob)
        neg_sample_idx = []
        for id, neg_id in enumerate(neg_mask_ids):
            neg_pc = mask_group_ind[neg_id]
            
            # if i==16:
            #     save_img(neg_id , mask_label, id2mask, img, os.path.join(save_dir,"neg.png"))
            if len(neg_pc) == 0:
                neg_sample_idx.append(-1)  # placeholder marking an invalid negative
            else:
                neg_pc = inverse_mapping[neg_pc]
                neg_sample_idx.append(np.random.choice(neg_pc, replace=True))
            aa = 1
        aa=1

        # Drop the triples whose negative is the -1 placeholder.
        neg_sample_idx = np.array(neg_sample_idx)
        good_mask = neg_sample_idx != -1
        # print(good_mask.shape)
        anchors = pc_ind[good_mask]
        positives = pos_sample_idx[good_mask]
        negatives = neg_sample_idx[good_mask]
        
        # a_label = pc_label[np.array(anchors)]
        # n_label = pc_label[np.array(negatives)]
        # acc = (a_label!=n_label).sum()/a_label.shape[0]
        

        PN_triples.extend(zip(anchors, positives, negatives))
    return np.array(PN_triples)

class Shape:
    """Plain record bundling all per-shape data handed out by the datasets."""

    # Attribute names, in constructor / ``unpack`` order.
    _FIELDS = (
        "id",
        "pc",
        "pc_label",
        "unseen_pc",
        "unseen_pc_label",
        "img",
        "mask_label",
        "pc_idx",
        "coords",
        "pc_fpfh",
        "pc_norm",
    )

    def __init__(self, id, pc, pc_label, unseen_pc, unseen_pc_label, img, mask_label, pc_idx, coords, pc_fpfh, pc_norm):
        """Store every argument on the instance under its own name."""
        values = (
            id, pc, pc_label, unseen_pc, unseen_pc_label, img,
            mask_label, pc_idx, coords, pc_fpfh, pc_norm,
        )
        for field, value in zip(self._FIELDS, values):
            setattr(self, field, value)

    def unpack(self):
        """Return all stored values as a tuple, in constructor order."""
        return tuple(getattr(self, field) for field in self._FIELDS)

def get_line_region(i,j,mask_pc_ind,nearest_index, grouped_indices):
    """Collect the points of all groups touched by both mask ``i`` and mask ``j``.

    Args:
        i, j: mask ids.
        mask_pc_ind: per-mask arrays of point indices.
        nearest_index: array mapping a point index to a group id.
        grouped_indices: per-group arrays of point indices.

    Returns:
        np.ndarray | None: concatenation of ``grouped_indices[k]`` for every
        group id ``k`` shared by both masks, or ``None`` when they share none.
    """
    groups_of_i = nearest_index[mask_pc_ind[i]]
    groups_of_j = nearest_index[mask_pc_ind[j]]
    shared_groups = np.intersect1d(groups_of_i, groups_of_j)

    members = [grouped_indices[group] for group in shared_groups]
    if not members:
        return None
    return np.concatenate(members)

class EdgeIndex:
    """Incrementally collected list of directed edges."""

    def __init__(self):
        # Parallel lists: edge k goes from s_nodes[k] to t_nodes[k].
        self.s_nodes, self.t_nodes = [], []

    def add(self, u,v):
        """Append the directed edge ``u -> v``."""
        self.s_nodes.append(u)
        self.t_nodes.append(v)

    def get_edge_index(self):
        """Return the edges as a stacked ``[sources, targets]`` tensor."""
        sources = torch.tensor(self.s_nodes)
        targets = torch.tensor(self.t_nodes)
        return torch.stack([sources, targets])

    def len(self):
        """Return the number of edges added so far."""
        return len(self.s_nodes)


class ShapenetPartH5(Dataset):
    """ShapeNetPart shapes of one category together with their mask graphs.

    Everything is loaded eagerly in ``__init__`` through ``load_data``; indexing
    returns the tuple produced by ``Shape.unpack``. Note that the per-shape
    ``graph`` dict is stored in the ``pc_fpfh`` slot of ``Shape``.
    """

    def __init__(self, split, category, args, save_data=False, show_figure=False):
        """Read the split/label metadata and load every shape of ``category``.

        Args:
            split: split name; ``"few_shot"`` is mapped to ``"train"``.
            category: category name, lower-cased before use.
            args: namespace providing ``use_cache``, ``debug``,
                ``use_propagate``, ``eliminate_sparseness``, ``use_gnn``,
                ``back_to_edges`` and ``use_pseudo_label``.
            save_data: forwarded to ``load_data``.
            show_figure: stored on the instance.
        """
        super().__init__()
        self.split = "train" if split == "few_shot" else split
        self.category = category.lower()
        self.data_path = "dataset/shapenet_part_seg_hdf5_data_Graph2"
        self.cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4, 
                       'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9, 
                       'motorbike': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15}
        self.seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
        self.index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]
        self.id = self.cat2id[category.lower()]
        # Offset subtracted from the raw labels in load_data.
        self.lb_start = self.index_start[self.id]

        # Context managers make sure the metadata files are closed again.
        with open(ospj(self.data_path, f"sample_ind_{self.split}.json")) as f:
            self.data_list = json.load(f)

        with open(os.path.join(self.data_path, "label_num.json")) as f:
            label_nums = json.load(f)
        self.num_label = label_nums[self.category.capitalize()]

        self.show_figure = show_figure

        self.use_cache = args.use_cache

        self.debug = args.debug
        self.use_propagate = args.use_propagate
        self.eliminate_sparseness = args.eliminate_sparseness
        self.use_gnn = args.use_gnn
        self.back_to_edges = args.back_to_edges
        self.use_pseudo_label = args.use_pseudo_label
        self.args = args
        self.shape_list = self.load_data(save_data)

    def get_conf_edges(self, graph, pc_label):
        """Derive the "confident" mask adjacency from the overlap counts.

        The diagonal of ``graph["edges_cnt"]`` is overwritten in place with
        ``graph["mask_area"]``. Two masks ``i != j`` are connected when the
        overlap count is at least 5% of the area of each of the two masks;
        every mask is connected to itself.

        Args:
            graph: dict with ``"edges_cnt"``, ``"mask_area"`` and ``"mask2id"``.
            pc_label: unused.

        Returns:
            tuple: ``(edges, edge_index)`` with ``edges`` an ``int64`` matrix of
            0/1 entries and ``edge_index`` the result of ``get_edge_index(edges)``.
        """
        edges_cnt = graph["edges_cnt"]
        mask_area = graph["mask_area"]
        self_i  = torch.arange(edges_cnt.shape[0])
        edges_cnt[self_i, self_i] = mask_area
        n_mask = max(graph["mask2id"].values())+1
        # Starts as all zeros, so only the connected pairs need to be written.
        edges = np.zeros((n_mask, n_mask),dtype=np.int64)
        ratio_bar = 0.05
        for i in range(edges.shape[0]):
            edges[i,i]=1
            for j in range(i+1, edges.shape[1]):
                ri = edges_cnt[i,j] / mask_area[i]
                rj = edges_cnt[i,j] / mask_area[j]
                if ri>=ratio_bar and rj >=ratio_bar:
                    edges[i,j]=1
                    edges[j,i]=1

        return edges, get_edge_index(edges)

    def get_mask_norm_P(self, pc):
        """Centre ``pc`` at its mean and scale it by its largest axis extent.

        Args:
            pc: array with one point per row.

        Returns:
            ``(pc - mean) / (max_extent + 1e-4)``.
        """
        mn = pc.min(axis=0)
        mx = pc.max(axis=0)
        dia = (mx-mn).max()
        center = pc.mean(axis=0)
        pc = (pc-center) / (dia+1e-4)
        return pc

    def get_normAng(self, pc, norm, pos, sample=False):
        """Absolute cosine between each normal and the direction ``pc - pos``.

        Args:
            pc: array with one point per row.
            norm: array with one vector per row, matching ``pc``.
            pos: reference position subtracted from every point.
            sample: when true and there are more than 50 points, return the
                values of 50 randomly chosen points only.

        Returns:
            np.ndarray: column vector of absolute cosine values.
        """
        N = pc.shape[0]
        v1 = pc-pos
        v2 = norm
        # Normalise both vector sets; the epsilon guards against zero length.
        v1 = v1 / (np.linalg.norm(v1, axis=1, keepdims=True) + 1e-8) 
        v2 = v2 / (np.linalg.norm(v2, axis=1, keepdims=True) + 1e-8) 
        Ang = np.sum(v1 * v2, axis=1, keepdims=True)
        sample_num = 50
        if sample and N>sample_num:
            indices = np.random.choice(N, sample_num, replace=False)
            Ang = Ang[indices]

        return np.abs(Ang)

    def load_data(self, save_data = False):
        """Load every shape of the category and build its mask graph.

        Shapes that end up without any strong or without any weak edge are
        skipped. With ``self.debug`` set, loading stops after the first shape
        that was kept.

        Args:
            save_data: accepted for interface compatibility; not used.

        Returns:
            list[Shape]: one entry per loaded shape.
        """
        shape_id_list = self.data_list[self.category.capitalize()]
        shape_list = []
        # The view positions are the same for every shape: load them once,
        # on first use.
        view_pos = None
        for ii, shape_id in tqdm(enumerate(shape_id_list), total=len(shape_id_list), desc=f"loading {self.category} data"):
            shape_id = str(shape_id)

            pc_meta_path = f"{self.data_path}/PC/{self.category}/{shape_id}/pc_meta.npz"
            # The context manager closes the archive once the arrays are read.
            with np.load(pc_meta_path) as pc_meta:
                pc = pc_meta["pc"]
                pc_label = pc_meta["pc_labels"] - self.lb_start
                pc_idx = pc_meta["pc_idx"]
                coords = pc_meta["coords"]
                img = pc_meta["feat"] 
                is_seen = pc_meta["is_seen"]
                # Only read when requested, so archives without pseudo labels
                # still load. Previously psd_label was never assigned and the
                # use_pseudo_label path raised NameError.
                psd_label = pc_meta["psd_label"] if self.use_pseudo_label else None

            pc_norm, pc_fpfh = get_fpfh(pc, voxel_size=0.02)

            unseen_pc, unseen_pc_label = pc, pc_label

            mask_path = f"{self.data_path}/SAM_PROCESS/{self.category}/{shape_id}"
            mask_label = np.load(os.path.join(mask_path, "mask_2.npy")).astype(np.int32)
            graph = {}
            graph["is_seen"] = is_seen
            # NOTE: pickle must only be used on trusted, locally generated files.
            graph_path = f"{self.data_path}/SAM_PROCESS/{self.category}/{shape_id}/graph_cnt.pkl"
            with open(graph_path, "rb") as f:
                mask_area, mask2id, id2mask, edges_cnt = pickle.load(f)
                graph["edges_cnt"] = edges_cnt
                graph["mask_area"] = mask_area
                graph["mask2id"] = mask2id
                graph["id2mask"] = id2mask

            graph_path = f"{self.data_path}/SAM_PROCESS/{self.category}/{shape_id}/pc_mask_ind_eli.pkl"

            with open(graph_path, "rb") as f:
                mask_pc_ind = pickle.load(f)
                graph["mask_pc_ind"] = mask_pc_ind["mask_pc_ind"]

            # get_conf_edges also writes mask_area onto the diagonal of edges_cnt.
            conf_edges, conf_edge_index = self.get_conf_edges(graph, pc_label)

            # Weak edges: stored adjacency plus any non-zero overlap, minus the
            # confident edges.
            adj_edges_path = f"{self.data_path}/SAM_PROCESS/{self.category}/{shape_id}/adj_edges.npy"
            adj_edges = np.load(adj_edges_path)
            tmp_adj_edges = edges_cnt!=0
            adj_edges |= tmp_adj_edges
            adj_edges -=  conf_edges
            adj_edges[adj_edges==-1]=0

            graph["adj_edges"] = adj_edges
            graph["weak_edges"] = adj_edges

            centers = np.arange(pc.shape[0])
            graph["centers"] = centers

            id2mask = graph["id2mask"]
            mask_pc_ind = graph["mask_pc_ind"]
            mask_num = max(graph["mask2id"].values())+1

            mask_group_ind = []
            mask_pc = []
            mask_normAng = []
            if view_pos is None:
                view_pos = np.load("my_view.npy")
            for i in range(mask_num):
                view, label = id2mask[i]
                pc_ind_i = mask_pc_ind[i]
                inter = np.intersect1d(pc_ind_i, centers)
                if inter.size > 0:
                    mask_group_ind.append(inter)
                    mask_pc.append(pc[inter])
                    mask_normAng.append(self.get_normAng(pc[inter], pc_norm[inter], view_pos[view], sample=False))
                else:
                    # Empty placeholders keep the lists aligned with the mask ids.
                    mask_group_ind.append(np.empty((0)))
                    mask_pc.append(np.empty((0)))
                    mask_normAng.append(np.empty((0)))

            graph["mask_group_ind"] = mask_group_ind
            graph["mask_pc"] = mask_pc
            graph["mask_normAng"] = mask_normAng

            n_mask = max(mask2id.values())+1

            strong_edge_index = EdgeIndex()
            weak_edge_index = EdgeIndex()

            # Edges are added in both directions and only between non-empty masks.
            for i in range(n_mask):
                for j in range(i, n_mask):
                    if conf_edges[i,j]==1 and mask_group_ind[i].size>0 and mask_group_ind[j].size>0:
                        strong_edge_index.add(i, j) 
                        strong_edge_index.add(j, i)

            weak_edges = graph["weak_edges"]
            for i in range(n_mask):
                for j in range(i, n_mask):
                    if weak_edges[i,j]==1 and mask_group_ind[i].size>0 and mask_group_ind[j].size>0:
                        weak_edge_index.add(i, j) 
                        weak_edge_index.add(j, i)

            if strong_edge_index.len()==0 or weak_edge_index.len()==0:
                continue

            graph["strong_edge_index"] = strong_edge_index.get_edge_index()
            graph["weak_edge_index"] = weak_edge_index.get_edge_index()

            if self.use_pseudo_label:
                graph["psd_label"] = psd_label

            # The graph dict travels in the pc_fpfh slot of Shape.
            shape_list.append(Shape(ii, pc, pc_label, unseen_pc, unseen_pc_label,img, mask_label,pc_idx,coords, graph, pc_norm))
            if self.debug:
                break

        return shape_list

    def __len__(self):
        """Return the number of loaded shapes."""
        return len(self.shape_list)

    def __getitem__(self, index):
        """Return the unpacked tuple of the shape at ``index``."""
        shape = self.shape_list[index]        
        return shape.unpack()

    def calc_ave_seen_vew_num(self):
        """Print, per shape and overall, the mean number of views per index.

        For each of the first ten views the unique values of ``pc_idx`` are
        collected; the occurrences of each value across views are then averaged.
        """
        b = []
        for ii in range(len(self.shape_list)):
            pc_idx = torch.tensor(self.shape_list[ii].pc_idx)
            a = []
            for jj in range(10):
                pc_idx_perview = pc_idx[jj].flatten()
                pc_idx_perview, count = torch.unique(pc_idx_perview, return_counts=True)
                a.append(pc_idx_perview)
            a = torch.cat(a)
            _, count = torch.unique(a, return_counts=True)
            print(torch.mean(count.float()))
            b.append(torch.mean(count.float()))
        print("total", torch.mean(torch.tensor(b)))    

def data_process(args):
    """Build the ``few_shot`` and ``test`` datasets of every category, then exit.

    Each dataset is constructed with ``save_data=True`` purely for its side
    effects. NOTE(review): ``ShapenetPart`` is not defined in the visible part
    of this file -- confirm where it comes from.
    """
    for category in tqdm(all_categories):
        for split in ("few_shot", "test"):
            dataset = ShapenetPart(split, category, args, save_data=True)
        print("good")
    exit(0)

def calc_conn_ratio(args):
    """Write per-category connectivity statistics to ``./output/conn_ratio.csv``.

    For every shape of every category's test set, the number of connected
    components of ``graph["edges"]`` and the number of edge entries are each
    divided by the number of masks. One line ``category, mean component ratio,
    mean edge ratio`` is produced per category, and the CSV is rewritten after
    each category with all lines so far.

    Changes versus the previous version: the edge-ratio accumulator is reset
    per category (it used to leak earlier categories into later rows), the
    component search is iterative so large graphs cannot hit the recursion
    limit, and the output directory is created when missing.

    NOTE(review): ``PartnetEpc`` is not defined in the visible part of this
    file -- confirm where it comes from.
    """
    def count_connected_components(adj_matrix):
        """Count connected components; entries > 0 in a row mark neighbours."""
        n = adj_matrix.shape[0]
        visited = np.zeros(n, dtype=bool)

        count = 0
        for start in range(n):
            if visited[start]:
                continue
            count += 1
            visited[start] = True
            # Explicit stack instead of recursion.
            stack = [start]
            while stack:
                node = stack.pop()
                for neighbor in np.where(adj_matrix[node] > 0)[0]:
                    if not visited[neighbor]:
                        visited[neighbor] = True
                        stack.append(neighbor)
        return count

    print(len(all_categories))
    os.makedirs("./output", exist_ok=True)
    lines = []
    for category in all_categories:
        dataset = PartnetEpc("test", category, args)
        ratios = []
        edges_ratio = []
        for iii in tqdm(range(len(dataset))):
            ii, pc, pc_label, unseen_pc, unseen_pc_label,img, mask_label,pc_idx,coords, graph, pc_norm = dataset[iii]
            edges = graph["edges"]

            conn_cnt = count_connected_components(edges)
            ratios.append(conn_cnt/edges.shape[0])
            edges_ratio.append(edges.sum()/edges.shape[0])
        line = f"{category}, {sum(ratios)/len(ratios)}, {sum(edges_ratio)/len(edges_ratio)}\n"
        print(line)
        lines.append(line)
        # Rewritten after every category so partial results survive a crash.
        with open("./output/conn_ratio.csv","w") as f:
            f.writelines(lines)

def calc_good_bad_edge(args):
    """Classify graph edges as good / mid / bad and write per-category ratios.

    For every shape of every category's test set, each non-zero entry
    ``edges[i, j]`` is classified using the ground-truth labels of the points
    of the two masks:

    * "mid"  -- both masks have label purity above 0.9, share the same majority
      label, and their point sets have an intersection-over-union above 0.9;
    * "good" -- same purity/label condition but a lower intersection-over-union;
    * "bad"  -- everything else.

    One line ``category, mean good ratio, mean mid ratio`` is produced per
    category and ``./output/good_mid_ratio005.csv`` is rewritten after each
    category with all lines so far.

    NOTE(review): ``PartnetEpc`` is not defined in the visible part of this
    file. A shape without any non-zero edge would make the ratio division fail;
    an empty dataset would make the mean fail.
    """
    # Overwritten immediately by the loop variable below.
    category = "Clock"
    # all_categories = ['KitchenPot','Lighter']
    print(len(all_categories))
    data_to_write = []
    for category in all_categories:
        dataset = PartnetEpc("test", category, args)
        goods = []
        mids = []
        for iii in tqdm(range(len(dataset))):
            ii, pc, pc_label, unseen_pc, unseen_pc_label,img, mask_label,pc_idx,coords, graph, pc_norm = dataset[iii]
            from src.utils import get_batchlabel_color
            # from src.render_pc import render_pc
            mask2id = graph["mask2id"]
            id2mask = graph["id2mask"]
            mask_pc_ind = graph["mask_pc_ind"]
            # edge_index = graph["edge_index"]
            edges = graph["edges"]
            # mask_weight = graph["weight"]
            
            edges_cnt = graph["edges_cnt"]
            mask_area = graph["mask_area"]
            
            n_mask = max(mask2id.values())+1
            n_view=10
            # The colour renderings below are computed but not used afterwards.
            rgb_mask_label = (get_batchlabel_color(mask_label)*255).astype(np.uint8)
            img_label = pc_label[pc_idx]+1
            img_label[pc_idx==-1]=0
            rgb_img_label = (get_batchlabel_color(img_label)*255).astype(np.uint8)
            
            # Not called anywhere in this function.
            def check_mask(img_ind_i, pc_idx_i):
                cnt1 = (pc_idx_i!=-1).sum()
                cnt2 = img_ind_i.sum()
                return cnt1 / cnt2 > 0.9

            
            good_edge_cnt = 0
            bad_edge_cnt = 0
            mid_edge_cnt = 0
            for i in range(n_mask):
                view, label = id2mask[i]
                img_ind = mask_label[view]==label
                # pc_ind = pc_idx[view,img_ind]
                # pc_ind = pc_ind[pc_ind!=-1]
                pc_ind_i = mask_pc_ind[i]
                if pc_ind_i.size>0:
                    # Majority ground-truth label of mask i and its purity.
                    mask_true_label = pc_label[pc_ind_i]
                    hist_i = np.bincount(mask_true_label, minlength=10)
                    lb_i = np.argmax(hist_i)
                    conf_i = hist_i[lb_i]/(hist_i.sum()+1e-6)
                else:
                    # lb_i is left untouched here; conf_i == 0 short-circuits
                    # the comparison below before lb_i is read.
                    conf_i=0  
                cnt = 0
                for j in range(n_mask):
                    if edges[i,j] ==0:
                        continue
                    view, label = id2mask[j]
                    img_ind = mask_label[view]==label
                    # pc_ind = pc_idx[view,img_ind]
                    # pc_ind = pc_ind[pc_ind!=-1]
                    pc_ind_j = mask_pc_ind[j]
                    if pc_ind_j.size>0:
                        # Majority ground-truth label of mask j and its purity.
                        mask_true_label = pc_label[pc_ind_j]
                        hist_j = np.bincount(mask_true_label, minlength=10)
                        lb_j = np.argmax(hist_j)
                        conf_j = hist_j[lb_j]/(hist_j.sum()+1e-6)
                    else:
                        conf_j=0
                    if conf_i>0.9 and conf_j>0.9 and lb_i==lb_j:
                        set1 = mask_pc_ind[i]
                        set2 = mask_pc_ind[j]
                        intersection = np.intersect1d(set1, set2)
                        union = np.union1d(set1, set2)
                        # Nearly identical point sets count as "mid".
                        if intersection.size/union.size > 0.9:
                            mid_edge_cnt+=1
                        else:
                            good_edge_cnt+=1
                    else:
                        bad_edge_cnt+=1
            goods.append(good_edge_cnt/(good_edge_cnt+bad_edge_cnt+mid_edge_cnt))
            mids.append(mid_edge_cnt/(good_edge_cnt+bad_edge_cnt+mid_edge_cnt))
            
        line = f"{category}, {sum(goods)/len(goods)}, {sum(mids)/len(mids)}\n"
        print(line)
        data_to_write.append(line)
        # if category == "Bottle":
        #     break
        # Rewritten after every category with all lines collected so far.
        with open("./output/good_mid_ratio005.csv", "w") as f_out:
            f_out.writelines(data_to_write)

def visualize_conn(args):
    """Dump debug images of every mask and of the masks adjacent to it.

    For each shape of the "Pen" test set and each mask that passes
    ``check_mask``, the source view, the view with the mask painted, and the
    label rendering are written to ``output/check_mask_conn/``. After that, one
    painted image is written per neighbouring mask (non-zero entry in
    ``graph["adj_edges"]``) that also passes ``check_mask``. Files are
    overwritten from mask to mask, so this is meant for stepping through in a
    debugger.

    NOTE(review): ``PartnetEpc`` is not defined in the visible part of this file.
    """
    category = "Pen"
    dataset = PartnetEpc("test", category, args)
    from src.utils import get_batchlabel_color

    def check_mask(img_ind_i, pc_idx_i):
        """True when more than 90% of the mask pixels map to a point."""
        return (pc_idx_i!=-1).sum() / img_ind_i.sum() > 0.9

    for iii in tqdm(range(len(dataset))):
        ii, pc, pc_label, unseen_pc, unseen_pc_label,img, mask_label,pc_idx,coords, graph, pc_norm = dataset[iii]

        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        mask_pc_ind = graph["mask_pc_ind"]
        edges = graph["adj_edges"]
        adj_edges = graph["adj_edges"]
        edges_cnt = graph["edges_cnt"]
        mask_area = graph["mask_area"]
        n_mask = max(mask2id.values())+1
        n_view=10
        rgb_mask_label = (get_batchlabel_color(mask_label)*255).astype(np.uint8)
        # Per-pixel ground-truth label, shifted by one so 0 can mean "no point".
        img_label = pc_label[pc_idx]+1
        img_label[pc_idx==-1]=0
        rgb_img_label = (get_batchlabel_color(img_label)*255).astype(np.uint8)
        path = "output/check_mask_conn/"
        os.makedirs(path, exist_ok=True)

        for i in range(n_mask):
            view, label = id2mask[i]
            img_ind = mask_label[view]==label
            if not check_mask(img_ind, pc_idx[view, img_ind]):
                continue
            # Plain view first, then the same view with the mask painted.
            canvas = (img[view]*255).astype(np.uint8)
            cv2.imwrite(os.path.join(path,"aimg.png"), canvas)
            canvas[img_ind] = np.array([20,20,200])
            cv2.imwrite(os.path.join(path,"vis_curr.png"), canvas)
            cv2.imwrite(os.path.join(path,"aimg_label.png"), rgb_img_label[view])

            cnt = 0
            for j in range(n_mask):
                if edges[i,j] ==0:
                    continue

                view, label = id2mask[j]
                img_ind = mask_label[view]==label
                if check_mask(img_ind, pc_idx[view, img_ind]):
                    neighbour_canvas = (img[view]*255).astype(np.uint8)
                    neighbour_canvas[img_ind] = np.array([20,20,200])
                    cv2.imwrite(os.path.join(path,f"zvis_curr_{cnt}.png"), neighbour_canvas)

                    print(f"cnt={cnt} , ({i},{j})")
                    cnt+=1 
            print("good")
        print("good2")

def check_psuedo_label(args):
    """Write per-view renderings of pseudo labels next to ground-truth labels.

    For every shape of the "Keyboard" test set the ten input views, the
    projected pseudo labels (``{view}.png``) and the projected ground-truth
    labels (``{view}_gt.png``) are written to ``./output/check_psuedo_label``.
    Files are overwritten from shape to shape.

    The pseudo labels are taken from the last element of the dataset tuple
    (named ``pc_norm`` when unpacked). The metadata file is now closed after
    reading, and the duplicated rendering loops share one helper.

    NOTE(review): ``PartnetEpc`` is not defined in the visible part of this file.
    """
    from src.utils import get_batchlabel_color
    from test_pc import get_legend

    category = "Keyboard"
    with open("PartNetE_meta.json") as f:
        meta = json.load(f)
    label_name = ["BK","uncertain","None"]
    label_name.extend(meta[category])
    dataset = PartnetEpc("test", category, args)
    save_dir = "./output/check_psuedo_label"
    os.makedirs(save_dir, exist_ok=True)

    def project_labels(point_labels, pc_idx):
        """Map per-point labels to per-pixel labels (+2; 0 where no point)."""
        img_label = point_labels[pc_idx]+2
        img_label[pc_idx==-1]=0
        rgb_img_label = (get_batchlabel_color(img_label)*255).astype(np.uint8)
        return img_label, rgb_img_label

    def write_label_views(img_label, rgb_img_label, suffix):
        """Write one legend-annotated label image per view."""
        for view in range(10):
            rgb = rgb_img_label[view]
            # Pixels without a point are drawn black.
            rgb[img_label[view]==0] = np.array([0,0,0])
            rgb = get_legend(rgb, img_label[view], label_name)
            cv2.imwrite(f"{save_dir}/{view}{suffix}.png", rgb)

    for iii in tqdm(range(len(dataset))):
        ii, pc, pc_label, unseen_pc, unseen_pc_label,img, mask_label,pc_idx,coords, graph, pc_norm = dataset[iii]
        psuedo_label = pc_norm

        img_label, rgb_img_label = project_labels(psuedo_label, pc_idx)

        for view in range(10):
            rgb = (img[view]*255).astype(np.uint8)
            cv2.imwrite(f"{save_dir}/{view}_img.png", rgb)

        write_label_views(img_label, rgb_img_label, "")

        img_label, rgb_img_label = project_labels(pc_label, pc_idx)
        write_label_views(img_label, rgb_img_label, "_gt")
        print("good")          
        
def render_mask(args):
    """Render the points of every non-empty mask of the "Pen" test shapes.

    Each mask's points are coloured by their labels via ``get_seg_color`` and
    passed to ``render_pc`` with the default output directory ``output/render``.

    NOTE(review): ``PartnetEpc`` is not defined in the visible part of this file.
    """
    from src.render_pc import render_pc
    from src.utils import get_seg_color
    category = "Pen"
    dataset = PartnetEpc("test", category, args)

    def render(pc, pc_label, save_dir="output/render", device="cuda:0"):
        """Render ``pc`` coloured by ``pc_label`` through ``render_pc``."""
        render_pc(pc, get_seg_color(pc_label), save_dir, device=device)

    def render_o(pc, pc_label):
        """Interactive Open3D alternative to ``render``; currently not called."""
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(pc)
        cloud.colors = o3d.utility.Vector3dVector(get_seg_color(pc_label))

        # Start the web visualisation server (viewed through a browser).
        o3d.visualization.webrtc_server.enable_webrtc()
        o3d.visualization.draw([cloud])

    for iii in tqdm(range(len(dataset))):
        ii, pc, pc_label, unseen_pc, unseen_pc_label,img, mask_label,pc_idx,coords, graph, pc_norm = dataset[iii]

        mask_pc_ind = graph["mask_pc_ind"]
        mask_num = max(graph["mask2id"].values())+1
        for mask_id in range(mask_num):
            members = mask_pc_ind[mask_id]
            if members.size > 0:
                render(pc[members], pc_label[members])
                print("good")
        print("good")
        
def check_PN_quality(args):
    """Measure how well the sampled triplets agree with ground-truth labels.

    For every category in ``all_categories`` the test split is loaded and,
    for each sample, the rows of ``graph["PN_tri"]`` are read as
    (anchor, positive, negative) point indices into ``pc_label``:

    * positive accuracy: fraction of rows whose anchor and positive points
      share a label;
    * negative accuracy: fraction of rows whose anchor and negative points
      have different labels.

    The per-category means are printed and written to
    ``./output/tri_alpha2.csv`` as ``category, acc_P, acc_N`` rows, one row
    per category.  Samples with no triplets are skipped (their accuracy is
    undefined); a category with no usable sample is reported as ``nan``.

    Args:
        args: parsed command-line namespace, forwarded to ``PartnetEpc``.
    """
    # all_categories = ["Phone"]
    out_path = "./output/tri_alpha2.csv"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # Accumulated across ALL categories; resetting this per category (as the
    # file is reopened in "w" mode) would leave only the last row on disk.
    data_to_write = []
    for category in all_categories:
        dataset = PartnetEpc("test", category, args)
        acc_P = []
        acc_N = []
        for iii in tqdm(range(len(dataset))):
            ii, pc, pc_label, unseen_pc, unseen_pc_label, img, mask_label, pc_idx, coords, graph, pc_norm = dataset[iii]
            PN_tri = graph["PN_tri"]
            n_tri = PN_tri.shape[0]
            if n_tri == 0:
                # No triplets: the ratios below would divide by zero.
                continue
            anchor_label = pc_label[PN_tri[:, 0]]
            pos_label = pc_label[PN_tri[:, 1]]
            neg_label = pc_label[PN_tri[:, 2]]
            acc_P.append((anchor_label == pos_label).sum() / n_tri)
            acc_N.append((anchor_label != neg_label).sum() / n_tri)

        # Guard against an empty dataset (or one made only of empty samples).
        mean_P = sum(acc_P) / len(acc_P) if acc_P else float("nan")
        mean_N = sum(acc_N) / len(acc_N) if acc_N else float("nan")
        line = f"{category}, {mean_P}, {mean_N}\n"
        print(line)
        data_to_write.append(line)

        # Rewrite the whole file after each category so that partial results
        # are on disk even if a later category fails.
        with open(out_path, "w") as f_out:
            f_out.writelines(data_to_write)
def mask_visualize(args):
    """Dump debug images highlighting each well-covered 2D mask.

    Iterates the "Laptop" test split.  For every mask id, the (view, label)
    pair from ``graph["id2mask"]`` selects the pixels of that mask; if more
    than 90% of those pixels have a ``pc_idx`` entry different from -1, the
    view image is written to ``./output/mask/aimg.png`` and a copy with the
    mask pixels painted over is written to ``./output/mask/vis_curr.png``.
    Both files are overwritten on every accepted mask.

    Args:
        args: parsed command-line namespace, forwarded to ``PartnetEpc``.
    """
    dataset = PartnetEpc("test", "Laptop", args)
    from src.utils import get_batchlabel_color

    out_dir = "./output/mask"
    os.makedirs(out_dir, exist_ok=True)
    plain_path = os.path.join(out_dir, "aimg.png")
    overlay_path = os.path.join(out_dir, "vis_curr.png")

    for sample_idx in tqdm(range(len(dataset))):
        sample = dataset[sample_idx]
        (_shape_id, _pc, _pc_label, _unseen_pc, _unseen_pc_label, img,
         mask_label, pc_idx, _coords, graph, _pc_norm) = sample

        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        _adj_edges = graph["adj_edges"]  # read but not used in this function

        # Mask ids run from 0 up to the largest id stored in mask2id.
        mask_total = max(mask2id.values()) + 1
        for mask_id in range(mask_total):
            view, label = id2mask[mask_id]
            pixel_sel = mask_label[view] == label

            # Share of the mask's pixels whose pc_idx entry is not -1.
            hit_count = (pc_idx[view, pixel_sel] != -1).sum()
            pixel_count = pixel_sel.sum()
            # Written as a negated ">" so a NaN ratio is rejected as well.
            if not (hit_count / pixel_count > 0.9):
                continue

            frame = (img[view] * 255).astype(np.uint8)
            cv2.imwrite(plain_path, frame)
            # Paint the mask pixels with a fixed colour for the overlay image.
            frame[pixel_sel] = np.array([20, 20, 200])
            cv2.imwrite(overlay_path, frame)
            print("good")
        print("good")
if __name__ == "__main__":
    # Script entry point: build the CLI, parse it, and run data processing.
    parser = argparse.ArgumentParser()

    # One or more categories; other values used before:
    # ["Box","Bucket","Camera","Cart","Clock"]
    parser.add_argument("--category", type=str, nargs="+", default=["Phone"])
    # parser.add_argument("--cuda", type=str, default="cuda:2")

    # Remaining options as (flag, type, default) rows, registered in order.
    _cli_options = [
        # NOTE: when switching --mode to test, change --output_dir as well.
        ("--mode", str, "train"),
        ("--output_dir", str, "res/tmp5"),
        ("--epoch", int, 20),
        ("--batch_size", int, 1),
        ("--lr", float, 0.01),

        ("--use_2d_feat", int, 1),
        ("--use_3d_feat", int, 0),

        ("--img_encoder", str, "dinov2"),
        ("--use_cache", int, 0),

        ("--use_propagate", int, 1),  # original note: "2"
        ("--eliminate_sparseness", int, 1),
        ("--ave_per_mask", int, 0),
        ("--use_gnn", int, 1),
        ("--ave_inter_mask", int, 0),

        ("--use_slow_start", int, -2),
        ("--use_new_classifier", int, 0),
        ("--use_js2weight", int, 0),
        ("--use_attn_ave", int, 0),
        ("--use_ave_best", int, 0),
        ("--back_to_edges", int, 0),
        ("--use_pseudo_label", int, 0),

        ("--conf_label_edge", int, 0),
        ("--gt_label_edge", int, 0),
        ("--ps_label_edge", int, 0),
        ("--img_feat_on_mask", int, 0),

        ("--All_graph", int, 1),
        ("--use_ball_propagate", int, 0),
        ("--graph4", int, 0),

        ("--self_supervised", int, 0),

        ("--use_proxy_contrast_loss", int, 0),
        ("--use_contrast_loss2", int, 0),
        ("--use_ref_loss", int, 0),
        ("--use_mask_consist_loss", int, 0),
        ("--use_triplet_loss", int, 0),

        ("--debug", int, 1),
    ]
    for _flag, _kind, _default in _cli_options:
        parser.add_argument(_flag, type=_kind, default=_default)

    args = parser.parse_args()

    data_process(args)
    # Other debug entry points; swap in by hand when needed:
    # check_PN_quality(args)
    # calc_good_bad_edge(args)
    # check_psuedo_label(args)
    # visualize_conn(args)
    # render_mask(args)
    # mask_visualize(args)
    