import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update import (
    BasicMultiUpdateBlock,
    BasicUpdateBlock,
    PatchMatchMultiUpdateBlock,
)
from core.extractor import BasicEncoder, MultiBasicEncoder, ResidualBlock
from core.corr import (
    CorrBlock1D,
    PytorchAlternateCorrBlock1D,
    CorrBlockFast1D,
    AlternateCorrBlock,
)
from core.utils.utils import coords_grid, upflow8
import functools

try:
    autocast = torch.cuda.amp.autocast
except (AttributeError, ImportError):
    # PyTorch < 1.6 ships no torch.cuda.amp; fall back to a context manager
    # that accepts the same ``enabled`` argument and does nothing.  Only the
    # "attribute/module is missing" failures are caught so that unrelated
    # errors (and KeyboardInterrupt) are not silently swallowed.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``."""

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass


@functools.lru_cache()
@torch.no_grad()
def make_warp_coef(scale, device):
    """Return the per-sub-pixel (x, y) offsets of a ``scale`` x ``scale`` patch.

    The offsets are measured from the patch centre, i.e. they run from
    ``-(scale - 1) / 2`` to ``(scale - 1) / 2`` in steps of one.

    Args:
        scale: side length of the patch (number of sub-pixels per axis).
        device: device on which the tensors are created; must be hashable
            because the result is memoised per ``(scale, device)``.

    Returns:
        Tuple ``(coef_x, coef_y)``, each of shape ``[1, scale * scale, 1, 1]``.
        Entry ``k`` describes sub-pixel ``(row k // scale, column k % scale)``:
        ``coef_x`` varies fastest (column offset), ``coef_y`` slowest (row
        offset).  The tensors are cached and shared between calls, so callers
        must not modify them in place.
    """
    center = (scale - 1) / 2
    index = torch.arange(scale, device=device) - center
    # Equivalent to flattening torch.meshgrid(index, index) in "ij" order, but
    # written without meshgrid: its default indexing is deprecated on recent
    # PyTorch, while the explicit ``indexing=`` argument does not exist on the
    # old releases this file still supports.
    coef_x = index.repeat(scale).reshape(1, -1, 1, 1)
    coef_y = index.repeat_interleave(scale).reshape(1, -1, 1, 1)
    return coef_x, coef_y


def disp_up(d, dx, dy, scale):
    """Upsample a disparity map by ``scale`` using per-pixel slopes.

    Every low-resolution pixel is expanded into a ``scale`` x ``scale`` patch
    whose values are ``(d + offset_x * dx + offset_y * dy) * scale``, with the
    sub-pixel offsets supplied by ``make_warp_coef``.

    Args:
        d: 4-D disparity tensor ``[n, c, h, w]``; the reshape below only works
            when the broadcast result has ``scale * scale`` channels, so
            presumably ``c == 1`` (same for ``dx``/``dy``) - TODO confirm.
        dx: slope applied to the horizontal sub-pixel offset.
        dy: slope applied to the vertical sub-pixel offset.
        scale: integer upsampling factor.

    Returns:
        Tensor of shape ``[n, 1, h * scale, w * scale]``.
    """
    batch, _, rows, cols = d.shape
    weight_x, weight_y = make_warp_coef(scale, d.device)

    # One channel per sub-pixel of the patch: [n, scale * scale, h, w].
    planes = (d + weight_x * dx + weight_y * dy) * scale

    # Split the channel axis into (patch row, patch column) and interleave
    # those with the spatial axes: [n, 1, h, scale, w, scale].
    tiles = planes.reshape(batch, 1, scale, scale, rows, cols)
    interleaved = tiles.permute(0, 1, 4, 2, 5, 3)
    return interleaved.reshape(batch, 1, rows * scale, cols * scale)


