import torch
import math
import torch.nn as nn
import torch.nn.functional as F

def normalize_quat(q, eps=1e-10):
    """Scale ``q`` to unit L2 length along its last dimension.

    Args:
        q: Tensor whose last dimension is normalized.
        eps: Lower bound applied to the norm before dividing, so an all-zero
            vector yields zeros instead of NaN.

    Returns:
        Tensor with the same shape as ``q``.
    """
    length = q.norm(dim=-1, keepdim=True)
    safe_length = torch.clamp(length, min=eps)
    return q / safe_length

def quat2mat(q):
    """Convert quaternions to rotation matrices.

    The quaternion norm is folded into the scale factor ``2 / |q|^2``, so the
    input does not have to be normalized.

    Formula reference:
    https://afni.nimh.nih.gov/pub/dist/src/pkundu/meica.libs/nibabel/quaternions.py

    Args:
        q: Tensor of shape (..., 4) in [w, x, y, z] order. Any number of
            leading dimensions is accepted (e.g. (4,), (B, 4), (B, N, 4)).

    Returns:
        Tensor of shape (..., 3, 3) holding one rotation matrix per quaternion.
    """
    # Work on the last dimension so batched and unbatched inputs both work.
    w, x, y, z = q.unbind(dim=-1)

    # Squared norm, clamped so a (near-)zero quaternion does not divide by 0.
    Nq = (w * w + x * x + y * y + z * z).clamp_min(1e-8)
    s = 2.0 / Nq

    X, Y, Z = x * s, y * s, z * s

    wX, wY, wZ = w * X, w * Y, w * Z
    xX, xY, xZ = x * X, x * Y, x * Z
    yY, yZ, zZ = y * Y, y * Z, z * Z

    row0 = torch.stack([1.0 - (yY + zZ), xY - wZ, xZ + wY], dim=-1)  # (..., 3)
    row1 = torch.stack([xY + wZ, 1.0 - (xX + zZ), yZ - wX], dim=-1)
    row2 = torch.stack([xZ - wY, yZ + wX, 1.0 - (xX + yY)], dim=-1)

    # Stack the rows on a new second-to-last axis -> (..., 3, 3).
    return torch.stack([row0, row1, row2], dim=-2)

def build_intrinsics_pyramid(self, K, num_levels=4):
    """Build down-scaled copies of a camera intrinsics matrix.

    Level ``i`` (0-based) divides the entries at [0, 0], [1, 1], [0, 2] and
    [1, 2] by ``2 ** (i + 1)``; every other entry is copied unchanged.

    NOTE(review): this is a module-level function, so ``self`` is never bound
    automatically and is not used in the body. It is kept only so the
    signature stays unchanged. A second ``build_intrinsics_pyramid`` (without
    ``self``) is defined later in this file and replaces this one at import
    time.

    Args:
        self: Unused placeholder.
        K: Intrinsics tensor, either (3, 3)-style 2-D or batched 3-D.
        num_levels: Number of pyramid levels to produce.

    Returns:
        List of ``num_levels`` float32 tensors. They are 3-D when ``K`` was
        3-D; otherwise the leading dimension is squeezed away again.
    """
    is_batched = K.dim() == 3
    if K.dim() == 2:
        K = K.unsqueeze(0)

    pyramid = []
    for level in range(num_levels):
        scale = 2.0 ** (level + 1)

        # Fresh float copy per level so the caller's tensor is never mutated.
        K_level = K.clone().float()
        K_level[:, 0, 0] /= scale
        K_level[:, 1, 1] /= scale
        K_level[:, 0, 2] /= scale
        K_level[:, 1, 2] /= scale
        pyramid.append(K_level)

    if not is_batched:
        pyramid = [K_level.squeeze(0) for K_level in pyramid]
    return pyramid

