# copied from UnSAM
# https://github.com/frank-xwang/UnSAM

from hmac import new
from math import dist
import os
from functools import partial
from xml.etree.ElementInclude import include
from matplotlib.pylab import single
import torch.nn.functional as F
import numpy as np
import torch
import heapq
import time

class DSU:
    """Disjoint-set union (union-find) over the integers ``0 .. n-1``.

    ``find`` applies path compression and is written iteratively: callers such
    as ``iterative_merge`` repeatedly attach old roots under a brand-new id,
    which can build parent chains far deeper than Python's recursion limit.
    """

    def __init__(self, n: int):
        self.parent = list(range(n))

    def find(self, x):
        """Return the root of ``x``'s set and point every node on the path at it."""
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: path compression.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge ``x``'s set into ``y``'s set; the root of ``y``'s set is kept."""
        root_x = self.find(x)
        root_y = self.find(y)
        self.parent[root_x] = root_y


def merged_clusters(i, j, clusters):
    """Build the cluster obtained by merging ``clusters[i]`` and ``clusters[j]``.

    The merged feature is the sum of the two features, and its normalised
    version is the float L2-normalised sum. The mask is the union of the two
    masks and the patch count is the sum. The neighbour set is the union of
    both neighbour sets without ``i`` and ``j`` themselves. Inputs are not modified.
    """
    first, second = clusters[i], clusters[j]
    summed = first['feature'] + second['feature']
    combined_neighbors = (first['neighbors'] | second['neighbors']) - {i, j}
    return {
        'feature': summed,
        'normalized_feature': F.normalize(summed.float(), dim=0),
        'mask': torch.logical_or(first['mask'], second['mask']),
        'num_of_patch': first['num_of_patch'] + second['num_of_patch'],
        'neighbors': combined_neighbors,
    }

def merged_clusters_dsu(i, j, clusters, dsu: DSU):
    """Merge ``clusters[i]`` and ``clusters[j]``, resolving neighbour ids via ``dsu``.

    Neighbour ids may refer to clusters that have since been merged away, so
    each one is replaced by its current DSU root. The root of ``i`` is then
    removed. The caller is expected to have already unioned ``i`` and ``j``
    into the new cluster id (as ``iterative_merge`` does), so this drops the
    merged cluster itself.

    Returns:
        (merged cluster dict, the same ``dsu`` object)
    """
    left, right = clusters[i], clusters[j]
    summed = left['feature'] + right['feature']
    recorded_ids = [*left['neighbors'], *right['neighbors']]
    resolved = {dsu.find(n) for n in recorded_ids}
    resolved.discard(dsu.find(i))
    merged = {
        'feature': summed,
        'normalized_feature': F.normalize(summed.float(), dim=0),
        'mask': torch.logical_or(left['mask'], right['mask']),
        'num_of_patch': left['num_of_patch'] + right['num_of_patch'],
        'neighbors': resolved,
    }
    return merged, dsu

def merged_clusters_inplace(i, j, clusters, k, features:torch.Tensor, normalized_features:torch.Tensor, masks:torch.Tensor):
    """Merge rows ``i`` and ``j`` into row ``k`` of the given tensors, in place.

    Row ``k`` of ``features`` becomes the sum of rows ``i`` and ``j``, and the
    same row of ``normalized_features`` becomes its L2-normalised version. Row
    ``k`` of ``masks`` becomes the logical OR of the two masks.

    Returns:
        (dict with merged 'num_of_patch' and 'neighbors' (minus ``i`` and ``j``),
        features, normalized_features, masks) -- the tensors are the same
        objects that were passed in.
    """
    features[k] = features[i] + features[j]
    normalized_features[k] = F.normalize(features[k], dim=0)
    masks[k] = torch.logical_or(masks[i], masks[j])
    a, b = clusters[i], clusters[j]
    info = {
        'num_of_patch': a['num_of_patch'] + b['num_of_patch'],
        'neighbors': (a['neighbors'] | b['neighbors']) - {i, j},
    }
    return info, features, normalized_features, masks

