import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import warp, get_robust_weight
from loss import *


def resize(x, scale_factor):
    """Rescale the spatial dimensions of ``x`` with bilinear interpolation.

    Args:
        x: tensor accepted by ``F.interpolate`` in bilinear mode.
        scale_factor: multiplier applied to the spatial size.

    Returns:
        The resampled tensor (``align_corners=False``).
    """
    interp_options = {
        "scale_factor": scale_factor,
        "mode": "bilinear",
        "align_corners": False,
    }
    return F.interpolate(x, **interp_options)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    """Build a ``Conv2d`` followed by a per-channel ``PReLU``.

    All arguments are forwarded to ``nn.Conv2d``; the PReLU has one learnable
    slope per output channel.

    Returns:
        ``nn.Sequential`` whose index 0 is the convolution and index 1 the PReLU.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    activation = nn.PReLU(out_channels)
    return nn.Sequential(conv, activation)


class ResBlock(nn.Module):
    """Residual block with extra refinement of the trailing ``side_channels``.

    Three full-width 3x3 convolutions (``conv1``, ``conv3``, ``conv5``) are
    interleaved with two narrower ones (``conv2``, ``conv4``) that only see
    and overwrite the last ``side_channels`` channels of the running feature
    map. The block input is added back before the final PReLU.

    Args:
        in_channels: channel count of the input and output.
        side_channels: number of trailing channels processed by conv2/conv4.
        bias: whether the convolutions use a bias term.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super().__init__()
        self.side_channels = side_channels
        # Creation order is kept as conv1..conv5, prelu so that seeded
        # initialisation and state_dict layout stay the same.
        self.conv1 = self._conv_prelu(in_channels, bias)
        self.conv2 = self._conv_prelu(side_channels, bias)
        self.conv3 = self._conv_prelu(in_channels, bias)
        self.conv4 = self._conv_prelu(side_channels, bias)
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    @staticmethod
    def _conv_prelu(channels, bias):
        """Return a channel-preserving 3x3 conv followed by a PReLU."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(channels),
        )

    def _refine_side(self, feat, side_conv):
        """Overwrite the trailing side channels of ``feat`` in place with ``side_conv`` of them."""
        tail = slice(-self.side_channels, None)
        # clone() so the convolution reads a copy of the slice that is
        # about to be overwritten.
        feat[:, tail, :, :] = side_conv(feat[:, tail, :, :].clone())
        return feat

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        hidden = self._refine_side(self.conv1(x), self.conv2)
        hidden = self._refine_side(self.conv3(hidden), self.conv4)
        return self.prelu(x + self.conv5(hidden))


class Encoder(nn.Module):
    """Four-level convolutional feature pyramid for a 3-channel image.

    Each level is a stride-2 ``convrelu`` followed by a stride-1 ``convrelu``,
    producing 32, 48, 72 and 96 channels respectively.
    """

    def __init__(self):
        super().__init__()
        self.pyramid1 = self._make_level(3, 32)
        self.pyramid2 = self._make_level(32, 48)
        self.pyramid3 = self._make_level(48, 72)
        self.pyramid4 = self._make_level(72, 96)

    @staticmethod
    def _make_level(in_ch, out_ch):
        """Return one pyramid level: a stride-2 conv block then a stride-1 conv block."""
        return nn.Sequential(
            convrelu(in_ch, out_ch, 3, 2, 1),
            convrelu(out_ch, out_ch, 3, 1, 1),
        )

    def forward(self, img):
        """Return the features of all four levels, finest first."""
        features = []
        current = img
        for level in (self.pyramid1, self.pyramid2, self.pyramid3, self.pyramid4):
            current = level(current)
            features.append(current)
        return tuple(features)



class Masked_self_feature(nn.Module):
    """Blend each pixel's feature with the mean feature of the object it belongs to.

    For every object mask the features under the mask are averaged and the
    average is written back over that region; a learned per-pixel,
    per-channel sigmoid gate then mixes the original features with this
    object-pooled map. The layers are built for 96-channel features.
    """

    def __init__(self):
        super().__init__()

        self.spatial_attn = convrelu(96*2, 96, 1, 1, 0)
        self.spatial_attn_mul1 = convrelu(96, 96, 3, 1, 1)
        self.spatial_attn_mul2 = nn.Conv2d(96, 96, 3, 1, 1)

    def forward(self, f0, mask0):
        """Return the gated mix of ``f0`` and its object-pooled counterpart.

        Args:
            f0: features of shape (B, C, H, W).
            mask0: object masks of shape (B, N, h, w); resized to (H, W) with
                nearest-neighbour interpolation. Pixels with value > 0 count
                as inside the object.
        """
        num_batch, num_objects, _, _ = mask0.size()
        _, _, height, width = f0.size()
        mask0 = F.interpolate(mask0, size=(height, width), mode='nearest')

        per_pixel = f0.clone()
        per_object = f0.clone()
        for b in range(num_batch):
            for n in range(num_objects):
                object_mask = mask0[b, n]
                if object_mask.sum() == 0:
                    # object absent in this sample
                    continue
                inside = object_mask > 0
                region_feats = per_pixel[b, :, inside]                 # [C, K]
                region_mean = region_feats.mean(dim=1, keepdim=True)   # [C, 1]
                # every pixel of the object receives the object's mean feature
                per_object[b, :, inside] = region_mean.repeat(1, region_feats.size(1))

        # learned gate between per-pixel and per-object features
        gate_in = torch.cat([per_pixel, per_object], dim=1)
        gate = self.spatial_attn(gate_in)
        gate = self.spatial_attn_mul1(gate)
        gate = torch.sigmoid(self.spatial_attn_mul2(gate))
        return per_pixel * gate + per_object * (1 - gate)


class Masked_cross_feature(nn.Module):
    """Blend each pixel's feature with the mean feature of the same object in the other frame.

    For every object present in both masks, the features of ``f1`` under
    ``mask1`` are averaged and written over the region of ``f0`` selected by
    ``mask0``; a learned sigmoid gate then mixes ``f0`` with this
    cross-pooled map. The layers are built for 96-channel features.
    """

    def __init__(self):
        super().__init__()

        self.spatial_attn = convrelu(96*2, 96, 1, 1, 0)
        self.spatial_attn_mul1 = convrelu(96, 96, 3, 1, 1)
        self.spatial_attn_mul2 = nn.Conv2d(96, 96, 3, 1, 1)

    def forward(self, f0, mask0, f1, mask1):
        """Return the gated mix of ``f0`` and the object means taken from ``f1``.

        Args:
            f0: features to update, shape (B, C, H, W).
            mask0: object masks for ``f0``, shape (B, N, h, w).
            f1: features of the other frame, indexed with ``mask1`` after it
                is resized to (H, W).
            mask1: object masks for ``f1``, same object ordering as ``mask0``.
        """
        num_batch, num_objects, _, _ = mask0.size()
        _, _, height, width = f0.size()
        mask0 = F.interpolate(mask0, size=(height, width), mode='nearest')
        mask1 = F.interpolate(mask1, size=(height, width), mode='nearest')

        # ---------------- cross-frame object pooling ----------------
        cross_pooled = f0.clone()
        for b in range(num_batch):
            for n in range(num_objects):
                object_mask0 = mask0[b, n]
                object_mask1 = mask1[b, n]
                present_in_both = not (object_mask0.sum() == 0 or object_mask1.sum() == 0)
                if not present_in_both:
                    continue

                inside0 = object_mask0 > 0
                inside1 = object_mask1 > 0
                own_feats = f0[b, :, inside0]      # [C, K0], only its size is needed
                other_feats = f1[b, :, inside1]    # [C, K1]

                # mean feature of the object in the other frame, copied to
                # every pixel of the object in this frame
                other_mean = other_feats.mean(dim=1, keepdim=True)
                cross_pooled[b, :, inside0] = other_mean.repeat(1, own_feats.size(1))

        # learned gate between own and cross-pooled features
        gate = self.spatial_attn(torch.cat([f0, cross_pooled], dim=1))
        gate = self.spatial_attn_mul1(gate)
        gate = torch.sigmoid(self.spatial_attn_mul2(gate))
        return f0 * gate + cross_pooled * (1 - gate)


class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: size of the last output dimension; falls back to
            ``in_features`` when falsy.
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability shared by both dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the perceptron along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class FlowAttention(nn.Module):
    """Region-wise cross-attention that updates features and predicts a flow field.

    For every batch item, the pixels of ``x`` selected by an object mask
    attend to the pixels of ``tgt`` selected by the matching mask; pixels not
    claimed by any matched object attend within the left-over region. The
    value tensor carries the target coordinates as two extra channels, so
    the attention output also yields a target position per query pixel.

    Args:
        dim: channel count of ``x`` and ``tgt``.
        bias: whether the q/k/v projections use a bias.
        qk_scale: optional override for the default ``dim ** -0.5`` scaling.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        norm_layer: normalisation layer constructor, called with ``dim``.
    """

    def __init__(self,
                 dim=96,
                 bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 norm_layer=nn.LayerNorm):

        super().__init__()
        self.dim = dim
        self.scale = qk_scale or dim**-0.5

        self.norm1 = norm_layer(dim)
        self.q_proj = nn.Linear(dim, dim, bias=bias)
        self.k_proj = nn.Linear(dim, dim, bias=bias)
        # values carry the 2-D target coordinates as two extra channels
        self.v_proj = nn.Linear(dim + 2, dim + 2, bias=bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim + 2, dim + 2)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * 2)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)

        # Cache of single-sample coordinate grids, keyed by
        # (H, W, device, dtype). Not a registered buffer.
        self.grids = {}

    def generate_grid(self, B, H, W, normalize=True):
        """Return a ``[B, 2, H, W]`` grid with channel 0 = x and channel 1 = y.

        With ``normalize=True`` coordinates are divided by ``W - 1`` and
        ``H - 1``; a size-1 axis keeps coordinate 0 rather than becoming
        0/0 = NaN. The batch dimension is an ``expand`` view.
        """
        ys = torch.arange(H)
        xs = torch.arange(W)
        if normalize:
            ys = ys / max(H - 1, 1)
            xs = xs / max(W - 1, 1)
        # Same layout as meshgrid with "ij" indexing, built by broadcasting
        # so no `indexing=` argument (and no deprecation warning) is needed.
        yy = ys.view(H, 1).expand(H, W)
        xx = xs.view(1, W).expand(H, W)
        grid = torch.stack([xx, yy], dim=0)
        return grid[None].expand(B, -1, -1, -1)

    def _base_grid(self, H, W, ref):
        """Return the cached ``[1, 2, H, W]`` grid matching ``ref``'s device and dtype."""
        key = (H, W, ref.device, ref.dtype)
        base = self.grids.get(key)
        if base is None:
            base = self.generate_grid(1, H, W).to(ref)
            self.grids[key] = base
        return base

    def apply_attention(self, x, tgt, grid, mask0, mask1, b):
        """Attend the ``mask0`` pixels of ``x[b]`` to the ``mask1`` pixels of ``tgt[b]``.

        Args:
            x, tgt: feature maps of shape (B, C, H, W).
            grid: coordinate grid of shape (B, 2, H, W).
            mask0, mask1: single-channel region masks of shape (B, 1, H, W);
                values > 0 select pixels.
            b: batch index to process.

        Returns:
            ``(attended, shortcut)`` where ``attended`` is [1, S0, C+2] (the
            last two channels are attention-weighted projected coordinates)
            and ``shortcut`` is the normalised query features [1, S0, C].
        """
        single_mask0 = mask0[b:b + 1]
        single_mask1 = mask1[b:b + 1]

        single_masked_x = x[b:b + 1, :, single_mask0[0, 0] > 0].permute(0, 2, 1)  # [1, S0, C]
        single_masked_tgt = tgt[b:b + 1, :, single_mask1[0, 0] > 0].permute(0, 2, 1)  # [1, S1, C]
        single_masked_grid = grid[b:b + 1, :, single_mask1[0, 0] > 0].permute(0, 2, 1)  # [1, S1, 2]

        single_masked_x = self.norm1(single_masked_x)
        shortcut = single_masked_x   # [1, S0, C]

        q = self.q_proj(single_masked_x) * self.scale  # [1, S0, C]
        k = self.k_proj(single_masked_tgt)  # [1, S1, C]
        v = self.v_proj(torch.cat([single_masked_tgt, single_masked_grid], dim=-1))  # [1, S1, C+2]

        attn = self.softmax(q @ k.transpose(-2, -1))  # [1, S0, S1]
        attn = self.attn_drop(attn)

        single_masked_x = attn @ v  # [1, S0, C+2]
        single_masked_x = self.proj_drop(self.proj(single_masked_x))  # [1, S0, C+2]

        return single_masked_x, shortcut    # [1, S0, C+2], [1, S0, C]

    def _update_region(self, x, tgt, grid, grid_tran, region0, region1, b):
        """Run attention for one region of sample ``b`` and write results into ``x`` and ``grid_tran``."""
        attended, shortcut = self.apply_attention(x, tgt, grid, region0, region1, b)
        selected = region0[b, 0] > 0

        # last two channels: predicted target coordinates of the query pixels
        grid_tran[b:b + 1, :, selected] = attended[:, :, -2:].permute(0, 2, 1)  # [1, 2, S0]

        feats = shortcut + attended[:, :, :-2]
        feats = feats + self.mlp(self.norm2(feats))  # [1, S0, C]
        x[b:b + 1, :, selected] = feats.permute(0, 2, 1)

    def forward(self, x, tgt, mask0, mask1):
        """Update ``x`` region by region and return it together with the flow.

        Args:
            x: query features (B, C, H, W). NOTE: modified in place and
                returned; the returned tensor is the same object as ``x``.
            tgt: target features (B, C, H, W).
            mask0: object masks for ``x``, (B, N, h, w), resized to (H, W)
                with nearest-neighbour interpolation.
            mask1: object masks for ``tgt``, same object ordering.

        Returns:
            ``(x, flow)`` with ``flow = predicted_coords - grid`` of shape
            (B, 2, H, W), in the (normalised) units of ``generate_grid``.
        """
        B, C, H, W = x.shape
        mask0 = F.interpolate(mask0, size=(H, W), mode='nearest')
        mask1 = F.interpolate(mask1, size=(H, W), mode='nearest')

        # The cache holds a batch-free grid; expand/repeat to the current
        # batch size on every call so a change of B cannot reuse a stale shape.
        base = self._base_grid(H, W, x)
        grid = base.expand(B, -1, -1, -1)      # read-only view
        grid_tran = base.repeat(B, 1, 1, 1)    # fresh, writable copy

        single_mask0_left = torch.ones(B, 1, H, W, device=mask0.device)
        single_mask1_left = torch.ones(B, 1, H, W, device=mask1.device)
        for b in range(B):
            for n in range(mask0.size(1)):
                # only objects visible in both frames are matched
                if mask0[b, n].sum() == 0 or mask1[b, n].sum() == 0:
                    continue

                self._update_region(x, tgt, grid, grid_tran, mask0[:, n:n + 1], mask1[:, n:n + 1], b)

                # remove the matched object from the left-over regions
                single_mask0_left[b:b + 1] = single_mask0_left[b:b + 1] - mask0[b:b + 1, n:n + 1]
                single_mask1_left[b:b + 1] = single_mask1_left[b:b + 1] - mask1[b:b + 1, n:n + 1]

            # everything not claimed by a matched object attends within the remainder
            self._update_region(x, tgt, grid, grid_tran, single_mask0_left, single_mask1_left, b)

        flow = grid_tran - grid
        return x, flow


