import torch
import torch.nn.functional as F
import torch.nn as nn
from core.utils.utils import bilinear_sampler, coords_grid, manual_pad

try:
    import corr_sampler
except:
    pass

try:
    import alt_cuda_corr
except:
    # alt_cuda_corr is not compiled
    pass
import math


class CorrSampler(torch.autograd.Function):
    """Autograd bridge to the compiled ``corr_sampler`` extension.

    ``forward``/``backward`` delegate to ``corr_sampler.forward`` and
    ``corr_sampler.backward``. ``corr_sampler`` only exists when the optional
    extension imported successfully at module load. Only the volume input
    receives a gradient; ``coords`` and ``radius`` get ``None``.
    """

    @staticmethod
    def forward(ctx, volume, coords, radius):
        # Keep what backward needs: both tensors plus the non-tensor radius.
        ctx.radius = radius
        ctx.save_for_backward(volume, coords)
        sampled, = corr_sampler.forward(volume, coords, radius)
        return sampled

    @staticmethod
    def backward(ctx, grad_output):
        volume, coords = ctx.saved_tensors
        # The incoming gradient is made contiguous before handing it to the extension.
        grad_volume, = corr_sampler.backward(
            volume, coords, grad_output.contiguous(), ctx.radius
        )
        return grad_volume, None, None


class CorrBlockFast1D:
    """1-D correlation pyramid sampled through the compiled ``CorrSampler``.

    The all-pairs row correlation between ``fmap1`` and ``fmap2`` is stored as
    ``num_levels`` tensors of shape [B, H, W1, 1, W2 // 2**level]. Each coarser
    level is an average-pool by 2 along the W2 axis.
    """

    def __init__(self, fmap1, fmap2, num_levels=4):
        self.num_levels = num_levels
        self.corr_pyramid = []
        # all pairs correlation: [B, H, W1, 1, W2]
        corr = CorrBlockFast1D.corr(fmap1, fmap2)
        batch, h1, w1, dim, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, 1, w2)
        for i in range(self.num_levels):
            if i > 0:
                # Pool only when another level is actually stored. The former
                # pool after the last level was discarded work and raised a
                # RuntimeError once the width had shrunk to 1.
                corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2 // 2**i))

    def __call__(self, coords, radius):
        """Sample every level around the x channel of ``coords``.

        Args:
            coords: [B, C, H, W] coordinates; only channel 0 (x) is used and it
                is divided by 2**level for each level.
            radius: search radius forwarded to the extension.

        Returns:
            Per-level samples reshaped to [B, -1, H, W] and concatenated on dim 1.
        """
        out_pyramid = []
        bz, _, ht, wd = coords.shape
        coords = coords[:, [0]]
        for i in range(self.num_levels):
            corr = CorrSampler.apply(
                self.corr_pyramid[i].squeeze(3), coords / 2**i, radius
            )
            out_pyramid.append(corr.view(bz, -1, ht, wd))
        return torch.cat(out_pyramid, dim=1)

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the [B, H, W1, 1, W2] row-wise dot-product volume / sqrt(D)."""
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape
        fmap1 = fmap1.view(B, D, H, W1)
        fmap2 = fmap2.view(B, D, H, W2)
        # corr[b, h, x1, x2] = sum_d fmap1[b, d, h, x1] * fmap2[b, d, h, x2]
        corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
        corr = corr.reshape(B, H, W1, 1, W2).contiguous()
        return corr / torch.sqrt(torch.tensor(D).float())


class PytorchAlternateCorrBlock1D:
    """Pure-PyTorch 1-D correlation computed on demand with ``F.grid_sample``.

    No all-pairs volume is stored. Each call samples ``fmap2`` at
    ``2 * radius + 1`` horizontal offsets around ``coords`` and correlates the
    samples with ``fmap1``. For each further level, ``fmap2`` is
    average-pooled by 2 along width and the x coordinate is divided by 2**level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        # Unused by this class; kept so the attribute set is unchanged.
        self.corr_pyramid = []
        self.fmap1 = fmap1
        self.fmap2 = fmap2

    def corr(self, fmap1, fmap2, coords):
        """Correlate ``fmap1`` with ``fmap2`` sampled at ``coords``.

        Args:
            fmap1: [B, D, H1, W1] left features.
            fmap2: [B, D, H, W] right features (possibly pooled).
            coords: [B, H1, W1, K, 2] pixel (x, y) sample positions in ``fmap2``.

        Returns:
            [B, H1, W1, K] dot products over D, divided by sqrt(D).
        """
        B, D, H, W = fmap2.shape
        # map pixel coordinates to [-1, 1] (align_corners=True convention)
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2 * xgrid / (W - 1) - 1
        ygrid = 2 * ygrid / (H - 1) - 1
        grid = torch.cat([xgrid, ygrid], dim=-1)

        output_corr = []
        for grid_slice in grid.unbind(3):
            fmapw_mini = F.grid_sample(fmap2, grid_slice, align_corners=True)
            output_corr.append(torch.sum(fmapw_mini * fmap1, dim=1))
        corr = torch.stack(output_corr, dim=1).permute(0, 2, 3, 1)

        return corr / torch.sqrt(torch.tensor(D).float())

    def __call__(self, coords):
        """Sample correlation features around ``coords`` on every level.

        Args:
            coords: [B, 2, H, W] pixel coordinates (channel 0 = x, 1 = y).

        Returns:
            [B, num_levels * (2 * radius + 1), H, W] float32 features.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        fmap1 = self.fmap1
        fmap2 = self.fmap2

        # The (x, y) offset table is identical for every level, so build it
        # once. It used to be rebuilt per level via torch.meshgrid without
        # ``indexing=``, which newer PyTorch versions warn about.
        # delta[k] = (x offset, 0): x offsets are evenly spaced in [-r, r].
        x_offsets = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack([x_offsets, torch.zeros_like(x_offsets)], dim=-1)
        delta = delta.to(coords.device)  # [2r+1, 2]
        centroid = coords.reshape(batch, h1, w1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            # Only x is rescaled because fmap2 is pooled along width only.
            centroid_lvl = torch.cat(
                [centroid[..., :1] / 2**i, centroid[..., 1:]], dim=-1
            )
            coords_lvl = centroid_lvl + delta  # [B, H, W, 2r+1, 2]
            out_pyramid.append(self.corr(fmap1, fmap2, coords_lvl))
            fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()


class CorrBlock1D:
    """All-pairs 1-D (horizontal) correlation pyramid, RAFT-Stereo style.

    Every pixel of ``fmap1`` is correlated with every pixel on the same row of
    ``fmap2``. The volume is stored as [B*H*W1, 1, 1, W2] and average-pooled
    along W2 ``num_levels`` times, so ``corr_pyramid`` holds ``num_levels + 1``
    entries. ``__call__`` reads only the first ``num_levels`` of them.
    """

    def __init__(self, fmap1, fmap2, num_levels=4):
        self.num_levels = num_levels
        self.corr_pyramid = []

        volume = CorrBlock1D.corr(fmap1, fmap2)
        b, h, w1, d, w2 = volume.shape
        level = volume.reshape(b * h * w1, d, 1, w2)

        self.corr_pyramid.append(level)
        for _ in range(self.num_levels):
            level = F.avg_pool2d(level, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(level)

    def __call__(
        self, coords, radius=4, propagation=False, propagate_flow=None, step=1
    ):
        """Look up correlation features.

        Search mode (``propagation=False``):
            ``coords`` must have a single channel (x), because it is reshaped
            to one value per pixel. On each level, ``2 * radius + 1`` offsets
            spaced by ``step`` are sampled around x / 2**level.

        Propagation mode:
            ``coords`` [B, C, H, W] is broadcast-added to ``propagate_flow``
            (laid out as [B, 9, 2, H, W]). The 9 resulting positions are then
            sampled on each level.

        Returns:
            [B, K, H, W] float32 tensor, with the levels concatenated along K.
        """
        if propagation:
            levels = self._lookup_propagation(coords, propagate_flow)
        else:
            levels = self._lookup_search(coords, radius, step)
        stacked = torch.cat(levels, dim=-1)
        return stacked.permute(0, 3, 1, 2).contiguous().float()

    def _lookup_search(self, coords, radius, step):
        """Sample 2*radius+1 horizontal offsets around x on each level."""
        x = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = x.shape
        taps = 2 * radius + 1
        offsets = torch.linspace(-radius * step, radius * step, taps)
        offsets = offsets.view(1, 1, taps, 1).to(x.device)
        centre = x.reshape(batch * h1 * w1, 1, 1, 1)

        sampled = []
        for level in range(self.num_levels):
            x0 = offsets + centre / 2**level
            points = torch.cat([x0, torch.zeros_like(x0)], dim=-1)
            values = bilinear_sampler(self.corr_pyramid[level], points)
            sampled.append(values.view(batch, h1, w1, -1))
        return sampled

    def _lookup_propagation(self, coords, propagate_flow):
        """Sample the 9 propagated candidate positions on each level.

        NOTE(review): unlike the search branch, x is NOT divided by 2**level
        here, even though coarser levels are pooled along W2. Confirm this is
        intended.
        """
        batch, _, h1, w1 = coords.shape
        # [B, 9, 2, H, W] -> [B, H, W, 9, 2] -> [B*H*W, 1, 9, 2]
        points = torch.unsqueeze(coords, 1) + propagate_flow
        points = points.permute(0, 3, 4, 1, 2).reshape(-1, 1, 9, 2)
        return [
            bilinear_sampler(self.corr_pyramid[level], points).view(batch, h1, w1, -1)
            for level in range(self.num_levels)
        ]

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the [B, H, W1, 1, W2] row-wise dot-product volume / sqrt(D)."""
        batch, dim, height, width1 = fmap1.shape
        _, _, _, width2 = fmap2.shape
        left = fmap1.view(batch, dim, height, width1)
        right = fmap2.view(batch, dim, height, width2)
        # volume[b, h, x1, x2] = sum_d left[b, d, h, x1] * right[b, d, h, x2]
        volume = torch.einsum("bdhx,bdhy->bhxy", left, right)
        volume = volume.reshape(batch, height, width1, 1, width2).contiguous()
        # Translated original note: "add cost aggregation" — none is performed here.
        return volume / torch.sqrt(torch.tensor(dim).float())