@torch.inference_mode()
def iterative_merge(features: torch.Tensor, threshes=(0.6, 0.5, 0.4, 0.3, 0.2, 0.1), min_size=4, merge_masks=False):
    """Hierarchically merge the pixels of a feature map into regions.

    Every pixel starts as its own cluster. Pairs of 4-connected clusters are
    merged greedily, most similar pair first. The similarity of two clusters
    is the dot product of their L2-normalised summed features. For each
    threshold in ``threshes``, all merges with similarity >= threshold are
    performed. The surviving clusters with at least ``min_size`` pixels are
    then recorded as one level of the hierarchy.

    Args:
        features: feature map of shape (H, W, C).
        threshes: similarity thresholds, processed in order; expected to be
            descending. Pairs whose similarity is not above ``threshes[-1]``
            are never considered for merging.
        min_size: minimum number of pixels for a cluster to be reported.
        merge_masks: if False, return one (K, H, W) bool tensor per threshold
            level; levels without any qualifying cluster are skipped. If True,
            return a single (K, H, W) tensor holding every distinct cluster
            reported at any level, ordered by cluster id.

    Returns:
        A list of bool mask tensors; empty when no cluster qualifies.
    """
    if not threshes:
        return []
    floor = threshes[-1]

    H, W = features.shape[:2]
    features = features.float().flatten(0, 1)
    normalized_features = F.normalize(features, dim=1)
    # Similarities of horizontally (p, p+1) and vertically (p, p+W) adjacent
    # pixels, moved to Python once instead of one .item() sync per pixel.
    horiz_sims = (normalized_features[:-1] * normalized_features[1:]).sum(dim=1).tolist()
    vert_sims = (normalized_features[:-W] * normalized_features[W:]).sum(dim=1).tolist()
    masks = torch.eye(H * W, dtype=torch.bool, device=features.device).view(H * W, H, W)

    clusters = []
    # stored_sims[k]: candidate merges (-sim, older_neighbour, k) of cluster k,
    # sorted so that the most similar candidate is last. Only the last entry
    # of each list is in the heap at any time. Entries whose neighbour no
    # longer exists are discarded lazily when they surface.
    stored_sims = []
    cluster_idx = 0
    for y in range(H):
        for x in range(W):
            own_neighbors = set()
            clusters.append({
                'feature': features[cluster_idx],
                'normalized_feature': normalized_features[cluster_idx],
                'mask': masks[cluster_idx],
                'num_of_patch': 1,
                'neighbors': own_neighbors,
            })
            sims = []
            if x > 0:
                left = cluster_idx - 1
                own_neighbors.add(left)
                clusters[left]['neighbors'].add(cluster_idx)
                sim = horiz_sims[left]
                if sim > floor:
                    sims.append((-sim, left, cluster_idx))
            if y > 0:
                up = cluster_idx - W
                own_neighbors.add(up)
                clusters[up]['neighbors'].add(cluster_idx)
                sim = vert_sims[up]
                if sim > floor:
                    sims.append((-sim, up, cluster_idx))
            sims.sort(key=lambda t: t[0], reverse=True)
            stored_sims.append(sims)
            cluster_idx += 1

    similarities = [cands[-1] for cands in stored_sims if cands]
    heapq.heapify(similarities)
    cluster_exists = [True] * len(clusters)
    # Each merge creates one new id, so at most 2 * (#pixels) ids are ever used.
    dsu = DSU(len(clusters) * 2)

    all_masks = []
    reported_ids = set()
    for thresh in threshes:
        while similarities:
            neg_sim, i, j = similarities[0]
            if -neg_sim < thresh:
                # Leave the best remaining candidate in the heap: it (and the
                # rest of stored_sims[j]) must stay available for lower thresholds.
                break
            heapq.heappop(similarities)
            if not cluster_exists[j]:
                continue
            if not cluster_exists[i]:
                # The popped entry is stored_sims[j][-1]; drop it and any other
                # stale candidates, then queue j's best still-valid candidate.
                candidates = stored_sims[j]
                while candidates and not cluster_exists[candidates[-1][1]]:
                    candidates.pop()
                if candidates:
                    heapq.heappush(similarities, candidates[-1])
                continue

            dsu.union(i, cluster_idx)
            dsu.union(j, cluster_idx)
            merged, dsu = merged_clusters_dsu(i, j, clusters, dsu)
            clusters.append(merged)
            cluster_exists.append(True)
            cluster_exists[i] = False
            cluster_exists[j] = False

            neighbors = list(merged['neighbors'])
            sims_tuple = []
            if neighbors:
                neighbor_feats = torch.stack([clusters[n]['normalized_feature'] for n in neighbors])
                neighbor_sims = (neighbor_feats @ merged['normalized_feature']).cpu().numpy()
                sims_tuple = [(-s, n, cluster_idx) for s, n in zip(neighbor_sims, neighbors) if s > floor]
                sims_tuple.sort(key=lambda t: t[0], reverse=True)
            # Always append, even when empty, so stored_sims stays indexed by cluster id.
            stored_sims.append(sims_tuple)
            if sims_tuple:
                heapq.heappush(similarities, sims_tuple[-1])
            cluster_idx += 1

        level_ids = [
            k for k, cluster in enumerate(clusters)
            if cluster_exists[k] and cluster['num_of_patch'] >= min_size
        ]
        if merge_masks:
            reported_ids.update(level_ids)
        elif level_ids:
            all_masks.append(torch.stack([clusters[k]['mask'] for k in level_ids]))

    if merge_masks:
        if not reported_ids:
            return []
        return [torch.stack([clusters[k]['mask'] for k in sorted(reported_ids)])]
    return all_masks


