# Adopted from https://github.com/DeMoriarty/fast_pytorch_kmeans/tree/master.

import math
import torch
from torch.nn.functional import normalize
from time import time
import numpy as np
from sklearn import cluster, mixture

try:
    from .init_methods import init_methods
    from .utils import *
except Exception as e:
    print(e)
    from init_methods import init_methods
    from utils import *

class Clustering_Base:
    """Shared scaffolding for the clustering back-ends defined in this module.

    Positional constructor arguments are kept on ``self.args`` and every
    keyword argument becomes an instance attribute of the same name.
    ``fit``, ``get_parent`` and ``predict`` are no-op hooks here; ``forward``
    calls them in that order.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        for key, value in kwargs.items():
            setattr(self, key, value)

    def fit(self, X=None):
        """Hook: compute the cluster allocation for ``X`` (no-op here).

        ``X`` is optional so the stub matches how ``forward`` calls it while
        still accepting the historical zero-argument call.
        """
        pass

    def get_parent(self, X=None):
        """Hook: pick the parent sample of each cluster (no-op here)."""
        pass

    def predict(self):
        """Hook: return the parent prediction for each sample (no-op here)."""
        pass

    def get_z(self, X):
        """Draw a random reference point with the shape, dtype and device of ``X``.

        The draw is divided by ``X.numel()``.  ``self.z_distribution`` selects
        the distribution: ``"uniform"`` samples from [-1, 1) before scaling,
        ``"normal"`` samples a standard normal.  When the attribute was never
        set, ``"uniform"`` is used (the default of the subclasses that declare
        the option).

        Raises:
            NotImplementedError: for any other ``z_distribution`` value.
        """
        distribution = getattr(self, "z_distribution", "uniform")
        # ``*_like`` already inherits X's device; forcing "cuda" here would
        # fail on CPU-only hosts and mismatch X on any other device.
        if distribution == "uniform":
            z = (torch.rand_like(X) * 2 - 1) / X.numel()
        elif distribution == "normal":
            z = torch.randn_like(X) / X.numel()
        else:
            raise NotImplementedError()
        return z

    def get_pairwise_distance(self, X):
        """Return the matrix of Euclidean distances between all rows of ``X``."""
        return torch.cdist(X.unsqueeze(0), X.unsqueeze(0), p=2).squeeze(0)

    def get_zdistance(self, X, z):
        """Return the Euclidean distance of every row of ``X`` to ``z``."""
        return torch.norm(X - z, p=2, dim=1)

    def vote_parent(self, X):
        """Return the index (0-d tensor) of the row of ``X`` closest to a fresh ``z``."""
        z = self.get_z(X[0:1, :])
        z_dist = self.get_zdistance(X, z)
        _, min_idx = torch.min(z_dist, dim=0)
        return min_idx

    def forward(self, X):
        """Run ``fit``, ``get_parent`` and ``predict`` in sequence on ``X``."""
        self.fit(X)
        self.get_parent(X)
        return self.predict()

class DBSCAN(Clustering_Base):
    """DBSCAN clustering (scikit-learn back-end) with per-cluster parent voting."""

    def __init__(self, eps=0.5, metric = "precomputed", z_distribution = "uniform",  **kwargs):
        '''
            X: parent first, child second, instance in each row

            eps, metric and any extra keyword arguments are forwarded to
            ``sklearn.cluster.DBSCAN`` and also stored as instance attributes.
        '''
        super().__init__(eps=eps, metric=metric, z_distribution=z_distribution, **kwargs)
        self.alg = cluster.DBSCAN(eps=eps, metric=metric, **kwargs)

    def fit(self, X):
        """Fit the scikit-learn estimator on ``X`` and return it.

        With ``metric="precomputed"`` the estimator expects a square distance
        matrix, so the pairwise Euclidean distances are passed; for any other
        metric the raw feature rows are passed instead.
        """
        if self.metric == "precomputed":
            data = self.get_pairwise_distance(X)
        else:
            data = X
        return self.alg.fit(data.detach().cpu().numpy())

    def get_parent(self, X):
        """Choose one parent sample per cluster and return their indices.

        Clusters with a single member use that member; larger clusters are
        resolved with ``vote_parent``.  Stores ``cluster_parent`` and
        ``cluster_label`` on the instance.
        """
        device = X.device  # keep every tensor next to X instead of assuming CUDA
        cluster_label = torch.from_numpy(self.alg.labels_).to(device)
        # Noise is labelled -1, so an all-noise result yields zero clusters.
        n_found = int(cluster_label.max()) + 1
        cluster_parent = torch.zeros(n_found, dtype=torch.long, device=device)
        for cluster_index in range(n_found):
            member_index = torch.nonzero(cluster_label == cluster_index).squeeze(-1)
            if member_index.numel() > 1:
                sub_index = self.vote_parent(X[member_index])
            else:
                sub_index = 0
            cluster_parent[cluster_index] = member_index[sub_index]
        self.cluster_parent = cluster_parent
        self.cluster_label = cluster_label
        return cluster_parent

    def predict(self):
        """Return, per sample, the index of its cluster's parent (-1 for noise)."""
        prediction = torch.full_like(self.cluster_label, -1)
        mask = self.cluster_label >= 0
        prediction[mask] = self.cluster_parent[self.cluster_label[mask]]
        self.prediction = prediction
        return prediction

