import torch 
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class SSIM(nn.Module):
    """Per-pixel SSIM dissimilarity between two image batches.

    Local means, variances and covariance are taken over 3x3 windows
    (stride 1) after 1-pixel reflection padding, so the output keeps the
    spatial size of the inputs. The returned map is (1 - SSIM) / 2 clamped
    to [0, 1]; it is computed independently for every channel.
    Adapted from: https://github.com/nianticlabs/monodepth2/tree/master
    """

    def __init__(self):
        super().__init__()
        # One 3x3, stride-1 average pool per local statistic.
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        # Reflection padding of 1 compensates for the 3x3 window shrinkage.
        self.refl = nn.ReflectionPad2d(1)

        # Stabilizing constants of the SSIM formula.
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        """Return the clamped (1 - SSIM) / 2 map for inputs ``x`` and ``y``."""
        x_pad = self.refl(x)
        y_pad = self.refl(y)

        mean_x = self.mu_x_pool(x_pad)
        mean_y = self.mu_y_pool(y_pad)

        # Var[X] = E[X^2] - E[X]^2 and Cov[X, Y] = E[XY] - E[X]E[Y].
        var_x = self.sig_x_pool(x_pad ** 2) - mean_x ** 2
        var_y = self.sig_y_pool(y_pad ** 2) - mean_y ** 2
        cov_xy = self.sig_xy_pool(x_pad * y_pad) - mean_x * mean_y

        numerator = (2 * mean_x * mean_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mean_x ** 2 + mean_y ** 2 + self.C1) * (var_x + var_y + self.C2)

        dissimilarity = (1 - numerator / denominator) / 2
        return dissimilarity.clamp(0, 1)
    


def get_smooth_loss(disp, img):
    """Edge-aware first-order smoothness penalty for a disparity map.

    Absolute disparity differences between horizontal and vertical
    neighbours are multiplied by exp(-|image difference|), where the image
    difference is averaged over channels. Disparity changes are therefore
    penalized less where the colour image has strong edges.
    Adapted from: https://github.com/nianticlabs/monodepth2/tree/master
    """
    def _neighbour_diffs(tensor):
        # |t[..., j] - t[..., j+1]| along width and |t[i] - t[i+1]| along height.
        along_w = (tensor[:, :, :, :-1] - tensor[:, :, :, 1:]).abs()
        along_h = (tensor[:, :, :-1, :] - tensor[:, :, 1:, :]).abs()
        return along_w, along_h

    disp_dx, disp_dy = _neighbour_diffs(disp)
    img_dx, img_dy = _neighbour_diffs(img)

    disp_dx *= torch.exp(-img_dx.mean(dim=1, keepdim=True))
    disp_dy *= torch.exp(-img_dy.mean(dim=1, keepdim=True))

    return disp_dx.mean() + disp_dy.mean()


def disp_to_depth(disp, min_depth, max_depth):
    """Convert a network's sigmoid output into disparity and depth.

    ``disp`` is mapped linearly so that 0 -> 1 / max_depth and
    1 -> 1 / min_depth. Depth is the reciprocal of that scaled disparity.

    Returns:
        Tuple ``(scaled_disp, depth)``.

    Adapted from: https://github.com/nianticlabs/monodepth2/tree/master
    """
    lowest_disp, highest_disp = 1 / max_depth, 1 / min_depth
    scaled = lowest_disp + (highest_disp - lowest_disp) * disp
    return scaled, 1 / scaled
    
    