class CorrBlock1D_Aggregation:
    """All-pairs 1-D correlation pyramid whose search lookup uses only x.

    Unlike ``CorrBlock1D``, the search branch slices ``coords[:, :1]``, so
    2-channel (x, y) coordinates can be passed directly. The pyramid holds
    exactly ``num_levels`` entries, each average-pooled by 2 along W2.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, offset=None):
        # ``offset`` is accepted for interface compatibility; it is not used.
        self.num_levels = num_levels
        self.corr_pyramid = []

        # all pairs correlation: [B, H, W1, 1, W2]
        corr = CorrBlock1D_Aggregation.corr(fmap1, fmap2)

        batch, h1, w1, dim, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, 1, w2)

        self.corr_pyramid.append(corr)
        # Build the coarser levels that __call__ iterates over. Previously only
        # level 0 was stored, so any num_levels > 1 raised IndexError in __call__.
        for _ in range(1, self.num_levels):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr)

    def __call__(
        self, coords, radius=4, propagation=False, propagate_flow=None, step=1
    ):
        """Look up correlation features.

        Search mode: on each level, ``2 * radius + 1`` offsets spaced by
        ``step`` are sampled around x / 2**level (x = ``coords[:, 0]``).

        Propagation mode: ``coords`` is broadcast-added to ``propagate_flow``
        (laid out as [B, S, 2, H, W], S candidates, 9 in this codebase). On
        level i the x component is divided by 2**i to match the pooling.

        Returns:
            [B, K, H, W] float32 tensor, with the levels concatenated along K.
        """
        out_pyramid = []
        if not propagation:
            r = radius
            x = coords[:, :1].permute(0, 2, 3, 1)
            batch, h1, w1, _ = x.shape
            # Offsets do not depend on the level, so build them once.
            dx = torch.linspace(-r * step, r * step, 2 * r + 1)
            dx = dx.view(1, 1, 2 * r + 1, 1).to(x.device)
            x_base = x.reshape(batch * h1 * w1, 1, 1, 1)
            for i in range(self.num_levels):
                x0 = dx + x_base / 2**i
                y0 = torch.zeros_like(x0)
                coords_lvl = torch.cat([x0, y0], dim=-1)
                corr = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
                out_pyramid.append(corr.view(batch, h1, w1, -1))
        else:
            batch, _, h1, w1 = coords.shape
            # [B, S, 2, H, W] -> [B, H, W, S, 2] -> [B*H*W, 1, S, 2]
            base = torch.unsqueeze(coords, 1) + propagate_flow
            base = base.permute(0, 3, 4, 1, 2).reshape(batch * h1 * w1, 1, -1, 2)
            for i in range(self.num_levels):
                if i == 0:
                    coords_lvl = base
                else:
                    # Rescale only x; the volume is pooled along W2 only.
                    coords_lvl = torch.cat(
                        [base[..., :1] / 2**i, base[..., 1:]], dim=-1
                    )
                corr = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
                out_pyramid.append(corr.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the [B, H, W1, 1, W2] row-wise dot-product volume / sqrt(D)."""
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape
        fmap1 = fmap1.view(B, D, H, W1)
        fmap2 = fmap2.view(B, D, H, W2)
        corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
        corr = corr.reshape(B, H, W1, 1, W2).contiguous()
        return corr / torch.sqrt(torch.tensor(D).float())