import matplotlib.pyplot as plt
import cupy as cp
import cuml
from cuml.cluster import DBSCAN, HDBSCAN
from cuml.manifold import UMAP
import sklearn
from sklearn.cluster import DBSCAN as skDBSCAN
from hdbscan import HDBSCAN as skHDBSCAN
from concurrent.futures import ThreadPoolExecutor
import warnings
# Suppress the UserWarning whose message says no autograd kernel is registered
# for "c10d::broadcast_"; every other warning is left untouched.
warnings.filterwarnings(
    action="ignore",
    message=r".*c10d::broadcast_: an autograd kernel was not registered.*",
    category=UserWarning,
)

def torch_to_cupy(x: torch.Tensor) -> cp.ndarray:
    """View a torch tensor as a CuPy array through a DLPack capsule (memory is shared)."""
    capsule = torch.utils.dlpack.to_dlpack(x)
    return cp.from_dlpack(capsule)

def cupy_to_torch(a: cp.ndarray) -> torch.Tensor:
    """View a CuPy array as a torch tensor through a DLPack capsule (memory is shared)."""
    capsule = a.toDlpack()
    return torch.utils.dlpack.from_dlpack(capsule)

# Module-level counter; not read or written anywhere in this part of the file.
global_timer: int = 0
@torch.inference_mode()
def dbscan_on_tensor(x: torch.Tensor, eps=0.6, min_samples=3, metric = 'l2'):
    """Cluster the rows of ``x`` with DBSCAN on a precomputed distance matrix.

    Args:
        x: (N, C) tensor, one point per row.
        eps, min_samples: DBSCAN parameters.
        metric: 'cosine' normalises the rows here and uses 1 - cosine similarity.
            'dot' uses 1 - x @ x.T and assumes the rows are already L2-normalised.
            Any other value uses the Euclidean distance (torch.cdist).

    Returns:
        Tensor of N integer labels on ``x``'s device; -1 marks noise points.
        Small inputs (< 256 rows) and CPU tensors use scikit-learn; everything
        else uses cuML on ``x``'s GPU.
    """
    if metric == 'cosine':
        X = F.normalize(x, dim=-1)
        distance = 1 - (X @ X.T)
    elif metric == 'dot':
        distance = 1 - (x @ x.T)
    else:
        distance = torch.cdist(x, x, p=2)
    # 1 - similarity can come out slightly negative through float rounding
    # (e.g. on the diagonal). scikit-learn rejects precomputed distance
    # matrices with negative entries, and a value of 0 is still within eps.
    distance = distance.clamp(min=0)

    if x.shape[0] < 256 or not x.is_cuda:
        # Small problems, and CPU tensors (which CuPy cannot wrap), run on the CPU.
        db = skDBSCAN(
            eps=eps,
            min_samples=min_samples,
            metric='precomputed',
            n_jobs=8,  # original note: likely has no effect with a precomputed metric
        )
        labels = db.fit_predict(distance.cpu().numpy())
        return torch.from_numpy(labels).to(x.device)

    dev_id = x.device.index
    with cp.cuda.Device(dev_id):
        D = torch_to_cupy(distance)
        # The cuML stream below is non-blocking, so it does not implicitly wait
        # for PyTorch's stream: make sure `distance` is fully computed first.
        torch.cuda.current_stream(x.device).synchronize()
        stream = cp.cuda.Stream(non_blocking=True)
        h = cuml.Handle(stream.ptr)
        db = DBSCAN(
            eps=eps,
            handle=h,
            min_samples=min_samples,
            metric='precomputed',
            output_type='cupy'  # keep results on GPU
        )
        labels_cp = db.fit_predict(D)
        # Likewise, finish cuML's work before torch consumes the labels.
        stream.synchronize()
        labels = cupy_to_torch(labels_cp)
    return labels

