import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from timm.models.layers.weight_init import trunc_normal_
import math

from .util import AntiAliasInterpolation2d, coords_grid, bilinear_sampler, batch_bilinear_sampler, Hourglass, kp2gaussian
from .generator import OcclusionAwareGenerator

class CorrBlock:
    """Correlation pyramid with a windowed lookup around query coordinates.

    The constructor stores ``corr`` plus ``num_levels - 1`` progressively
    2x average-pooled copies of it.  Calling the object samples, at every
    pyramid level, a ``(2 * radius + 1) x (2 * radius + 1)`` window centred on
    each query coordinate and concatenates the samples of all levels.

    Args:
        corr: correlation tensor accepted by ``F.avg_pool2d``.
            NOTE(review): presumably shaped ``(batch * h1 * w1, 1, p, q)`` so
            that it lines up with the flattened coordinates built in
            ``__call__`` -- confirm against ``bilinear_sampler``.
        num_levels: number of pyramid levels (including the input level).
        radius: half-size of the square lookup window.
        h, w, p, q: accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self, corr, num_levels=2, radius=3, h=64, w=64, p=64, q=64):
        self.num_levels = num_levels
        self.radius = radius

        # Level 0 is the input; every further level halves the spatial size.
        self.corr_pyramid = [corr]
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample every pyramid level in a window around ``coords``.

        Args:
            coords: tensor of shape ``(batch, 2, h1, w1)`` holding the lookup
                centres expressed at level-0 resolution.

        Returns:
            Float tensor of shape ``(batch, C, h1, w1)`` where ``C`` is the
            concatenation of the per-level samples (presumably
            ``num_levels * (2 * radius + 1) ** 2`` -- depends on the sampler's
            output layout).
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The offset window only depends on the radius and the device, so it
        # is built once instead of once per pyramid level.
        dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
        dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
        delta = torch.stack(torch.meshgrid(dy, dx), dim=-1)
        delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)

        # One lookup centre per (batch, y, x) position.
        centroid = coords.reshape(batch*h1*w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]

            # Coordinates shrink by 2 per level, matching the pooled volume.
            coords_lvl = centroid / 2**i + delta_lvl

            if batch > 1 and h1 >= 128:
                # Large inputs go through the mini-batched sampler
                # (presumably to bound peak memory -- confirm in util).
                corr = batch_bilinear_sampler(corr, coords_lvl, h=h1, w=w1, mini_batch=1)
            else:
                corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

class BasicMotionEncoder(nn.Module):
    """Fuse correlation samples and the current flow into motion features.

    The correlation input is expected to have ``2 * (2 * 3 + 1) ** 2 = 98``
    channels.  The output has 128 channels: 126 learned feature channels
    followed by the 2 unmodified channels of ``delta_flow``.
    """

    def __init__(self):
        super().__init__()
        lookup_radius = 3
        pyramid_levels = 2
        corr_channels = pyramid_levels * (2 * lookup_radius + 1) ** 2

        # Correlation branch.
        self.convc1 = nn.Conv2d(corr_channels, 128, kernel_size=1, padding=0)
        self.convc2 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        # Flow branch.
        self.convf1 = nn.Conv2d(2, 128, kernel_size=7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        # Fusion layer; 2 channels are left free for the raw flow that is
        # concatenated in ``forward``.
        self.conv = nn.Conv2d(96 + 64, 128 - 2, kernel_size=3, padding=1)

    def forward(self, delta_flow, corr):
        """Return ``cat([fused_features, delta_flow], dim=1)``.

        Args:
            delta_flow: tensor with 2 channels, ``(B, 2, H, W)``.
            corr: tensor with 98 channels, ``(B, 98, H, W)``.
        """
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(delta_flow))))

        fused = self.conv(torch.cat((corr_feat, flow_feat), dim=1))
        fused = F.relu(fused)
        return torch.cat((fused, delta_flow), dim=1)

class RefineFlow(nn.Module):
    """Predict a flow update and an occlusion update from motion + context.

    Two parallel two-layer conv heads read the same 256-channel input: one
    emits 2 flow channels, the other 1 occlusion channel.
    """

    def __init__(self):
        super().__init__()
        # Projects the 192-channel warped features down to 128 channels.
        self.convc1 = nn.Conv2d(192, 128, kernel_size=3, padding=1)
        # Flow head.
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 2, kernel_size=3, padding=1)
        # Occlusion head.
        self.convo1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.convo2 = nn.Conv2d(128, 1, kernel_size=3, padding=1)

    def forward(self, m_f, warp_f):
        """Run both heads on ``cat([m_f, relu(convc1(warp_f))])``.

        Args:
            m_f: motion features, 128 channels.
            warp_f: warped context features, 192 channels.

        Returns:
            Tuple ``(out, inp)`` where ``out`` is ``(B, 3, H, W)`` with the
            flow update in channels 0-1 and the occlusion update in channel
            2, and ``inp`` is the 256-channel tensor fed to both heads.
        """
        context = F.relu(self.convc1(warp_f))
        inp = torch.cat((m_f, context), dim=1)

        flow_hidden = F.relu(self.conv1(inp))
        flow = self.conv2(flow_hidden)

        occ_hidden = F.relu(self.convo1(inp))
        occ = self.convo2(occ_hidden)

        return torch.cat((flow, occ), dim=1), inp

class SCORRFlow(nn.Module):
    """Coarse-to-fine flow/occlusion refinement driven by a structure
    correlation volume, followed by occlusion-aware image generation.

    Args:
        num_kp: number of keypoint heatmap channels.
        dim: channel width of the structure embeddings; the correlation
            volume is scaled by ``dim ** -0.5``.
        size: full image resolution.  The correlation volume lives at
            ``size // 4``.
            NOTE(review): the refinement loop hard-codes level index 3 as the
            level matching the correlation resolution and the ``channels``
            table has no entry above ``size``; both hold for ``size=256``
            (``size=512`` would hit a missing key) -- confirm before using
            other sizes.
    """

    def __init__(self, num_kp=10, dim=256, size=256):
        super(SCORRFlow, self).__init__()

        self.corr_enc = BasicMotionEncoder()
        self.refine = RefineFlow()
        self.scale = dim ** -0.5
        h = size // 4
        w = size // 4
        self.size = size
        self.h = h
        self.w = w

        ### driving and source structure encoder
        self.kp = Hourglass(block_expansion=64,in_features=num_kp,max_features=1024, num_blocks=5)
        self.kp_img = Hourglass(block_expansion=64,in_features=num_kp+3,max_features=1024, num_blocks=5)
        self.kp_head = nn.Conv2d(in_channels=self.kp.out_filters, out_channels=dim, kernel_size=1,padding=0)
        self.kp_img_head = nn.Conv2d(in_channels=self.kp_img.out_filters, out_channels=dim, kernel_size=1,padding=0)
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_kp, h, w))
        trunc_normal_(self.pos_embedding, std=.02)
        ### driving and source structure encoder

        # Produces the quarter-resolution image used when ``img`` is None.
        self.down =  AntiAliasInterpolation2d(3, 0.25)

        ### image generator
        self.generator = OcclusionAwareGenerator(num_channels=3, block_expansion=64, max_features=512, num_up_blocks=5)
        ### image generator

        # Encoder feature width, keyed by spatial resolution.
        channels = {
            size//32: 512,
            size//16: 512,
            size//8: 512,
            size//4: 256,
            size//2: 128,
            size: 64,
        }
        self.num_iter = int(math.log(self.size, 2)) - 2

        ### warped source feature to context
        # One 1x1 projection (to 192 channels) per refinement level.
        self.to_context = nn.ModuleList()
        for i in range(self.num_iter):
            f_channel = channels[(size//128) * (2 ** (i+2))]
            self.to_context.append(nn.Conv2d(f_channel, 192, 1, padding=0))
        ### warped source feature to context

    def forward(self, kp_s, kp_d, dense_motion, img, img_full, prior_only=False, scorr_only=False):
        """Generate an image from source features warped by refined flow.

        Args:
            kp_s: mapping whose ``'kp'`` entry holds the source keypoints.
            kp_d: mapping whose ``'kp'`` entry holds the driving keypoints.
            dense_motion: mapping with ``'deformation'`` (channel-last; it is
                permuted with ``(0, 3, 1, 2)``; presumably a normalized
                sampling grid in ``[-1, 1]``) and ``'occlusion'`` entries.
            img: reduced-resolution source image, or ``None`` to derive it
                from ``img_full`` with ``self.down``.
            img_full: full-resolution source image fed to the generator's
                encoder.
            prior_only: skip the correlation-based refinement and warp every
                feature level directly with the prior deformation.
            scorr_only: initialise the flow from the soft-argmax of the
                correlation volume instead of the prior deformation.  No
                occlusion prior exists in this mode, so zero logits are used.

        Returns:
            Tuple ``(out, warp_img, occlusion)``: the generator output, the
            warped full-resolution image, and all occlusion maps resized to
            ``self.size`` and concatenated along dim 3.
        """
        feature = self.generator.encode(img_full)
        if img is None:
            img = self.down(img_full)
        b,_,h,w = img.shape

        if not prior_only:
            ### calculating structure correlation volume
            kp_s = kp2gaussian(kp_s['kp'], (h,w), 0.1) + self.pos_embedding
            kp_d = kp2gaussian(kp_d['kp'], (h,w), 0.1) + self.pos_embedding
            fe_s = self.kp_img(torch.cat([kp_s,img],dim=1))
            fe_d = self.kp(kp_d)
            k_s = self.kp_img_head(fe_s)
            q_d = self.kp_head(fe_d)
            f_s = rearrange(k_s, 'b c h w -> b (h w) c', h=self.h, w=self.w)
            f_d = rearrange(q_d, 'b c h w -> b (h w) c', h=self.h, w=self.w)
            # Row index = driving position, column index = source position.
            corr_volume = torch.einsum('bic,bjc->bij',f_d,f_s) * self.scale
            ### calculating structure correlation volume

            ### prior motion initialization
            id_grid = coords_grid(b, self.h, self.w, corr_volume.device)
            # Map the deformation from [-1, 1] to pixel units, then subtract
            # the identity grid to obtain a displacement.
            init_flow = (h-1)*(dense_motion['deformation'].permute(0,3,1,2)+1) / 2.0 - id_grid
            init_occlusion = dense_motion['occlusion']
            ### prior motion initialization

            ### SCORR only initialization
            if scorr_only:
                id_grid = coords_grid(b, self.h, self.w, corr_volume.device)
                id_grid = rearrange(id_grid, 'b c h w -> b (h w) c', h=self.h, w=self.w)
                # Expected source coordinate under the softmax-normalised
                # correlation (soft-argmax).
                init_flow = torch.einsum('bij,bjc->bic', corr_volume.softmax(-1), id_grid)
                init_flow = rearrange(init_flow, 'b (h w) c-> b c h w', h=self.h, w=self.w)
                id_grid = rearrange(id_grid, 'b (h w) c -> b c h w', h=self.h, w=self.w)
                init_flow = init_flow - id_grid
                init_occlusion = None
            ### SCORR only initialization

            if init_occlusion is None:
                # Without a prior, zero logits act as the additive identity
                # for the occlusion updates below.  Previously ``None`` was
                # propagated and the refinement loop raised on
                # ``occlusion + d_occ``.
                init_occlusion = torch.zeros(
                    b, 1, self.h, self.w,
                    device=corr_volume.device, dtype=corr_volume.dtype,
                )

            # Start refinement at 1/8 of the correlation resolution; flow
            # magnitudes are rescaled accordingly.
            flow = F.interpolate(init_flow, scale_factor=1.0/(2**3),mode='bilinear',align_corners=True) / 8.0
            occlusion = F.interpolate(init_occlusion, scale_factor=1.0/(2**3),mode='bilinear',align_corners=True)
            corr_volume = rearrange(corr_volume, 'b (h w) n -> (b n) h w', h=self.h, w=self.w).unsqueeze(1)

        out_warp_f = []
        out_occlusion = []
        # Refined absolute sampling coordinates per level.  Previously this
        # list was appended to without being defined (NameError).  It is
        # collected for inspection only and is not part of the return value.
        out_flow = []

        ### prior only
        if prior_only:
            flow = dense_motion['deformation']
            occlusion = dense_motion['occlusion']
            for i in range(self.num_iter):
                if flow.shape[2] != feature[i].shape[2]:
                    flow_res = F.interpolate(flow.permute(0,3,1,2), size=feature[i].shape[2:], mode='bilinear',align_corners=True)
                    occlusion_res = F.interpolate(occlusion, size=feature[i].shape[2:], mode='bilinear',align_corners=True)
                else:
                    flow_res = flow.permute(0,3,1,2)
                    occlusion_res = occlusion
                out_warp_f.append(F.grid_sample(feature[i], flow_res.permute(0,2,3,1)))
                out_occlusion.append(torch.sigmoid(occlusion_res))
            # ``flow_res`` here is the grid of the last (finest) level.
            warp_img = F.grid_sample(img_full, flow_res.permute(0,2,3,1))
            out = self.generator.decode(out_warp_f, warp_img, out_occlusion)
            occlusion = torch.cat(
                [F.interpolate(occ, size=self.size, mode='bilinear',align_corners=True) for occ in out_occlusion],
                dim=3,
            )
            return out, warp_img, occlusion
        ### prior only

        ### flow refinements
        for i in range(self.num_iter):
            res = self.size//128 * (2 ** (i+2))
            id_grid = coords_grid(b, res, res, corr_volume.device)
            flow_sample = flow
            id_grid_sample = id_grid
            if i < 3:
                # Coarser than the correlation volume: pool the query
                # (driving) axis down and scale lookup coordinates up.
                corr_volume_res = F.avg_pool2d(corr_volume, 2**(3-i), stride=2**(3-i))
                scale = 2**(3-i)
            elif i == 3:
                corr_volume_res = corr_volume
                scale = 1
            elif i > 3:
                # Finer than the correlation volume: look up at the
                # correlation resolution and upsample the result afterwards.
                corr_volume_res = corr_volume
                scale = 0.5**(i-3)
                flow_sample = F.interpolate(flow, size=self.h, mode='bilinear',align_corners=True) * scale
                id_grid_sample = coords_grid(b, self.h,self.w, corr_volume.device)
                scale = 1
            corr_volume_res = rearrange(corr_volume_res, '(b n) c h w -> (b h w) c n', n=self.h*self.w)
            corr_volume_res = rearrange(corr_volume_res, 'b c (p q) -> b c p q', p=self.h,q=self.w)
            corr_fn = CorrBlock(corr_volume_res)
            corr = corr_fn((flow_sample+id_grid_sample)*scale)
            if i > 3:
                corr = F.interpolate(corr, size=flow.shape[2], mode='bilinear',align_corners=True)
            m_f = self.corr_enc(flow, corr)
            warp_f = bilinear_sampler(feature[i], (flow+id_grid).permute(0,2,3,1))
            warp_f = F.relu(self.to_context[i](warp_f))
            d_flow, _ = self.refine(m_f,warp_f)
            flow_w = flow + d_flow[:,0:2,:,:]
            d_occ = d_flow[:,2:,:,:]
            occlusion = occlusion + d_occ
            out_flow.append(flow_w+id_grid)
            out = bilinear_sampler(feature[i], (flow_w+id_grid).permute(0,2,3,1))
            out_occlusion.append(torch.sigmoid(occlusion))
            out_warp_f.append(out)

            if i != self.num_iter-1:
                # Next level = upsampled prior + accumulated upsampled
                # residuals from all previous levels.
                next_res = self.size//128 * (2 ** (i+3))
                if i <= 3:
                    scale = 2 ** (3-i) / 2.0
                else:
                    scale = 0.5 ** (i-3) / 2.0
                d_f = F.interpolate(d_flow[:,0:2,:,:], scale_factor=2, mode='bilinear',align_corners=True) * 2
                flow = d_f + F.interpolate(init_flow, size=next_res, mode='bilinear',align_corners=True) / scale
                if i ==0:
                    d_f_pre = d_f
                else:
                    flow = flow + F.interpolate(d_f_pre, scale_factor=2, mode='bilinear',align_corners=True) * 2
                    d_f_pre = d_f + F.interpolate(d_f_pre, scale_factor=2, mode='bilinear',align_corners=True) * 2
                d_occ = F.interpolate(d_occ, scale_factor=2, mode='bilinear',align_corners=True)
                occlusion = d_occ + F.interpolate(init_occlusion, size=next_res, mode='bilinear',align_corners=True)
                if i ==0:
                    d_occ_pre = d_occ
                else:
                    occlusion = occlusion + F.interpolate(d_occ_pre, scale_factor=2, mode='bilinear',align_corners=True)
                    d_occ_pre = d_occ + F.interpolate(d_occ_pre, scale_factor=2, mode='bilinear',align_corners=True)
        ### flow refinements

        # NOTE(review): this warps with ``flow`` (the input of the last
        # level), not the refined ``flow_w`` -- confirm this is intended.
        warp_img = bilinear_sampler(img_full, (flow+id_grid).permute(0,2,3,1))
        out = self.generator.decode(out_warp_f, warp_img, out_occlusion)
        # The prior occlusion is appended after decoding, so it only shows up
        # in the returned visualisation strip.
        out_occlusion.append(torch.sigmoid(init_occlusion))
        out_occlusion_vis = [
            F.interpolate(occ, size=self.size, mode='bilinear',align_corners=True)
            for occ in out_occlusion
        ]
        occlusion = torch.cat(out_occlusion_vis,dim=3)
        return out, warp_img, occlusion
