import itertools
import math

import numpy as np
import numba
from scipy.spatial.distance import cdist, pdist
import amd
from amd.periodicset import PeriodicSet
from amd.periodicset import PeriodicSet
from amd._nearest_neighbors import _integer_lattice_generator, _lattice_to_cloud
from amd.calculate import _collapse_into_groups, PDD, AMD_estimate, AMD


def ADA(pset: PeriodicSet, k: int):
    """Return the deviation of the AMD of a periodic set from its asymptote.

    This is ``AMD(pset, k) - AMD_estimate(pset, k)``: the difference between
    the actual AMD values up to ``k`` and their asymptotic estimate. It is
    the AMD counterpart of :func:`PDA`, which applies the same
    'normalisation' to every row of the PDD.

    Parameters
    ----------
    pset : PeriodicSet
        The periodic set.
    k : int
        Number of AMD values to compute.

    Returns
    -------
    The element-wise difference between the AMD and its asymptotic estimate.
    """
    return AMD(pset, k) - AMD_estimate(pset, k)


def PDA(pset: PeriodicSet, k: int):
    """Return the 'pointwise deviation from asymptotic' of a periodic set.

    The PDD of ``pset`` up to ``k`` is computed and the asymptotic estimate
    ``AMD_estimate(pset, k)`` is subtracted from the distance columns of
    every row. The weight column (column 0) is left untouched. This is
    essentially a 'normalisation' of the PDD.
    """
    deviation = PDD(pset, k)
    asymptote = AMD_estimate(pset, k)
    # Only the distance columns are shifted; column 0 holds the row weights.
    deviation[:, 1:] -= asymptote
    return deviation


def pauling_homometric_crystal(u):
    """Return one of Pauling's homometric crystals as a PeriodicSet.

    The motif has 24 points in an identity (unit cubic) cell, with
    coordinates wrapped into [0, 1). They are two quartets of base points,
    each taken under the three cyclic shifts of its coordinates
    (x, y, z), (z, x, y) and (y, z, x). The whole motif is declared as a
    single orbit represented by point 0 with multiplicity 24.

    Parameters
    ----------
    u : float
        Free coordinate parameter; this module's demo compares u and -u.
    """
    quartets = (
        [(u, 0, 0.25), (-u, 0.5, 0.25), (0.5 - u, 0, 0.75), (u + 0.5, 0.5, 0.75)],
        [(-u, 0, 0.75), (u, 0.5, 0.75), (u + 0.5, 0, 0.25), (0.5 - u, 0.5, 0.25)],
    )
    # np.roll(p, 1) -> (z, x, y); np.roll(p, 2) -> (y, z, x)
    points = np.array([
        np.roll(base, shift)
        for quartet in quartets
        for shift in range(3)
        for base in quartet
    ])
    return amd.PeriodicSet(
        np.mod(points, 1), np.identity(3),
        asym_unit=np.array([0]),
        multiplicities=np.array([len(points)])
    )


def PDM(pset: PeriodicSet, k: int, h: int, l: int, flatten: bool = False):
    """Return the first ``l`` moments of each column of PDD^h up to ``k``.

    The order-``h`` PDD is computed without lexicographic row sorting,
    because the moments sum over rows. It is then reduced by
    ``_matrix_moments``, giving one row per moment (1..l) and one column per
    distance column. With ``flatten=True`` the result is returned as a
    row-major 1-D vector instead.
    """
    moments = _matrix_moments(PDDh(pset, k, h, lexsort=False), l)
    if not flatten:
        return moments
    return moments.reshape(-1)