class GMM(Clustering_Base):
    """Gaussian-mixture clustering (scikit-learn back-end) with parent voting."""

    def __init__(self, n_components, covariance_type= "full", random_state= 42, z_distribution = "uniform",  **kwargs):
        '''
            X: parent first, child second, instance in each row
            default config:
                n_components= number of clusters,
                covariance_type= "full",
                random_state= 42,

            Extra keyword arguments are forwarded to
            ``sklearn.mixture.GaussianMixture`` and stored as attributes.
        '''
        super().__init__(
            n_components=n_components,
            covariance_type=covariance_type,
            random_state=random_state,
            z_distribution=z_distribution,
            **kwargs)
        self.alg = mixture.GaussianMixture(
            n_components,
            covariance_type=covariance_type,
            random_state=random_state,
            **kwargs)

    def fit(self, X):
        """Fit the mixture model on the rows of ``X`` and return the estimator."""
        return self.alg.fit(X.detach().cpu().numpy())

    def get_parent(self, X):
        """Choose one parent sample per mixture component via ``vote_parent``.

        Returns a long tensor indexed by component label holding the row index
        of the chosen parent; components that received no sample keep -1.
        Stores ``cluster_parent`` and ``cluster_label`` on the instance.
        """
        device = X.device  # keep every tensor next to X instead of assuming CUDA
        cluster_label = torch.from_numpy(
            self.alg.predict(X.detach().cpu().numpy())
        ).to(device)

        n_found = int(cluster_label.max()) + 1
        cluster_parent = torch.full((n_found,), -1, dtype=torch.long, device=device)
        for cluster_index in range(n_found):
            member_index = torch.nonzero(cluster_label == cluster_index).squeeze(-1)
            if member_index.numel() == 0:
                # A component may win no sample; voting on an empty subset
                # would raise inside torch.min, so leave its parent at -1.
                continue
            cluster_parent[cluster_index] = member_index[self.vote_parent(X[member_index])]
        self.cluster_parent = cluster_parent
        self.cluster_label = cluster_label
        return cluster_parent

    def predict(self):
        """Return, per sample, the index of its cluster's parent sample."""
        prediction = torch.full_like(self.cluster_label, -1)
        mask = self.cluster_label >= 0
        prediction[mask] = self.cluster_parent[self.cluster_label[mask]]
        self.prediction = prediction
        return prediction

