import amd
import collections
from scipy.spatial.distance import squareform, pdist
import numpy as np


def _collapse_into_groups(overlapping):
    """Greedily partition row indices into groups from a pairwise overlap mask.

    ``overlapping`` is a condensed boolean vector (the layout produced by
    ``pdist``). It is expanded to a square matrix with ``squareform``.

    Rows are visited in ascending order. Each row that has no group yet
    starts a new group and pulls in every still-unassigned row it overlaps
    with. The grouping is not transitive: a row that only overlaps a
    non-seed member of a group is not pulled into that group.

    Returns a list of groups. Each group is a list of row indices in
    ascending order, and groups are ordered by their smallest row index.
    """
    matrix = squareform(overlapping)
    n_rows = len(matrix)
    label_of = [None] * n_rows
    n_labels = 0

    for seed in range(n_rows):
        if label_of[seed] is not None:
            continue  # already absorbed by an earlier seed
        label_of[seed] = n_labels
        for other in np.flatnonzero(matrix[seed]):
            if label_of[other] is None:
                label_of[other] = n_labels
        n_labels += 1

    # Seeds get increasing labels and every member of a label is >= its
    # seed, so bucketing by label gives groups ordered by first row.
    members = [[] for _ in range(n_labels)]
    for row, label in enumerate(label_of):
        members[label].append(row)
    return members


def custom_PDD(
        periodic_set,
        k: int,
        lexsort: bool = False,
        collapse: bool = False,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = True,
        constrained: bool = True,
) -> tuple:
    """Compute a pointwise distance distribution (PDD) over every motif point.

    Every motif point gets one row. Rows start with the uniform weight
    ``1 / len(motif)``.

    Parameters
    ----------
    periodic_set : amd.PeriodicSet or tuple
        An ``amd.PeriodicSet`` or a ``(motif, cell)`` pair, as accepted by
        ``extract_motif_cell``. When ``collapse`` and ``constrained`` are
        both true, it must also expose a ``types`` array with one entry per
        motif point.
    k : int
        Passed to ``amd.nearest_neighbours`` as the neighbour count.
    lexsort : bool
        If true, sort rows lexicographically by their distances, with the
        first distance column as the primary key.
    collapse : bool
        If true, merge rows that are within ``collapse_tol`` of each other
        (Chebyshev distance). Merged rows have their weights summed and
        their distances averaged.
    collapse_tol : float
        Merging tolerance. A negative value disables merging.
    return_row_groups : bool
        If true, also return which original row indices formed each row.
    constrained : bool
        If true, only merge rows whose points have equal types and whose
        neighbour type sequences are identical.

    Returns
    -------
    tuple
        ``(pdd, groups, inds, cloud)`` if ``return_row_groups`` is true,
        otherwise ``(pdd, inds, cloud)``. Column 0 of ``pdd`` holds the row
        weights, followed by the distances from ``amd.nearest_neighbours``.
        ``inds`` and ``cloud`` are passed through from that call unchanged.
    """
    # ``motif`` is passed as the query argument below, so ``dists`` has one
    # row per motif point. The weights from extract_motif_cell may follow
    # the asymmetric unit instead, so they are replaced by uniform weights.
    motif, cell, _, _ = extract_motif_cell(periodic_set)
    weights = np.full((len(motif),), 1 / len(motif))
    dists, cloud, inds = amd.nearest_neighbours(motif, cell, motif, k)
    groups = [[i] for i in range(len(dists))]

    if collapse and collapse_tol >= 0:
        # Condensed boolean vector: True where two rows differ by at most
        # collapse_tol in every distance column.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol

        if constrained:
            # Read type information only when it is actually used. This lets
            # unconstrained collapsing work for (motif, cell) tuples and for
            # sets without ``types``.
            types = periodic_set.types
            types_match = pdist(types.reshape((-1, 1))) == 0
            # ``inds % n`` maps neighbour indices back to motif indices.
            # NOTE(review): this assumes the cloud from amd.nearest_neighbours
            # repeats the motif in order; confirm against amd.
            neighbor_types = types[inds % types.shape[0]]
            neighbors_match = pdist(neighbor_types) == 0
            overlapping = overlapping & types_match & neighbors_match

        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            weights = np.array([sum(weights[group]) for group in groups])
            dists = np.array([np.average(dists[group], axis=0) for group in groups])

    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # np.lexsort treats its last key as primary. After rot90, the last
        # key row is the first distance column.
        lex_ordering = np.lexsort(np.rot90(dists))
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups, inds, cloud
    return pdd, inds, cloud


def extract_motif_cell(pset: amd.PeriodicSet):
    """Split a periodic set into ``(motif, cell, asymmetric_unit, weights)``.

    For an ``amd.PeriodicSet`` that has both ``asymmetric_unit`` and
    ``wyckoff_multiplicities`` set:

    - the asymmetric unit is ``pset.motif[pset.asymmetric_unit]``;
    - the weights are the multiplicities normalised to sum to 1.

    In every other case the whole motif is used as the asymmetric unit,
    with uniform weights. Any input that is not an ``amd.PeriodicSet`` is
    unpacked as a ``(motif, cell)`` pair.
    """
    def _uniform(points):
        # Equal weight for every point.
        return np.full((len(points),), 1 / len(points))

    if not isinstance(pset, amd.PeriodicSet):
        motif, cell = pset
        return motif, cell, motif, _uniform(motif)

    motif, cell = pset.motif, pset.cell
    asym_unit = pset.asymmetric_unit
    wyc_muls = pset.wyckoff_multiplicities

    if asym_unit is not None and wyc_muls is not None:
        return motif, cell, pset.motif[asym_unit], wyc_muls / np.sum(wyc_muls)

    # Missing symmetry information: fall back to the full motif.
    return motif, cell, motif, _uniform(motif)
