import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update import (
    BasicMultiUpdateBlock,
    BasicUpdateBlock,
    PatchMatchMultiUpdateBlock,
    C_BasicUpdateBlock,
)
from core.extractor import (
    BasicEncoder,
    MultiBasicEncoder,
    ResidualBlock,
    C_BasicEncoder,
)
from core.corr import (
    CorrBlock1D,
    PytorchAlternateCorrBlock1D,
    CorrBlockFast1D,
    AlternateCorrBlock,
    PAGCL,
)
from core.utils.utils import coords_grid, upflow8
import functools

try:
    autocast = torch.cuda.amp.autocast
except (AttributeError, ImportError):
    # PyTorch < 1.6 has no torch.cuda.amp: fall back to a no-op context manager
    # so `with autocast(enabled=...)` keeps working (always full precision).
    # Only these errors are caught so KeyboardInterrupt etc. still propagate.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast`` on old PyTorch."""

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass


class ResBlock(nn.Module):
    """Residual block: two dilated 3x3 convolutions plus an identity skip.

    Padding equals the dilation, so spatial size and channel count ``c0`` are
    preserved. LeakyReLU(0.2) is applied between the convolutions and again
    after the skip connection is added.
    """

    def __init__(self, c0, dilation=1):
        super().__init__()
        pad = dilation
        layers = [
            nn.Conv2d(c0, c0, 3, 1, pad, dilation),
            nn.LeakyReLU(0.2),
            nn.Conv2d(c0, c0, 3, 1, pad, dilation),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, input):
        residual = self.conv(input)
        return self.relu(residual + input)


class Refine(nn.Module):
    """Residual refinement head that predicts a correction added to ``hpy``.

    ``left`` and ``hpy`` are concatenated along channels (so ``cin`` must be
    their combined channel count). The result is projected to ``cres``
    channels, passed through one ``ResBlock`` per entry of ``dilations``, and
    reduced to a single channel before being added back to ``hpy``.
    """

    def __init__(self, cin, cres, dilations):
        super().__init__()
        self.conv1x1 = nn.Sequential(nn.Conv2d(cin, cres, 1), nn.LeakyReLU(0.2))
        self.conv1 = nn.Sequential(nn.Conv2d(cres, cres, 3, 1, 1), nn.LeakyReLU(0.2))
        self.res_block = nn.Sequential(*[ResBlock(cres, d) for d in dilations])
        self.convn = nn.Conv2d(cres, 1, 3, 1, 1)

    def forward(self, hpy, left):
        features = torch.cat((left, hpy), dim=1)
        for stage in (self.conv1x1, self.conv1, self.res_block, self.convn):
            features = stage(features)
        return hpy + features


def make_warp_coef(scale, device):
    """Return the x and y sub-pixel offsets of a ``scale`` x ``scale`` patch.

    Offsets are centred on zero (e.g. ``[-0.5, 0.5]`` for ``scale == 2``) and
    enumerate the ``scale * scale`` sub-pixels in row-major order, with y
    varying slowest. Both tensors have shape ``[1, scale * scale, 1, 1]``.
    Returns ``(coef_x, coef_y)``.
    """
    offsets = torch.arange(scale, device=device) - (scale - 1) / 2
    # Row-major flattening of an "ij" meshgrid: y repeats each value, x tiles.
    coef_y = offsets.repeat_interleave(scale)
    coef_x = offsets.repeat(scale)
    return coef_x.reshape(1, -1, 1, 1), coef_y.reshape(1, -1, 1, 1)


def disp_up(d, dx, dy, scale):
    """Upsample ``d`` by ``scale`` using a per-pixel linear extrapolation.

    Each low-resolution pixel expands into a ``scale`` x ``scale`` patch whose
    values are ``(d + coef_x * dx + coef_y * dy) * scale``. Here ``coef_*``
    are the centred sub-pixel offsets from ``make_warp_coef``. ``dx``/``dy``
    are presumably per-pixel disparity slopes (TODO confirm with callers).
    ``d`` is expected as ``[N, 1, H, W]``; the result is
    ``[N, 1, H * scale, W * scale]``.
    """
    coef_x, coef_y = make_warp_coef(scale, d.device)
    patches = (d + coef_x * dx + coef_y * dy) * scale  # [N, scale*scale, H, W]
    # Channel i * scale + j moves to output pixel (h * scale + i, w * scale + j).
    return F.pixel_shuffle(patches, scale)