class KMeans(Clustering_Base):
    '''
    K-means clustering implemented with PyTorch

    Parameters:
        n_clusters: int,
        Number of clusters

        max_iter: int, default: 100
        Maximum number of iterations

        tol: float, default: 0.0001
        Iteration stops once the summed squared centroid shift is <= tol

        init_method: str, default: 'fps'
        Key into ``init_methods`` selecting the centroid initializer

        pcthreshold: default: None
        When not None it is passed as ``threshold`` to
        ``get_parent_child_subset`` and only the returned parent rows are
        offered to the initializer

        **kwargs:
        Stored as instance attributes by ``Clustering_Base``

    Attributes:
        centroids: torch.Tensor, shape: [n_clusters, n_features]
        cluster centroids

        labels_, _labels: torch.Tensor, shape: [n_samples]
        cluster index of every sample from the last ``fit``
    '''
    def __init__(
                self,
                n_clusters,
                max_iter=100,
                tol=0.0001,
                init_method="fps",
                pcthreshold = None,
                **kwargs
            ):
        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            tol=tol,
            init_method=init_method,
            pcthreshold=pcthreshold,
            **kwargs
        )

        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.init_func = init_methods[init_method]
        self.pcthreshold = pcthreshold

        self.centroids = None
        self.labels_ = None
        self._labels = None

    def min_dist(self, a, b):
        """
        Compute the minimum Euclidean distance of each vector in `a`
        to the vectors in `b`

        Parameters:
        a: torch.Tensor, shape: [n_samples, n_features]

        b: torch.Tensor, shape: [n_clusters, n_features]

        Return:
        (min distance, index of the closest row of `b`), one entry per row of `a`
        """
        dist = torch.cdist(a.unsqueeze(0), b.unsqueeze(0), p=2).squeeze(0)
        min_dist_value, min_dist_index = dist.min(dim=-1)
        return min_dist_value, min_dist_index

    def fit(self, X, centroids=None):
        """
        Parameters:
        X: torch.Tensor, shape: [n_samples, n_features]

        centroids: {torch.Tensor, None}, default: None
            if given, centroids will be initialized with given tensor
            if None, centroids are produced by the configured initializer

        Return:
        cluster labels: torch.Tensor, shape: [n_samples]
        """
        assert isinstance(X, torch.Tensor), "input must be torch.Tensor"
        assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point"
        assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features] "

        centroid_candidates = X
        if self.pcthreshold is not None:
            estimated_parents_index, _ = get_parent_child_subset(X, threshold=self.pcthreshold)
            if len(estimated_parents_index) > 0:
                centroid_candidates = X[estimated_parents_index]
            else:
                # Same fallback as KMeans_Phylogeny: never hand the
                # initializer an empty candidate set.
                print("No parent found, using all samples")

        if centroids is None:
            self.centroids = self.init_func(centroid_candidates, self.n_clusters)
        else:
            self.centroids = centroids

        closest = None
        cluster_ids = torch.arange(self.n_clusters, device=X.device)[:, None]

        for _ in range(self.max_iter):
            closest = self.min_dist(a=X, b=self.centroids)[1]

            # One-hot membership, shape [n_clusters, n_samples].
            mask = (closest[None] == cluster_ids).to(X.dtype)
            new_centroids = mask @ X / mask.sum(-1)[..., :, None]
            # An empty cluster divides 0 by 0; replace the NaN row with zeros.
            torch.nan_to_num_(new_centroids)

            error = (new_centroids - self.centroids).pow(2).sum()
            self.centroids = new_centroids

            if error <= self.tol:
                break

        self._labels = closest
        self.labels_ = closest  # keep the attribute declared in __init__ in sync
        return closest

    def get_parent(self, X):
        """Choose one parent sample per cluster via ``vote_parent``.

        Returns a long tensor indexed by cluster label holding the row index
        of the chosen parent; clusters without any sample keep -1.
        Stores ``cluster_parent`` and ``cluster_label`` on the instance.
        """
        cluster_label = self._labels

        n_found = int(cluster_label.max()) + 1
        cluster_parent = torch.full(
            (n_found,), -1, dtype=torch.long, device=cluster_label.device
        )
        for cluster_index in range(n_found):
            member_index = torch.nonzero(cluster_label == cluster_index).squeeze(-1)
            if member_index.numel() == 0:
                # Lower-numbered clusters can end up empty; voting on an
                # empty subset would raise inside torch.min.
                continue
            cluster_parent[cluster_index] = member_index[self.vote_parent(X[member_index])]
        self.cluster_parent = cluster_parent
        self.cluster_label = cluster_label
        return cluster_parent

    def predict(self):
        """Return, per sample, the index of its cluster's parent sample."""
        prediction = torch.full_like(self.cluster_label, -1)
        mask = self.cluster_label >= 0
        prediction[mask] = self.cluster_parent[self.cluster_label[mask]]
        self.prediction = prediction
        return prediction