class CorrgroupBlock1D_Aggregation:
    """Grouped-window 1-D correlation volume.

    The columns of ``fmap1`` are split into groups of ``window_hight``. Each
    group is correlated only against a ``2 * max_disp`` wide window of
    ``fmap2`` (see ``corr``). ``self.increments`` holds per-column offsets from
    ``build_increments``; ``__call__`` subtracts them from absolute x to get
    window-local coordinates. Presumably each offset is the start column of
    that group's window — confirm against ``build_increments``.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, max_disp=96, window_hight=32):
        self.num_levels = num_levels
        self.corr_pyramid = []

        batch, _, h1, w1 = fmap1.shape
        self.window_hight = window_hight
        self.window_width = max_disp * 2
        window_width = self.window_width
        num_groups = w1 // window_hight
        end_group = (w1 - max_disp) // window_hight
        # NOTE(review): ZeroDivisionError when end_group == 0, i.e. when
        # w1 - max_disp < window_hight.
        step_size = math.ceil((w1 - window_width) / end_group)

        # Per-column window offsets, broadcast to [B, 1, H, W1].
        self.increments = build_increments(
            num_groups, window_hight, window_width, step_size, w1, fmap1.device
        ).expand(batch, 1, h1, w1)

        corr = CorrgroupBlock1D_Aggregation.corr(
            fmap1, fmap2, window_width, num_groups, step_size, window_hight
        )
        corr = corr.reshape(batch * h1 * w1, 1, 1, window_width)
        self.corr_pyramid.append(corr)
        # Coarser levels for num_levels > 1. Previously only level 0 was stored,
        # so __call__ raised IndexError unless num_levels == 1.
        for _ in range(1, num_levels):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr)

    def __call__(
        self, coords, radius=4, propagation=False, propagate_flow=None, step=1
    ):
        """Look up correlation features in window-local coordinates.

        ``coords`` [B, C, H, W] holds absolute coordinates (channel 0 = x). It
        is not modified: a shifted copy is built internally.

        Search mode: on each level, ``2 * radius + 1`` offsets spaced by
        ``step`` are sampled around local_x / 2**level.

        Propagation mode: the shifted coords are broadcast-added to
        ``propagate_flow`` (laid out as [B, S, 2, H, W]). For level i > 0, x is
        divided by 2**i.

        Returns:
            [B, K, H, W] float32 tensor, with the levels concatenated along K.
        """
        # Convert absolute x to window-local x in a new tensor. The former
        # in-place update shifted the caller's coordinates on every call.
        coords = torch.cat([coords[:, :1] - self.increments, coords[:, 1:]], dim=1)

        out_pyramid = []
        if not propagation:
            r = radius
            x = coords[:, :1].permute(0, 2, 3, 1)
            batch, h1, w1, _ = x.shape
            dx = torch.linspace(-r * step, r * step, 2 * r + 1)
            dx = dx.view(1, 1, 2 * r + 1, 1).to(x.device)
            x_base = x.reshape(batch * h1 * w1, 1, 1, 1)
            for i in range(self.num_levels):
                x0 = dx + x_base / 2**i
                y0 = torch.zeros_like(x0)
                coords_lvl = torch.cat([x0, y0], dim=-1)
                corr = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
                out_pyramid.append(corr.view(batch, h1, w1, -1))
        else:
            batch, _, h1, w1 = coords.shape
            # [B, S, 2, H, W] -> [B, H, W, S, 2] -> [B*H*W, 1, S, 2]
            base = torch.unsqueeze(coords, 1) + propagate_flow
            base = base.permute(0, 3, 4, 1, 2).reshape(batch * h1 * w1, 1, -1, 2)
            for i in range(self.num_levels):
                if i == 0:
                    coords_lvl = base
                else:
                    coords_lvl = torch.cat(
                        [base[..., :1] / 2**i, base[..., 1:]], dim=-1
                    )
                corr = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
                out_pyramid.append(corr.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2, window_width, num_groups, step_size, window_hight=32):
        """Correlate each column group of ``fmap1`` with its window of ``fmap2``.

        Returns:
            [B, H, W, 1, window_width] volume scaled by 1/sqrt(C).
        """
        Batch, Channel, H, W = fmap1.shape
        # Build the window column indices: row g covers window_width
        # consecutive columns starting at g * step_size.
        B_indices = torch.arange(window_width).unsqueeze(0) + torch.arange(
            0, num_groups * step_size, step_size
        ).unsqueeze(1)

        B_indices = B_indices.view(num_groups, window_width)
        # Mirror the index table along both axes so the windows are counted
        # from the right edge of fmap2.
        B_indices = W - B_indices.flip(dims=[0]) - 1
        B_indices = B_indices.flip(dims=[1])
        # Any window that would reach left of column 0 is replaced by
        # columns 0 .. window_width - 1.
        under_limit = B_indices < 0
        first_valid_indices = (
            torch.arange(0, window_width).unsqueeze(0).expand(num_groups, -1)
        )
        B_indices = torch.where(
            under_limit.any(dim=1, keepdim=True), first_valid_indices, B_indices
        )

        B_indices = B_indices.view(-1)  # flatten to 1-D for gathering

        # Gather every group's window from fmap2.
        B_expanded = fmap2[:, :, :, B_indices].view(
            Batch, Channel, H, num_groups, window_width
        )
        # Split fmap1's columns into groups.
        A_grouped = fmap1.view(Batch, Channel, H, num_groups, window_hight)

        # Move channels last for the contraction.
        A_grouped = A_grouped.permute(0, 2, 3, 4, 1)  # [B, H, G, window_hight, C]
        B_expanded = B_expanded.permute(0, 2, 3, 4, 1)  # [B, H, G, window_width, C]

        # Dot product of every left column in a group with every window column.
        corr = torch.einsum("bnkqc,bnkpc->bnkqp", A_grouped, B_expanded)

        corr = corr.reshape(Batch, H, W, 1, window_width).contiguous()
        return corr / torch.sqrt(torch.tensor(Channel).float())


def build_increments(
    num_groups, window_height, window_width, step_size, W, device
):
    """Return per-column window offsets as a [1, 1, 1, num_groups * window_height] tensor.

    Group g (0-based, left to right) first gets the raw offset
    ``(num_groups - 1 - g) * step_size``. Any raw offset whose window would run
    past ``W`` (offset + window_width > W) is replaced by ``W - window_width``.
    Each offset is then mapped to ``W - window_width - offset``. Finally the
    value is repeated ``window_height`` times, once per column of the group.

    NOTE(review): the original definition carried a bare "# issue" marker;
    its meaning is not recorded.
    """
    raw = torch.arange(num_groups - 1, -1, -1, device=device) * step_size
    raw = raw.repeat_interleave(window_height).view(
        1, 1, 1, num_groups * window_height
    )

    # Clamp offsets whose window would overrun the right edge.
    overrun = raw + window_width > W
    clamped = torch.where(overrun, W - window_width, raw)
    return W - window_width - clamped


class CorrgroupBlock1D_Aggregation_patchmatch:
    """Grouped-window 1-D correlation for PatchMatch-style search/propagation.

    Original design note (translated): "Do not use a multi-level volume, only
    one; for propagation pick the 8 neighbours and the two points around each,
    9*(1*2+1)=27 in total; search uses three levels, hence 9*3."

    The current propagation branch samples 9 points per level. With
    ``num_levels > 1``, coarser levels are average-pooled by 2 along the window
    axis, matching the /2**level scaling in ``__call__``.
    """

    def __init__(
        self,
        fmap1,
        fmap2,
        num_levels=4,
        max_disp=96,
        window_hight=32,
        radius=4,
        window_width=384,
    ):
        self.num_levels = num_levels
        self.radius = radius

        batch, _, h1, w1 = fmap1.shape
        self.window_hight = window_hight
        self.window_width = window_width
        num_groups = w1 // window_hight
        end_group = (w1 - max_disp) // window_hight
        # NOTE(review): ZeroDivisionError when end_group == 1.
        step_size = (w1 - self.window_width - 1) // (end_group - 1)

        # Search offsets from -radius to radius (unit spacing), shared by every
        # call: [1, 1, 2r+1, 1].
        n_taps = 2 * self.radius + 1
        self.dx = (
            torch.linspace(-self.radius, self.radius, n_taps)
            .view(1, 1, n_taps, 1)
            .to(fmap1.device)
        )

        # Per-column window offsets, broadcast to [B, 1, H, W1].
        self.increments = build_increments(
            num_groups, window_hight, self.window_width, step_size, w1, fmap1.device
        ).expand(batch, 1, h1, w1)

        # Level-0 windowed volume: [B*H*W1, 1, 1, window_width].
        self.corr_pyramid = CorrgroupBlock1D_Aggregation_patchmatch.corr(
            fmap1, fmap2, self.window_width, num_groups, step_size, window_hight
        ).reshape(batch * h1 * w1, 1, 1, self.window_width)
        self.corr_pyramid_list = [self.corr_pyramid]
        # Coarser levels for num_levels > 1. Previously only level 0 existed,
        # so __call__ raised IndexError unless num_levels == 1.
        corr = self.corr_pyramid
        for _ in range(1, num_levels):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid_list.append(corr)

    def __call__(
        self, coords, propagation=False, propagate_flow=None, step=1, radius=4
    ):
        """Look up correlation features in window-local coordinates.

        Args:
            coords: [B, 2, H, W] absolute coordinates. Only x (channel 0) is
                used after the window shift. The tensor is not modified.
            propagation: choose propagation (True) or search (False).
            propagate_flow: [B, S, 2, H, W] candidate offsets (propagation only).
            step, radius: accepted for interface compatibility but ignored;
                the search offsets come from ``self.dx``.

        Returns:
            [B, K, H, W] float32 tensor, with the levels concatenated along K.
        """
        batch, c, h1, w1 = coords.shape
        assert c == 2
        # Window-local x only: [B, 1, H, W].
        coords = coords[:, :1] - self.increments

        out_pyramid = []
        if not propagation:
            x_base = coords.reshape(batch * h1 * w1, 1, 1, 1)
            for i in range(self.num_levels):
                x0 = self.dx + x_base / 2**i
                y0 = torch.zeros_like(x0)
                coords_lvl = torch.cat([x0, y0], dim=-1)
                out = bilinear_sampler(self.corr_pyramid_list[i], coords_lvl)
                out_pyramid.append(out.view(batch, h1, w1, -1))
        else:
            # NOTE(review): coords has a single (x) channel here, so the
            # broadcast below adds window-local x to BOTH channels of
            # propagate_flow. The y sample coordinate is therefore x + flow_y,
            # not flow_y. Confirm this against bilinear_sampler's handling of y.
            # [B, S, 2, H, W] -> [B, H, W, S, 2] -> [B*H*W, 1, S, 2]
            base = torch.unsqueeze(coords, 1) + propagate_flow
            base = base.permute(0, 3, 4, 1, 2).reshape(batch * h1 * w1, 1, -1, 2)
            for i in range(self.num_levels):
                if i == 0:
                    coords_lvl = base
                else:
                    # Scale only x, and only once per level. The former
                    # in-place ``coords_lvl[:, 0] /= 2**i`` hit the wrong axis
                    # and compounded across levels.
                    coords_lvl = torch.cat(
                        [base[..., :1] / 2**i, base[..., 1:]], dim=-1
                    )
                out = bilinear_sampler(self.corr_pyramid_list[i], coords_lvl)
                out_pyramid.append(out.view(batch, h1, w1, -1))

        corr = torch.cat(out_pyramid, dim=-1)
        return corr.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2, window_width, num_groups, step_size, window_hight=32):
        """Correlate each column group of ``fmap1`` with its window of ``fmap2``.

        Returns:
            [B, H, W, 1, window_width] volume scaled by 1/sqrt(C).
        """
        Batch, Channel, H, W = fmap1.shape
        # Build the window column indices: row g covers window_width
        # consecutive columns starting at g * step_size.
        B_indices = torch.arange(window_width).unsqueeze(0) + torch.arange(
            0, num_groups * step_size, step_size
        ).unsqueeze(1)

        B_indices = B_indices.view(num_groups, window_width)
        # Mirror the index table along both axes so the windows are counted
        # from the right edge of fmap2.
        B_indices = W - B_indices.flip(dims=[0]) - 1
        B_indices = B_indices.flip(dims=[1])
        # Any window that would reach left of column 0 is replaced by
        # columns 0 .. window_width - 1.
        under_limit = B_indices < 0
        first_valid_indices = (
            torch.arange(0, window_width).unsqueeze(0).expand(num_groups, -1)
        )
        B_indices = torch.where(
            under_limit.any(dim=1, keepdim=True), first_valid_indices, B_indices
        )
        B_indices = B_indices.view(-1)  # flatten to 1-D for gathering

        # Translated author notes: "Run with TensorRT — first check that it
        # runs at all; TorchScript; TensorRT can run it fully."

        # Gather every group's window from fmap2.
        B_expanded = fmap2[:, :, :, B_indices].view(
            Batch, Channel, H, num_groups, window_width
        )
        # Split fmap1's columns into groups.
        A_grouped = fmap1.view(Batch, Channel, H, num_groups, window_hight)

        # Move channels last for the contraction.
        A_grouped = A_grouped.permute(0, 2, 3, 4, 1)  # [B, H, G, window_hight, C]
        B_expanded = B_expanded.permute(0, 2, 3, 4, 1)  # [B, H, G, window_width, C]

        # Dot product of every left column in a group with every window column.
        corr = torch.einsum("bnkqc,bnkpc->bnkqp", A_grouped, B_expanded)

        corr = corr.reshape(Batch, H, W, 1, window_width).contiguous()
        return corr / torch.sqrt(torch.tensor(Channel).float())


class AlternateCorrBlock:
    """Correlation through the optional ``alt_cuda_corr`` extension (disabled).

    The constructor raises ``NotImplementedError`` before doing anything else,
    so instances cannot be created normally and the code after the raise is
    unreachable. NOTE: this module defines ``AlternateCorrBlock`` a second time
    further down with an identical body; that later definition replaces this
    one at import time.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        raise NotImplementedError
        # --- unreachable below; retained for reference ---
        self.num_levels = num_levels
        self.radius = radius
        self.pyramid = [(fmap1, fmap2)]
        for _ in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Correlate via ``alt_cuda_corr`` on each level and stack the results.

        The full-resolution left map is always paired with the level-i right
        map, and coordinates are divided by 2**i. The result is scaled by
        1/sqrt(channels).
        """
        points = coords.permute(0, 2, 3, 1)
        B, H, W, _ = points.shape
        left = self.pyramid[0][0]
        dim = left.shape[1]
        # The left map is the same on every level, so convert it to channels-last once.
        left_nhwc = left.permute(0, 2, 3, 1).contiguous()

        per_level = []
        for i in range(self.num_levels):
            right_nhwc = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
            scaled = (points / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(left_nhwc, right_nhwc, scaled, self.radius)
            per_level.append(corr.squeeze(1))

        stacked = torch.stack(per_level, dim=1).reshape(B, -1, H, W)
        return stacked / torch.sqrt(torch.tensor(dim).float())


# class CorrBlock1D:
#     def __init__(self, fmap1, fmap2, num_levels=4):
#         self.num_levels = num_levels
#         self.corr_pyramid = []

#         # all pairs correlation
#         corr = CorrBlock1D.corr(fmap1, fmap2)

#         batch, h1, w1, dim, w2 = corr.shape
#         corr = corr.reshape(batch * h1 * w1, dim, 1, w2)

#         self.corr_pyramid.append(corr)
#         for i in range(self.num_levels):
#             corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
#             self.corr_pyramid.append(corr)

#     def __call__(self, coords, radius):
#         r = radius
#         coords = coords[:, :1].permute(0, 2, 3, 1)
#         batch, h1, w1, _ = coords.shape

#         out_pyramid = []
#         for i in range(self.num_levels):
#             corr = self.corr_pyramid[i]
#             dx = torch.linspace(-r, r, 2 * r + 1)
#             dx = dx.view(1, 1, 2 * r + 1, 1).to(coords.device)
#             x0 = dx + coords.reshape(batch * h1 * w1, 1, 1, 1) / 2**i
#             y0 = torch.zeros_like(x0)

#             coords_lvl = torch.cat([x0, y0], dim=-1)
#             corr = bilinear_sampler(corr, coords_lvl)
#             corr = corr.view(batch, h1, w1, -1)
#             out_pyramid.append(corr)

#         out = torch.cat(out_pyramid, dim=-1)
#         return out.permute(0, 3, 1, 2).contiguous().float()

#     @staticmethod
#     def corr(fmap1, fmap2):
#         B, D, H, W1 = fmap1.shape
#         _, _, _, W2 = fmap2.shape
#         fmap1 = fmap1.view(B, D, H, W1)
#         fmap2 = fmap2.view(B, D, H, W2)
#         corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
#         corr = corr.reshape(B, H, W1, 1, W2).contiguous()
#         return corr / torch.sqrt(torch.tensor(D).float())


class AlternateCorrBlock:
    """Correlation through the optional ``alt_cuda_corr`` extension (disabled).

    This is the second definition of ``AlternateCorrBlock`` in this module. Its
    body is identical to the first, and it is the one in effect after import.
    The constructor raises ``NotImplementedError`` immediately, so the rest of
    the code is unreachable.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        raise NotImplementedError
        # --- unreachable below; retained for reference ---
        self.num_levels = num_levels
        self.radius = radius
        self.pyramid = [(fmap1, fmap2)]
        for _ in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Correlate via ``alt_cuda_corr`` on each level and stack the results.

        The full-resolution left map is always paired with the level-i right
        map, and coordinates are divided by 2**i. The result is scaled by
        1/sqrt(channels).
        """
        points = coords.permute(0, 2, 3, 1)
        B, H, W, _ = points.shape
        left = self.pyramid[0][0]
        dim = left.shape[1]
        # The left map is the same on every level, so convert it to channels-last once.
        left_nhwc = left.permute(0, 2, 3, 1).contiguous()

        per_level = []
        for i in range(self.num_levels):
            right_nhwc = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
            scaled = (points / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(left_nhwc, right_nhwc, scaled, self.radius)
            per_level.append(corr.squeeze(1))

        stacked = torch.stack(per_level, dim=1).reshape(B, -1, H, W)
        return stacked / torch.sqrt(torch.tensor(dim).float())


class PAGCL:
    """
    Implementation of Adaptive Group Correlation Layer (AGCL).

    Difference from CREStereo (translated author note): when matching, we
    directly swap in the neighbours' matching relations (their flow) instead of
    matching against the neighbours.

    The features are split into 4 channel groups. For every pixel and each of
    ``search_num = 9`` candidates, the right feature map is sampled at the
    candidate position and correlated with the left feature of that group (the
    channel mean of the product). Output: [N, 4 * 9, H, W].
    """

    def __init__(self, fmap1, fmap2, grid, grid_1d, att=None):
        self.fmap1 = fmap1
        self.fmap2 = fmap2
        self.grid = grid
        self.grid_1d = grid_1d

        self.att = att
        # Coordinate grid from the project helper coords_grid(batch, H, W).
        # Presumably pixel coordinates with channel 0 = x, 1 = y — confirm.
        self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3]).to(
            fmap1.device
        )

    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
        """Correlate ``left_feature`` with shifted crops of padded ``right_feature``.

        ``right_feature`` is padded with ``manual_pad`` (pady, padx), where
        pad = psize // 2 * dilate per axis. Each dilated offset then yields one
        channel: the channel mean of ``left * crop``. (This method was
        previously defined twice with identical bodies.)
        """
        N, C, H, W = left_feature.shape

        di_y, di_x = dilate[0], dilate[1]
        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

        right_pad = manual_pad(right_feature, pady, padx)

        corr_list = []
        for h in range(0, pady * 2 + 1, di_y):
            for w in range(0, padx * 2 + 1, di_x):
                right_crop = right_pad[:, :, h : h + H, w : w + W]
                assert right_crop.shape == left_feature.shape
                corr = torch.mean(left_feature * right_crop, dim=1, keepdims=True)
                corr_list.append(corr)

        return torch.cat(corr_list, dim=1)

    def __call__(self, flow, propagation=True):
        """Run propagation (uses ``self.grid``) or search (uses ``self.grid_1d``)."""
        if propagation:
            corr = self.propagation_fn(self.fmap1, self.fmap2, flow, self.grid)
        else:
            corr = self.search_fn(self.fmap1, self.fmap2, flow, self.grid_1d)
        return corr

    @staticmethod
    def _normalize_coords(coords, H, W):
        """Map pixel (x, y) in the last dim to grid_sample's [-1, 1] range.

        Uses the align_corners=True convention. A size-1 axis is guarded
        against division by zero; every value then maps to that single pixel.
        """
        x = coords[..., 0] / (max(W - 1, 1) / 2) - 1
        y = coords[..., 1] / (max(H - 1, 1) / 2) - 1
        return torch.stack([x, y], dim=-1)

    def _grouped_correlation(self, lefts, rights, sample_grid, C, search_num):
        """Sample each of the 4 right groups at ``sample_grid`` and correlate.

        Args:
            lefts, rights: channel groups, each [N, C, H, W].
            sample_grid: normalized grid [N, search_num * H, W, 2].

        Returns:
            [N, 4 * search_num, H, W].
        """
        N, _, H, W = lefts[0].shape
        corrs = []
        for i in range(4):
            warped = F.grid_sample(
                rights[i],
                sample_grid,
                mode="bilinear",
                padding_mode="border",
                align_corners=True,
            ).reshape(N, C, search_num, H, W)
            # Broadcasting replaces the former repeat_interleave copy.
            corrs.append(torch.mean(lefts[i].unsqueeze(2) * warped, dim=1))
        return torch.cat(corrs, dim=1)

    def propagation_fn(self, left_feature, right_feature, flow, grid):
        """Correlate at positions displaced by the neighbours' (propagated) flow.

        ``grid`` is used to sample ``flow``. It is assumed to produce a
        candidate-major output of 9 * H * W positions (e.g. [N, 9*H, W, 2]) —
        NOTE(review): confirm the layout with the caller.
        """
        N, C, H, W = left_feature.shape

        if self.att is not None:
            # 'n c h w -> n (h w) c'
            left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)
            right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)
            left_feature, right_feature = self.att(left_feature, right_feature)
            # 'n (h w) c -> n c h w'
            left_feature, right_feature = [
                x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                for x in [left_feature, right_feature]
            ]

        lefts = torch.split(left_feature, left_feature.shape[1] // 4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1] // 4, dim=1)

        C = C // 4
        search_num = 9

        # The sampled flow and the sample grid do not depend on the channel
        # group, so they are computed once instead of four times.
        sampled_flow = F.grid_sample(
            flow, grid, mode="bilinear", padding_mode="border", align_corners=True
        )
        # grid_sample keeps the 2 flow channels outermost ([N, 2, ...]), so
        # split them first and then move the candidate axis in front:
        # [N, 2, S, H, W] -> [N, S, 2, H, W]. The former direct
        # reshape(N, S, 2, H, W) interleaved x/y across candidates.
        propagate_flow = sampled_flow.reshape(N, 2, search_num, H, W).transpose(1, 2)
        propagate_flow[:, :, 1, :, :] = 0  # no vertical propagation

        coords = torch.unsqueeze(self.coords, 1) + propagate_flow  # [N, S, 2, H, W]
        # [N, S, H, W, 2]. The former permute(0, 1, 4, 2, 3) gave [N, S, W, 2, H]
        # and normalised candidates 0/1 instead of the x/y components.
        coords = coords.permute(0, 1, 3, 4, 2)
        coords = self._normalize_coords(coords, H, W).reshape(N, search_num * H, W, 2)

        return self._grouped_correlation(lefts, rights, coords, C, search_num)

    def search_fn(self, left_feature, right_feature, flow, grid):
        """Correlate at the current flow plus the 9 offsets in ``grid``.

        ``grid`` must be reshapeable to [N, 9, H, W, 2] pixel offsets.
        NOTE(review): the combined flow + offset is clipped to >= 0 in both x
        and y; confirm this matches the flow sign convention.
        """
        N, C, H, W = left_feature.shape

        lefts = torch.split(left_feature, left_feature.shape[1] // 4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1] // 4, dim=1)

        C = C // 4
        search_num = 9

        offsets = torch.clip(
            flow.permute(0, 2, 3, 1).unsqueeze(1)
            + grid.reshape(N, search_num, H, W, 2),
            min=0,
        )
        coords = self.coords.permute(0, 2, 3, 1).unsqueeze(1) + offsets  # [N, S, H, W, 2]
        # grid_sample expects normalized coordinates; these used to be passed
        # in pixel units, so almost every sample clamped to the border.
        coords = self._normalize_coords(coords, H, W).reshape(N, search_num * H, W, 2)

        return self._grouped_correlation(lefts, rights, coords, C, search_num)


class AGCL:
    """
    Implementation of Adaptive Group Correlation Layer (AGCL).

    Holds a pair of feature maps and computes group-wise correlation between
    them. The channels are split into 4 groups. For each pixel, the left
    features are correlated (channel mean of the product) with right features
    taken from a 9-tap window around ``coords + flow``.

    ``__call__`` dispatches to one of two modes:

    * ``iter_mode=True`` -> ``corr_iter``: warp the right map by ``flow`` once,
      then correlate over a fixed integer window (``get_correlation``).
    * ``iter_mode=False`` -> ``corr_att_offset``: optionally apply ``att`` to
      both maps, then sample every window tap individually with a per-pixel
      ``extra_offset`` added to each tap.
    """

    def __init__(self, fmap1, fmap2, att=None):
        """
        Args:
            fmap1: left feature map, ``[N, C, H, W]``.
            fmap2: right feature map; expected to have the same shape as
                ``fmap1`` (the attention reshape and the shape assert in
                ``get_correlation`` rely on it).
            att: optional callable applied to the two flattened
                ``[N, H*W, C]`` maps in ``corr_att_offset``; must return two
                tensors of the same shape.
        """
        self.fmap1 = fmap1
        self.fmap2 = fmap2

        self.att = att

        # Base sampling grid from coords_grid; it is added to a flow tensor and
        # permuted to [N, H, W, 2], so presumably [N, 2, H, W] with (x, y) in
        # the channel dim — TODO confirm against coords_grid.
        self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3]).to(
            fmap1.device
        )

    def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
        """Compute the correlation volume for the current ``flow``.

        Args:
            flow: displacement added to ``self.coords``.
            extra_offset: per-pixel tap offsets; used only when
                ``iter_mode`` is false.
            small_patch: use a 3x3 window instead of a 1x9 window.
            iter_mode: select ``corr_iter`` instead of ``corr_att_offset``.

        Returns:
            ``[N, 4 * 9, H, W]`` correlation volume.
        """
        if iter_mode:
            corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
        else:
            corr = self.corr_att_offset(
                self.fmap1, self.fmap2, flow, extra_offset, small_patch
            )
        return corr

    @staticmethod
    def _patch_config(small_patch):
        """Return ``(psize_list, dilate_list)`` for the 4 channel groups.

        ``small_patch`` selects a 3x3 window, otherwise a 1x9 window is used.
        Both have 9 taps and unit dilation.
        """
        psize = (3, 3) if small_patch else (1, 9)
        return [psize] * 4, [(1, 1)] * 4

    @staticmethod
    def _window_offsets(psize, dilate, device):
        """Integer ``(x, y)`` offsets of every window tap, shaped ``[1, K, 1, 1, 2]``.

        Taps are ordered row-major (y outer, x inner), the same order in which
        ``get_correlation`` emits its output channels.
        """
        ry = psize[0] // 2 * dilate[0]
        rx = psize[1] // 2 * dilate[1]
        x_grid, y_grid = torch.meshgrid(
            torch.arange(-rx, rx + 1, dilate[1], device=device),
            torch.arange(-ry, ry + 1, dilate[0], device=device),
            indexing="xy",
        )
        offsets = torch.stack((x_grid, y_grid)).reshape(2, -1).permute(1, 0)  # [K, 2]
        return offsets[None, :, None, None, :]

    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
        """Correlate ``left_feature`` with shifted crops of ``right_feature``.

        ``right_feature`` is padded with ``manual_pad``. The crop loop below
        needs at least ``pady`` rows and ``padx`` columns of padding on each
        side. Each window offset yields one output channel holding the channel
        mean of ``left_feature * crop``.

        Args:
            left_feature: ``[N, C, H, W]``.
            right_feature: same shape as ``left_feature``.
            psize: window size ``(py, px)``.
            dilate: step between taps ``(dy, dx)``.

        Returns:
            ``[N, K, H, W]``, one channel per window tap, row-major
            (y outer, x inner).
        """
        _, _, H, W = left_feature.shape

        di_y, di_x = dilate[0], dilate[1]
        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

        right_pad = manual_pad(right_feature, pady, padx)

        corr_list = []
        for h in range(0, pady * 2 + 1, di_y):
            for w in range(0, padx * 2 + 1, di_x):
                right_crop = right_pad[:, :, h : h + H, w : w + W]
                assert right_crop.shape == left_feature.shape
                corr_list.append(
                    torch.mean(left_feature * right_crop, dim=1, keepdim=True)
                )

        return torch.cat(corr_list, dim=1)

    def corr_iter(self, left_feature, right_feature, flow, small_patch):
        """Warp ``right_feature`` by ``flow``, then correlate over a fixed window.

        Returns:
            ``[N, 4 * 9, H, W]``: 9 window taps for each of the 4 channel groups.
        """
        coords = (self.coords + flow).permute(0, 2, 3, 1)  # [N, H, W, 2]
        right_feature = bilinear_sampler(right_feature, coords)

        psize_list, dilate_list = self._patch_config(small_patch)

        lefts = torch.split(left_feature, left_feature.shape[1] // 4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1] // 4, dim=1)

        corrs = [
            self.get_correlation(lefts[i], rights[i], psize_list[i], dilate_list[i])
            for i in range(len(psize_list))
        ]
        return torch.cat(corrs, dim=1)

    def corr_att_offset(
        self, left_feature, right_feature, flow, extra_offset, small_patch
    ):
        """Group correlation with per-pixel learned sampling offsets.

        Args:
            left_feature: ``[N, C, H, W]`` feature map.
            right_feature: same shape as ``left_feature``.
            flow: added to ``self.coords`` to give the window centres.
            extra_offset: ``[N, 9 * 2, H, W]``; an ``(x, y)`` offset per window
                tap and pixel, added on top of the fixed window offsets.
            small_patch: 3x3 window if true, else 1x9.

        Returns:
            ``[N, 4 * 9, H, W]`` correlation volume.
        """
        N, C, H, W = left_feature.shape

        if self.att is not None:
            # 'n c h w -> n (h w) c'
            left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)
            right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)
            # Apply the attention callable to both flattened maps.
            left_feature, right_feature = self.att(left_feature, right_feature)
            # 'n (h w) c -> n c h w'
            left_feature, right_feature = [
                x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                for x in [left_feature, right_feature]
            ]

        lefts = torch.split(left_feature, left_feature.shape[1] // 4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1] // 4, dim=1)
        group_channels = C // 4

        psize_list, dilate_list = self._patch_config(small_patch)

        search_num = 9
        extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(
            0, 1, 3, 4, 2
        )  # [N, search_num, H, W, 2]

        # Window centres do not depend on the channel group: compute them once.
        base_coords = (self.coords + flow).permute(0, 2, 3, 1).unsqueeze(1)  # [N, 1, H, W, 2]

        corrs = []
        for i in range(len(psize_list)):
            # [1, K, 1, 1, 2] broadcast against [N, search_num, H, W, 2];
            # K must equal search_num.
            offsets = (
                self._window_offsets(psize_list[i], dilate_list[i], self.fmap1.device)
                + extra_offset
            )
            coords = (base_coords + offsets).reshape(N, -1, W, 2)  # [N, search_num*H, W, 2]

            sampled = bilinear_sampler(rights[i], coords)  # [N, C/4, search_num*H, W]
            sampled = sampled.reshape(N, group_channels, -1, H, W)  # [N, C/4, search_num, H, W]

            # Broadcast the left group over the search axis instead of
            # materializing search_num copies of it.
            corrs.append(torch.mean(lefts[i].unsqueeze(2) * sampled, dim=1))

        return torch.cat(corrs, dim=1)
