# cython: language_level=2
import numpy as np
cimport numpy as np
from libc.math cimport INFINITY


cdef inline np.npy_int64 condensed_index(np.npy_int64 n, np.npy_int64 i,
                                         np.npy_int64 j):
    """
    Calculate the condensed index of element (i, j) in an n x n condensed
    matrix.

    The condensed form stores the strict upper triangle row by row, so for
    ``lo < hi`` the pair (lo, hi) lives at
    ``n * lo - lo * (lo + 1) // 2 + (hi - lo - 1)``.  The pair is treated as
    unordered: (i, j) and (j, i) map to the same index.

    The diagonal (``i == j``) has no entry in a condensed matrix; callers must
    not ask for it.  0 is returned in that case.
    """
    cdef np.npy_int64 lo, hi
    if i < j:
        lo = i
        hi = j
    elif i > j:
        lo = j
        hi = i
    else:
        return 0
    # Explicit floor division: a plain `/` only means integer division under
    # language_level=2 and would turn into true division under
    # language_level=3.  lo * (lo + 1) is always even, so nothing is lost.
    return n * lo - (lo * (lo + 1)) // 2 + (hi - lo - 1)


cdef double _average(double d_xi, double d_yi, double d_xy,
                     int size_x, int size_y, int size_i):
    """Return the size-weighted mean of the distances from i to x and to y.

    ``d_xi`` and ``d_yi`` are weighted by the sizes of clusters x and y
    respectively.  ``d_xy`` and ``size_i`` are accepted but not used by this
    rule (presumably to keep a common signature for linkage-update
    functions -- TODO confirm).
    """
    cdef double weighted_sum = size_x * d_xi + size_y * d_yi
    cdef int merged_size = size_x + size_y
    return weighted_sum / merged_size


def nn_chain_improved(double[:] dists, int n, int max_size):
    """Perform hierarchical clustering using the nearest-neighbor chain
    algorithm, with an upper bound on cluster size.

    Two clusters are only merged when their combined size does not exceed
    ``max_size``.  After a merge, the distance from every other cluster to
    the new one is the size-weighted mean of its distances to the two merged
    clusters (see ``_average``).

    Parameters
    ----------
    dists : ndarray
        Condensed matrix of pairwise distances between the observations, of
        length ``n * (n - 1) // 2``.  It is copied, not modified.
    n : int
        The number of observations.
    max_size : int
        Maximum number of observations allowed in a single cluster.

    Returns
    -------
    Z : ndarray, shape (n - 1, 4)
        Computed linkage matrix, stably sorted by merge distance.  Each row
        that records a merge is ``[x, y, distance, size]``.  Here ``x < y``
        are the slot indices (0 .. n-1) of the two merged clusters, and the
        merged cluster keeps slot ``y``.  Because of ``max_size``, fewer than
        ``n - 1`` merges may be possible.  Unused rows stay all-zero
        (size 0) and sort ahead of every merge with a positive distance.
        Use ``Z[Z[:, 3] > 0]`` to keep only real merges.

    Raises
    ------
    ValueError
        If ``n < 1`` or ``len(dists) != n * (n - 1) // 2``.
    """
    if n < 1:
        raise ValueError("n must be at least 1, got %d" % n)
    # Widen before multiplying: n * (n - 1) overflows a C int for n > 46341.
    cdef np.npy_int64 expected_len = (<np.npy_int64> n) * (n - 1) // 2
    if dists.shape[0] != expected_len:
        raise ValueError(
            "dists must be a condensed distance matrix of length "
            "n * (n - 1) // 2 = %d, got length %d"
            % (expected_len, dists.shape[0]))

    Z_arr = np.zeros((n - 1, 4))
    cdef double[:, :] Z = Z_arr

    cdef double[:] D = dists.copy()  # Distances between clusters (updated).
    cdef int[:] size = np.ones(n, dtype=np.intc)  # Sizes; 0 = dropped slot.
    # 1 while a cluster may still take part in a merge.  It becomes 0 once the
    # cluster is dropped, has reached max_size, or has no admissible partner.
    cdef int[:] unforbid = np.ones(n, dtype=np.intc)

    # Stack holding the nearest-neighbor chain.  Every entry is written
    # before it is read, so the buffer does not need initialising.
    cdef int[:] cluster_chain = np.empty(n, dtype=np.intc)
    cdef int chain_length = 0

    cdef int i, k, x, y, nx, ny, ni
    cdef double dist, current_min

    for k in range(n - 1):
        if chain_length == 0:
            # Start a new chain from the lowest-indexed available cluster.
            for i in range(n):
                if unforbid[i]:
                    cluster_chain[0] = i
                    chain_length = 1
                    break
        if chain_length == 0:
            # Nothing can be merged any more; remaining rows of Z stay zero.
            break

        # Go through chain of neighbors until two mutual neighbors are found.
        while True:
            x = cluster_chain[chain_length - 1]

            # We want to prefer the previous element in the chain as the
            # minimum, to avoid potentially going in cycles.
            if chain_length > 1:
                y = cluster_chain[chain_length - 2]
                current_min = D[condensed_index(n, x, y)]
            else:
                y = -1
                current_min = INFINITY

            for i in range(n):
                if (not unforbid[i]) or x == i:
                    continue
                # Skip partners that would make the merged cluster too big.
                if size[x] + size[i] > max_size:
                    continue

                dist = D[condensed_index(n, x, i)]
                if dist < current_min:
                    current_min = dist
                    y = i

            if y == -1:
                # Only reachable when the chain holds just x: no admissible
                # partner with a distance below INFINITY was found.
                break
            if chain_length > 1 and y == cluster_chain[chain_length - 2]:
                # x and y are mutual nearest neighbors.
                break

            cluster_chain[chain_length] = y
            chain_length += 1

        if chain_length == 1:
            # Live cluster sizes only grow and the set of available clusters
            # only shrinks, so this cluster can never gain a partner.  Drop
            # it.  No merge is recorded for this k, so row k of Z stays zero.
            unforbid[cluster_chain[0]] = 0
            chain_length = 0
            continue

        # Merge clusters x and y and pop them from the chain.
        chain_length -= 2

        # This is a convention used in fastcluster.
        if x > y:
            x, y = y, x

        # Original numbers of points in clusters x and y.
        nx = size[x]
        ny = size[y]

        # Record the new node.
        Z[k, 0] = x
        Z[k, 1] = y
        Z[k, 2] = current_min
        Z[k, 3] = nx + ny
        size[x] = 0  # Cluster x will be dropped.
        unforbid[x] = 0
        size[y] = nx + ny  # Cluster y will be replaced with the new cluster.
        if size[y] == max_size:
            # Every live cluster has size >= 1, so a full cluster can never
            # merge again.
            unforbid[y] = 0

        # Update the distance matrix (dropped slots, including x, have
        # size 0 and are skipped).
        for i in range(n):
            ni = size[i]
            if ni == 0 or i == y:
                continue

            D[condensed_index(n, i, y)] = _average(
                D[condensed_index(n, i, x)],
                D[condensed_index(n, i, y)],
                current_min, nx, ny, ni)

    # Sort Z by cluster distances; mergesort is stable, so ties keep their
    # merge order.
    order = np.argsort(Z_arr[:, 2], kind='mergesort')
    Z_arr = Z_arr[order]

    return Z_arr