import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from einops import rearrange
from .pair_wise_distance import PairwiseDistFunction

def index_reverse(index):
    """Return the inverse of every permutation stored along the last dim.

    Args:
        index: integer torch.Tensor whose last dimension holds, for each
            leading position, a permutation of ``0 .. N-1`` (e.g. the indices
            returned by ``torch.sort``). Any number of leading dims is
            accepted.

    Returns:
        torch.Tensor with the same shape, dtype and device as ``index`` such
        that ``out[..., index[..., j]] == j``.
    """
    positions = torch.arange(index.shape[-1], device=index.device, dtype=index.dtype)
    index_r = torch.zeros_like(index)
    # A single scatter along the last dim inverts all rows at once, so no
    # Python-level loop over the batch is needed. scatter_ requires int64
    # indices, hence the (no-op for int64 input) cast.
    index_r.scatter_(-1, index.long(), positions.expand_as(index))
    return index_r

def superpixel_gather(x, index):
    """Reorder ``x`` along the last dimension covered by ``index``.

    Args:
        x: torch.Tensor whose leading dimensions match ``index.shape``.
        index: integer torch.Tensor of gather indices. It is broadcast over
            the trailing dimensions of ``x`` so that whole trailing slices
            are moved together.

    Returns:
        torch.Tensor of the same shape as ``x``, gathered along dimension
        ``index.dim() - 1``.

    Raises:
        AssertionError: if the leading shape of ``x`` differs from
            ``index.shape``.
    """
    gather_dim = index.dim() - 1
    assert x.shape[:index.dim()] == index.shape, "x ({:}) and index ({:}) shape incompatible".format(x.shape, index.shape)

    # Append one singleton axis per trailing dim of x, then broadcast.
    n_trailing = x.dim() - index.dim()
    full_index = index[(Ellipsis,) + (None,) * n_trailing].expand(x.shape)

    return torch.gather(x, dim=gather_dim, index=full_index)

def calc_init_centroid(images, num_spixels_width, num_spixels_height):
    """
    Build the initial superpixel grid for a batch of feature maps.

    Args:
        images: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels_width: int
            number of superpixels along the width
        num_spixels_height: int
            number of superpixels along the height
    Return:
        centroids: torch.Tensor
            A Tensor of shape (B, C, num_spixels_height * num_spixels_width)
            holding the adaptive-average-pooled features of each grid cell
        init_label_map: torch.Tensor
            A Tensor of shape (B, H * W), with the dtype of ``centroids``,
            holding for every pixel the index of its initial grid cell
    """
    n_batch, n_channels, img_h, img_w = images.shape
    grid = (num_spixels_height, num_spixels_width)

    # Initial centroids are the mean features of each grid cell.
    centroids = F.adaptive_avg_pool2d(images, grid)

    with torch.no_grad():
        # Number the cells row-major, then blow the small label image up to
        # full resolution with nearest-neighbour interpolation.
        cell_ids = torch.arange(grid[0] * grid[1], device=images.device)
        cell_ids = cell_ids.reshape(1, 1, *centroids.shape[-2:]).type_as(centroids)
        init_label_map = F.interpolate(cell_ids, size=(img_h, img_w), mode="nearest")
        init_label_map = init_label_map.repeat(n_batch, 1, 1, 1)

    flat_labels = init_label_map.reshape(n_batch, -1)
    flat_centroids = centroids.reshape(n_batch, n_channels, -1)
    return flat_centroids, flat_labels