class Mlp(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> dropout -> ``fc2`` -> dropout.

    NOTE: a class with the same name and behaviour is also defined earlier
    in this file; this later definition is the one the module-level name
    ``Mlp`` refers to once the module has been imported.

    Args:
        in_features: size of the last input dimension.
        hidden_features: hidden width, ``in_features`` when falsy.
        out_features: output width, ``in_features`` when falsy.
        act_layer: activation class, instantiated without arguments.
        drop: probability used by the single shared ``nn.Dropout``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x


class FlowAttention1(nn.Module):
    """Global cross-attention from ``x`` to ``tgt`` that also predicts a flow field.

    Every pixel of ``x`` attends to every pixel of ``tgt``. The value tensor
    carries the target coordinates as two extra channels, so the attention
    output contains, per query pixel, an attention-weighted (and projected)
    target position from which the query's own position is subtracted.

    Args:
        dim: channel count of ``x`` and ``tgt``.
        bias: whether the q/k/v projections use a bias.
        qk_scale: optional override for the default ``dim ** -0.5`` scaling.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        norm_layer: normalisation layer constructor, called with ``dim``.
    """

    def __init__(self,
                 dim=96,
                 bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 norm_layer=nn.LayerNorm):

        super().__init__()
        self.dim = dim
        # Honour qk_scale, as FlowAttention does; the default is unchanged.
        self.scale = qk_scale or dim**-0.5

        self.norm1 = norm_layer(dim)
        # define the projection layers
        self.q_proj = nn.Linear(dim, dim, bias=bias)
        self.k_proj = nn.Linear(dim, dim, bias=bias)
        self.v_proj = nn.Linear(dim + 2, dim + 2, bias=bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim+2, dim+2)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * 2)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)

        # Cache of single-sample coordinate grids, keyed by
        # (H, W, device, dtype). Not a registered buffer.
        self.grids = {}

    def generate_grid(self, B, H, W, normalize=True):
        """Return a ``[B, 2, H, W]`` grid with channel 0 = x and channel 1 = y.

        With ``normalize=True`` coordinates are divided by ``W - 1`` and
        ``H - 1``; a size-1 axis keeps coordinate 0 rather than becoming
        0/0 = NaN. The batch dimension is an ``expand`` view.
        """
        ys = torch.arange(H)
        xs = torch.arange(W)
        if normalize:
            ys = ys / max(H - 1, 1)
            xs = xs / max(W - 1, 1)
        # Same layout as meshgrid with "ij" indexing, built by broadcasting.
        yy = ys.view(H, 1).expand(H, W)
        xx = xs.view(1, W).expand(H, W)
        grid = torch.stack([xx, yy], dim=0)
        return grid[None].expand(B, -1, -1, -1)

    def _base_grid(self, H, W, ref):
        """Return the cached ``[1, 2, H, W]`` grid matching ``ref``'s device and dtype."""
        key = (H, W, ref.device, ref.dtype)
        base = self.grids.get(key)
        if base is None:
            base = self.generate_grid(1, H, W).to(ref)
            self.grids[key] = base
        return base

    def forward(self, x, tgt, return_attn=False):
        """
        Args:
            x: query features with shape of (B, C, H, W)
            tgt: target features with the same shape as ``x``
            return_attn: accepted for interface compatibility; currently ignored

        Returns:
            Tensor of shape (B, C+2, H, W): the updated features followed by
            two flow channels (predicted target coordinates minus the query
            pixel's own grid coordinates).
        """
        B, C, H, W = x.shape
        # The cache is batch-free; expanding per call keeps it valid when the
        # batch size differs from the first call's.
        grid = self._base_grid(H, W, x).expand(B, -1, -1, -1)
        grid = grid.flatten(2).permute(0, 2, 1)  # [B, HW, 2]

        x = x.flatten(2).permute(0, 2, 1)        # [B, HW, C]
        tgt_tokens = tgt.flatten(2).permute(0, 2, 1)
        k = tgt_tokens
        v = torch.cat([tgt_tokens, grid], dim=-1)  # [B, HW, C+2]

        x = self.norm1(x)
        shortcut = x

        q = self.q_proj(x)  # [B, HW, C]
        k = self.k_proj(k)
        v = self.v_proj(v)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)  # [B, HW, HW]

        x = attn @ v
        x = self.proj_drop(self.proj(x))  # [B, HW, C+2]

        # split correspondence (last two channels) from features, then MLP
        flow = x[..., -2:] - grid
        x = x[..., :-2]
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        x = torch.cat([x, flow], dim=-1)
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)  # [B, C+2, H, W]
        return x

class BasicBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ``act``, plus an identity skip.

    Convolution weights are Xavier-uniform initialised and biases zeroed.

    Args:
        in_channels: channel count of input and output.
        act: activation module applied after each convolution. When ``None``
            (default) a fresh ``nn.LeakyReLU(negative_slope=0.1)`` is created
            for this block, so default-constructed blocks do not share one
            activation object.
        stride: accepted for interface compatibility; not used.
    """

    def __init__(self, in_channels, act=None, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.act = nn.LeakyReLU(negative_slope=0.1) if act is None else act
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return ``act(conv2(act(conv1(x)))) + x`` (no activation after the sum)."""
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        return out + x



class InitialDecoder(nn.Module):
    """Decoder that maps the plug features plus the time embedding to 4 + 96 channels.

    The input is expected to have ``96*2 + 4`` channels; one more channel is
    added by tiling ``embt`` over the spatial grid.
    """

    def __init__(self):
        super().__init__()
        in_ch = 96*2 + 4 + 1
        mid_ch = 96*2
        self.convblock = nn.Sequential(
            convrelu(in_ch, mid_ch),
            ResBlock(mid_ch, 32),
            nn.Conv2d(mid_ch, 4+96, kernel_size=3, stride=1, padding=1, bias=True),
        )

    def forward(self, f0, embt):
        """Tile ``embt`` to the spatial size of ``f0``, concatenate on channels and decode."""
        _, _, height, width = f0.shape
        embt_map = embt.repeat(1, 1, height, width)
        decoder_in = torch.cat((f0, embt_map), dim=1)
        return self.convblock(decoder_in)