class FGVONetLoss_depth(nn.Module):
    """Combined multi-level pose loss and bidirectional photometric loss.

    The pose term compares four levels of predicted quaternions/translations
    with ground truth and weights the levels by ``layer_weights``. The rotation
    and translation errors are balanced by learnable scalars ``w_q`` and
    ``w_x``: each error L enters as L * exp(-w) + w.

    The photometric term warps images with depth, intrinsics ``K`` and the
    level-0 pose, in both directions (t1 -> t2 and t2 -> t1). It is balanced
    by ``w_d`` in the same way and scaled by 0.1 in ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Learnable balancing weights (initial values as chosen by the author).
        self.w_x = nn.Parameter(torch.tensor(0.0))
        self.w_q = nn.Parameter(torch.tensor(-2.5))
        self.w_d = nn.Parameter(torch.tensor(0.0))
        # Per-level weights for the pose losses of levels l0..l3.
        self.layer_weights = [0.2, 0.4, 0.8, 1.6]
        self.ssim = SSIM()

    def normalize_quat(self, q, eps=1e-10):
        """Scale quaternions to unit norm along the last dim (norm clamped at ``eps``)."""
        norm = torch.norm(q, dim=-1, keepdim=True).clamp_min(eps)
        return q / norm

    def normalize_feat(self, feat):
        """Z-score ``feat`` per channel over dims (2, 3), then squash with a sigmoid.

        Not called by any method of this class.
        """
        mean = feat.mean(dim=[2, 3], keepdim=True)
        std = feat.std(dim=[2, 3], keepdim=True).clamp_min(1e-8)
        z_score = (feat - mean) / std
        return torch.sigmoid(z_score)

    def quat2mat(self, q):
        """
        q: (B,4) tensor in [w, x, y, z] format, not necessarily normalized
        returns R: (B,3,3) rotation matrices
        https://afni.nimh.nih.gov/pub/dist/src/pkundu/meica.libs/nibabel/quaternions.py
        """
        w, x, y, z = q.unbind(dim=1)

        # Squared norm; dividing by it makes the result independent of |q|.
        Nq = (w * w + x * x + y * y + z * z).clamp_min(1e-8)
        s = 2.0 / Nq

        X = x * s
        Y = y * s
        Z = z * s

        wX = w * X; wY = w * Y; wZ = w * Z
        xX = x * X; xY = x * Y; xZ = x * Z
        yY = y * Y; yZ = y * Z; zZ = z * Z

        row0 = torch.stack([1.0 - (yY + zZ), xY - wZ, xZ + wY], dim=1)  # (B,3)
        row1 = torch.stack([xY + wZ, 1.0 - (xX + zZ), yZ - wX], dim=1)
        row2 = torch.stack([xZ - wY, yZ + wX, 1.0 - (xX + yY)], dim=1)

        return torch.stack([row0, row1, row2], dim=1)  # (B,3,3)

    def invert_pose(self, q, t):
        """Invert the rigid transform X2 = R(q) @ X1 + t.

        Args:
            q: (B, 4) quaternion [w, x, y, z] (need not be unit length).
            t: (B, 3) translation.

        Returns:
            ``(q_inv, t_inv)``: the conjugate quaternion (inverse rotation) and
            t_inv = -R(q)^T @ t, so that X1 = R(q_inv) @ X2 + t_inv.
        """
        q_inv = q.clone()
        q_inv[:, 1:] *= -1  # conjugate: negate the vector part

        R = self.quat2mat(q)
        t_inv = -torch.bmm(R.transpose(1, 2), t.unsqueeze(-1)).squeeze(-1)

        return q_inv, t_inv

    def build_intrinsics_pyramid(self, K, num_levels=4):
        """Return ``num_levels`` float32 copies of ``K`` with fx, fy, cx, cy scaled down.

        Level ``l`` is divided by ``2 ** (l + 1)``, so the first entry is for
        half resolution. ``K`` may be (3, 3) or (B, 3, 3); the outputs have
        the same batching as the input.
        """
        is_batched = K.dim() == 3
        if K.dim() == 2:
            K = K.unsqueeze(0)

        Ks = []
        for l in range(num_levels):
            scale = 2.0 ** (l + 1)
            K_l = K.clone().float()
            K_l[:, 0, 0] /= scale
            K_l[:, 1, 1] /= scale
            K_l[:, 0, 2] /= scale
            K_l[:, 1, 2] /= scale
            Ks.append(K_l)

        if not is_batched:
            Ks = [K_l.squeeze(0) for K_l in Ks]
        return Ks

    def pose_loss_cal(self,
                      l0_q, l0_t,
                      l1_q, l1_t,
                      l2_q, l2_t,
                      l3_q, l3_t,
                      q_gt, t_gt):
        """Weighted multi-level pose loss.

        For each level: squared-L2 error between ``q_gt`` and the normalized
        predicted quaternion, plus the L1 error of the translation. Both are
        combined with the learnable balancing (L * exp(-w) + w) and weighted
        by ``layer_weights``.

        Returns:
            ``(total_loss, total_q_loss, total_t_loss)``. The last two are
            level-weighted sums without the learnable balancing (for logging).
        """
        levels = [(l0_q, l0_t), (l1_q, l1_t), (l2_q, l2_t), (l3_q, l3_t)]
        losses, raw_q_losses, raw_t_losses = [], [], []

        for (l_q, l_t), weight in zip(levels, self.layer_weights):
            l_q_norm = self.normalize_quat(l_q)

            # NOTE(review): no sign canonicalization, so q and -q (the same
            # rotation) give different errors.
            loss_q = ((q_gt - l_q_norm).pow(2).sum(dim=-1) + 1e-10).mean()
            loss_x = (l_t - t_gt).norm(p=1, dim=-1).mean()

            raw_q_losses.append(weight * loss_q)
            raw_t_losses.append(weight * loss_x)

            level_loss = (
                loss_x * torch.exp(-self.w_x) + self.w_x
                + loss_q * torch.exp(-self.w_q) + self.w_q
            )
            losses.append(weight * level_loss)

        return sum(losses), sum(raw_q_losses), sum(raw_t_losses)

    def photometric_loss_cal(self, feat1, feat2, depth, flow_mask, K, q, t, alpha=0.85):
        """Masked SSIM + L1 error between ``feat1`` and ``feat2`` warped into its view.

        Each pixel of ``feat1`` is back-projected with ``depth`` and ``K``,
        moved by X2 = R(q) @ X + t, projected with ``K`` and used to
        bilinearly sample ``feat2``.

        Args:
            feat1: (B, C, H, W) reference view; ``depth`` belongs to this view.
            feat2: (B, C, H, W) view that is sampled at the warped locations.
            depth: depth of ``feat1`` with H*W values per batch item.
            flow_mask: mask/weights multiplied element-wise with the
                (B, 1, H, W) per-pixel error. The loss is the mask-weighted mean.
            K: intrinsics, (B, 3, 3) or a single shared (3, 3) matrix.
            q: (B, 4) quaternion [w, x, y, z] (normalized here).
            t: (B, 3) translation.
            alpha: weight of the SSIM term relative to L1.

        Returns:
            Scalar photometric loss.
        """
        B, _, H, W = feat1.shape
        device = feat1.device
        dtype = feat1.dtype

        # Homogeneous pixel coordinates (u, v, 1): (B, 3, H*W).
        u = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)
        v = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
        pix = torch.cat([u, v, torch.ones_like(u)], dim=1).view(B, 3, -1)

        K = K.to(dtype=dtype, device=device)
        if K.dim() == 2:
            # A single intrinsics matrix shared by the whole batch.
            K = K.unsqueeze(0).expand(B, 3, 3)
        K_inv = torch.inverse(K)

        # Back-project: X = depth * K^-1 @ pix  -> (B, 3, H*W).
        cam_pts = K_inv.bmm(pix) * depth.reshape(B, 1, -1)

        # Rigid transform into the second camera: X2 = R @ X + t.
        R = self.quat2mat(self.normalize_quat(q))
        cam2 = R.bmm(cam_pts) + t.reshape(B, 3, 1)

        # Project and convert to grid_sample's [-1, 1] coordinates (align_corners=True).
        proj = K.bmm(cam2)
        x2 = proj[:, 0] / (proj[:, 2] + 1e-6)
        y2 = proj[:, 1] / (proj[:, 2] + 1e-6)
        x2_norm = 2.0 * (x2.view(B, H, W) / (W - 1)) - 1.0
        y2_norm = 2.0 * (y2.view(B, H, W) / (H - 1)) - 1.0
        grid = torch.stack([x2_norm, y2_norm], dim=-1)  # (B, H, W, 2)

        feat2_warp = F.grid_sample(feat2, grid, mode='bilinear',
                                   padding_mode='zeros', align_corners=True)

        # Reduce both error terms over channels so the map is (B, 1, H, W).
        # The mask normalization below then counts every pixel exactly once
        # instead of over-weighting SSIM by the number of channels.
        l1 = (feat2_warp - feat1).abs().mean(dim=1, keepdim=True)
        ssim_map = self.ssim(feat2_warp, feat1).mean(dim=1, keepdim=True)
        photo_map = alpha * ssim_map + (1 - alpha) * l1

        valid_pixels = flow_mask.sum() + 1e-8
        return (flow_mask * photo_map).sum() / valid_pixels

    def forward(self, l0_q, l0_t, l1_q, l1_t, l2_q, l2_t, l3_q, l3_t, K, depth_t1, depth_t2, img_t1, img_t2, t_gt, q_gt,
                flow_mask_forward, flow_mask_backward):
        """Total loss = 0.1 * balanced bidirectional photometric loss + pose loss.

        Returns:
            ``(loss, raw_q_loss, raw_t_loss)``; the last two come from
            ``pose_loss_cal`` and are meant for logging.
        """
        pose_loss, raw_q_loss, raw_t_loss = self.pose_loss_cal(
            l0_q, l0_t, l1_q, l1_t, l2_q, l2_t, l3_q, l3_t, q_gt, t_gt)

        l0_q_inv, l0_t_inv = self.invert_pose(l0_q, l0_t)

        # Forward: t1 is the reference view (depth_t1); img_t2 is sampled
        # after moving t1's points with (l0_q, l0_t).
        loss_layer_forward = self.photometric_loss_cal(
            img_t1, img_t2, depth_t1, flow_mask_forward, K, l0_q, l0_t)

        # Backward: t2 is the reference view (depth_t2); img_t1 is sampled
        # after moving t2's points with the inverse pose.
        loss_layer_backward = self.photometric_loss_cal(
            img_t2, img_t1, depth_t2, flow_mask_backward, K, l0_q_inv, l0_t_inv)

        depth_loss = (loss_layer_forward + loss_layer_backward) * torch.exp(-self.w_d) + self.w_d

        loss = 0.1 * depth_loss + pose_loss

        return loss, raw_q_loss, raw_t_loss