def sample_features_from_coords(target_coords, num_samples=32, radius=3):
    """Pick the ``num_samples`` nearest in-bounds neighbours of each target.

    For every target coordinate a ``(2 * radius + 1) ** 2`` window of integer
    offsets is added, the resulting candidates are clamped to the image
    bounds, and the candidates closest to the (unclamped) target are kept.

    Args:
        target_coords: Tensor of shape [B, H, W, 2]; index 0 of the last
            dimension is clamped to ``[0, W - 1]`` (x) and index 1 to
            ``[0, H - 1]`` (y).
        num_samples: Number of neighbours to keep per location. Must not
            exceed ``(2 * radius + 1) ** 2``; ``torch.topk`` raises otherwise.
        radius: Half-width of the square candidate window. The default of 3
            reproduces the original fixed 7x7 window.

    Returns:
        Tensor of shape [B, H, W, num_samples, 2] with the selected
        coordinates, ordered from nearest to farthest.
    """
    _, H, W, _ = target_coords.shape

    steps = torch.arange(-radius, radius + 1,
                         dtype=target_coords.dtype, device=target_coords.device)
    offsets = torch.stack(torch.meshgrid(steps, steps, indexing="ij"), dim=-1)
    offsets = offsets.reshape(-1, 2)  # [C,2] where C=(2*radius+1)^2

    centers = target_coords.unsqueeze(-2)  # [B,H,W,1,2]
    cand = centers + offsets.view(1, 1, 1, -1, 2)  # [B,H,W,C,2]

    # Keep candidates inside the image before measuring distances.
    cand_x = cand[..., 0].clamp(0, W - 1)
    cand_y = cand[..., 1].clamp(0, H - 1)
    cand = torch.stack([cand_x, cand_y], dim=-1)  # [B,H,W,C,2]

    distance_square = ((cand - centers) ** 2).sum(dim=-1)  # [B,H,W,C]

    # largest=False -> smallest distances first.
    _, idx = torch.topk(distance_square, k=num_samples, dim=-1, largest=False)  # [B,H,W,K]
    idx_exp = idx.unsqueeze(-1).expand(-1, -1, -1, -1, 2)  # [B,H,W,K,2]
    return torch.gather(cand, -2, idx_exp)  # [B,H,W,K,2]

def quat_mul(q, r):
    """Return the Hamilton product ``q * r`` of two quaternion tensors.

    Both inputs store [w, x, y, z] along their last dimension (size 4); the
    leading dimensions are combined element-wise with normal broadcasting.

    Args:
        q: Left operand, shape (..., 4).
        r: Right operand, shape (..., 4).

    Returns:
        Tensor of shape (..., 4) holding the product in [w, x, y, z] order.
    """
    a1, b1, c1, d1 = torch.unbind(q, dim=-1)
    a2, b2, c2, d2 = torch.unbind(r, dim=-1)

    components = [
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # w
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # x
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # y
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # z
    ]
    return torch.stack(components, dim=-1)


