# Process data batches

import torch

def unique_eigen(Lambda):
    """Group runs of equal consecutive eigenvalues independently for each graph.

    Each row of ``Lambda`` holds the eigenvalues of one graph. Only
    *consecutive* equal values are merged, so equal eigenvalues are assumed to
    be adjacent within a row (e.g. rows are sorted) — TODO confirm at callers.
    Equality is exact, so any rounding must be done beforehand.

    Args:
        Lambda: tensor of shape (batch_size, num_eigvals).

    Returns:
        uniqued: 1-D tensor of the distinct eigenvalue runs, listed row by row.
        inverse: (batch_size, num_eigvals) int64 tensor; ``inverse[b, j]`` is
            the index into ``uniqued``/``counts`` of ``Lambda[b, j]``'s run.
        counts: 1-D int64 tensor with the length (multiplicity) of each run.
    """
    batch_size, _ = Lambda.shape

    # A value below every eigenvalue, inserted at the start of each row so that
    # runs can never continue across a graph boundary. It is derived from
    # Lambda, so it keeps Lambda's dtype and device.
    delim = torch.min(Lambda) - 1e6
    delim_col = delim.reshape(1, 1).expand(batch_size, 1)
    Lambda_delim = torch.cat([delim_col, Lambda], dim=1)

    # With dim=None the input is flattened, but ``inverse`` keeps its 2-D shape.
    uniqued, inverse, counts = torch.unique_consecutive(
        Lambda_delim, return_inverse=True, return_counts=True
    )

    # Each row's delimiter sits in column 0, so inverse[:, 0] holds exactly the
    # delimiter run ids. Identifying them by position rather than by value
    # stays correct when ``min - 1e6`` saturates (e.g. to -inf in float16).
    keep = torch.ones_like(uniqued, dtype=torch.bool)
    keep[inverse[:, 0]] = False

    # Row b's eigenvalue ids are offset by the b + 1 delimiter runs that
    # precede or start it in the flattened order.
    row_offset = torch.arange(1, batch_size + 1,
                              device=Lambda.device).unsqueeze(-1)

    return uniqued[keep], inverse[:, 1:] - row_offset, counts[keep]

def _append_eigenspaces(out_dict, key, U_part, Lambda_part, graph_part):
    """Concatenate a batch of eigenspaces onto ``out_dict[key]``, creating it if absent."""
    entry = out_dict.get(key)
    if entry is None:
        out_dict[key] = {'U': U_part, 'Lambda': Lambda_part,
                         'graph_idx': graph_part}
        return
    entry['U'] = torch.cat([entry['U'], U_part])
    entry['Lambda'] = torch.cat([entry['Lambda'], Lambda_part])
    entry['graph_idx'] = torch.cat([entry['graph_idx'], graph_part])