def PDDh(
        pset: PeriodicSet,
        k: int,
        h: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
):
    """Return the h-th order PDD up to k for a periodic set.

    For ``h == 1`` this delegates to ``PDD`` with the same options.

    For ``h >= 2``, row i (before collapsing) belongs to the i-th point of
    the asymmetric unit, or of the whole motif if the set has no asymmetric
    unit / multiplicities.
    - Column 0 holds the row weight: multiplicity divided by the motif size,
      or a uniform ``1 / m``.
    - Columns 1..k hold, in increasing order, the perimeters (sums of
      pairwise distances) of the k smallest simplices formed by that point
      and h other points of the set.
    - Each perimeter is divided by ``h * (h + 1) / 2``, the number of pairs
      in a simplex with ``h + 1`` vertices, so it becomes the mean pairwise
      distance of the simplex.

    Parameters
    ----------
    pset : PeriodicSet
        The periodic set.
    k : int
        Number of simplex columns; must be >= 1 when ``h >= 2``.
    h : int
        Number of points joined to each motif point; must be >= 1.
    lexsort : bool
        Sort the rows lexicographically.
    collapse : bool
        Merge rows that lie within ``collapse_tol`` of each other in Chebyshev
        distance. Grouping is done by ``_collapse_into_groups``; the merged
        weights are summed and the merged rows averaged.
    collapse_tol : float
        Tolerance used when collapsing rows.
    return_row_groups : bool
        Also return, for each output row, the list of original row indices
        merged into it.

    Returns
    -------
    pdd : np.ndarray
        Array of shape ``(n_rows, k + 1)``.
    groups : list of lists
        Only when ``return_row_groups`` is True.

    Raises
    ------
    ValueError
        If ``h < 1``, or if ``h >= 2`` and ``k < 1``.
    """

    if h == 1:
        return PDD(
            pset, k,
            lexsort=lexsort,
            collapse=collapse,
            collapse_tol=collapse_tol,
            return_row_groups=return_row_groups
        )

    # Validate before the neighbour search. With h == 0,
    # _smallest_simplex_perimeters would loop forever, because
    # comb(n, 0) == 1 never reaches k > 1. With k < 1 it would fail with an
    # unrelated concatenate error.
    if h < 1:
        raise ValueError(f'h must be a positive integer, got {h}')
    if k < 1:
        raise ValueError(f'k must be a positive integer, got {k}')

    m = pset.motif.shape[0]
    if pset.asym_unit is None or pset.multiplicities is None:
        asym_unit = np.arange(len(pset.motif))
        weights = np.full((m, ), 1 / m, dtype=np.float64)
    else:
        asym_unit = pset.asym_unit
        weights = pset.multiplicities / m

    dists = _smallest_simplex_perimeters(pset.motif, pset.cell, asym_unit, k, h)

    groups = [[i] for i in range(len(dists))]

    if collapse:
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array(
                [np.average(dists[group], axis=0) for group in groups],
                dtype=np.float64
            )

    pdd = np.empty(shape=(len(dists), k + 1), dtype=np.float64)

    if lexsort:
        # rot90 turns the first distance column into the last lexsort key,
        # which np.lexsort uses as the primary key.
        lex_ordering = np.lexsort(np.rot90(dists))
        pdd[:, 0] = weights[lex_ordering]
        pdd[:, 1:] = dists[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
    else:
        pdd[:, 0] = weights
        pdd[:, 1:] = dists

    # Perimeter -> mean pairwise distance over the h(h+1)/2 pairs.
    pdd[:, 1:] *= (2 / (h * (h + 1)))

    if return_row_groups:
        return pdd, groups
    return pdd


def PDDh_finite(
        cloud: np.ndarray,
        h: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4
):
    """Return the h-th order PDD for a finite point set.

    Row i (before collapsing) belongs to ``cloud[i]`` and has weight
    ``1 / len(cloud)``. Its remaining columns are the perimeters (sums of
    pairwise distances) of the simplices formed by ``cloud[i]`` and h other
    cloud points. They are sorted increasingly, truncated to
    ``comb(len(cloud) - 1, h)`` columns, and divided by ``h * (h + 1) / 2``.

    Rows within ``collapse_tol`` of each other in Chebyshev distance are
    merged when ``collapse`` is set: their weights are summed and their
    values averaged. Rows are sorted lexicographically when ``lexsort`` is
    set.
    """
    n_points = len(cloud)
    # Simplices that reuse the row's own point are reported as inf and sort
    # last; there are comb(n - 1, h) simplices without it.
    n_cols = math.comb(n_points - 1, h)
    perimeters = np.sort(_all_simplex_perimeters(cloud, cloud, h))
    perimeters = perimeters[:, :n_cols]
    row_weights = np.full((n_points, ), 1 / n_points, dtype=np.float64)

    if collapse:
        close_pairs = pdist(perimeters, metric='chebyshev') <= collapse_tol
        if close_pairs.any():
            merged = _collapse_into_groups(close_pairs)
            row_weights = np.array([row_weights[g].sum() for g in merged])
            perimeters = np.array(
                [perimeters[g].mean(axis=0) for g in merged],
                dtype=np.float64
            )

    if lexsort:
        order = np.lexsort(np.rot90(perimeters))
        row_weights = row_weights[order]
        perimeters = perimeters[order]

    scale = 2 / (h * (h + 1))
    return np.column_stack((row_weights, perimeters * scale))


@numba.njit(cache=True, fastmath=True)
def _matrix_moments(pdd, l):
    """Return weighted power-mean 'moments' of the value columns of a PDD.

    Column 0 of ``pdd`` is used as the row weights and columns 1: as the
    values. The result has shape (l, k), where k is the number of value
    columns. For moment = 1..l, row ``moment - 1`` is

        m ** (1/l - 1) * (sum_i w_i * pdd[i, 1:] ** moment) ** (1 / moment)

    where m is the number of rows of ``pdd``.

    NOTE(review): the scale factor uses ``l`` rather than ``moment``, so it
    is the same for every row of the result. Also, m counts rows *after* any
    collapsing done by the caller (PDM uses the default collapse=True).
    Confirm this normalisation is intended.
    """
    k = pdd.shape[-1] - 1
    m = pdd.shape[0]
    ret = np.zeros((l, k), dtype=np.float64)
    weights = pdd[:, 0]
    for moment in range(1, l+1):
        ret[moment - 1] = (m ** (1/l - 1)) * (
            np.sum((pdd[:, 1:] ** moment) * weights[:, None], axis=0)
        ) ** (1. / moment)
    return ret


def _smallest_simplex_perimeters(motif, cell, x_inds, k, h):
    """Return, for each motif point indexed by ``x_inds``, the perimeters of
    the k smallest simplices formed by that point and h points of the
    periodic set generated by ``motif`` and ``cell``.

    The perimeter of a simplex is the sum of its pairwise distances. The
    result normally has shape (len(x_inds), k), with each row sorted in
    increasing order. ``_simplex_perimeters`` gives perimeter inf to a
    simplex whose h cloud points include the row's own motif point.

    The search runs in two phases:
    1. Build a cloud from whole integer-lattice shells until it contains at
       least k h-subsets.
    2. Add one shell at a time. Keep only points that could still belong to
       a simplex no larger than the current largest k-th perimeter. Stop when
       the new shell has no lattice vector close enough, or when the k
       smallest perimeters no longer change.
    """
    x_inds = x_inds.astype(np.int64)
    x = motif[x_inds]
    m, dims = motif.shape
    int_lat_gen = iter(_integer_lattice_generator(dims))

    # Phase 1: whole lattice shells until comb(n_points, h) >= k.
    # NOTE(review): this count includes subsets that contain x itself
    # (perimeter inf). So when comb(n_points - 1, h) < k the first `dists`
    # may hold inf, max_d becomes inf, and the next pass keeps every point.
    layers = []
    n_points = 0
    while math.comb(n_points, h) < k:
        layer = next(int_lat_gen)
        layers.append(layer)
        n_points += len(layer) * m
    cloud = _lattice_to_cloud(motif, np.concatenate(layers) @ cell)
    x_cloud_cdist = cdist(x, cloud)
    dists = _simplex_perimeters(cloud, x_inds, h, x_cloud_cdist, np.amin(x_cloud_cdist, axis=0))
    # Keep the k smallest perimeters per row, sorted.
    dists.partition(k - 1)
    dists = dists[:, :k]
    dists.sort()

    # Pruning radii. If c is a cloud point in a simplex with x whose
    # perimeter is <= max_d, then |c - x| <= max_d / h: the direct edge plus
    # the h - 1 two-edge paths through the other vertices are each >= |c - x|.
    # So for c = p + l (p a motif point, l a lattice vector):
    #   |l| <= max_d / h + motif_diam  (bound)
    #   |c| <= max_d / h + x_norm_max  (bound2)
    motif_diam = cdist(x, motif).max()
    x_norm_max = np.linalg.norm(x, axis=-1).max()
    max_d = dists[:, -1].max()
    bound = max_d / h + motif_diam
    bound2 = max_d / h + x_norm_max

    # Phase 2: grow shell by shell until nothing can improve.
    # NOTE(review): both stopping rules presumably assume the generator
    # yields shells of non-decreasing Euclidean norm after `@ cell`. Confirm
    # this holds for skewed cells.
    while True:
        lattice = next(int_lat_gen) @ cell
        lattice = lattice[np.linalg.norm(lattice, axis=-1) <= bound]
        if lattice.size == 0:  # None are close enough
            break

        layer = _lattice_to_cloud(motif, lattice)
        cloud = np.vstack((cloud, layer))
        # we have to keep the motif in the same place in cloud so we can ignore
        # the asym points properly when searching through simplices
        mask = np.linalg.norm(cloud, axis=-1) <= bound2
        mask[:len(motif)] = True
        cloud = cloud[mask]
        x_cloud_cdist = cdist(x, cloud)
        x_cloud_mins = np.amin(x_cloud_cdist, axis=0)
        # max_d lets _simplex_perimeters skip subsets that cannot be among
        # the k smallest.
        dists_ = _simplex_perimeters(cloud, x_inds, h, x_cloud_cdist, x_cloud_mins, max_d=max_d)

        if dists_.shape[-1] > k:
            dists_.partition(k - 1)
            dists_ = dists_[:, :k]
        dists_.sort()
        # The new shell changed nothing: treat the result as converged.
        if np.allclose(dists, dists_):
            break
        dists = dists_.copy()
        # Tighten the pruning radii with the improved k-th perimeters.
        max_d = dists[:, -1].max()
        bound = max_d / h + motif_diam
        bound2 = max_d / h + x_norm_max
    return dists


@numba.njit()
def _simplex_perimeters(cloud, x_inds, h, x_cloud_cdist, x_cloud_mins, max_d=None):
    """Find perimeters (sum of pairwise distances) of all simplices made of
    one point x and h points of the cloud, optionally filtering out those
    which are too large.

    The h-subsets of cloud indices are enumerated in lexicographic order.
    ``indices`` holds the current subset, and ``pointer`` is the position of
    the last index of the prefix currently being examined.

    Parameters
    ----------
    cloud : point coordinates, one row per point.
    x_inds : for each x, its index in ``cloud``. A subset containing that
        index gets perimeter inf for that x. NOTE(review): callers keep the
        motif as the first rows of ``cloud`` so that these indices line up.
    h : number of cloud points in each simplex. Assumes h >= 2, because the
        initial ``pointer = 1`` reads ``indices[1]``.
    x_cloud_cdist : row x_i holds the distances from the x_i-th x to every
        cloud point.
    x_cloud_mins : per cloud point, a lower bound on its distance to every x
        (callers pass the column-wise minimum of ``x_cloud_cdist``).
    max_d : if given, a prefix whose perimeter lower bound exceeds max_d is
        skipped together with all its completions. Subsets that pass may
        still have a perimeter greater than max_d.

    Returns
    -------
    Array of shape (len(x_cloud_cdist), n_simplices). Column j holds the
    perimeters of the j-th enumerated subset that was not pruned.
    """

    ret = []
    len_x = len(x_cloud_cdist)
    n = len(cloud)
    indices = np.arange(h, dtype=np.int64)
    pointer = 1
    while True:

        # Sum of pairwise distances among the prefix indices[0..pointer].
        partial_perim = 0
        for i_ in range(pointer):
            for j_ in range(i_+1, pointer+1):
                partial_perim += np.sqrt(((cloud[indices[i_]] - cloud[indices[j_]]) ** 2).sum())

        # Lower bound on the perimeter of any simplex extending this prefix.
        # Every remaining term is non-negative, and each prefix point is at
        # least x_cloud_mins away from x.
        bound = partial_perim + np.sum(x_cloud_mins[indices[:pointer+1]])

        if max_d is None or bound <= max_d:
            if pointer == h - 1:
                # Complete subset: record one perimeter per x, or inf when
                # the subset contains that x itself.
                for x_i in range(len_x):
                    if np.any(indices == x_inds[x_i]):
                        ret.append(np.inf)
                    else:
                        ret.append(partial_perim + np.sum(x_cloud_cdist[x_i, indices]))
            else:
                # Extend the prefix with the smallest admissible next index.
                pointer += 1
                indices[pointer] = indices[pointer - 1] + 1
                continue

        # Advance to the next subset in lexicographic order. Find the
        # rightmost position (at or left of pointer) not yet at its maximum
        # i + n - h, increment it and truncate the prefix there. If no such
        # position exists, the enumeration is finished.
        done = True
        for i in range(pointer, -1, -1):
            if indices[i] != i + n - h:
                done = False
                break
        if done:
            break
        indices[i] += 1
        pointer = i

    # `ret` was filled subset-major (every x for subset 0, then subset 1,
    # ...). Lay it out as (x, subset).
    n_simplices = int(len(ret) / len_x)
    ret_ = np.empty((len_x, n_simplices), dtype=np.float64)
    p = 0
    for j in range(n_simplices):
        for i in range(len_x):
            ret_[i, j] = ret[p]
            p += 1
    return ret_


# Defined at module level so numba compiles it once per signature. A jitted
# function defined inside another function gets a new dispatcher, and so a
# fresh compilation, every time the outer function runs.
@numba.njit()
def _fill_all_simplex_perims(dm, cloud, all_inds, ret):
    """Fill ``ret`` in place with simplex perimeters.

    ``ret[x_ind, j]`` is the perimeter of the simplex made of the
    ``x_ind``-th x point and the cloud points ``all_inds[j]``. The entry is
    inf if that x point lies within 1e-15 of one of those cloud points.
    ``dm`` holds the distances from each x point (row) to each cloud point
    (column).
    """
    for j, inds in enumerate(all_inds):
        subset = cloud[inds]
        # Pairwise distances among the chosen cloud points, shared by every x.
        partial_perim = 0
        for i_ in range(len(subset)):
            for j_ in range(i_+1, len(subset)):
                partial_perim += np.sqrt(((subset[i_] - subset[j_]) ** 2).sum())
        for x_ind in range(ret.shape[0]):
            norms = dm[x_ind, inds]
            # x coincides with a chosen cloud point: not a valid simplex.
            if np.any(norms < 1e-15):
                ret[x_ind, j] = np.inf
            else:
                ret[x_ind, j] = np.sum(norms) + partial_perim


def _all_simplex_perimeters(x, cloud, h):
    """Return the perimeters of all simplices made of one point of ``x`` and
    ``h`` points of ``cloud``.

    The result has shape ``(len(x), comb(len(cloud), h))``. Entry ``[i, j]``
    is the sum of pairwise distances in the simplex made of ``x[i]`` and the
    j-th h-subset of ``cloud``, with subsets taken in
    ``itertools.combinations`` order.

    NOTE: a point of x is never used twice in one simplex. If ``x[i]`` lies
    within 1e-15 of any point of the subset, entry ``[i, j]`` is inf instead.
    """
    ret = np.empty((len(x), math.comb(len(cloud), h)), dtype=np.float64)
    # One row of cloud indices per h-subset.
    all_inds = np.fromiter(
        itertools.chain.from_iterable(
            itertools.combinations(range(len(cloud)), h)
        ), int).reshape(-1, h)
    dm = cdist(x, cloud)
    _fill_all_simplex_perims(dm, cloud, all_inds, ret)
    return ret





if __name__ == '__main__':

    k = 100
    h = 2
    # Number of distance columns to display. Column 0 of a PDD holds the row
    # weights, so the first n_show distances are columns 1..n_show.
    n_show = 5

    cubic_pset = amd.PeriodicSet.cubic()
    p = PDDh(cubic_pset, k, h)
    print(f'Cubic lattice PDD, h={h}, k={n_show}: ', p[:, :n_show + 1])

    hex_pset = amd.PeriodicSet.hexagonal()
    p2 = PDDh(hex_pset, k, h)
    emd = amd.EMD(p, p2)
    print(f'Hexagonal lattice PDD, h={h}, k={n_show}: ', p2[:, :n_show + 1])
    print(f'Distance between cubic and hexagonal lattices by PDD, h={h}, k={k}: ', emd)

    # Pauling's homometric pair: the same construction with parameters +u / -u.
    c1 = pauling_homometric_crystal(0.03)
    c2 = pauling_homometric_crystal(-0.03)
    print(f"Distance between Pauling's homometric structures by PDD^h, k={k}:")
    for order in (1, 2, 3):
        emd_h = amd.EMD(PDDh(c1, k, order), PDDh(c2, k, order))
        print(f"PDD^{order} earth mover's distance = {emd_h}")