@torch.no_grad()
def get_abs_indices(init_label_map, num_spixels_width):
    """Build (batch, superpixel, pixel) index triples for the 3x3 neighbourhood.

    Args:
        init_label_map: torch.Tensor of shape (B, N) holding the initial
            superpixel index of every pixel.
        num_spixels_width: int, number of superpixels per grid row; used as
            the stride between vertically adjacent superpixels.

    Returns:
        torch.LongTensor of shape (3, B * 9 * N). Row 0 is the batch index,
        row 1 the absolute index of one of the 9 neighbouring superpixels
        (may fall outside the valid range; callers mask those out), row 2
        the pixel index. Entries are ordered batch-major, then neighbour,
        then pixel.
    """
    n_batch, n_pixel = init_label_map.shape
    device = init_label_map.device

    # Offsets of the 3x3 neighbourhood in flattened superpixel indices:
    # (-1, 0, +1) rows times the row stride, plus (-1, 0, +1) columns.
    step = torch.arange(-1, 2.0, device=device)
    neighbour_offsets = (step[:, None] * num_spixels_width + step[None, :]).reshape(-1)

    full_shape = (n_batch, 9, n_pixel)
    batch_idx = torch.arange(n_batch, device=device)[:, None, None].expand(full_shape).reshape(-1).long()
    spix_idx = (init_label_map[:, None, :] + neighbour_offsets[None, :, None]).reshape(-1).long()
    pix_idx = torch.arange(n_pixel, device=device).expand(full_shape).reshape(-1).long()

    return torch.stack((batch_idx, spix_idx, pix_idx), 0)


@torch.no_grad()
def get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width):
    """Convert neighbourhood affinities into absolute hard superpixel labels.

    Args:
        affinity_matrix: torch.Tensor whose dim 1 enumerates the 9
            neighbouring superpixels of each pixel.
        init_label_map: torch.Tensor with the initial superpixel index of
            every pixel; must broadcast against the arg-max over dim 1 of
            ``affinity_matrix``.
        num_spixels_width: int, number of superpixels per grid row.

    Returns:
        torch.LongTensor: for each pixel, its initial label shifted by the
        offset of the neighbour with the highest affinity.
    """
    best_neighbour = affinity_matrix.max(dim=1).indices

    # Same 3x3 offset table as used when building the absolute indices.
    step = torch.arange(-1, 2.0, device=affinity_matrix.device)
    neighbour_offsets = (step[:, None] * num_spixels_width + step[None, :]).reshape(-1)

    return (init_label_map + neighbour_offsets[best_neighbour]).long()