class MeanShift(Clustering_Base):
    """Flat-kernel mean-shift clustering implemented with PyTorch.

    Parameters:
        max_iter: number of mean-shift updates performed by ``fit``
        bandwidth: samples closer than this to a center contribute to its update
        centroid_merge_threshold: centers closer than this are linked and merged
        **kwargs: stored as instance attributes by ``Clustering_Base``
    """

    def __init__(
                self,
                max_iter=100,
                bandwidth=0.1,
                centroid_merge_threshold=0.01,
                **kwargs
            ):
        super().__init__(
            max_iter=max_iter,
            bandwidth=bandwidth,
            centroid_merge_threshold=centroid_merge_threshold,
            **kwargs
        )

        self.max_iter = max_iter
        self.bandwidth = bandwidth
        self.centroid_merge_threshold = centroid_merge_threshold

    def min_dist(self, a, b):
        """Return (min distance, index of the closest row of ``b``) for each row of ``a``."""
        dist = torch.cdist(a.unsqueeze(0), b.unsqueeze(0), p=2).squeeze(0)
        min_dist_value, min_dist_index = dist.min(dim=-1)
        return min_dist_value, min_dist_index

    def one_iter(self, centers, X, bandwidth):
        """Move every center to the mean of the samples within ``bandwidth`` of it."""
        cdist = torch.cdist(centers.unsqueeze(0), X.unsqueeze(0), p=2).squeeze(0)
        mask = cdist < bandwidth # n_centers, n_samples
        weight = mask.to(dtype=X.dtype)
        weight = weight / weight.sum(dim=1, keepdim=True)
        return weight @ X

    def find_connected_components(self, mask):
        """Group node indices into connected components of the adjacency ``mask``.

        ``mask[i, j]`` being truthy links node ``i`` to node ``j``.  Returns a
        list of components, each a list of ints in depth-first visiting order
        (lowest-index neighbor first); components are ordered by their
        smallest starting node.

        The traversal uses an explicit stack rather than recursion because
        there can be as many centers as samples, which would overrun Python's
        recursion limit on a long chain.  Neighbor look-ups are vectorised on
        a CPU copy of the mask.
        """
        n = mask.shape[0]
        adjacency = torch.as_tensor(mask).detach().to("cpu", dtype=torch.bool)
        visited = torch.zeros(n, dtype=torch.bool)
        components = []

        for start in range(n):
            if visited[start]:
                continue
            visited[start] = True
            component = [start]
            stack = [start]
            while stack:
                node = stack[-1]
                pending = torch.nonzero(adjacency[node] & ~visited).flatten()
                if pending.numel() == 0:
                    stack.pop()  # every neighbor explored: backtrack
                    continue
                # The lowest unvisited neighbor is exactly the node a
                # recursive DFS would descend into next.
                nxt = int(pending[0])
                visited[nxt] = True
                component.append(nxt)
                stack.append(nxt)
            components.append(component)

        return components

    def merge_centers(self, centers, threshold):
        """Collapse centers linked by distances below ``threshold``.

        Each connected group is represented by its first visited member.
        """
        cdist = torch.cdist(centers.unsqueeze(0), centers.unsqueeze(0), p=2).squeeze(0)
        components = self.find_connected_components(cdist < threshold)
        return torch.stack([centers[component[0]] for component in components])

    def fit(self, X, centroids=None, pcthreshold = None):
        """Run ``max_iter`` mean-shift updates starting from the samples themselves.

        ``centroids`` and ``pcthreshold`` are accepted but not used.
        Returns, and stores in ``self._labels``, the index of the nearest
        merged center for every row of ``X``.
        """
        centers = X
        for _ in range(self.max_iter):
            centers = self.one_iter(centers, X, self.bandwidth)
        centers = self.merge_centers(centers, self.centroid_merge_threshold)
        _, labels = self.min_dist(X, centers)
        self._labels = labels
        return labels

    def get_parent(self, X):
        """Choose one parent sample per cluster via ``vote_parent``.

        Returns a long tensor indexed by cluster label holding the row index
        of the chosen parent; clusters without any sample keep -1.
        Stores ``cluster_parent`` and ``cluster_label`` on the instance.
        """
        cluster_label = self._labels

        n_found = int(cluster_label.max()) + 1
        cluster_parent = torch.full(
            (n_found,), -1, dtype=torch.long, device=cluster_label.device
        )
        for cluster_index in range(n_found):
            member_index = torch.nonzero(cluster_label == cluster_index).squeeze(-1)
            if member_index.numel() == 0:
                # Voting on an empty subset would raise inside torch.min.
                continue
            cluster_parent[cluster_index] = member_index[self.vote_parent(X[member_index])]
        self.cluster_parent = cluster_parent
        self.cluster_label = cluster_label
        return cluster_parent

    def predict(self):
        """Return, per sample, the index of its cluster's parent sample."""
        prediction = torch.full_like(self.cluster_label, -1)
        mask = self.cluster_label >= 0
        prediction[mask] = self.cluster_parent[self.cluster_label[mask]]
        self.prediction = prediction
        return prediction