class Decoder4(nn.Module):
    """Coarsest refinement decoder: ``96*3 + 4`` input channels, 76 output channels at 2x resolution."""

    def __init__(self):
        super().__init__()
        in_ch = 96*3 + 4
        mid_ch = 96*3
        self.convblock = nn.Sequential(
            convrelu(in_ch, mid_ch),
            ResBlock(mid_ch, 32),
            nn.ConvTranspose2d(mid_ch, 76, kernel_size=4, stride=2, padding=1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        """Warp ``f0``/``f1`` by their flows, concatenate with ``ft_`` and both flows, and decode."""
        # shape unpack retained from the original; the values are unused
        _b, _c, _h, _w = f0.shape
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        decoder_in = torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1)
        return self.convblock(decoder_in)


class Decoder3(nn.Module):
    """Refinement decoder: 220 input channels, 52 output channels at 2x resolution."""

    def __init__(self):
        super().__init__()
        mid_ch = 216
        self.convblock = nn.Sequential(
            convrelu(mid_ch + 4, mid_ch),
            ResBlock(mid_ch, 32),
            nn.ConvTranspose2d(mid_ch, 52, kernel_size=4, stride=2, padding=1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        """Warp ``f0``/``f1`` by their flows, concatenate with ``ft_`` and both flows, and decode."""
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        decoder_in = torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1)
        return self.convblock(decoder_in)


class Decoder2(nn.Module):
    """Refinement decoder: 148 input channels, 36 output channels at 2x resolution."""

    def __init__(self):
        super().__init__()
        mid_ch = 144
        self.convblock = nn.Sequential(
            convrelu(mid_ch + 4, mid_ch),
            ResBlock(mid_ch, 32),
            nn.ConvTranspose2d(mid_ch, 36, kernel_size=4, stride=2, padding=1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        """Warp ``f0``/``f1`` by their flows, concatenate with ``ft_`` and both flows, and decode."""
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        decoder_in = torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1)
        return self.convblock(decoder_in)


class Decoder1(nn.Module):
    """Finest refinement decoder: 100 input channels, 8 output channels at 2x resolution."""

    def __init__(self):
        super().__init__()
        mid_ch = 96
        self.convblock = nn.Sequential(
            convrelu(mid_ch + 4, mid_ch),
            ResBlock(mid_ch, 32),
            nn.ConvTranspose2d(mid_ch, 8, kernel_size=4, stride=2, padding=1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        """Warp ``f0``/``f1`` by their flows, concatenate with ``ft_`` and both flows, and decode."""
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        decoder_in = torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1)
        return self.convblock(decoder_in)


class Plug(nn.Module):
    """Mask-guided module producing initial bidirectional flows and an intermediate feature.

    Pipeline (see ``forward``): per-frame object pooling, cross-frame object
    pooling, region-wise flow attention in both directions, warping to the
    intermediate time, fusion, a global attention step, and a final decoder.
    All sub-modules are built for 96-channel features.
    """

    def __init__(self):
        super(Plug, self).__init__()
        # Each of these is a single instance applied to both frames, so the
        # two directions share weights.
        self.masked_self_feature = Masked_self_feature()
        self.masked_cross_feature = Masked_cross_feature()
        self.flowattention = FlowAttention()
        self.flowattention1 = FlowAttention1()

        # Fuses the two warped feature maps (96*2 channels) into 96 channels.
        self.pred_feats_t = nn.Sequential(
            nn.Conv2d(96*2, 96, kernel_size=3, stride=1, padding=1),
            nn.PReLU(96),
            BasicBlock(96, nn.PReLU(96)),
        )

        self.initial_decoder = InitialDecoder()
        
    def forward(self, f0, f1, embt, mask0, mask1):
        """Compute the initial flows and intermediate feature.

        Args:
            f0, f1: features of the two input frames, shape (B, C, H, W).
            embt: time embedding, tiled over (H, W) by ``InitialDecoder``.
            mask0, mask1: object masks for frame 0 / frame 1; the sub-modules
                resize them to (H, W) themselves.

        Returns:
            ``(up_flow0_5, up_flow1_5, ft_5_)``: two 2-channel flows and the
            remaining decoder channels.
        """

        B, C, H, W = f0.size()

        # Intra-frame object pooling.
        f0 = self.masked_self_feature(f0, mask0)
        f1 = self.masked_self_feature(f1, mask1)

        # Cross-frame object pooling; f11 is computed from the self-pooled f0,
        # not from f00.
        f00 = self.masked_cross_feature(f0, mask0, f1, mask1)
        f11 = self.masked_cross_feature(f1, mask1, f0, mask0)

        # NOTE: FlowAttention.forward as written in this file writes into its
        # first argument and returns it, so f000 is the same tensor as f00 and
        # the second call receives the already-updated f00 as its target.
        # The order of these two calls therefore matters.
        f000, flow01 = self.flowattention(f00, f11, mask0, mask1)
        f111, flow10 = self.flowattention(f11, f00, mask1, mask0)

        # Half of each frame-to-frame flow is used as the flow for the
        # intermediate frame.
        flow_t0 = 0.5 * flow10
        flow_t1 = 0.5 * flow01

        f0_warp = warp(f000, flow_t0)
        f1_warp = warp(f111, flow_t1)

        feats_t = self.pred_feats_t(torch.cat([f0_warp, f1_warp], dim=1))

        identity = torch.stack([f000, f111], dim=1)  # [B, 2, C, H, W]

        # feats_t is duplicated so it can attend to f000 and f111 in a single
        # batched call of size 2B.
        flow_t_feats_t = self.flowattention1(feats_t.unsqueeze(1).repeat(1, 2, 1, 1, 1).flatten(0, 1), identity.flatten(0, 1)).view(B, 2, -1, H, W)   # [B, 2, C+2, H, W]

        # flatten(1, 2) merges the pair dimension into channels: [B, 2*(C+2), H, W].
        f_out = self.initial_decoder(flow_t_feats_t.flatten(1, 2), embt)
        # Decoder channels 0:2 and 2:4 are residuals added to the halved flows.
        up_flow0_5 = flow_t0 + f_out[:, 0:2]
        up_flow1_5 = flow_t1 + f_out[:, 2:4]
        ft_5_ = f_out[:, 4:]

        return up_flow0_5, up_flow1_5, ft_5_


class Model(nn.Module):
    """Frame-interpolation network: encoder pyramid, mask-guided plug, coarse-to-fine decoders.

    Args:
        local_rank: accepted for interface compatibility; not used here.
        lr: accepted for interface compatibility; not used here.
    """

    def __init__(self, local_rank=-1, lr=1e-4):
        super(Model, self).__init__()
        self.encoder = Encoder()
        self.decoder4 = Decoder4()
        self.decoder3 = Decoder3()
        self.decoder2 = Decoder2()
        self.decoder1 = Decoder1()
        self.plug = Plug()
        self.l1_loss = Charbonnier_L1()
        # Constructed but not referenced by inference() or forward().
        self.rb_loss = Charbonnier_Ada()

    @staticmethod
    def _refine_level(decoder, ft_, f0, f1, flow0, flow1):
        """Run one decoder level and return ``(flow0, flow1, rest)`` at 2x resolution.

        The decoder's first four channels are residuals added to the incoming
        flows after they are upsampled by 2 and doubled in magnitude.
        """
        out = decoder(ft_, f0, f1, flow0, flow1)
        up_flow0 = out[:, 0:2] + 2.0 * resize(flow0, scale_factor=2.0)
        up_flow1 = out[:, 2:4] + 2.0 * resize(flow1, scale_factor=2.0)
        return up_flow0, up_flow1, out[:, 4:]

    def _estimate(self, img0, img1, embt, mask0, mask1):
        """Run encoder, plug and all decoders.

        Returns:
            ``(up_flow0_1, up_flow1_1, up_mask_1, up_res_1, up_flow0_5, up_flow1_5)``:
            final flows, sigmoid blend mask, residual, and the plug's coarse flows.
        """
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1)

        up_flow0_5, up_flow1_5, ft_5_ = self.plug(f0_4, f1_4, embt, mask0, mask1)

        up_flow0_4, up_flow1_4, ft_3_ = self._refine_level(
            self.decoder4, ft_5_, f0_4, f1_4, up_flow0_5, up_flow1_5)
        up_flow0_3, up_flow1_3, ft_2_ = self._refine_level(
            self.decoder3, ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        up_flow0_2, up_flow1_2, ft_1_ = self._refine_level(
            self.decoder2, ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        up_flow0_1, up_flow1_1, tail = self._refine_level(
            self.decoder1, ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        # After the four flow channels the last decoder emits one mask logit
        # followed by the residual channels.
        up_mask_1 = torch.sigmoid(tail[:, 0:1])
        up_res_1 = tail[:, 1:]
        return up_flow0_1, up_flow1_1, up_mask_1, up_res_1, up_flow0_5, up_flow1_5

    @staticmethod
    def _blend(img0, img1, flow0, flow1, mask, res):
        """Warp both images, blend them with ``mask``, add ``res`` and clamp to [0, 1]."""
        img0_warp = warp(img0, flow0)
        img1_warp = warp(img1, flow1)
        imgt_merge = mask * img0_warp + (1 - mask) * img1_warp
        return torch.clamp(imgt_merge + res, 0, 1)

    def inference(self, img0, img1, embt, scale_factor, mask0, mask1):
        """Predict the intermediate frame, estimating motion at a rescaled resolution.

        Images and masks are bilinearly resized by ``scale_factor`` before the
        network runs; flows, mask and residual are resized back by
        ``1 / scale_factor`` (flows are also rescaled in magnitude) and applied
        to the full-resolution images.

        Returns:
            The predicted frame, clamped to [0, 1].
        """
        img0_ = resize(img0, scale_factor=scale_factor)
        img1_ = resize(img1, scale_factor=scale_factor)

        mask0_ = resize(mask0, scale_factor=scale_factor)
        mask1_ = resize(mask1, scale_factor=scale_factor)

        up_flow0_1, up_flow1_1, up_mask_1, up_res_1, _, _ = self._estimate(
            img0_, img1_, embt, mask0_, mask1_)

        up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
        up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
        up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
        up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))

        return self._blend(img0, img1, up_flow0_1, up_flow1_1, up_mask_1, up_res_1)

    def forward(self, img0, img1, embt, imgt, flow=None, mask0=None, mask1=None):
        """Predict the intermediate frame and compute the training losses.

        Args:
            img0, img1: input frames.
            embt: time embedding passed to the plug.
            imgt: ground-truth intermediate frame.
            flow: accepted for interface compatibility; not used.
            mask0, mask1: object masks for the two frames. Required despite
                the ``None`` defaults, because the plug always uses them.

        Returns:
            ``(imgt_pred, loss_rec, loss_dis, loss_grad)``. ``loss_dis`` is
            always ``0 * loss_rec``.

        Raises:
            ValueError: if ``mask0`` or ``mask1`` is ``None``.
        """
        if mask0 is None or mask1 is None:
            raise ValueError("Model.forward requires both mask0 and mask1")

        (up_flow0_1, up_flow1_1, up_mask_1, up_res_1,
         up_flow0_5, up_flow1_5) = self._estimate(img0, img1, embt, mask0, mask1)

        imgt_pred = self._blend(img0, img1, up_flow0_1, up_flow1_1, up_mask_1, up_res_1)

        # Recon_loss
        loss_rec = self.l1_loss(imgt_pred - imgt)

        # Flow_loss (disabled: kept as a zero-valued term)
        loss_dis = 0.00 * loss_rec

        # Grad_loss: supervise the plug's coarse flows directly by upsampling
        # them 16x to image resolution and warping the inputs.
        Up_Flow0 = 16.0 * resize(up_flow0_5, 16.0)
        Up_Flow1 = 16.0 * resize(up_flow1_5, 16.0)
        Warped_img0 = warp(img0, Up_Flow0)
        Warped_img1 = warp(img1, Up_Flow1)
        loss_grad = 0.1 * self.l1_loss(Warped_img0 - imgt) + 0.1 * self.l1_loss(Warped_img1 - imgt)

        return imgt_pred, loss_rec, loss_dis, loss_grad
