import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update import (
    BasicMultiUpdateBlock,
    BasicUpdateBlock,
    PatchMatchMultiUpdateBlock,
    C_BasicUpdateBlock,
)
from core.extractor import (
    BasicEncoder,
    MultiBasicEncoder,
    ResidualBlock,
    C_BasicEncoder,
)
from core.corr import (
    CorrBlock1D,
    PytorchAlternateCorrBlock1D,
    CorrBlockFast1D,
    AlternateCorrBlock,
    PAGCL,
)
from core.utils.utils import coords_grid, upflow8
import functools

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # torch.cuda.amp does not exist on PyTorch < 1.6; fall back to a no-op
    # context manager so `with autocast(enabled=...)` keeps working.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Accepts the ``enabled`` flag for call-site compatibility and ignores it.
        """

        def __init__(self, enabled=True):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            # Returning None lets exceptions raised in the body propagate.
            pass


@functools.lru_cache()
@torch.no_grad()
def make_warp_coef(scale, device):
    """Return centred x/y offsets for every cell of a ``scale`` x ``scale`` patch.

    The offsets are ``arange(scale) - (scale - 1) / 2`` laid out over the patch
    in row-major order and reshaped to ``(1, scale * scale, 1, 1)``.
    Results are memoised per ``(scale, device)`` pair by ``lru_cache``.

    Returns:
        Tuple ``(coef_x, coef_y)``; ``coef_x`` varies fastest along the
        flattened patch axis, ``coef_y`` slowest.
    """
    half_extent = (scale - 1) / 2
    offsets = torch.arange(scale, device=device) - half_extent
    # Row-major flattening of the patch: the column (x) offset cycles through
    # all values for each row, while the row (y) offset is held for a full row.
    coef_x = offsets.repeat(scale).reshape(1, -1, 1, 1)
    coef_y = offsets.repeat_interleave(scale).reshape(1, -1, 1, 1)
    return coef_x, coef_y


def disp_up(d, dx, dy, scale):
    """Upsample ``d`` by ``scale`` using a per-pixel first-order (planar) model.

    Each input pixel is expanded into a ``scale`` x ``scale`` patch whose values
    are ``(d + offset_x * dx + offset_y * dy) * scale``, with the offsets taken
    from ``make_warp_coef``. The patches are then interleaved spatially.

    Args:
        d: Tensor whose size unpacks as ``(n, _, h, w)``; the reshape below
            requires the channel axis to broadcast to ``scale * scale`` planes
            (presumably a single channel — confirm against callers).
        dx, dy: Slopes multiplied with the x / y patch offsets.
        scale: Integer upsampling factor.

    Returns:
        Tensor of shape ``(n, 1, h * scale, w * scale)``.
    """
    batch, _, rows, cols = d.size()
    coef_x, coef_y = make_warp_coef(scale, d.device)
    # One plane per patch cell: (n, scale * scale, h, w).
    planes = (d + coef_x * dx + coef_y * dy) * scale
    # Split the plane axis into (patch_row, patch_col), move each next to its
    # spatial axis, then merge so patch cells become neighbouring pixels.
    return (
        planes.reshape(batch, 1, scale, scale, rows, cols)
        .permute(0, 1, 4, 2, 5, 3)
        .reshape(batch, 1, rows * scale, cols * scale)
    )


class ResBlock(nn.Module):
    """Two 3x3 convolutions with an identity skip connection.

    Computes ``LeakyReLU(conv(input) + input)`` where ``conv`` is
    Conv -> LeakyReLU(0.2) -> Conv. Padding equals the dilation, so the spatial
    size is preserved and the skip addition is shape-compatible.

    Args:
        c0: Number of input and output channels.
        dilation: Dilation (and padding) used by both convolutions.
    """

    def __init__(self, c0, dilation=1):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(
                c0,
                c0,
                kernel_size=3,
                stride=1,
                padding=dilation,
                dilation=dilation,
            )

        self.conv = nn.Sequential(conv3x3(), nn.LeakyReLU(0.2), conv3x3())
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, input):
        """Apply the residual branch, add the skip path and activate."""
        return self.relu(self.conv(input) + input)


class Refine(nn.Module):
    """Predict a one-channel residual and add it to the input hypothesis.

    The network concatenates ``left`` and ``hpy`` along the channel axis, then
    applies a 1x1 convolution, a 3x3 convolution, a stack of ``ResBlock``s (one
    per entry of ``dilations``) and a final 3x3 convolution to one channel.

    Args:
        cin: Channel count of the concatenated ``(left, hpy)`` input.
        cres: Channel width used inside the refinement trunk.
        dilations: Iterable of dilation rates, one ``ResBlock`` each.
    """

    def __init__(self, cin, cres, dilations):
        super().__init__()
        slope = 0.2
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(cin, cres, kernel_size=1),
            nn.LeakyReLU(slope),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(cres, cres, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(slope),
        )
        self.res_block = nn.Sequential(*[ResBlock(cres, rate) for rate in dilations])
        self.convn = nn.Conv2d(cres, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, hpy, left):
        """Return ``hpy`` plus the residual predicted from ``(left, hpy)``."""
        features = self.conv1x1(torch.cat((left, hpy), dim=1))
        residual = self.convn(self.res_block(self.conv1(features)))
        return hpy + residual


class RAFTStereo_modified(nn.Module):
    """Coarse-to-fine recurrent flow estimator working on three feature scales.

    A single encoder (``fnet``) produces feature maps for both images. The
    left feature map is also split into a hidden state and a context input.
    Flow is then refined in three stages: on the features average-pooled by 4,
    on the features average-pooled by 2, and on the unpooled features, each
    stage with its own update block.

    ``args`` must provide ``hidden_dims``, ``num_neighbors``, ``n_downsample``,
    ``mixed_precision``, ``corr_implementation`` and ``corr_levels``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.dropout = 0
        self.fnet = C_BasicEncoder(
            output_dim=256, norm_fn="instance", dropout=self.dropout
        )

        # One independent update block per scale. Attribute names and creation
        # order are kept stable so existing state dicts keep loading.
        self.update_block = self._make_update_block(args)
        self.update_block8 = self._make_update_block(args)
        self.update_block16 = self._make_update_block(args)

        self.num_neighbors = args.num_neighbors
        # One offset-prediction head per scale; each emits an (x, y) pair per
        # neighbour, i.e. 2 * num_neighbors channels.
        self.propa_conv = self._make_propa_conv(args)
        self.propa_conv8 = self._make_propa_conv(args)
        self.propa_conv16 = self._make_propa_conv(args)

    @staticmethod
    def _make_update_block(args):
        """Build the recurrent update block shared in shape by all scales."""
        return C_BasicUpdateBlock(
            hidden_dim=args.hidden_dims[2],
            cor_planes=4 * 9,
            mask_size=4,
            flow_dim=2,
            output_dim=2,
        )

    @staticmethod
    def _make_propa_conv(args):
        """Build a Conv-BN-LeakyReLU head predicting per-neighbour offsets."""
        out = args.num_neighbors * 2
        return nn.Sequential(
            nn.Conv2d(
                in_channels=args.hidden_dims[0] * 2,
                out_channels=out,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm2d(out),
            nn.LeakyReLU(0.2),
        )

    def freeze_bn(self):
        """Put every ``BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, fmap):
        """Create a randomly initialised flow for the resolution of ``fmap``.

        Flow is represented as the difference between two coordinate grids,
        ``flow = coords1 - coords0``. ``coords1`` is shifted along channel 0 by
        a uniform random amount in ``[0, W)``; channel 1 is left unshifted.

        Returns:
            ``(flow, gradient)`` where ``gradient`` is a zero tensor shaped
            like ``flow``.
        """
        N, _, H, W = fmap.shape

        coords0 = coords_grid(N, H, W).to(fmap.device)
        coords1 = coords_grid(N, H, W).to(fmap.device)
        res = torch.rand(N, 1, H, W, device=fmap.device) * W
        coords1[:, 0, :] = coords1[:, 0, :] + res[:, 0, :]
        # Subtract the untouched grid; subtracting coords1 from itself would
        # always yield zeros and silently discard the random initialisation.
        flow = coords1 - coords0
        gradient = torch.zeros_like(flow)

        return flow, gradient

    def get_grid(self, batch, height, width, device, offset, dilation=1):
        """Build a normalised sampling grid for each pixel and its neighbours.

        For every pixel the grid holds the pixel itself followed by
        ``num_neighbors`` positions, each a fixed neighbourhood offset plus the
        learned ``offset``. Coordinates are mapped so that 0 becomes -1 and
        ``size - 1`` becomes +1.

        Args:
            batch, height, width: Size of the feature map being sampled.
            device: Device on which the grid is created.
            offset: Tensor indexed as ``[B, 2 * num_neighbors, H * W]`` with
                channel ``2 * i`` the x offset and ``2 * i + 1`` the y offset of
                neighbour ``i``.
            dilation: Spacing of the fixed neighbourhood.

        Returns:
            Tensor of shape ``[B, (num_neighbors + 1) * H, W, 2]`` with
            ``(x, y)`` in the last axis.
        """
        # The 8-connected neighbourhood as (dy, dx) pairs, row by row.
        original_offset = [
            [-dilation, -dilation],
            [-dilation, 0],
            [-dilation, dilation],
            [0, -dilation],
            [0, dilation],
            [dilation, -dilation],
            [dilation, 0],
            [dilation, dilation],
        ]
        if self.num_neighbors == 16:
            # Add a second ring at twice the distance.
            original_offset += [[2 * dy, 2 * dx] for dy, dx in original_offset]

        with torch.no_grad():
            y_grid, x_grid = torch.meshgrid(
                [
                    torch.arange(0, height, dtype=torch.float32, device=device),
                    torch.arange(0, width, dtype=torch.float32, device=device),
                ]
            )
            y_grid = y_grid.contiguous().view(height * width)
            x_grid = x_grid.contiguous().view(height * width)
            xy = torch.stack((x_grid, y_grid))  # [2, H*W]
            xy = torch.unsqueeze(xy, 0).repeat(batch, 1, 1)  # [B, 2, H*W]

        xy_list = [xy.unsqueeze(2)]
        for i in range(self.num_neighbors):
            base_y, base_x = original_offset[i]
            offset_x = base_x + offset[:, 2 * i, :].unsqueeze(1)
            offset_y = base_y + offset[:, 2 * i + 1, :].unsqueeze(1)
            shift = torch.cat((offset_x, offset_y), dim=1)
            xy_list.append((xy + shift).unsqueeze(2))

        xy = torch.cat(xy_list, dim=2)  # [B, 2, num_neighbors + 1, H*W]
        x_normalized = xy[:, 0, :, :] / ((width - 1) / 2) - 1
        y_normalized = xy[:, 1, :, :] / ((height - 1) / 2) - 1
        grid = torch.stack((x_normalized, y_normalized), dim=3)
        grid = grid.view(batch, (self.num_neighbors + 1) * height, width, 2)
        return grid

    def get_1d_grid(self, batch, height, width, device):
        """Build horizontal-only neighbour offsets for every pixel.

        Entry 0 along axis 2 holds the absolute ``(x, y)`` pixel coordinates;
        the following ``num_neighbors`` entries hold constant offsets with
        x in ``-4..-1, 1..4`` and y equal to 0 (offsets, not absolute positions).

        Only eight offsets are defined, so ``num_neighbors`` must not exceed 8.

        Returns:
            Tensor of shape ``[B, 2, num_neighbors + 1, H * W]``.
        """
        # (dy, dx) pairs along the horizontal line through the pixel.
        original_offset = [
            [0, -4],
            [0, -3],
            [0, -2],
            [0, -1],
            [0, 1],
            [0, 2],
            [0, 3],
            [0, 4],
        ]

        with torch.no_grad():
            y_grid, x_grid = torch.meshgrid(
                [
                    torch.arange(0, height, dtype=torch.float32, device=device),
                    torch.arange(0, width, dtype=torch.float32, device=device),
                ]
            )
            y_grid = y_grid.contiguous().view(height * width)
            x_grid = x_grid.contiguous().view(height * width)
            xy = torch.stack((x_grid, y_grid))  # [2, H*W]
            xy = torch.unsqueeze(xy, 0).repeat(batch, 1, 1)  # [B, 2, H*W]

        offset_list = [xy.unsqueeze(2)]
        for i in range(self.num_neighbors):
            base_y, base_x = original_offset[i]
            offset_x = base_x + torch.zeros((batch, 1, height * width), device=device)
            offset_y = base_y + torch.zeros((batch, 1, height * width), device=device)
            offset_list.append(torch.cat((offset_x, offset_y), dim=1).unsqueeze(2))
        return torch.cat(offset_list, dim=2)

    def upsample_flow(self, flow, mask):
        """Upsample ``flow`` by ``2 ** n_downsample`` using a convex combination.

        Each output pixel is a softmax-weighted sum over the 3x3 neighbourhood
        of its source pixel; flow values are multiplied by the same factor.

        Args:
            flow: Tensor of shape ``[N, D, H, W]``.
            mask: Weights viewable as ``[N, 1, 9, factor, factor, H, W]``.

        Returns:
            Tensor of shape ``[N, D, factor * H, factor * W]``.
        """
        N, D, H, W = flow.shape
        factor = 2**self.args.n_downsample

        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)
        up_flow = F.unfold(factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, D, factor * H, factor * W)

    def _propagation_grids(self, propa_conv, fmap):
        """Return ``(grid, grid_1d)`` sampling grids for one feature scale."""
        batch, _, height, width = fmap.shape
        propa_weight = propa_conv(fmap).view(
            batch, self.num_neighbors * 2, height * width
        )
        grid = self.get_grid(batch, height, width, fmap.device, propa_weight)
        grid_1d = self.get_1d_grid(batch, height, width, fmap.device)
        return grid, grid_1d

    @staticmethod
    def _resize_flow(flow_up, reference):
        """Bilinearly resize ``flow_up`` to ``reference``'s spatial size.

        Flow values are rescaled by the height ratio so they stay expressed in
        pixels of the target resolution.
        """
        scale = reference.shape[2] / flow_up.shape[2]
        return scale * F.interpolate(
            flow_up,
            size=(reference.shape[2], reference.shape[3]),
            mode="bilinear",
            align_corners=True,
        )

    def _refine_at_scale(
        self, corr_fn, update_block, net, inp, flow, num_iters, period, output_scale
    ):
        """Run ``num_iters`` recurrent updates at a single scale.

        Args:
            corr_fn: Correlation lookup called as ``corr_fn(flow, propagation=...)``.
            update_block: Update module for this scale.
            net, inp: Hidden state and context input for this scale.
            flow: Starting flow at this scale.
            num_iters: Number of update iterations (must be >= 1).
            period: Propagation is requested when ``itr % period == 0``.
            output_scale: Extra bilinear upsampling factor applied to the
                convex-upsampled flow before it is recorded (1 for none).

        Returns:
            ``(flow, flow_up, predictions)``: the refined flow, the last
            convex-upsampled flow, and the first channel of each recorded
            prediction.
        """
        predictions = []
        flow_up = None
        for itr in range(num_iters):
            propagation = itr % period == 0

            # Gradients are not propagated through the flow across iterations.
            flow = flow.detach()
            out_corrs = corr_fn(flow, propagation=propagation)

            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = update_block(
                    net,
                    inp,
                    out_corrs,
                    torch.cat([flow], dim=1),
                )

            flow = flow + delta_flow[:, 0:2]
            flow_up = self.upsample_flow(flow, up_mask)
            if output_scale == 1:
                flow_res = flow_up
            else:
                flow_res = output_scale * F.interpolate(
                    flow_up,
                    size=(
                        output_scale * flow_up.shape[2],
                        output_scale * flow_up.shape[3],
                    ),
                    mode="bilinear",
                    align_corners=True,
                )
            predictions.append(flow_res[:, :1])
        return flow, flow_up, predictions

    def forward(
        self,
        image1,
        image2,
        patchmatch_rounds,
        iters=12,
        flow_init=None,
        test_mode=False,
    ):
        """Estimate the flow between a pair of frames, coarse to fine.

        Args:
            image1, image2: Input images; they are rescaled with
                ``2 * (x / 255) - 1`` (presumably 0-255 valued — confirm
                against the data loader).
            patchmatch_rounds: Controls how often propagation is requested.
                The two coarse stages propagate every
                ``max(patchmatch_rounds - 1, 1)`` iterations, the finest stage
                every ``patchmatch_rounds`` iterations. Must be >= 1.
            iters: Total iteration budget; each of the three stages runs
                ``iters // 3`` iterations, so it must be >= 3.
            flow_init: Optional tensor added to the initial flow of the
                coarsest stage; it must broadcast against that flow.
            test_mode: When true, return only the final prediction.

        Returns:
            In test mode, the first channel of the final upsampled flow, twice.
            Otherwise ``((preds_coarsest, preds_middle, preds_finest),
            gradient_predictions)`` where each ``preds_*`` is a list with one
            tensor per iteration and ``gradient_predictions`` is an empty list.

        Raises:
            ValueError: If ``iters < 3``, ``patchmatch_rounds < 1`` or
                ``args.corr_implementation`` is not a known value.
        """
        stage_iters = iters // 3
        if stage_iters < 1:
            raise ValueError(
                f"iters must be >= 3 so every scale runs at least once, got {iters}"
            )
        if patchmatch_rounds < 1:
            raise ValueError(
                f"patchmatch_rounds must be >= 1, got {patchmatch_rounds}"
            )
        # Clamp so patchmatch_rounds == 1 does not produce a modulo by zero.
        coarse_period = max(patchmatch_rounds - 1, 1)

        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        hdim = self.args.hidden_dims[2]

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        with autocast(enabled=self.args.mixed_precision):
            # The left feature map doubles as context: one half becomes the
            # recurrent hidden state, the other the static context input.
            net, inp = torch.split(fmap1, [hdim, hdim], dim=1)
            net = torch.tanh(net)
            inp = F.relu(inp)
            net_dw8 = F.avg_pool2d(net, 2, stride=2)
            inp_dw8 = F.avg_pool2d(inp, 2, stride=2)
            net_dw16 = F.avg_pool2d(net, 4, stride=4)
            inp_dw16 = F.avg_pool2d(inp, 4, stride=4)

        corr_implementation = self.args.corr_implementation
        if corr_implementation == "reg":  # Default
            corr_block = CorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif corr_implementation == "alt":  # More memory efficient than reg
            corr_block = PytorchAlternateCorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif corr_implementation == "reg_cuda":  # Faster version of reg
            corr_block = CorrBlockFast1D
        elif corr_implementation == "alt_cuda":  # Faster version of alt
            corr_block = AlternateCorrBlock
        else:
            raise ValueError(
                f"unknown corr_implementation: {corr_implementation!r}"
            )

        # Pooled feature pyramids, built after the optional float conversion.
        fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
        fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)
        fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
        fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)

        # Sampling grids for a PAGCL-style correlation (imported from
        # core.corr). The correlation blocks built below do not consume them;
        # the calls are kept because the propagation heads contain BatchNorm
        # layers whose running statistics are updated by the forward pass.
        grid, grid_1d = self._propagation_grids(self.propa_conv, fmap1)
        grid8, grid_1d8 = self._propagation_grids(self.propa_conv8, fmap1_dw8)
        grid16, grid_1d16 = self._propagation_grids(self.propa_conv16, fmap1_dw16)

        # Each scale correlates its own left/right feature pair.
        corr_fn = corr_block(fmap1, fmap2, num_levels=self.args.corr_levels)
        corr_fn_dw8 = corr_block(
            fmap1_dw8, fmap2_dw8, num_levels=self.args.corr_levels
        )
        corr_fn_dw16 = corr_block(
            fmap1_dw16, fmap2_dw16, num_levels=self.args.corr_levels
        )

        flow_dw16, _ = self.initialize_flow(fmap1_dw16)
        if flow_init is not None:
            flow_dw16 = flow_dw16 + flow_init

        gradient_predictions = []

        # Coarsest stage: features pooled by 4.
        flow_dw16, flow_up, flow_predictions16 = self._refine_at_scale(
            corr_fn_dw16,
            self.update_block16,
            net_dw16,
            inp_dw16,
            flow_dw16,
            stage_iters,
            coarse_period,
            4,
        )

        # Middle stage: features pooled by 2, seeded from the coarsest result.
        flow_dw8 = self._resize_flow(flow_up, fmap1_dw8)
        flow_dw8, flow_up, flow_predictions8 = self._refine_at_scale(
            corr_fn_dw8,
            self.update_block8,
            net_dw8,
            inp_dw8,
            flow_dw8,
            stage_iters,
            coarse_period,
            2,
        )

        # Finest stage: unpooled features, seeded from the middle result.
        flow = self._resize_flow(flow_up, fmap1)
        flow, flow_up, flow_predictions = self._refine_at_scale(
            corr_fn,
            self.update_block,
            net,
            inp,
            flow,
            stage_iters,
            patchmatch_rounds,
            1,
        )

        if test_mode:
            return flow_up[:, :1], flow_up[:, :1]

        return (
            flow_predictions16,
            flow_predictions8,
            flow_predictions,
        ), gradient_predictions