class KMeans_Phylogeny(KMeans):
    """K-means variant that tracks, for every cluster, the sample acting as its center.

    ``alpha`` is compared with a fresh ``torch.rand(1)`` draw on every update:
    when the draw is below ``alpha`` the centroid is the cluster mean,
    otherwise the centroid is the member closest to a random point ``z``.
    The remaining parameters are those of ``KMeans``.
    """

    def __init__(
                self,
                n_clusters,
                max_iter=100,
                tol=0.0001,
                init_method="fps",
                alpha = 1,
                pcthreshold = None,
                **kwargs
            ):
        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            tol=tol,
            init_method=init_method,
            alpha=alpha,
            pcthreshold=pcthreshold,
            **kwargs
        )

        self.centroid_indices = None

    def choose_center(self, X, assignment):
        """Compute new centroids and the sample index representing each cluster.

        Parameters:
        X: torch.Tensor, shape: [n_samples, n_features]
        assignment: torch.Tensor, shape: [n_samples], cluster index per sample

        Return:
        (centers, center_index): centroids with one row per cluster, and for
        each cluster the index of the row of ``X`` selected as its center.
        """
        cluster_ids = torch.arange(self.n_clusters, device=X.device)[:, None]
        mask = (assignment[None] == cluster_ids).to(X.dtype)  # [n_clusters, n_samples]
        # Large penalty on samples outside a cluster so argmin stays inside it.
        cluster_filter = (1 - mask).transpose(0, 1) * (1e6)

        if torch.rand(1) < self.alpha:
            # An empty cluster divides 0 by 0.  Zero the NaN row as KMeans.fit
            # does; a NaN centroid would otherwise win every later
            # nearest-centroid search because min() propagates NaN.
            centers = torch.nan_to_num(mask @ X / mask.sum(-1)[..., :, None])

            dist_X_mean = torch.cdist(X.unsqueeze(0), centers.unsqueeze(0), p=2).squeeze(0) # num_samples, num_clusters
            center_index = (dist_X_mean + cluster_filter).argmin(dim=0)
        else:
            z = self.get_z(X[0:1, :])
            dist_X_z = torch.cdist(X.unsqueeze(0), z.unsqueeze(0), p=2).squeeze(0) # num_samples, 1
            final_dist = dist_X_z.expand(-1, self.n_clusters) + cluster_filter

            center_index = final_dist.argmin(dim=0)
            centers = X[center_index]

        return centers, center_index

    def fit(self, X, centroids=None):
        """
        Parameters:
        X: torch.Tensor, shape: [n_samples, n_features]

        centroids: {torch.Tensor, None}, default: None
            if given, used as the initial centroids
            if None, centroids are produced by the configured initializer

        Return:
        cluster labels: torch.Tensor, shape: [n_samples]
        """
        centroid_candidates = X
        if self.pcthreshold is not None:
            estimated_parents_index, _ = get_parent_child_subset(X, threshold=self.pcthreshold)
            if len(estimated_parents_index) > 0:
                centroid_candidates = X[estimated_parents_index]
            else:
                print("No parent found, using all samples")

        if centroids is None:
            self.centroids = self.init_func(centroid_candidates, self.n_clusters)
        else:
            self.centroids = centroids

        closest = None

        for _ in range(self.max_iter):
            closest = self.min_dist(a=X, b=self.centroids)[1]

            new_centers, new_center_indices = self.choose_center(X, closest)
            error = (new_centers - self.centroids).pow(2).sum()

            self.centroids = new_centers
            self.centroid_indices = new_center_indices

            if error <= self.tol:
                break

        self._labels = closest
        self.labels_ = closest  # mirror the attribute the parent class declares
        return closest

    def get_parent(self, X):
        """Return the center indices found by ``fit`` as the cluster parents.

        ``X`` is unused; ``cluster_parent`` and ``cluster_label`` are stored
        on the instance for ``predict``.
        """
        self.cluster_parent = self.centroid_indices
        self.cluster_label = self._labels
        return self.cluster_parent