class RAFTStereo_modified(nn.Module):
    """RAFT-Stereo variant that refines disparity coarse-to-fine over three scales.

    Updates run on ``fmap1`` average-pooled by 4 (labelled "1/16"), pooled by
    2 ("1/8") and at ``fmap1`` resolution ("1/4"). At each scale a learned
    offset head (``propa_conv*``) predicts neighbour positions. Candidate
    flows sampled there are passed to the correlation lookup ("propagation").
    The vertical flow component is zeroed after every update, so only
    horizontal disparity is estimated.
    """

    def __init__(self, args):
        super().__init__()
        output_dim = 2
        flow_dim = 2
        self.args = args
        self.dropout = 0

        context_dims = args.hidden_dims
        self.cnet = MultiBasicEncoder(
            output_dim=[args.hidden_dims, context_dims],
            norm_fn=args.context_norm,
            downsample=args.n_downsample,
        )

        if args.shared_backbone:
            self.conv2 = nn.Sequential(
                ResidualBlock(128, 128, "instance", stride=1),
                nn.Conv2d(128, 256, 3, padding=1),
            )
        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", downsample=args.n_downsample
            )

        # One update block per scale. Only the finest uses mask_size=4 (its
        # up_mask feeds ``upsample_flow``) and the default ``update`` setting.
        self.update_block = C_BasicUpdateBlock(
            hidden_dim=args.hidden_dims[2],
            cor_planes=4 * 9,
            mask_size=4,
            flow_dim=flow_dim,
            output_dim=output_dim,
        )

        self.update_block8 = C_BasicUpdateBlock(
            hidden_dim=args.hidden_dims[2],
            cor_planes=4 * 9,
            mask_size=2,
            flow_dim=flow_dim,
            output_dim=output_dim,
            update=False,
        )

        self.update_block16 = C_BasicUpdateBlock(
            hidden_dim=args.hidden_dims[2],
            cor_planes=4 * 9,
            mask_size=2,
            flow_dim=flow_dim,
            output_dim=output_dim,
            update=False,
        )

        self.num_neighbors = args.num_neighbors
        # Each head predicts an (x, y) offset per neighbour. NOTE(review): the
        # heads are applied to the 256-channel feature maps, so this assumes
        # hidden_dims[0] * 2 == 256 — confirm.
        propa_in = args.hidden_dims[0] * 2
        propa_out = self.num_neighbors * 2
        self.propa_conv = self._make_propa_conv(propa_in, propa_out)
        self.propa_conv8 = self._make_propa_conv(propa_in, propa_out)
        self.propa_conv16 = self._make_propa_conv(propa_in, propa_out)

    @staticmethod
    def _make_propa_conv(in_channels, out_channels):
        """Build the 3x3 conv + BatchNorm + LeakyReLU head used for propagation offsets."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def freeze_bn(self):
        """Put every BatchNorm2d layer into eval mode (running stats frozen)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, fmap):
        """Return ``(coords0, coords1, gradient)`` for a zero initial flow.

        Flow is represented as ``coords1 - coords0``. Both grids come from
        ``coords_grid`` at ``fmap``'s spatial size. ``gradient`` is
        ``zeros_like(coords0)``.
        """
        N, _, H, W = fmap.shape

        coords0 = coords_grid(N, H, W).to(fmap.device)
        coords1 = coords_grid(N, H, W).to(fmap.device)
        gradient = torch.zeros_like(coords0)

        return coords0, coords1, gradient

    def get_grid(self, batch, height, width, device, offset, dilation=1):
        """Build ``grid_sample`` locations for every pixel and its neighbours.

        Args:
            batch, height, width: size of the map the grid will sample from.
            device: device for the base pixel grid.
            offset: learned offsets of shape
                ``[batch, 2 * num_neighbors, height * width]``. Channel
                ``2 * i`` is the x offset and ``2 * i + 1`` the y offset of
                neighbour ``i``.
            dilation: pixel spacing of the fixed neighbourhood.

        Returns:
            Normalised grid ``[batch, (num_neighbors + 1) * height, width, 2]``.
            Row block ``k`` holds candidate ``k``; candidate 0 is the pixel itself.
        """
        d = dilation
        # Fixed 8-neighbourhood as [dy, dx] pairs. Entries 2 and 5 previously
        # duplicated entries 0 and 7, so the top-right and bottom-left
        # neighbours were never sampled.
        base_offsets = [
            [-d, -d],
            [-d, 0],
            [-d, d],
            [0, -d],
            [0, d],
            [d, -d],
            [d, 0],
            [d, d],
        ]
        if self.num_neighbors == 16:
            # Same directions at twice the distance.
            base_offsets += [[2 * dy, 2 * dx] for dy, dx in base_offsets]

        with torch.no_grad():
            # Row-major pixel coordinates (a flattened "ij" meshgrid).
            y_grid = torch.arange(
                0, height, dtype=torch.float32, device=device
            ).repeat_interleave(width)
            x_grid = torch.arange(
                0, width, dtype=torch.float32, device=device
            ).repeat(height)
            xy = torch.stack((x_grid, y_grid)).unsqueeze(0).repeat(batch, 1, 1)  # [B, 2, H*W]

        xy_list = [xy.unsqueeze(2)]
        for i in range(self.num_neighbors):
            dy, dx = base_offsets[i]
            offset_x = dx + offset[:, 2 * i, :].unsqueeze(1)
            offset_y = dy + offset[:, 2 * i + 1, :].unsqueeze(1)
            shift = torch.cat((offset_x, offset_y), dim=1)
            xy_list.append((xy + shift).unsqueeze(2))

        xy = torch.cat(xy_list, dim=2)  # [B, 2, num_neighbors + 1, H*W]
        x_normalized = xy[:, 0, :, :] / ((width - 1) / 2) - 1
        y_normalized = xy[:, 1, :, :] / ((height - 1) / 2) - 1
        grid = torch.stack((x_normalized, y_normalized), dim=3)  # [B, K+1, H*W, 2]
        return grid.view(batch, (self.num_neighbors + 1) * height, width, 2)

    def _propagation_grid(self, propa_conv, fmap):
        """Predict neighbour offsets from ``fmap`` with ``propa_conv`` and build the grid."""
        batch, _, height, width = fmap.shape
        offset = propa_conv(fmap).view(batch, self.num_neighbors * 2, height * width)
        return self.get_grid(batch, height, width, fmap.device, offset)

    def _sample_propagated_flow(self, flow, grid):
        """Sample candidate flows at each pixel's propagation neighbours.

        Args:
            flow: current flow ``[N, 2, H, W]``.
            grid: output of ``get_grid`` for the same ``N``, ``H`` and ``W``.

        Returns:
            ``[N, num_neighbors + 1, 2, H, W]`` candidates with the vertical
            component set to zero.
        """
        batch, _, height, width = flow.shape
        sampled = F.grid_sample(
            flow,
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )
        # grid_sample returns [N, 2, (K+1)*H, W] with candidates stacked along
        # the height axis. Split that axis, then move the candidate axis in
        # front of the channel axis. A direct reshape to [N, 9, 2, H, W] (as
        # before) interleaved x/y values of different candidates.
        candidates = (
            sampled.reshape(batch, 2, self.num_neighbors + 1, height, width)
            .transpose(1, 2)
            .contiguous()
        )
        candidates[:, :, 1] = 0
        return candidates

    @staticmethod
    def _upsample_bilinear(flow, factor):
        """Bilinearly resize ``flow`` by integer ``factor`` and scale its values to match."""
        return factor * F.interpolate(
            flow,
            size=(factor * flow.shape[2], factor * flow.shape[3]),
            mode="bilinear",
            align_corners=True,
        )

    def upsample_flow(self, flow, mask, n_downsample=2):
        """Upsample ``flow`` [N, D, H, W] by ``2**n_downsample`` via convex combination.

        ``mask`` must hold ``9 * factor**2`` channels at ``H x W``. For each
        output sub-pixel, it is softmax-normalised over the 3x3 neighbourhood.
        """
        N, D, H, W = flow.shape
        factor = 2**n_downsample

        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)
        up_flow = F.unfold(factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, D, factor * H, factor * W)

    def forward(
        self,
        image1,
        image2,
        patchmatch_rounds,
        iters=12,
        flow_init=None,
        test_mode=False,
    ):
        """Estimate disparity between a rectified stereo pair.

        Args:
            image1, image2: images with values in [0, 255].
            patchmatch_rounds: propagation period (must be >= 2). The 1/8
                stage propagates when ``itr % (patchmatch_rounds - 1) == 0``
                and the finest stage when ``itr % patchmatch_rounds == 0``.
            iters: update budget (must be >= 4). ``iters // 4`` updates run at
                each of the two coarse scales and ``iters // 2`` at the finest.
            flow_init: accepted for interface compatibility; unused.
            test_mode: if True, return only the final upsampled disparity (twice).

        Returns:
            In test mode, ``(disp, disp)``. Otherwise a pair:

            * ``(preds16, preds8, preds)``: per-iteration ``[N, 1, ...]``
              disparities, each at 4x ``fmap1`` resolution.
            * ``(grads16, grads8, grads)``: always empty lists; the gradient
              heads are disabled.

        Raises:
            ValueError: on ``iters < 4``, ``patchmatch_rounds < 2`` or an
                unknown ``args.corr_implementation``.
        """
        # These used to fail later with UnboundLocalError / ZeroDivisionError.
        if iters < 4:
            raise ValueError(f"iters must be >= 4, got {iters}")
        if patchmatch_rounds < 2:
            raise ValueError(f"patchmatch_rounds must be >= 2, got {patchmatch_rounds}")

        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        # Context and feature networks.
        with autocast(enabled=self.args.mixed_precision):
            if self.args.shared_backbone:
                *cnet_list, x = self.cnet(
                    torch.cat((image1, image2), dim=0),
                    dual_inp=True,
                    num_layers=self.args.n_gru_layers,
                )
                fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0] // 2)
            else:
                cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.fnet([image1, image2])
            net_list = [torch.tanh(c[0]) for c in cnet_list]
            inp_list = [torch.relu(c[1]) for c in cnet_list]
        net, net_dw8, net_dw16 = net_list[0], net_list[1], net_list[2]
        inp, inp_dw8, inp_dw16 = inp_list[0], inp_list[1], inp_list[2]

        corr_impl = self.args.corr_implementation
        if corr_impl == "reg":  # Default
            corr_block = CorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif corr_impl == "alt":  # More memory efficient than reg
            corr_block = PytorchAlternateCorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif corr_impl == "reg_cuda":  # Faster version of reg
            corr_block = CorrBlockFast1D
        elif corr_impl == "alt_cuda":  # Faster version of alt
            corr_block = AlternateCorrBlock
        else:
            raise ValueError(f"Unknown corr_implementation: {corr_impl!r}")

        fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
        fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)
        fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
        fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)

        grid = self._propagation_grid(self.propa_conv, fmap1)
        grid8 = self._propagation_grid(self.propa_conv8, fmap1_dw8)
        grid16 = self._propagation_grid(self.propa_conv16, fmap1_dw16)

        corr_levels = self.args.corr_levels
        corr_fn = corr_block(fmap1, fmap2, num_levels=corr_levels)
        corr_fn_dw8 = corr_block(fmap1_dw8, fmap2_dw8, num_levels=corr_levels)
        corr_fn_dw16 = corr_block(fmap1_dw16, fmap2_dw16, num_levels=corr_levels)
        # TODO: (1) cost smoothing — penalise neighbouring disparities that differ
        # by more than 1 as abrupt changes; (2) cost aggregation — also aggregate
        # over the same disparity when building the cost volume.

        coords0, coords1, _ = self.initialize_flow(fmap1_dw16)

        flow_predictions = []
        flow_predictions8 = []
        flow_predictions16 = []
        # Gradient heads are disabled; these stay empty but keep the return shape.
        gradient_predictions = []
        gradient_predictions8 = []
        gradient_predictions16 = []
        steps = [1, 1, 1]

        # Stage 1: 1/16 scale.
        step = steps[0]
        for itr in range(iters // 4):
            # NOTE(review): `itr % (patchmatch_rounds - 1)` is always smaller than
            # `patchmatch_rounds - 1`, so this is never true and no propagation
            # happens at this scale. Kept to preserve trained behaviour; confirm
            # the intended schedule (e.g. `== 0`, as in the 1/8 stage).
            propagation = itr % (patchmatch_rounds - 1) == patchmatch_rounds - 1
            coords1 = coords1.detach()
            propagate_flow_16 = (
                self._sample_propagated_flow(coords1 - coords0, grid16)
                if propagation
                else None
            )
            out_corrs = corr_fn_dw16(
                coords1,
                radius=self.args.corr_radius,
                propagation=propagation,
                propagate_flow=propagate_flow_16,
                step=max(1, step / 2**itr),
            )

            with autocast(enabled=self.args.mixed_precision):
                net_dw16, _, delta_flow = self.update_block16(
                    net_dw16, inp_dw16, out_corrs, coords1 - coords0
                )
            delta_flow[:, 1] = 0.0
            coords1 = coords1 + delta_flow[:, 0:2]

            flow_dw8 = self._upsample_bilinear(coords1 - coords0, 2)
            flow_predictions16.append(self._upsample_bilinear(flow_dw8, 8)[:, :1])

        # Stage 2: 1/8 scale, initialised from the upsampled 1/16 estimate.
        coords0 = coords_grid(
            fmap1_dw8.shape[0], fmap1_dw8.shape[2], fmap1_dw8.shape[3]
        ).to(fmap1_dw8.device)
        coords1 = coords0 + flow_dw8
        step = steps[1]
        for itr in range(iters // 4):
            propagation = itr % (patchmatch_rounds - 1) == 0
            # NOTE(review): unlike stage 1, candidates are sampled before the
            # detach below, so gradients can flow through them — confirm intent.
            propagate_flow_8 = (
                self._sample_propagated_flow(coords1 - coords0, grid8)
                if propagation
                else None
            )
            coords1 = coords1.detach()
            out_corrs = corr_fn_dw8(
                coords1,
                radius=self.args.corr_radius,
                propagation=propagation,
                propagate_flow=propagate_flow_8,
                step=max(1, step / 2**itr),
            )

            with autocast(enabled=self.args.mixed_precision):
                net_dw8, _, delta_flow = self.update_block8(
                    net_dw8, inp_dw8, out_corrs, coords1 - coords0
                )
            delta_flow[:, 1] = 0.0
            coords1 = coords1 + delta_flow[:, 0:2]

            flow_dw4 = self._upsample_bilinear(coords1 - coords0, 2)
            flow_predictions8.append(self._upsample_bilinear(flow_dw4, 4)[:, :1])

        # Stage 3: fmap1 resolution, initialised from the upsampled 1/8 estimate.
        coords0 = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3]).to(
            fmap1.device
        )
        coords1 = coords0 + flow_dw4
        step = steps[2]
        for itr in range(iters // 2):
            propagation = itr % patchmatch_rounds == 0
            # Same pre-detach sampling order as stage 2.
            propagate_flow = (
                self._sample_propagated_flow(coords1 - coords0, grid)
                if propagation
                else None
            )
            coords1 = coords1.detach()
            out_corrs = corr_fn(
                coords1,
                radius=self.args.corr_radius,
                propagation=propagation,
                propagate_flow=propagate_flow,
                step=max(1, step / 2**itr),
            )

            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(
                    net, inp, out_corrs, coords1 - coords0
                )
            delta_flow[:, 1] = 0.0
            coords1 = coords1 + delta_flow[:, 0:2]

            flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up[:, :1])

        if test_mode:
            return flow_up[:, :1], flow_up[:, :1]

        return (flow_predictions16, flow_predictions8, flow_predictions), (
            gradient_predictions16,
            gradient_predictions8,
            gradient_predictions,
        )