class ResBlock(nn.Module):
    """Residual block: conv -> LeakyReLU -> conv, skip connection, LeakyReLU.

    Both convolutions are 3x3 with stride 1 and use ``padding == dilation``,
    so the spatial size and the channel count ``c0`` are preserved.
    """

    def __init__(self, c0, dilation=1):
        super().__init__()
        first_conv = nn.Conv2d(
            c0, c0, kernel_size=3, stride=1, padding=dilation, dilation=dilation
        )
        second_conv = nn.Conv2d(
            c0, c0, kernel_size=3, stride=1, padding=dilation, dilation=dilation
        )
        # Keep the (conv, activation, conv) layout so parameter names stay
        # ``conv.0.*`` and ``conv.2.*``.
        self.conv = nn.Sequential(first_conv, nn.LeakyReLU(0.2), second_conv)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, input):
        """Return ``LeakyReLU(conv(input) + input)``."""
        return self.relu(self.conv(input) + input)


class Refine(nn.Module):
    """Residual refinement head that predicts a one-channel correction.

    Args:
        cin: number of channels of ``cat(left, hpy)`` fed to the first layer.
        cres: number of channels used inside the refinement trunk.
        dilations: iterable of dilation rates, one ``ResBlock`` per entry.
    """

    def __init__(self, cin, cres, dilations):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(cin, cres, kernel_size=1),
            nn.LeakyReLU(0.2),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(cres, cres, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.res_block = nn.Sequential(*[ResBlock(cres, rate) for rate in dilations])
        self.convn = nn.Conv2d(cres, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, hpy, left):
        """Return ``hpy`` plus the correction predicted from ``(left, hpy)``.

        ``left`` and ``hpy`` are concatenated along the channel axis (``left``
        first); the one-channel network output is added to ``hpy``.
        """
        features = self.conv1x1(torch.cat((left, hpy), dim=1))
        features = self.res_block(self.conv1(features))
        return hpy + self.convn(features)


class RAFTStereo(nn.Module):
    def __init__(self, args):
        """Build the encoders, the GRU update block and the propagation heads.

        Args:
            args: configuration object; this constructor reads ``hidden_dims``,
                ``context_norm``, ``n_downsample``, ``n_gru_layers``,
                ``shared_backbone`` and ``num_neighbors`` from it and stores
                the object itself as ``self.args``.
        """
        super().__init__()
        self.args = args

        hidden_dims = args.hidden_dims
        # The context features use the same channel widths as the hidden state.
        context_dims = hidden_dims

        self.cnet = MultiBasicEncoder(
            output_dim=[hidden_dims, context_dims],
            norm_fn=args.context_norm,
            downsample=args.n_downsample,
        )
        self.update_block = BasicMultiUpdateBlock(
            self.args, hidden_dims=hidden_dims, flow_dim=2, output_dim=2
        )

        # One 3x3 conv per GRU level producing three hidden-width chunks
        # (forward() splits the output into thirds along the channel axis).
        zqr_convs = []
        for level in range(self.args.n_gru_layers):
            zqr_convs.append(
                nn.Conv2d(context_dims[level], hidden_dims[level] * 3, 3, padding=1)
            )
        self.context_zqr_convs = nn.ModuleList(zqr_convs)

        if args.shared_backbone:
            self.conv2 = nn.Sequential(
                ResidualBlock(128, 128, "instance", stride=1),
                nn.Conv2d(128, 256, 3, padding=1),
            )
        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", downsample=args.n_downsample
            )

        # Head predicting an (x, y) offset per propagation neighbour.
        # NOTE(review): forward() applies it to fmap1 (256 channels from
        # conv2; presumably also from fnet given output_dim=256), so
        # hidden_dims[0] * 2 has to equal 256 - confirm for other configs.
        self.num_neighbors = args.num_neighbors
        offset_channels = self.num_neighbors * 2
        self.propa_conv = nn.Sequential(
            nn.Conv2d(
                args.hidden_dims[0] * 2,
                offset_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm2d(offset_channels),
            nn.LeakyReLU(0.2),
        )
        # 4 + 3 input channels; in forward()'s (currently disabled) refinement
        # path these are flow_up (1), flow_pm (1), gradient_up (2), image1 (3).
        self.refine = Refine(4 + 3, 16, [1, 1])

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, corr):
        """Create the initial coordinate grids and a zero gradient field.

        Flow is represented as the difference between two coordinate grids,
        ``flow = coords1 - coords0``; both grids start out identical, so the
        initial flow is zero.

        Args:
            corr: 5-D tensor whose first three dimensions are read as
                ``(N, H, W)``; it also determines the target device.

        Returns:
            Tuple ``(coords0, coords1, gradient)`` where ``gradient`` is a
            zero tensor shaped like ``coords1``.
        """
        batch, rows, cols, _, _ = corr.shape
        target = corr.device

        coords0 = coords_grid(batch, rows, cols).to(target)
        coords1 = coords_grid(batch, rows, cols).to(target)
        # zeros_like already allocates on the device of coords1.
        gradient = torch.zeros_like(coords1)

        return coords0, coords1, gradient

    # def initialize_flow(self, corr):
    #     """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
    #     N, H, W, _, _ = corr.shape

    #     coords0 = coords_grid(N, H, W).to(corr.device)
    #     coords1 = coords_grid(N, H, W).to(corr.device)
    #     res = torch.rand(N, 1, H, W, device=corr.device) * W
    #     coords1[:, 0, :] = coords1[:, 0, :] + res[:, 0, :]
    #     # _,res=torch.max(corr,dim=4)
    #     # res=res.permute(0,3,1,2)
    #     # coords1[:,0,:]=res[:,0,:]
    #     gradient = torch.zeros_like(coords1).to(corr.device)

    #     return coords0, coords1, gradient

    def get_grid(self, batch, height, width, device, offset, dilation=1):

        original_offset = [
            [-dilation, -dilation],
            [-dilation, 0],
            [-dilation, -dilation],
            [0, -dilation],
            [0, dilation],
            [dilation, dilation],
            [dilation, 0],
            [dilation, dilation],
        ]
        if self.num_neighbors == 16:
            for i in range(len(original_offset)):
                offset_x, offset_y = original_offset[i]
                original_offset.append([2 * offset_x, 2 * offset_y])

        with torch.no_grad():
            y_grid, x_grid = torch.meshgrid(
                [
                    torch.arange(0, height, dtype=torch.float32, device=device),
                    torch.arange(0, width, dtype=torch.float32, device=device),
                ]
            )
            y_grid, x_grid = y_grid.contiguous().view(
                height * width
            ), x_grid.contiguous().view(height * width)
            xy = torch.stack((x_grid, y_grid))  # [2, H*W]
            xy = torch.unsqueeze(xy, 0).repeat(batch, 1, 1)  # [B, 2, H*W]
        xy_list = []
        offset_list = []
        xy_list.append(xy.unsqueeze(2))
        offset_list.append(torch.zeros(batch, 2, 1, height * width).to(device))
        for i in range(self.num_neighbors):
            original_offset_y, original_offset_x = original_offset[i]
            offset_x = original_offset_x + offset[:, 2 * i, :].unsqueeze(1)
            offset_y = original_offset_y + offset[:, 2 * i + 1, :].unsqueeze(1)
            temp = torch.cat((offset_x, offset_y), dim=1)
            xy_list.append((xy + temp).unsqueeze(2))
            offset_list.append(temp.unsqueeze(2))

            # offset_x = offset[:, 2 * i, :].unsqueeze(1)
            # offset_y = offset[:, 2 * i + 1, :].unsqueeze(1)
            # temp=torch.cat((offset_x, offset_y), dim=1)
            # xy_list.append((xy + temp).unsqueeze(2))
            # offset_list.append(temp.unsqueeze(2))

        offset_grid = torch.cat(offset_list, dim=2)
        xy = torch.cat(xy_list, dim=2)  # [B, 2, 9, H*W]
        x_normalized = xy[:, 0, :, :] / ((width - 1) / 2) - 1
        y_normalized = xy[:, 1, :, :] / ((height - 1) / 2) - 1
        grid = torch.stack((x_normalized, y_normalized), dim=3)  # [B, 9, H*W, 2]
        grid = grid.view(batch, (self.num_neighbors + 1) * height, width, 2)
        return grid, offset_grid
        # return grid

    def get_1d_grid(self, batch, height, width, device):

        original_offset = [
            [0, -4],
            [0, -3],
            [0, -2],
            [0, -1],
            [0, 1],
            [0, 2],
            [0, 3],
            [0, 4],
        ]

        with torch.no_grad():
            y_grid, x_grid = torch.meshgrid(
                [
                    torch.arange(0, height, dtype=torch.float32, device=device),
                    torch.arange(0, width, dtype=torch.float32, device=device),
                ]
            )
            y_grid, x_grid = y_grid.contiguous().view(
                height * width
            ), x_grid.contiguous().view(height * width)
            xy = torch.stack((x_grid, y_grid))  # [2, H*W]
            xy = torch.unsqueeze(xy, 0).repeat(batch, 1, 1)  # [B, 2, H*W]
        offset_list = []
        for i in range(self.num_neighbors):
            original_offset_y, original_offset_x = original_offset[i]
            offset_x = original_offset_x + torch.zeros(
                (batch, 1, height * width), device=device
            )
            offset_y = original_offset_y + torch.zeros(
                (batch, 1, height * width), device=device
            )
            temp = torch.cat((offset_x, offset_y), dim=1)
            offset_list.append(temp.unsqueeze(2))
        offset_1d_grid = torch.cat(offset_list, dim=2)
        return offset_1d_grid
        # return grid

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, D, H, W = flow.shape
        factor = 2**self.args.n_downsample

        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)
        up_flow = F.unfold(factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, D, factor * H, factor * W)

    def upsample_flow_with_gradient(self, flow, gradient, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, D, H, W = flow.shape
        factor = 2**self.args.n_downsample
        mask = mask.view(N, 3, 9, factor, factor, H, W)
        up_flow = F.unfold(factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 1, 9, 1, 1, H, W)
        up_gradient = F.unfold(factor * gradient, [3, 3], padding=1)
        up_gradient = up_gradient.view(N, 2, 9, 1, 1, H, W)
        gradient_mask = mask[:, 1:3, :].view(N, 2, 9, factor, factor, H, W)
        # upsample_mask = torch.sigmoid(gradient_mask)/2
        upsample_mask = mask[:, 0:1, :]
        upsample_mask = upsample_mask.view(N, 1, 9, factor, factor, H, W)
        upsample_mask = torch.softmax(upsample_mask, dim=2)
        gradient_offset = torch.sum(gradient_mask * up_gradient, dim=1).unsqueeze(1)
        up_flow_with_gradient_offset = up_flow + gradient_offset
        up_flow_with_gradient_offset = up_flow_with_gradient_offset.view(
            N, 1, 9, factor, factor, H, W
        )
        up_flow_with_offset = torch.sum(
            upsample_mask * up_flow_with_gradient_offset, dim=2
        )
        up_flow_with_offset = up_flow_with_offset.permute(0, 1, 4, 2, 5, 3)
        # print(up_flow_with_gradient_offset)
        return up_flow_with_offset.reshape(N, D, factor * H, factor * W)

    def forward(
        self,
        image1,
        image2,
        patchmatch_rounds,
        iters=12,
        flow_init=None,
        test_mode=False,
    ):
        """Iteratively estimate the horizontal flow between two images.

        The estimate is refined for ``iters`` iterations.  Iteration ``itr`` is
        a propagation step when
        ``itr % patchmatch_rounds == patchmatch_rounds - 1``: the current flow
        is sampled at the neighbour positions from ``get_grid`` and the
        correlation of those candidates is fed to the update block.  All other
        iterations look the correlation up in a window of radius
        ``args.corr_radius`` around the current estimate.  The vertical
        component of every update is zeroed.

        Args:
            image1: first image; values are mapped to ``[-1, 1]`` via
                ``2 * (x / 255) - 1`` (so presumably given in ``[0, 255]``).
            image2: second image, normalised the same way.
            patchmatch_rounds: period of the propagation steps; must be a
                positive integer (``0`` raises ``ZeroDivisionError``).
            iters: number of refinement iterations.
            flow_init: optional tensor added to the initial target coordinates.
            test_mode: if true, only the final iteration is upsampled.

        Returns:
            In test mode the tuple ``(coords1 - coords0, flow_up)`` of the
            final iteration.  Otherwise ``(flow_predictions,
            gradient_predictions)``, two lists with one entry per iteration;
            ``gradient`` is never updated in this method, so every gradient
            entry is the zero tensor created by ``initialize_flow``.

        Raises:
            ValueError: if ``args.corr_implementation`` is not one of
                ``"reg"``, ``"alt"``, ``"reg_cuda"`` or ``"alt_cuda"``.
        """

        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            if self.args.shared_backbone:
                *cnet_list, x = self.cnet(
                    torch.cat((image1, image2), dim=0),
                    dual_inp=True,
                    num_layers=self.args.n_gru_layers,
                )
                fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0] // 2)
            else:
                cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.fnet([image1, image2])
            net_list = [torch.tanh(x[0]) for x in cnet_list]
            inp_list = [torch.relu(x[1]) for x in cnet_list]

            # Rather than running the GRU's conv layers on the context features
            # multiple times, we do it once at the beginning
            inp_list = [
                list(conv(i).split(split_size=conv.out_channels // 3, dim=1))
                for i, conv in zip(inp_list, self.context_zqr_convs)
            ]

        batch, _, height, width = fmap1.shape
        if self.args.corr_implementation == "reg":  # Default
            corr_block = CorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif self.args.corr_implementation == "alt":  # More memory efficient than reg
            corr_block = PytorchAlternateCorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif self.args.corr_implementation == "reg_cuda":  # Faster version of reg
            corr_block = CorrBlockFast1D
        elif self.args.corr_implementation == "alt_cuda":  # Faster version of alt
            corr_block = AlternateCorrBlock
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"unknown corr_implementation: {self.args.corr_implementation!r}"
            )
        corr_fn = corr_block(fmap1, fmap2, num_levels=self.args.corr_levels)

        coords0, coords1, gradient = self.initialize_flow(corr_fn.corr_pyramid[0])
        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        gradient_predictions = []
        # Per-pixel neighbour offsets for the propagation steps; they depend
        # only on fmap1, so the sampling grid is built once.
        propa_weight = self.propa_conv(fmap1).view(
            batch, self.num_neighbors * 2, height * width
        )
        grid, offset_grid = self.get_grid(
            batch, height, width, fmap1.device, propa_weight
        )
        for itr in range(iters):
            coords1 = coords1.detach()
            gradient = gradient.detach()
            # Current flow, recomputed from the detached coordinates on every
            # iteration.  The propagation branch used to read a ``flow`` left
            # over from an earlier iteration: that value was undefined when the
            # first iteration was a propagation step (patchmatch_rounds == 1)
            # and stale in test mode, where the refresh at the end of the loop
            # body is skipped.
            flow = coords1 - coords0
            if itr % patchmatch_rounds == patchmatch_rounds - 1:
                batch, _, height, width = coords1.size()
                # Flow of the pixel itself and of its neighbours, stacked along
                # a new axis: [B, 2, num_neighbors + 1, H, W].
                propagate_disp = F.grid_sample(
                    flow,
                    grid,
                    mode="bilinear",
                    padding_mode="border",
                    align_corners=False,
                ).view(batch, 2, self.num_neighbors + 1, height, width)

                if self.args.look_up_before_propa:
                    # Look the correlation up once and propagate the result.
                    corr = corr_fn(coords1, radius=0)
                    propagate_corr = F.grid_sample(
                        corr,
                        grid,
                        mode="bilinear",
                        padding_mode="border",
                        align_corners=False,
                    ).view(
                        batch,
                        (self.num_neighbors + 1) * self.args.corr_levels,
                        height,
                        width,
                    )
                else:
                    # Look the correlation up for every propagated candidate.
                    # NOTE(review): the other corr_fn calls receive absolute
                    # coordinates (coords1) whereas this one receives a flow
                    # (coords1 - coords0); confirm corr_fn's contract - it
                    # looks like coords0 should be added here.
                    corr_list = []
                    for i in range(self.num_neighbors + 1):
                        corr_list.append(
                            corr_fn(propagate_disp[:, :, i, :, :], radius=0)
                        )
                    propagate_corr = torch.cat(corr_list, dim=1)

                with autocast(enabled=self.args.mixed_precision):
                    if (
                        self.args.n_gru_layers == 3 and self.args.slow_fast_gru
                    ):  # Update low-res GRU
                        net_list = self.update_block(
                            net_list,
                            inp_list,
                            iter32=True,
                            iter16=False,
                            iter08=False,
                            update=False,
                        )
                    if (
                        self.args.n_gru_layers >= 2 and self.args.slow_fast_gru
                    ):  # Update low-res GRU and mid-res GRU
                        net_list = self.update_block(
                            net_list,
                            inp_list,
                            iter32=self.args.n_gru_layers == 3,
                            iter16=True,
                            iter08=False,
                            update=False,
                        )
                    net_list, up_mask, score = self.update_block(
                        net_list,
                        inp_list,
                        propagate_corr,
                        torch.cat([flow], dim=1),
                        iter32=self.args.n_gru_layers == 3,
                        iter16=self.args.n_gru_layers >= 2,
                    )

                # in stereo mode, project flow onto epipolar
                score[:, 1] = 0.0
                # F(t+1) = F(t) + \Delta(t)
                coords1 = coords1 + score[:, 0:2]

            else:
                corr = corr_fn(
                    coords1, radius=self.args.corr_radius
                )  # index correlation volume
                with autocast(enabled=self.args.mixed_precision):
                    if (
                        self.args.n_gru_layers == 3 and self.args.slow_fast_gru
                    ):  # Update low-res GRU
                        net_list = self.update_block(
                            net_list,
                            inp_list,
                            iter32=True,
                            iter16=False,
                            iter08=False,
                            update=False,
                        )
                    if (
                        self.args.n_gru_layers >= 2 and self.args.slow_fast_gru
                    ):  # Update low-res GRU and mid-res GRU
                        net_list = self.update_block(
                            net_list,
                            inp_list,
                            iter32=self.args.n_gru_layers == 3,
                            iter16=True,
                            iter08=False,
                            update=False,
                        )
                    net_list, up_mask, delta_flow = self.update_block(
                        net_list,
                        inp_list,
                        corr,
                        torch.cat([flow], dim=1),
                        iter32=self.args.n_gru_layers == 3,
                        iter16=self.args.n_gru_layers >= 2,
                    )

                # in stereo mode, project flow onto epipolar
                delta_flow[:, 1] = 0.0

                # F(t+1) = F(t) + \Delta(t)
                coords1 = coords1 + delta_flow[:, 0:2]

            # We do not need to upsample or output intermediate results in test_mode
            if test_mode and itr < iters - 1:
                continue

            # upsample predictions
            flow = coords1 - coords0
            if up_mask is None:
                flow_up = upflow8(flow)
            else:
                flow_up = self.upsample_flow(flow, up_mask)
                flow_up = flow_up[:, :1]
                # The plane-based refinement of the final prediction is
                # disabled; set this flag to run self.refine on it.
                refine = False
                if refine and itr == iters - 1:
                    flow_pm = disp_up(
                        flow[:, :1], gradient[:, 0:1], gradient[:, 1:2], 4
                    )
                    gradient_up = F.interpolate(gradient, scale_factor=4)
                    flow_pm = torch.cat((flow_pm, gradient_up), dim=1)
                    flow_up = self.refine(flow_up, torch.cat((flow_pm, image1), dim=1))
            flow_predictions.append(flow_up)
            gradient_predictions.append(gradient)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions, gradient_predictions