class MeanShift_Phylogeny(MeanShift):
    """Mean-shift variant that tracks, for every center, the sample representing it.

    ``alpha`` is compared with a fresh ``torch.rand(1)`` draw on every update:
    when the draw is below ``alpha`` a center becomes the mean of its
    neighborhood, otherwise it becomes the neighbor closest to a random point
    ``z``.  The remaining parameters are those of ``MeanShift``.
    NOTE: ``fit`` here does not call ``merge_centers``, so
    ``centroid_merge_threshold`` is stored but not used by this class.
    """

    def __init__(
                self,
                max_iter=100,
                bandwidth=0.1,
                centroid_merge_threshold=0.01,
                alpha = 1,
                **kwargs,
            ):
        super().__init__(
            max_iter=max_iter,
            bandwidth=bandwidth,
            centroid_merge_threshold=centroid_merge_threshold,
            alpha=alpha,
            **kwargs
        )

        # ``fit`` writes ``centroids``; declare it so it exists before fitting.
        # The historical singular attribute is kept for existing readers.
        self.centroids = None
        self.centroid = None
        self.centroid_indices = None

    def choose_center(self, X, bool_mask):
        """Compute new centers and the sample index representing each of them.

        Parameters:
        X: torch.Tensor, shape: [n_samples, n_features]
        bool_mask: torch.Tensor, shape: [n_centers, n_samples], True where a
            sample belongs to a center's neighborhood

        Return:
        (centers, center_index): one center row per mask row, and for each
        the index of the row of ``X`` selected for it.
        """
        mask = bool_mask.to(X.dtype)
        # Large penalty on samples outside a neighborhood so argmin stays inside it.
        cluster_filter = (1 - mask).transpose(0, 1) * (1e6)

        if torch.rand(1) < self.alpha:
            centers = mask @ X / mask.sum(-1)[..., :, None]

            dist_X_mean = torch.cdist(X.unsqueeze(0), centers.unsqueeze(0), p=2).squeeze(0) # num_samples, num_clusters
            center_index = (dist_X_mean + cluster_filter).argmin(dim=0)
        else:
            z = self.get_z(X[0:1, :])
            dist_X_z = torch.cdist(X.unsqueeze(0), z.unsqueeze(0), p=2).squeeze(0) # num_samples, 1
            final_dist = dist_X_z.expand(-1, cluster_filter.shape[1]) + cluster_filter

            center_index = final_dist.argmin(dim=0)
            centers = X[center_index]

        return centers, center_index

    def one_iter(self, centers, X, bandwidth):
        """Update every center from the samples within ``bandwidth`` of it.

        Returns ``(centers, center_index)`` as produced by ``choose_center``.
        """
        cdist = torch.cdist(centers.unsqueeze(0), X.unsqueeze(0), p=2).squeeze(0)
        neighborhood = cdist < bandwidth # n_centers, n_samples
        return self.choose_center(X, neighborhood)

    def fit(self, X, centroids=None, pcthreshold = None):
        """Run ``max_iter`` updates starting from the samples themselves.

        ``centroids`` and ``pcthreshold`` are accepted but not used.
        Stores ``centroids``, ``centroid_indices`` and ``_labels`` and returns
        the index of the nearest center for every row of ``X``.
        """
        centers = X
        # With zero iterations every sample is its own center; starting from
        # the identity mapping keeps ``center_index`` defined in that case.
        center_index = torch.arange(X.shape[0], device=X.device)
        for _ in range(self.max_iter):
            centers, center_index = self.one_iter(centers, X, self.bandwidth)

        self.centroids = centers
        self.centroid_indices = center_index

        _, labels = self.min_dist(X, centers)
        self._labels = labels
        return labels

    def get_parent(self, X):
        """Return the center indices found by ``fit`` as the cluster parents.

        ``X`` is unused; ``cluster_parent`` and ``cluster_label`` are stored
        on the instance for ``predict``.
        """
        self.cluster_parent = self.centroid_indices
        self.cluster_label = self._labels
        return self.cluster_parent