def ssn_iter(pixel_features, stoken_size=(16, 16), n_iter=2):
    """
    computing assignment iterations on a grid defined by the superpixel size
    detailed process is in Algorithm 1, line 2 - 6

    Superpixels are seeded on a regular grid whose cells are ``stoken_size``
    pixels large. For ``n_iter`` rounds the pixel/superpixel affinities are
    recomputed and the superpixel features are replaced by the
    affinity-weighted mean of the pixel features. The iterations run under
    ``torch.no_grad()``.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        stoken_size: sequence of two ints
            (height, width) in pixels of one initial superpixel cell; the
            grid has H // height rows and W // width columns
        n_iter: int
            A number of iterations, must be at least 1
    Return:
        abs_affinity: torch.Tensor
            dense affinities of the last iteration, indexed as
            [batch, superpixel, pixel]
        spixel_features: torch.Tensor
            A Tensor of shape (B, C, number of superpixels)
    Raises:
        ValueError: if ``n_iter`` is smaller than 1 (no affinity would be
            computed, so there is nothing to return)
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1, got {}".format(n_iter))

    height, width = pixel_features.shape[-2:]
    sheight, swidth = stoken_size
    num_spixels_height = height // sheight
    num_spixels_width = width // swidth
    num_spixels = num_spixels_height * num_spixels_width

    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous()

    with torch.no_grad():
        # Neighbour indices that fall outside the grid are dropped. Neither
        # the mask nor the kept indices change between iterations.
        mask = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels)
        valid_indices = abs_indices[:, mask]

        for _ in range(n_iter):
            # NOTE(review): PairwiseDistFunction is defined outside this file;
            # its output must hold one value per (batch, neighbour, pixel)
            # entry of abs_indices for the masking below to line up.
            dist_matrix = PairwiseDistFunction.apply(
                pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)

            affinity_matrix = (-dist_matrix).softmax(1)
            reshaped_affinity_matrix = affinity_matrix.reshape(-1)

            # Scatter the 9-neighbour affinities into a dense
            # [batch, superpixel, pixel] matrix.
            sparse_abs_affinity = torch.sparse_coo_tensor(valid_indices, reshaped_affinity_matrix[mask])
            abs_affinity = sparse_abs_affinity.to_dense().contiguous()

            # Affinity-weighted mean of the pixel features; the epsilon
            # guards against superpixels that received no weight.
            spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \
                / (abs_affinity.sum(2, keepdim=True) + 1e-16)

            # Back to channel-major layout for the next distance computation
            # and for the returned value.
            spixel_features = spixel_features.permute(0, 2, 1).contiguous()

    return abs_affinity, spixel_features


def ssn_iter_origin(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations
    detailed process is in Algorithm 1, line 2 - 6

    The superpixel grid is derived from the requested superpixel count and
    the aspect ratio of the input. For ``n_iter`` rounds the
    pixel/superpixel affinities are recomputed and the superpixel features
    are replaced by the affinity-weighted mean of the pixel features. The
    iterations run under ``torch.no_grad()``.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A (requested) number of superpixels
        n_iter: int
            A number of iterations, must be at least 1
    Return:
        abs_affinity: torch.Tensor
            dense affinities of the last iteration, indexed as
            [batch, superpixel, pixel]; its size is inferred by
            ``torch.sparse_coo_tensor`` from the largest kept index
        spixel_features: torch.Tensor
            superpixel features laid out as (B, C, superpixel)
    Raises:
        ValueError: if ``n_iter`` is smaller than 1 (no affinity would be
            computed, so there is nothing to return)
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1, got {}".format(n_iter))

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous()

    with torch.no_grad():
        # Neither the mask nor the kept indices change between iterations.
        # NOTE(review): after the int() truncation above the grid holds
        # num_spixels_width * num_spixels_height cells, which can be fewer
        # than the requested num_spixels used as the upper bound here --
        # confirm that keeping indices beyond the real grid is intended.
        mask = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels)
        valid_indices = abs_indices[:, mask]

        for _ in range(n_iter):
            # NOTE(review): PairwiseDistFunction is defined outside this file;
            # its output must hold one value per (batch, neighbour, pixel)
            # entry of abs_indices for the masking below to line up.
            dist_matrix = PairwiseDistFunction.apply(
                pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)

            affinity_matrix = (-dist_matrix).softmax(1)
            reshaped_affinity_matrix = affinity_matrix.reshape(-1)

            # Scatter the 9-neighbour affinities into a dense
            # [batch, superpixel, pixel] matrix.
            sparse_abs_affinity = torch.sparse_coo_tensor(valid_indices, reshaped_affinity_matrix[mask])
            abs_affinity = sparse_abs_affinity.to_dense().contiguous()

            # Affinity-weighted mean of the pixel features; the epsilon
            # guards against superpixels that received no weight.
            spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \
                / (abs_affinity.sum(2, keepdim=True) + 1e-16)

            # Back to channel-major layout for the next distance computation
            # and for the returned value.
            spixel_features = spixel_features.permute(0, 2, 1).contiguous()

    return abs_affinity, spixel_features


class GenSP(nn.Module):
    """Module wrapper that generates superpixels with ``ssn_iter_origin``.

    Args:
        n_iter: int, number of assignment iterations used in ``forward``.
    """

    def __init__(self, n_iter=3):
        super().__init__()
        self.n_iter = n_iter

    def forward(self, x, num_spixels):
        """Run the superpixel iterations on ``x``.

        Args:
            x: torch.Tensor of shape (B, C, H, W).
            num_spixels: int, requested number of superpixels.

        Returns:
            The ``(soft_association, spixel_features)`` pair produced by
            ``ssn_iter_origin``.
        """
        # ssn_iter (grid given by a superpixel size instead of a count) is
        # the alternative implementation available in this module.
        return ssn_iter_origin(x, num_spixels, self.n_iter)

# Ad-hoc smoke test: runs the superpixel iterations on a random input, sorts
# the resulting affinities and tries the gather / inverse-gather helpers.
if __name__ == '__main__':
    # NOTE(review): np is not used in the visible part of this block.
    import numpy as np
    # spa = GenSP(3)
    # # a = torch.ones((1,3,64,64))
    a = torch.randn((1,3,16,16))
    # Pixel-major view of the input: (1, 3, 16, 16) -> (1, 256, 3).
    x_input = a.flatten(2).permute(0, 2, 1)
    print(x_input.shape)
    # stoken=[2,2]
    # affinity_matrix, num_spixels = spa(a, stoken)
    # print("_____________iter___________")
    # print(affinity_matrix, num_spixels)
    # pixel_features = a.reshape(*a.shape[:2], -1)
    # print(pixel_features.shape)
    # spixel_features_1 = torch.bmm(affinity_matrix, pixel_features.permute(0, 2, 1).contiguous()) / (affinity_matrix.sum(2, keepdim=True) + 1e-16)
    # print(spixel_features_1.shape)
    # NOTE: despite its name, `hard_labels` receives the second return value
    # of ssn_iter_origin, i.e. the superpixel features, not hard labels.
    abs_affinity, hard_labels = ssn_iter_origin(a, 6, 3)
    print("_____________ssn_iter_origin___________")
    print(abs_affinity.shape, hard_labels.shape)
    print(hard_labels.shape)
    # Sort each superpixel's affinities over the pixel dimension.
    x_sort_values, x_sort_indices = torch.sort(abs_affinity, dim=-1, stable=False)
    print(x_sort_indices.shape)
    # NOTE(review): x_sort_indices is 3-D (same shape as abs_affinity) --
    # confirm index_reverse handles that rank as intended.
    x_sort_indices_reverse = index_reverse(x_sort_indices)
    print(x_sort_indices_reverse)
    # NOTE(review): superpixel_gather asserts x.shape[:index.dim()] ==
    # index.shape; x_input is (1, 256, 3) while the index has the affinity's
    # 3-D shape -- verify these are compatible.
    semantic_x = superpixel_gather(x_input, x_sort_indices)
    # Presumably meant to undo the reordering above; `y` is not used
    # afterwards in the visible code.
    y = superpixel_gather(semantic_x, x_sort_indices_reverse)
    print(semantic_x.shape)
    # print("++++++++++++++++++++++++++++")
    # print(abs_affinity.shape, hard_labels.shape, spixel_features.shape)
    # print("===========================")
    # print(torch.max(hard_labels), torch.min(hard_labels))
    # print((abs_affinity.sum(2, keepdim=True) + 1e-16).shape)
    # ata = torch.bmm(abs_affinity.permute(0,2,1), abs_affinity) / (abs_affinity.permute(0,2,1).sum(2, keepdim=True) + 1e-16)
    # print(ata)
    # recon = torch.bmm(abs_affinity.permute(0,2,1).contiguous(),(spixel_features.permute(0,2,1).contiguous() * (abs_affinity.sum(2, keepdim=True) + 1e-16)))
    # print(recon.shape)
    # print("=================================")
    # b = torch.randn((1,25,64))
    # def pairwise_distance(x):
    #     """
    #     Compute pairwise distance of a point cloud.
    #     Args:
    #         x: tensor (batch_size, num_points, num_dims)
    #     Returns:
    #         pairwise distance: (batch_size, num_points, num_points)
    #     """
    #     with torch.no_grad():
    #         x_inner = -2*torch.matmul(x, x.transpose(2, 1))
    #         x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
    #         return x_square + x_inner + x_square.transpose(2, 1)
    
    # dist = pairwise_distance(b)
    # _, nn_idx = torch.topk(-dist, k=16) # b, n, k
    # center_idx = torch.arange(0, 25).repeat(1, 16, 1).transpose(2, 1)
    # print(nn_idx)
    # print(torch.stack((nn_idx, center_idx), dim=0))