@torch.inference_mode()
def hdbscan_on_tensor(x: torch.Tensor, min_cluster_size=7, min_samples=5):
    """Cluster the rows of ``x`` with cuML HDBSCAN on ``x``'s GPU.

    A falsy ``min_samples`` (e.g. None or 0) falls back to ``min_cluster_size``.
    The labels come back as a torch tensor that shares memory with cuML's CuPy output.
    """
    effective_min_samples = min_cluster_size if not min_samples else min_samples
    with cp.cuda.Device(x.device.index):
        points = torch_to_cupy(x)
        clusterer = HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=effective_min_samples,
            output_type='cupy',  # keep results on GPU
        )
        raw_labels = clusterer.fit_predict(points)
        labels = cupy_to_torch(raw_labels)
    return labels

def run_umap_2d(emb_cupy, *, n_neighbors=15, min_dist=0.1):
    """Project ``emb_cupy`` to 2-D with cuML UMAP (cosine metric); the result stays on the GPU."""
    settings = dict(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist, metric='cosine')
    reducer = UMAP(**settings)
    projected = reducer.fit_transform(emb_cupy)
    return projected

def visualise_cluster(tensor: torch.Tensor, labels_cupy: cp.ndarray, X2d_host=None,
                      umap_kw: dict = None, out_path: str = "./playground/hdbscan.png"):
    """Save a 2-D scatter plot of clustered points to ``out_path``.

    Args:
        tensor: embeddings, one row per point. Its device selects the GPU. It is
            projected to 2-D with UMAP only when ``X2d_host`` is not given.
        labels_cupy: per-point cluster labels; -1 is drawn as grey "noise".
        X2d_host: optional precomputed (n, 2) NumPy coordinates.
        umap_kw: extra keyword arguments for ``run_umap_2d``.
        out_path: image path. Missing parent directories are created; a bare
            file name is written to the current directory.
    """
    dev = tensor.device.index or 0
    with cp.cuda.Device(dev):             # make sure we're on the same GPU
        if X2d_host is None:
            X2d = run_umap_2d(torch_to_cupy(tensor), **(umap_kw or {}))
            X2d_host = cp.asnumpy(X2d)     # tiny (n x 2) copy -> CPU
        labels_host = cp.asnumpy(labels_cupy)

    unique = np.unique(labels_host)
    cmap = plt.get_cmap("tab10", len(unique))
    fig = plt.figure(figsize=(6, 5))
    try:
        for lbl in unique:
            mask = labels_host == lbl
            plt.scatter(
                X2d_host[mask, 0],
                X2d_host[mask, 1],
                s=40,
                color="lightgrey" if lbl == -1 else cmap(lbl),
                label="noise" if lbl == -1 else f"cluster {lbl}",
                alpha=0.9,
                edgecolors="k" if lbl != -1 else "none",
                linewidths=.3,
            )

        plt.title("DBSCAN/HDBSCAN clusters (UMAP 2-D projection)")
        plt.axis("off")
        plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
        plt.tight_layout()

        out_dir = os.path.dirname(out_path)
        if out_dir:  # os.makedirs('') raises for a bare file name
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