def quat_inv(q, eps=1e-10):
    """Return the multiplicative inverse of quaternions in [w, x, y, z] order.

    The inverse is the conjugate divided by the squared norm. Both steps act
    on the last dimension, so any number of leading dimensions is supported
    (e.g. (4,), (B, 4), (B, N, 4)).

    Args:
        q: Tensor of shape (..., 4).
        eps: Lower bound applied to the squared norm before dividing, to
            avoid division by zero for an all-zero quaternion.

    Returns:
        Tensor with the same shape as ``q``.
    """
    # Conjugate: keep w, negate the vector part. Built out-of-place so no
    # indexing assumption about the number of leading dimensions is needed.
    q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
    norm_sq = q.pow(2).sum(-1, keepdim=True).clamp_min(eps)
    return q_conj / norm_sq


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head self-attention over a token sequence.

    Args:
        embed_dim: Channel size of the input and output tokens.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        bias: Whether the four linear projections use a bias term.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, bias=True):
        super().__init__()
        # Fail at construction time: otherwise the head reshape in forward()
        # raises an opaque view-size error on the first call.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _split_heads(self, t, B, N):
        # [B, N, C] -> [B, heads, N, head_dim]
        return t.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention.

        Args:
            x: Tensor of shape [B, N, C] with ``C == embed_dim``.
            mask: Optional tensor broadcastable to [B, heads, N, N]. Positions
                where ``mask == 0`` are excluded from attention. A query row
                that is masked everywhere produces NaN after the softmax.

        Returns:
            Tensor of shape [B, N, C].
        """
        B, N, C = x.shape

        q = self._split_heads(self.q_proj(x), B, N)
        k = self._split_heads(self.k_proj(x), B, N)
        v = self._split_heads(self.v_proj(x), B, N)

        # [B, heads, N, N], scaled by sqrt(head_dim).
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / (self.head_dim ** 0.5)

        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn_logits, dim=-1)

        out = torch.matmul(attn, v)  # [B, heads, N, head_dim]
        out = out.transpose(1, 2).contiguous().view(B, N, C)  # merge heads
        return self.out_proj(out)

class FGVONetLoss(nn.Module):
    """Multi-level pose loss with learnable balancing terms.

    Four (quaternion, translation) predictions are compared against one
    ground truth. For every level the translation and rotation errors are
    combined as ``loss * exp(-w) + w`` using the learnable scalars ``w_x``
    and ``w_q``, then scaled by the fixed per-level weight.
    """

    def __init__(self):
        super().__init__()
        # Learnable balancing scalars for translation (w_x) and rotation (w_q).
        self.w_x = nn.Parameter(torch.tensor(0.0))
        self.w_q = nn.Parameter(torch.tensor(-2.5))
        # Fixed weights applied to levels 0..3, in that order.
        self.layer_weights = [0.2, 0.4, 0.8, 1.6]

    def normalize_quat(self, q, eps=1e-10):
        """Return ``q`` scaled to unit L2 norm along the last dimension."""
        return q / torch.norm(q, dim=-1, keepdim=True).clamp_min(eps)

    def forward(self,
                l0_q, l0_t,
                l1_q, l1_t,
                l2_q, l2_t,
                l3_q, l3_t,
                q_gt, t_gt):
        """Compute the combined loss and the weighted per-term sums.

        Args:
            l0_q .. l3_q: Predicted quaternions per level (normalized here
                before comparison).
            l0_t .. l3_t: Predicted translations per level.
            q_gt: Ground-truth quaternion, compared as given.
            t_gt: Ground-truth translation.

        Returns:
            Tuple ``(total_loss, total_raw_q_loss, total_raw_t_loss)``. The two
            "raw" sums are the level-weighted rotation and translation errors
            without the learnable balancing terms.
        """
        predictions = (
            (l0_q, l0_t),
            (l1_q, l1_t),
            (l2_q, l2_t),
            (l3_q, l3_t),
        )

        total_loss = 0
        total_raw_q_loss = 0
        total_raw_t_loss = 0

        for weight, (pred_q, pred_t) in zip(self.layer_weights, predictions):
            unit_q = self.normalize_quat(pred_q)

            # Rotation: mean squared distance between quaternion vectors.
            sq_err = (q_gt - unit_q).pow(2).sum(dim=-1)
            loss_q = (sq_err + 1e-10).mean()

            # Translation: mean L1 distance.
            loss_x = torch.norm(pred_t - t_gt, p=1, dim=-1).mean()

            level_loss = (
                loss_x * torch.exp(-self.w_x) + self.w_x +
                loss_q * torch.exp(-self.w_q) + self.w_q
            )

            total_raw_q_loss = total_raw_q_loss + weight * loss_q
            total_raw_t_loss = total_raw_t_loss + weight * loss_x
            total_loss = total_loss + weight * level_loss

        return total_loss, total_raw_q_loss, total_raw_t_loss


class WindowAttnPool(nn.Module):
    """Self-attention inside non-overlapping square windows, with a residual.

    The feature map is cut into ``window_size x window_size`` tiles; the
    pixels of each tile attend to each other, and the result is added back
    to the input. The output has the same shape as the input.

    Args:
        dim: Channel count of the feature map.
        window_size: Side length of a window; the input height and width
            must be multiples of it (the reshape fails otherwise).
        num_heads: Number of attention heads.
    """

    def __init__(self, dim, window_size=8, num_heads=4):
        super().__init__()
        self.win = window_size
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        # One learnable positional embedding per pixel of a window.
        self.pos_emb = nn.Parameter(torch.zeros(1, window_size * window_size, dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, x):
        """Apply windowed attention to ``x`` of shape [B, C, H, W]."""
        B, C, H, W = x.shape
        w = self.win
        rows, cols = H // w, W // w

        # Partition into windows: [B,C,H,W] -> [B*rows*cols, w*w, C].
        tiles = x.view(B, C, rows, w, cols, w)
        tiles = tiles.permute(0, 2, 4, 3, 5, 1).contiguous()
        tokens = tiles.view(-1, w * w, C)

        tokens = self.norm(tokens + self.pos_emb)
        attended, _ = self.attn(tokens, tokens, tokens)

        # Undo the partition: [B*rows*cols, w*w, C] -> [B,C,H,W].
        attended = attended.view(B, rows, cols, w, w, C)
        restored = attended.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)

        return x + restored  # [B,C,H,W]
    
def build_intrinsics_pyramid(K, num_levels=4):
    """Build down-scaled copies of a camera intrinsics matrix.

    Level ``i`` (0-based) divides the entries at [0, 0], [1, 1], [0, 2] and
    [1, 2] by ``2 ** (i + 1)``; every other entry is copied unchanged. The
    input tensor is not modified.

    Args:
        K: Intrinsics tensor, either a single 2-D matrix or a batched 3-D
            tensor.
        num_levels: Number of pyramid levels to produce.

    Returns:
        List of ``num_levels`` float32 tensors. They are 3-D when ``K`` was
        3-D; otherwise the leading dimension is squeezed away again.
    """
    is_batched = K.dim() == 3
    if K.dim() == 2:
        K = K.unsqueeze(0)

    Ks = []
    for level in range(num_levels):
        scale = 2.0 ** (level + 1)

        # Fresh float copy per level so in-place division never touches K.
        K_scaled = K.clone().float()
        K_scaled[:, 0, 0] /= scale
        K_scaled[:, 1, 1] /= scale
        K_scaled[:, 0, 2] /= scale
        K_scaled[:, 1, 2] /= scale
        Ks.append(K_scaled)

    if not is_batched:
        Ks = [K_scaled.squeeze(0) for K_scaled in Ks]
    return Ks


class FlowGuidedAttention(nn.Module):
    """Two-stage, flow-guided attention fusion of two feature maps.

    Global stage: every pixel of ``img1`` is warped by the optical flow into
    ``img2``; ``nsample_coarse`` learned offsets around the warped position
    are sampled from ``img2`` and fused with the ``img1`` feature using
    softmax attention weights.

    Local stage: ``nsample_fine`` learned offsets around every pixel are
    sampled from the global-stage result and fused the same way.

    Args:
        nsample_coarse: Samples per pixel in the global stage.
        nsample_fine: Samples per pixel in the local stage.
        C_in: Channel count of ``img1`` and ``img2``.
        mlp1: Channel sizes of the global feature MLP.
        mlp2: Channel sizes of the global weighting MLP.
        mlp1_local: Channel sizes of the local feature MLP.
        mlp2_local: Channel sizes of the local weighting MLP.
        geo_hiden: Output channels of the geometry encoders.
    """

    def __init__(self, nsample_coarse=32, nsample_fine=9, C_in=128, mlp1=(128, 64, 64), mlp2=(128, 64),
                 mlp1_local=(64, 64), mlp2_local=(64,64), geo_hiden=64):
        super().__init__()

        self.nsample_coarse = nsample_coarse
        self.nsample_fine = nsample_fine

        # Global stage. Input per sample: 3 geometry values + both features.
        self.mlp1 = self._pointwise_mlp(3 + 2 * C_in, mlp1)
        self.geo_enc1 = nn.Conv2d(3, geo_hiden, 1)
        self.mlp2 = self._pointwise_mlp(mlp1[-1] + geo_hiden, mlp2)
        self.global_weight = nn.Conv2d(mlp2[-1], 1, 1)
        # Predicts an (x, y) offset per coarse sample from the stacked features.
        self.offset_predictor_global = nn.Conv2d(
            C_in * 2, self.nsample_coarse * 2, kernel_size=3, padding=1)

        # Local stage. Operates on the global-stage output (mlp1[-1] channels).
        self.mlp1_local = self._pointwise_mlp(3 + 2 * mlp1[-1], mlp1_local)
        self.geo_enc1_local = nn.Conv2d(3, geo_hiden, 1)
        self.mlp2_local = self._pointwise_mlp(mlp1_local[-1] + geo_hiden, mlp2_local)
        self.global_weight_local = nn.Conv2d(mlp2_local[-1], 1, 1)
        self.offset_predictor_local = nn.Conv2d(
            mlp1[-1], self.nsample_fine * 2, kernel_size=3, padding=1)

    @staticmethod
    def _pointwise_mlp(in_ch, channels):
        """Stack of 1x1 Conv2d -> BatchNorm2d -> ReLU blocks."""
        layers = []
        for ch in channels:
            layers += [nn.Conv2d(in_ch, ch, 1), nn.BatchNorm2d(ch), nn.ReLU(inplace=True)]
            in_ch = ch
        return nn.Sequential(*layers)

    @staticmethod
    def _sample(feat_map, coords):
        """Bilinearly sample ``feat_map`` [B,C,H,W] at pixel ``coords`` [B,H,W,K,2].

        Returns a tensor of shape [B,H,W,K,C]. Coordinates outside the map
        are clamped to the border.
        """
        B, H, W, K, _ = coords.shape
        C = feat_map.shape[1]

        # Map pixel coordinates to [-1, 1] (align_corners=True convention).
        # max(..., 1) avoids a division by zero for maps of width/height 1.
        x_norm = 2.0 * coords[..., 0] / max(W - 1, 1) - 1.0
        y_norm = 2.0 * coords[..., 1] / max(H - 1, 1) - 1.0
        grid_norm = torch.stack([x_norm, y_norm], dim=-1).view(B, H, W * K, 2)

        sampled = F.grid_sample(
            feat_map, grid_norm,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )  # [B,C,H,W*K]
        return sampled.view(B, C, H, W, K).permute(0, 2, 3, 4, 1)

    @staticmethod
    def _fuse(geo, feats, feat_mlp, geo_enc, weight_mlp, weight_head):
        """Attention-weighted sum over the K samples of every pixel.

        ``geo`` is [B,H,W,K,3] and ``feats`` is [B,H,W,K,F]; the result is
        [B,H,W,C_out], where C_out is the output width of ``feat_mlp``.
        """
        B, H, W, K = feats.shape[:4]
        n = B * H * W

        # Treat every pixel as a batch entry: [n, channels, K, 1].
        feat_conv = feats.view(n, K, -1).transpose(1, 2).unsqueeze(-1)
        feat_conv = feat_mlp(feat_conv)

        geom_conv = geo.view(n, K, -1).transpose(1, 2).unsqueeze(-1)
        geom_conv = geo_enc(geom_conv)

        weight_in = torch.cat([geom_conv, feat_conv], dim=1)
        score = weight_head(weight_mlp(weight_in)).squeeze(3)  # [n,1,K]
        attention = torch.softmax(score, dim=-1).unsqueeze(-1)  # [n,1,K,1]

        fused = (attention * feat_conv).sum(dim=2).squeeze(-1)  # [n,C_out]
        return fused.view(B, H, W, -1)

    def forward(self, img1, img2, optical_flow12):
        """Fuse ``img2`` into ``img1`` along the given optical flow.

        Args:
            img1: Feature map of shape [B, C_in, H, W].
            img2: Feature map of shape [B, C_in, H, W].
            optical_flow12: Tensor of shape [B, 2, H, W]; channel 0 is added
                to the x pixel coordinate and channel 1 to the y coordinate.

        Returns:
            Tensor of shape [B, mlp1_local[-1], H, W].
        """
        B, _, H, W = img1.shape
        k_coarse = self.nsample_coarse
        k_fine = self.nsample_fine

        # Pixel grid built directly on img1's device, avoiding a CPU tensor
        # and a host-to-device copy on every call. Kept at batch size 1 and
        # broadcast where needed.
        ys, xs = torch.meshgrid(
            torch.arange(H, dtype=torch.float32, device=img1.device),
            torch.arange(W, dtype=torch.float32, device=img1.device),
            indexing='ij'
        )
        grid = torch.stack([xs, ys], dim=-1).unsqueeze(0).to(img1.dtype)  # [1,H,W,2]

        # ---- Global fusion ----
        target_coords = grid + optical_flow12.permute(0, 2, 3, 1)  # [B,H,W,2]

        offset_global = self.offset_predictor_global(torch.cat([img1, img2], dim=1))
        offset_global = offset_global.view(B, k_coarse, 2, H, W).permute(0, 3, 4, 1, 2)  # [B,H,W,K,2]
        img2_sample_coords = target_coords.unsqueeze(3) + offset_global  # [B,H,W,K,2]

        img2_sample_feats = self._sample(img2, img2_sample_coords)  # [B,H,W,K,d]
        img1_feats = img1.permute(0, 2, 3, 1).unsqueeze(3).expand(-1, -1, -1, k_coarse, -1)  # [B,H,W,K,d]

        # Geometry: displacement of each sample from its source pixel + length.
        coords_diff = img2_sample_coords - grid.unsqueeze(3)  # [B,H,W,K,2]
        distance = torch.norm(coords_diff, dim=-1, keepdim=True)  # [B,H,W,K,1]
        geo_information = torch.cat([coords_diff, distance], dim=-1)  # [B,H,W,K,3]

        feats = torch.cat([geo_information, img1_feats, img2_sample_feats], dim=-1)  # [B,H,W,K,3+2d]
        fused_feat = self._fuse(geo_information, feats,
                                self.mlp1, self.geo_enc1, self.mlp2, self.global_weight)  # [B,H,W,mlp1[-1]]

        # ---- Local fusion ----
        fused_feat_img1 = fused_feat.permute(0, 3, 1, 2)  # [B,mlp1[-1],H,W]

        offset_local = self.offset_predictor_local(fused_feat_img1)
        offset_local = offset_local.view(B, k_fine, 2, H, W).permute(0, 3, 4, 1, 2)  # [B,H,W,K,2]
        target_coords_local = grid.unsqueeze(3) + offset_local  # [B,H,W,K,2]

        img1_sample_feats = self._sample(fused_feat_img1, target_coords_local)  # [B,H,W,K,d]
        img1_local_feats = fused_feat.unsqueeze(3).expand(-1, -1, -1, k_fine, -1)  # [B,H,W,K,d]

        distance_local = torch.norm(offset_local, dim=-1, keepdim=True)  # [B,H,W,K,1]
        # NOTE(review): unlike the global stage, this uses absolute sample
        # coordinates rather than offsets as the geometry input - confirm
        # this is intended before changing it (trained weights depend on it).
        geo_information_local = torch.cat([target_coords_local, distance_local], dim=-1)  # [B,H,W,K,3]

        feats_local = torch.cat(
            [geo_information_local, img1_local_feats, img1_sample_feats], dim=-1)  # [B,H,W,K,3+2d]
        fused_feat_local_global = self._fuse(
            geo_information_local, feats_local,
            self.mlp1_local, self.geo_enc1_local, self.mlp2_local, self.global_weight_local
        )  # [B,H,W,mlp1_local[-1]]

        return fused_feat_local_global.permute(0, 3, 1, 2)


def main():
    """Smoke test: run FlowGuidedAttention once on random inputs.

    Uses CUDA when available, otherwise the CPU, and prints the chosen
    device and the shape of the network output.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Running on: {device}')

    batch, channels, height, width = 2, 64, 192, 640
    feat_a = torch.randn(batch, channels, height, width, device=device)
    feat_b = torch.randn_like(feat_a)
    flow_ab = torch.randn(batch, 2, height, width, device=device)

    model = FlowGuidedAttention(C_in=channels).to(device)

    # Inference only; no gradients are needed for a shape check.
    with torch.no_grad():
        result = model(feat_a, feat_b, flow_ab)

    print('Output shape:', result.shape)


if __name__ == '__main__':
    main()