import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update import (
    BasicMultiUpdateBlock,
    BasicUpdateBlock,
    PatchMatchMultiUpdateBlock,
    C_BasicUpdateBlock,
)
from core.extractor import (
    BasicEncoder,
    MultiBasicEncoder,
    ResidualBlock,
    C_BasicEncoder,
)
from core.corr import (
    CorrBlock1D,
    PytorchAlternateCorrBlock1D,
    CorrBlockFast1D,
    AlternateCorrBlock,
    PAGCL,
    CorrBlock1D_Aggregation,
    CorrgroupBlock1D_Aggregation,
    CorrgroupBlock1D_Aggregation_patchmatch,
)
from core.utils.utils import coords_grid, upflow8
import functools
from torchvision.ops.deform_conv import DeformConv2d

try:
    # Mixed-precision context manager (available from PyTorch 1.6).
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # Older PyTorch has no ``torch.cuda.amp``: fall back to a no-op context
    # manager with the same ``enabled`` constructor argument.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast`` on PyTorch < 1.6."""

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            # Returning None (falsy) lets exceptions propagate unchanged.
            pass


class RAFTStereo_modified(nn.Module):
    """RAFT-Stereo variant with coarse-to-fine refinement and PatchMatch-style propagation.

    Disparity is refined at three levels: the encoder feature map ``fmap1`` and
    2x / 4x average-pooled copies of it (the ``*_dw8`` / ``*_dw16`` members).
    At each level a correlation object is built by ``corr_block``. Its cost
    volume is re-sampled along learned paths by a 3x3 ``DeformConv2d`` and
    pooled into a pyramid (``corr_aggregation``). A ``C_BasicUpdateBlock`` then
    iteratively refines the coordinates. Every ``patchmatch_rounds``-th
    iteration, the current flow sampled at learned neighbour positions is
    handed to the correlation lookup as propagation candidates.

    NOTE(review): ``insert_mid`` / ``get_path_offset`` and the 3x3 deformable
    convolutions only line up when ``args.num_neighbors == 8``.
    """

    def __init__(self, args):
        super().__init__()
        output_dim = 2
        flow_dim = 2
        self.args = args
        self.dropout = 0  # not read anywhere in this class; kept for compatibility
        # Correlation window at the finest level; coarser levels use // 2 and // 4.
        self.window_height = 8
        self.window_width = int(1.5 * self.args.max_disp // 4)
        context_dims = args.hidden_dims
        self.cnet = MultiBasicEncoder(
            output_dim=[args.hidden_dims, context_dims],
            norm_fn=args.context_norm,
            downsample=args.n_downsample,
        )

        if args.shared_backbone:
            self.conv2 = nn.Sequential(
                ResidualBlock(128, 128, "instance", stride=1),
                nn.Conv2d(128, 256, 3, padding=1),
            )
        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", downsample=args.n_downsample
            )

        hidden_dim = args.hidden_dims[2]
        cor_planes = self.args.corr_levels * 9
        # Finest level: mask_size=4 feeds the convex upsampling in upsample_flow.
        self.update_block = C_BasicUpdateBlock(
            hidden_dim=hidden_dim,
            cor_planes=cor_planes,
            mask_size=4,
            flow_dim=flow_dim,
            output_dim=output_dim,
        )
        self.update_block8 = C_BasicUpdateBlock(
            hidden_dim=hidden_dim,
            cor_planes=cor_planes,
            mask_size=2,
            flow_dim=flow_dim,
            output_dim=output_dim,
            update=False,
        )
        self.update_block16 = C_BasicUpdateBlock(
            hidden_dim=hidden_dim,
            cor_planes=cor_planes,
            mask_size=2,
            flow_dim=flow_dim,
            output_dim=output_dim,
            update=False,
        )

        self.num_neighbors = args.num_neighbors
        # These heads run on fmap1 (256 channels from fnet/conv2); presumably
        # hidden_dims[0] * 2 == 256 in every config -- TODO confirm.
        head_in = args.hidden_dims[0] * 2
        # Learned (x, y) offsets for each propagation neighbour, per level.
        self.propa_conv = self._conv_bn_lrelu(head_in, self.num_neighbors * 2)
        self.propa_conv8 = self._conv_bn_lrelu(head_in, self.num_neighbors * 2)
        self.propa_conv16 = self._conv_bn_lrelu(head_in, self.num_neighbors * 2)
        # One scalar path length per non-centre deformable-conv tap, per level.
        self.path_conv = self._conv_bn_lrelu(head_in, self.num_neighbors)
        self.path_conv8 = self._conv_bn_lrelu(head_in, self.num_neighbors)
        self.path_conv16 = self._conv_bn_lrelu(head_in, self.num_neighbors)

        # Deformable convs over the cost volume, whose window axis is the channel axis.
        self.d_conv = DeformConv2d(self.window_width, self.window_width, 3, 1, 1)
        self.d_conv8 = DeformConv2d(
            self.window_width // 2, self.window_width // 2, 3, 1, 1
        )
        self.d_conv16 = DeformConv2d(
            self.window_width // 4, self.window_width // 4, 3, 1, 1
        )

    @staticmethod
    def _conv_bn_lrelu(in_channels, out_channels):
        """Return a 3x3 Conv2d -> BatchNorm2d -> LeakyReLU(0.2) head."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def freeze_bn(self):
        """Put every BatchNorm2d submodule into eval mode (freezes running stats)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, fmap):
        """Return ``(coords0, coords1)`` pixel-coordinate grids at ``fmap``'s resolution.

        Flow is represented as ``coords1 - coords0``. Both grids are built the
        same way, so the initial flow is zero.
        """
        N, _, H, W = fmap.shape
        coords0 = coords_grid(N, H, W).to(fmap.device)
        coords1 = coords_grid(N, H, W).to(fmap.device)
        return coords0, coords1

    def get_path_offset(self, offset):
        """Turn per-tap path lengths into a DeformConv2d offset tensor.

        Args:
            offset: ``[B, 9, H, W]`` scalar length for each tap of a 3x3 kernel
                (row-major, centre at index 4).

        Returns:
            float32 ``[B, 18, H, W]``. Channel ``2j`` is ``dy_j * offset_j`` and
            ``2j + 1`` is ``dx_j * offset_j``, where ``(dy_j, dx_j)`` is the
            unit direction of tap ``j``. Each tap therefore slides along its
            own ray away from the centre.
        """
        batch, _, height, width = offset.shape
        dy = torch.tensor(
            [-1, -1, -1, 0, 0, 0, 1, 1, 1], device=offset.device, dtype=torch.float32
        ).view(1, 9, 1, 1)
        dx = torch.tensor(
            [-1, 0, 1, -1, 0, 1, -1, 0, 1], device=offset.device, dtype=torch.float32
        ).view(1, 9, 1, 1)
        # Interleave (dy, dx) per tap: [B, 9, 2, H, W] -> [B, 18, H, W].
        interleaved = torch.stack((offset * dy, offset * dx), dim=2)
        return interleaved.reshape(batch, 18, height, width).to(torch.float32)

    def get_grid(self, batch, height, width, device, offset, dilation=1):
        """Build the ``grid_sample`` grid for each pixel and its propagation neighbours.

        Args:
            batch, height, width: size of the level the grid samples.
            device: device for the base pixel grid.
            offset: ``[B, 2 * num_neighbors, H * W]`` learned offsets. Channel
                ``2i`` is added to x and ``2i + 1`` to y of neighbour ``i``.
            dilation: spacing of the fixed neighbour pattern.

        Returns:
            Normalised (x, y) grid ``[B, (num_neighbors + 1) * H, W, 2]``.
            Slab 0 is the pixel itself; slab ``i + 1`` is neighbour ``i``.

        Raises:
            ValueError: if ``num_neighbors`` exceeds the fixed pattern size.

        The fixed pattern is the full 8-neighbourhood ([dy, dx]). Earlier code
        listed [-d, -d] and [d, d] twice and never sampled [-d, d] / [d, -d];
        checkpoints trained with that table may need fine-tuning. With 16
        neighbours, a second ring at twice the distance is appended.
        """
        d = dilation
        neighbor_offsets = [
            [-d, -d], [-d, 0], [-d, d],
            [0, -d],           [0, d],
            [d, -d],  [d, 0],  [d, d],
        ]  # fmt: skip
        if self.num_neighbors == 16:
            neighbor_offsets += [[2 * dy, 2 * dx] for dy, dx in neighbor_offsets]
        if self.num_neighbors > len(neighbor_offsets):
            raise ValueError(
                f"num_neighbors={self.num_neighbors} is not supported "
                f"(max {len(neighbor_offsets)})"
            )

        # Base pixel grid in "ij" order, flattened to [B, 2, H*W] as (x, y).
        ys = torch.arange(height, dtype=torch.float32, device=device)
        xs = torch.arange(width, dtype=torch.float32, device=device)
        y_grid = ys.view(height, 1).expand(height, width).reshape(-1)
        x_grid = xs.view(1, width).expand(height, width).reshape(-1)
        base = torch.stack((x_grid, y_grid)).unsqueeze(0).repeat(batch, 1, 1)

        samples = [base.unsqueeze(2)]
        for i, (dy, dx) in enumerate(neighbor_offsets[: self.num_neighbors]):
            shift = torch.cat(
                (
                    dx + offset[:, 2 * i].unsqueeze(1),
                    dy + offset[:, 2 * i + 1].unsqueeze(1),
                ),
                dim=1,
            )
            samples.append((base + shift).unsqueeze(2))
        xy = torch.cat(samples, dim=2)  # [B, 2, K+1, H*W]

        # align_corners=True normalisation; max(.., 1) avoids 0-division on 1-px axes.
        x_normalized = xy[:, 0] / (max(width - 1, 1) / 2) - 1
        y_normalized = xy[:, 1] / (max(height - 1, 1) / 2) - 1
        grid = torch.stack((x_normalized, y_normalized), dim=3)  # [B, K+1, H*W, 2]
        return grid.view(batch, (self.num_neighbors + 1) * height, width, 2)

    def _propagation_grid(self, fmap, propa_conv):
        """Predict neighbour offsets from ``fmap`` with ``propa_conv`` and build the sampling grid."""
        batch, _, height, width = fmap.shape
        weight = propa_conv(fmap).view(batch, self.num_neighbors * 2, height * width)
        return self.get_grid(batch, height, width, fmap.device, weight)

    def _sample_neighbor_flow(self, flow, grid):
        """Sample the x-flow at each pixel's propagation neighbours.

        Args:
            flow: ``[B, 2, H, W]`` current flow (``coords1 - coords0``).
            grid: output of ``get_grid`` for the same level.

        Returns:
            ``[B, num_neighbors + 1, 2, H, W]`` candidate flows. The y component
            is zero. Earlier code reshaped the ``[B, 2, K*H, W]`` sample straight
            to ``[B, 9, 2, H, W]``, which interleaved neighbours with channels.
        """
        batch, _, height, width = flow.shape
        sampled = F.grid_sample(
            flow[:, :1],
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        ).view(batch, self.num_neighbors + 1, height, width)
        return torch.stack((sampled, torch.zeros_like(sampled)), dim=2)

    def upsample_flow(self, flow, mask, n_downsample=2):
        """Upsample ``flow`` [N, D, H, W] by ``2**n_downsample`` via convex combination.

        ``mask`` holds ``9 * factor**2`` logits per pixel. Each output pixel is a
        softmax-weighted mix of the 3x3 coarse neighbourhood, with flow values
        scaled by ``factor``.
        """
        N, D, H, W = flow.shape
        factor = 2**n_downsample

        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)
        up_flow = F.unfold(factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, D, factor * H, factor * W)

    def corr_aggregation(self, corr):
        """Rebuild ``corr.corr_pyramid_list`` from ``corr.corr_pyramid``.

        Entry 0 becomes the current volume, and ``corr_levels - 1`` copies,
        each 2x average-pooled along the last axis, are appended. Assumes
        ``corr_pyramid_list`` already exists with at least one entry (set by
        the corr block -- not visible here).
        """
        rows, h1, w1, window = corr.corr_pyramid.shape
        corr.corr_pyramid_list[0] = corr.corr_pyramid.reshape(
            rows * h1 * w1, 1, 1, window
        )
        corr_now = corr.corr_pyramid
        for _ in range(self.args.corr_levels - 1):
            corr_now = F.avg_pool2d(corr_now, [1, 2], stride=[1, 2])
            corr.corr_pyramid_list.append(corr_now)
        return corr

    def insert_mid(self, offset):
        """Insert a zero channel at index 4 (the 3x3 kernel centre) of ``offset``."""
        centre = 4
        mid = offset.new_zeros(
            (offset.shape[0], 1, offset.shape[2], offset.shape[3])
        )
        return torch.cat((offset[:, :centre], mid, offset[:, centre:]), dim=1)

    def _build_corr(self, corr_block, fmap1, fmap2, scale, path_conv, d_conv):
        """Build one level's correlation object with a path-deformed, pooled cost volume.

        ``scale`` is this level's pooling factor relative to the finest level
        (1, 2 or 4). It divides the window sizes and the disparity range.
        """
        batch, _, height, width = fmap1.shape
        window_width = self.window_width // scale
        corr_fn = corr_block(
            fmap1,
            fmap2,
            self.args.corr_levels,
            self.args.max_disp // (4 * scale),
            self.window_height // scale,
            4,
            window_width,
        )
        # Window axis -> channels so the deformable conv slides over (H, W).
        volume = corr_fn.corr_pyramid.reshape(
            batch, height, width, window_width
        ).permute(0, 3, 1, 2)
        path_len = path_conv(fmap1).view(batch, self.num_neighbors, height, width)
        volume = d_conv(volume, self.get_path_offset(self.insert_mid(path_len)))
        corr_fn.corr_pyramid = volume.permute(0, 2, 3, 1).reshape(
            batch * height * width, 1, 1, -1
        )
        # TODO (translated from original notes): 1) cost smoothing -- penalise
        # neighbouring disparities that differ by more than 1; 2) cost
        # aggregation -- also consider the same disparity when building the volume.
        return self.corr_aggregation(corr_fn)

    def _refine_coarse_level(
        self,
        n_iters,
        patchmatch_rounds,
        corr_fn,
        grid,
        update_block,
        net,
        inp,
        coords0,
        coords1,
        disp_divisor,
        full_res_factor,
    ):
        """Run ``n_iters`` GRU updates at a pooled level.

        Returns:
            ``(net, next_flow, predictions)``. ``next_flow`` is the last flow,
            clipped to ``[-max_disp // disp_divisor, -0.1]`` and 2x-upsampled
            to seed the next level. ``predictions`` holds one x-flow per
            iteration, upsampled a further ``full_res_factor`` times.
        """
        step = 1
        next_flow = None
        predictions = []
        height, width = coords0.shape[2], coords0.shape[3]
        for itr in range(n_iters):
            coords1 = coords1.detach()
            propagation = itr % patchmatch_rounds == patchmatch_rounds - 1
            propagate_flow = (
                self._sample_neighbor_flow(coords1 - coords0, grid)
                if propagation
                else None
            )
            out_corrs = corr_fn(
                coords1,
                radius=self.args.corr_radius,
                propagation=propagation,
                propagate_flow=propagate_flow,
                step=max(1, step / 2**itr),
            )
            with autocast(enabled=self.args.mixed_precision):
                net, _, delta_flow = update_block(net, inp, out_corrs, coords1 - coords0)
            # NOTE(review): delta_flow[:, 0:1] broadcasts onto BOTH coordinate
            # channels (kept as-is; trained weights depend on it).
            coords1 = coords1 + delta_flow[:, 0:1]
            next_flow = 2 * F.interpolate(
                torch.clip(
                    coords1 - coords0,
                    min=-self.args.max_disp // disp_divisor,
                    max=-0.1,
                ),
                size=(2 * height, 2 * width),
                mode="bilinear",
                align_corners=True,
            )
            flow_res = full_res_factor * F.interpolate(
                next_flow,
                size=(
                    full_res_factor * next_flow.shape[2],
                    full_res_factor * next_flow.shape[3],
                ),
                mode="bilinear",
                align_corners=True,
            )
            predictions.append(flow_res[:, :1])
        return net, next_flow, predictions

    def forward(
        self,
        image1,
        image2,
        patchmatch_rounds,
        iters=12,
        flow_init=None,
        test_mode=False,
    ):
        """Estimate horizontal flow between a rectified image pair.

        Args:
            image1, image2: images, assumed in [0, 255]; rescaled to [-1, 1].
            patchmatch_rounds: propagation runs on iterations where
                ``itr % patchmatch_rounds == patchmatch_rounds - 1``; must be >= 1.
            iters: ``iters // 4`` iterations run on each pooled level and
                ``iters // 2`` on the finest level; must be >= 4.
            flow_init: accepted for interface compatibility; currently unused.
            test_mode: if True, return only the final upsampled prediction.

        Returns:
            With ``test_mode``, ``(flow_up, flow_up)``, the x channel of the
            final upsampled flow. Otherwise ``((preds16, preds8, preds), ([], [], []))``:
            per-iteration upsampled x-flows per level, plus always-empty
            gradient-prediction lists.

        Raises:
            ValueError: for ``iters < 4``, ``patchmatch_rounds < 1`` or an
                unknown ``args.corr_implementation``.
        """
        if iters < 4:
            raise ValueError(f"iters must be >= 4 so every level runs, got {iters}")
        if patchmatch_rounds < 1:
            raise ValueError(f"patchmatch_rounds must be >= 1, got {patchmatch_rounds}")
        corr_blocks = {
            "reg": CorrgroupBlock1D_Aggregation_patchmatch,  # default
            "alt": PytorchAlternateCorrBlock1D,  # more memory efficient than reg
            "reg_cuda": CorrBlockFast1D,  # faster version of reg
            "alt_cuda": AlternateCorrBlock,  # faster version of alt
        }
        impl = self.args.corr_implementation
        if impl not in corr_blocks:
            raise ValueError(f"unknown corr_implementation: {impl!r}")
        corr_block = corr_blocks[impl]

        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        # Context and feature networks.
        with autocast(enabled=self.args.mixed_precision):
            if self.args.shared_backbone:
                *cnet_list, x = self.cnet(
                    torch.cat((image1, image2), dim=0),
                    dual_inp=True,
                    num_layers=self.args.n_gru_layers,
                )
                fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0] // 2)
            else:
                cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.fnet([image1, image2])
            net_list = [torch.tanh(x[0]) for x in cnet_list]
            inp_list = [torch.relu(x[1]) for x in cnet_list]
        net, net_dw8, net_dw16 = net_list[0], net_list[1], net_list[2]
        inp, inp_dw8, inp_dw16 = inp_list[0], inp_list[1], inp_list[2]

        if impl in ("reg", "alt"):
            fmap1, fmap2 = fmap1.float(), fmap2.float()

        fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
        fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)
        fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
        fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)

        grid = self._propagation_grid(fmap1, self.propa_conv)
        grid8 = self._propagation_grid(fmap1_dw8, self.propa_conv8)
        grid16 = self._propagation_grid(fmap1_dw16, self.propa_conv16)

        corr_fn = self._build_corr(
            corr_block, fmap1, fmap2, 1, self.path_conv, self.d_conv
        )
        corr_fn_dw8 = self._build_corr(
            corr_block, fmap1_dw8, fmap2_dw8, 2, self.path_conv8, self.d_conv8
        )
        corr_fn_dw16 = self._build_corr(
            corr_block, fmap1_dw16, fmap2_dw16, 4, self.path_conv16, self.d_conv16
        )

        # Coarsest level (4x pooled), starting from zero flow.
        coords0, coords1 = self.initialize_flow(fmap1_dw16)
        _, flow_dw8, flow_predictions16 = self._refine_coarse_level(
            iters // 4, patchmatch_rounds, corr_fn_dw16, grid16,
            self.update_block16, net_dw16, inp_dw16, coords0, coords1, 16, 8,
        )  # fmt: skip

        # Middle level (2x pooled), seeded by the upsampled coarse flow.
        coords0 = coords_grid(
            fmap1_dw8.shape[0], fmap1_dw8.shape[2], fmap1_dw8.shape[3]
        ).to(fmap1_dw8.device)
        _, flow, flow_predictions8 = self._refine_coarse_level(
            iters // 4, patchmatch_rounds, corr_fn_dw8, grid8,
            self.update_block8, net_dw8, inp_dw8, coords0, coords0 + flow_dw8, 8, 4,
        )  # fmt: skip

        # Finest level, with learned convex upsampling.
        coords0 = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3]).to(
            fmap1.device
        )
        coords1 = coords0 + flow
        step = 1
        flow_predictions = []
        for itr in range(iters // 2):
            coords1 = coords1.detach()
            propagation = itr % patchmatch_rounds == patchmatch_rounds - 1
            propagate_flow = (
                self._sample_neighbor_flow(coords1 - coords0, grid)
                if propagation
                else None
            )
            out_corrs = corr_fn(
                coords1,
                radius=self.args.corr_radius,
                propagation=propagation,
                propagate_flow=propagate_flow,
                step=max(1, step / 2**itr),
            )
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(
                    net, inp, out_corrs, coords1 - coords0
                )
            # NOTE(review): broadcasts the x update onto both channels (kept as-is).
            coords1 = coords1 + delta_flow[:, 0:1]
            flow_up = self.upsample_flow(
                torch.clip(coords1 - coords0, min=-self.args.max_disp // 4, max=-0.1),
                up_mask,
            )
            flow_predictions.append(flow_up[:, :1])

        if test_mode:
            return flow_up[:, :1], flow_up[:, :1]

        return (flow_predictions16, flow_predictions8, flow_predictions), ([], [], [])