def process_data(Lambda, U, max_mult=None, use_signnet=False, num_eigenvecs=None,
                 method='hard_split', projects=None):
    """Group the eigenvectors of a batch of graphs by eigenspace.

    Input:
        - Lambda: (batch_size, num_eigvals)
        - U: (batch_size, num_nodes, num_eigvals); U[b, :, j] is the
          eigenvector for Lambda[b, j].
        - max_mult: if given, eigenspaces with dimension > max_mult are reduced
          to dimension max_mult (plus a remainder for 'hard_split').
        - method: 'hard_split' or 'learnable_project'; only consulted for
          eigenspaces larger than max_mult.
            * 'hard_split': split an eigenspace of dim d into d // max_mult
              chunks of max_mult vectors, plus one chunk of d % max_mult
              vectors if that remainder is nonzero.
            * 'learnable_project': apply projects[str(d)] to the
              (num_eigenspaces, num_nodes, d) tensor. Its output is stored
              under key max_mult, so it presumably must have shape
              (num_eigenspaces, num_nodes, max_mult) — confirm with callers.
        - projects: mapping from str(eigenspace_dim) to a callable; only used
          by 'learnable_project'.

    Output a dict:

        - keys = every eigenspace dimension that occurs among the nonzero
          eigenvalues (after any max_mult reduction).

        - Each key maps to a dict containing
            * 'U': (num_eigenspaces, num_nodes, eigenspace_dim)
            * 'Lambda': (num_eigenspaces, )
            * 'graph_idx': (num_eigenspaces, )

    Eigenspaces are the runs of equal consecutive eigenvalues in each row, as
    found by ``unique_eigen``. Zero eigenvalues are dropped.

    If "use_signnet = True", instead return the tuple:

        - Lambda: (batch_size, num_eigenvecs)
        - U: (batch_size, num_nodes, num_eigenvecs)

    Here eigenvalues are sorted increasingly, and zero eigenvalues are moved
    to the end. "num_eigenvecs" must be provided.

    Raises:
        ValueError: if an eigenspace exceeds max_mult and ``method`` is not
            'hard_split' or 'learnable_project'.
    """
    batch_size, eigen_size = Lambda.shape
    device = Lambda.device

    if use_signnet:
        assert num_eigenvecs is not None
        assert num_eigenvecs <= eigen_size

        rows = torch.arange(batch_size, device=device).unsqueeze(-1).expand(
            batch_size, eigen_size)
        sort_Lambda, sort_ind = Lambda.sort(dim=-1)
        # The advanced indices around the slice put their dims first, giving
        # (B, E, N); permute back to (B, N, E).
        sort_U = U[rows, :, sort_ind].permute(0, 2, 1)

        # Move zero eigenvalues to the end. The key is unique per position, so
        # the increasing order of the nonzero part is preserved exactly. (The
        # default argsort is not stable and may reorder ties.)
        positions = torch.arange(eigen_size, device=device)
        zero_last_key = (sort_Lambda == 0).long() * eigen_size + positions
        nonzero_ind = zero_last_key.argsort(dim=-1)
        sort_Lambda = sort_Lambda[rows, nonzero_ind]
        sort_U = sort_U[rows, :, nonzero_ind].permute(0, 2, 1)

        return (sort_Lambda[:, :num_eigenvecs],
                sort_U[:, :, :num_eigenvecs])

    if max_mult is not None:
        max_mult = int(max_mult)

    num_nodes = U.shape[1]
    total = batch_size * eigen_size

    # Graph that each flattened eigenvalue (row-major over Lambda) belongs to.
    graph_idx = torch.arange(batch_size, device=device).repeat_interleave(eigen_size)

    _, inverse, counts = unique_eigen(Lambda)
    count_sorted, sort_index = counts.sort()

    # Rank of each eigenspace when eigenspaces are ordered by multiplicity.
    sort_inverse = torch.empty_like(sort_index)
    sort_inverse[sort_index] = torch.arange(sort_index.shape[0],
                                            device=sort_index.device)
    Lambda_sorted_idx = sort_inverse[inverse]

    # Order eigenvalues by eigenspace rank. Each eigenspace becomes contiguous.
    # The original position breaks ties, which keeps the vector order inside an
    # eigenspace (and therefore the hard_split chunking) deterministic.
    flat_rank = Lambda_sorted_idx.reshape(-1)
    main_idx = (flat_rank * total
                + torch.arange(total, device=device)).argsort()

    Lambda_sorted = Lambda.reshape(-1)[main_idx]
    # (B, N, E) -> (N, B*E), so column k matches Lambda.reshape(-1)[k].
    U_sorted = U.permute(1, 0, 2).reshape(num_nodes, total)[:, main_idx]
    graph_idx_sorted = graph_idx[main_idx]

    # Drop zero eigenvalues. Whole eigenspaces are removed, so the remaining
    # ones stay contiguous.
    eigen_mask = Lambda_sorted != 0
    Lambda_sorted = Lambda_sorted[eigen_mask]
    U_sorted = U_sorted[:, eigen_mask]
    graph_idx_sorted = graph_idx_sorted[eigen_mask]

    # Distinct multiplicities (ascending) and how many eigenspaces have each.
    counts_set, counts_counts = torch.unique_consecutive(count_sorted,
                                                         return_counts=True)
    # For every sorted eigenvalue: index into counts_set of its multiplicity.
    split_eigenspace = torch.arange(
        counts_set.shape[0], device=counts_set.device
    ).repeat_interleave(counts_set * counts_counts)[eigen_mask]

    out_dict = {}

    for i in range(counts_set.shape[0]):
        mask_i = split_eigenspace == i
        num_vecs = int(mask_i.sum())
        if num_vecs == 0:
            # Only zero-eigenvalue eigenspaces had this multiplicity.
            continue
        dim = int(counts_set[i])
        cc = num_vecs // dim  # number of eigenspaces of this dimension

        U_i = U_sorted[:, mask_i].reshape(num_nodes, cc, dim)
        Lambda_i = Lambda_sorted[mask_i].reshape(cc, dim)[:, 0]
        graph_i = graph_idx_sorted[mask_i].reshape(cc, dim)[:, 0]

        if max_mult is None or dim <= max_mult:
            _append_eigenspaces(out_dict, dim, U_i.permute(1, 0, 2),
                                Lambda_i, graph_i)
        elif method == 'hard_split':
            n, r = divmod(dim, max_mult)
            split = n * max_mult
            # Split each eigenspace's first n * max_mult vectors into n chunks.
            chunks = U_i[:, :, :split].reshape(
                num_nodes, cc * n, max_mult).permute(1, 0, 2)
            _append_eigenspaces(out_dict, max_mult, chunks,
                                Lambda_i.repeat_interleave(n),
                                graph_i.repeat_interleave(n))
            if r != 0:
                _append_eigenspaces(out_dict, r,
                                    U_i[:, :, split:].permute(1, 0, 2),
                                    Lambda_i, graph_i)
        elif method == 'learnable_project':
            proj = projects[str(dim)]
            _append_eigenspaces(out_dict, max_mult,
                                proj(U_i.permute(1, 0, 2)),
                                Lambda_i, graph_i)
        else:
            raise ValueError(
                f"Unknown method {method!r}; expected 'hard_split' or "
                f"'learnable_project'"
            )

    return out_dict