@torch.inference_mode()
def split_masks(mask: torch.Tensor, label: torch.Tensor, include_noise=False):
    """Split ``mask`` into one mask per cluster label.

    Args:
        mask: (H, W) mask.
        label: one integer label per nonzero pixel of ``mask``, in row-major
            order (the order of ``features[mask]`` / ``mask.nonzero()``).
            -1 marks noise.
        include_noise: if True, every noise pixel becomes its own single-pixel
            mask after the cluster masks. If False, noise pixels are left out of
            all output masks.

    Returns:
        Tensor of shape (num_masks, H, W) with ``mask``'s dtype. If there is no
        cluster (all pixels are noise and noise is excluded) or ``label`` is
        empty, a single mask covering the whole input ``mask`` is returned.
    """
    H, W = mask.shape
    if label.numel() == 0:
        # Empty input mask: nothing to split, return it as the single mask.
        return torch.zeros(1, H, W, dtype=mask.dtype, device=mask.device)

    num_clusters = label.max().item() + 1
    noise_idx = (label == -1).nonzero(as_tuple=False).flatten()
    num_noises = noise_idx.shape[0]
    y, x = mask.nonzero(as_tuple=True)
    new_label = label.clone()

    if include_noise and num_noises:
        new_label[noise_idx] = torch.arange(num_clusters, num_clusters + num_noises, device=new_label.device, dtype=new_label.dtype)
        num_masks = num_clusters + num_noises
    elif num_clusters == 0:
        # No clusters found: keep the whole mask as the only cluster.
        new_label = torch.zeros_like(label)
        num_masks = 1
    else:
        # Drop noise pixels; indexing with -1 would otherwise silently add
        # them to the last cluster's mask.
        keep = new_label >= 0
        new_label, y, x = new_label[keep], y[keep], x[keep]
        num_masks = num_clusters

    new_masks = torch.zeros(num_masks, H, W, dtype=mask.dtype, device=mask.device)
    new_masks[new_label, y, x] = True
    return new_masks

@torch.inference_mode()
def cluster_in_masks(features: torch.Tensor, sam_masks: torch.Tensor, algo='dbscan', cluster_args: dict = None, skip_ids: list = None):
    """Split masks into sub-masks by clustering the features inside each one.

    Args:
        features: (H, W, C) feature map.
        sam_masks: (N, H, W) masks, used to select pixels of ``features``.
        algo: 'dbscan' or 'hdbscan'.
        cluster_args: keyword arguments for the clustering function. Three
            extra options are consumed here:
            - 'normalize' (default True): L2-normalise the features first.
            - 'cluster_thresh' (default 10): minimum mask area to be clustered.
            - 'include_noise' (default False): passed to ``split_masks``.
            The caller's dict is not modified.
        skip_ids: indices of masks to keep unsplit; negative indices count
            from the end.

    Returns:
        (M, H, W) tensor in the original mask order. Skipped masks and masks
        smaller than ``cluster_thresh`` are kept as-is; every other mask is
        replaced by its sub-masks.

    Raises:
        ValueError: if ``algo`` is unknown.
    """
    H, W, C = features.shape
    N = sam_masks.shape[0]
    assert H == sam_masks.shape[1] and W == sam_masks.shape[2], "features and sam_masks should have the same spatial dimensions"
    # Work on a copy: the pops below must not mutate the caller's dict.
    cluster_args = dict(cluster_args or {})
    normalize = cluster_args.pop('normalize', True)
    cluster_thresh = cluster_args.pop('cluster_thresh', 10)
    include_noise = cluster_args.pop('include_noise', False)

    if algo == 'dbscan':
        cluster_algo = partial(dbscan_on_tensor, **cluster_args)
    elif algo == 'hdbscan':
        cluster_algo = partial(hdbscan_on_tensor, **cluster_args)
    else:
        raise ValueError(f"Unknown clustering algorithm: {algo}")

    # Clamp +/-inf to finite values so that normalisation does not produce NaN.
    features = features.float().clamp(-1e9, 1e9)
    if normalize:
        features = F.normalize(features, dim=-1)

    skip = {i if i >= 0 else N + i for i in (skip_ids or ())}
    areas = sam_masks.flatten(1).sum(dim=1).tolist()  # mask areas, computed once
    to_cluster = [i for i in range(N) if i not in skip and areas[i] >= cluster_thresh]

    # Cluster every selected mask first, then assemble the output in mask order.
    labels = {i: cluster_algo(features[sam_masks[i]]) for i in to_cluster}
    new_sam_masks = [
        split_masks(sam_masks[i], labels[i], include_noise=include_noise) if i in labels
        else sam_masks[i].unsqueeze(0)
        for i in range(N)
    ]
    if not new_sam_masks:
        # No input masks: torch.cat would fail on an empty list.
        return sam_masks.clone()
    return torch.cat(new_sam_masks, dim=0)

# from concurrent.futures import ThreadPoolExecutor

# def worker(tensor):
#     return hdbscan_on_tensor(tensor, min_cluster_size=10)

# with ThreadPoolExecutor(max_workers=len(feature_groups)) as ex:
#     clustered = list(ex.map(worker, feature_groups))