def get_alg(args):
    """Build the clustering object selected by ``args.method``.

    Only the attributes of ``args`` needed by the chosen method are read.
    Note that plain ``KMeans`` is always created with ``pcthreshold=None``.

    Raises:
        ValueError: when ``args.method`` names no supported algorithm.
    """
    # Builders are lazy so that unrelated attributes of ``args`` are never touched.
    builders = {
        "DBSCAN": lambda: DBSCAN(eps=args.eps, z_distribution=args.z),
        "GMM": lambda: GMM(n_components=args.n_clusters, z_distribution=args.z),
        "KMeans": lambda: KMeans(
            n_clusters=args.n_clusters,
            max_iter=args.max_iter,
            tol=args.tol,
            pcthreshold=None,
            z_distribution=args.z,
        ),
        "MeanShift": lambda: MeanShift(
            max_iter=args.max_iter,
            bandwidth=args.bandwidth,
            centroid_merge_threshold=args.centroid_merge_threshold,
            z_distribution=args.z,
        ),
        "KMeans_Phylogeny": lambda: KMeans_Phylogeny(
            n_clusters=args.n_clusters,
            max_iter=args.max_iter,
            tol=args.tol,
            pcthreshold=args.pcthreshold,
            alpha=args.alpha,
            z_distribution=args.z,
        ),
        "MeanShift_Phylogeny": lambda: MeanShift_Phylogeny(
            max_iter=args.max_iter,
            bandwidth=args.bandwidth,
            centroid_merge_threshold=args.centroid_merge_threshold,
            alpha=args.alpha,
            z_distribution=args.z,
        ),
    }
    for method_name, build in builders.items():
        if args.method == method_name:
            return build()
    raise ValueError(f"method {args.method} is not supported")