"""
Geometry to Semantics Data Format Converter

This module handles the conversion between geometry model outputs and semantics model input formats.
Key transformations:
- Extrinsics matrices to pose matrices 
- Depth map format and resolution conversion
- Intrinsics matrix format conversion
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple, Optional, Union


class GeometryToSemanticsConverter:
    """
    Converts geometry model outputs to semantics model compatible input format.

    Geometry model outputs:
    - images: [B, S, 3, H, W] RGB images (typically 518x518)
    - depth: [B, S, H, W, 1] depth maps
    - extrinsics: [B, S, 3, 4] camera extrinsics (camera-to-world)
    - intrinsics: [B, S, 3, 3] camera intrinsics

    Semantics model expects:
    - images: [B, V, 3, H, W] RGB images (336x336)
    - depths: [B, V, H, W] depth maps
    - poses: [B, V, 4, 4] camera poses (world-to-camera)
    - intrinsics: [B, V, 4, 4] camera intrinsics

    NOTE(review): the camera-to-world convention for the incoming extrinsics
    is taken from this module's own description; confirm it against the
    geometry model actually in use, since the poses are produced by inverting
    those matrices.
    """

    def __init__(self, target_resolution: int = 336):
        """
        Initialize converter with target resolution for semantics model.

        Args:
            target_resolution: Side length (in pixels) of the square images and
                depth maps handed to the semantics model (default: 336).
        """
        self.target_resolution = target_resolution

    @staticmethod
    def _to_linalg_dtype(tensor: torch.Tensor) -> torch.Tensor:
        """Upcast half-precision tensors to float32 for torch.linalg routines."""
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor.float()
        return tensor

    def convert_extrinsics_to_poses(self, extrinsics: torch.Tensor) -> torch.Tensor:
        """
        Convert geometry extrinsics to semantics pose format.

        Each 3x4 [R|t] matrix is extended with a [0, 0, 0, 1] row to a 4x4
        homogeneous transform and then inverted, turning the camera-to-world
        input into the world-to-camera pose the semantics model expects.

        Args:
            extrinsics: [B, S, 3, 4] extrinsics matrices from geometry model.
                Any number of leading batch dimensions is accepted.

        Returns:
            poses: [B, S, 4, 4] pose matrices for semantics model, in the same
                dtype and on the same device as ``extrinsics``.

        Raises:
            RuntimeError: If a matrix is singular or the trailing dimensions
                are not 3x4.
        """
        out_dtype = extrinsics.dtype

        # Homogeneous bottom row [0, 0, 0, 1], one per matrix in the batch.
        bottom_row = extrinsics.new_zeros((*extrinsics.shape[:-2], 1, 4))
        bottom_row[..., 0, 3] = 1.0

        # Create 4x4 camera-to-world matrices.
        cam_to_world = torch.cat([extrinsics, bottom_row], dim=-2)

        # A single batched inversion replaces a per-matrix Python loop.
        # Half-precision inputs are inverted in float32 because the linalg
        # kernels reject low-precision dtypes; the result is cast back.
        poses = torch.linalg.inv(self._to_linalg_dtype(cam_to_world))
        return poses.to(out_dtype)

    def convert_depth_format(self, depth: torch.Tensor) -> torch.Tensor:
        """
        Convert geometry depth format to semantics format.

        Args:
            depth: [B, S, H, W, 1] depth maps from geometry model. A tensor
                that is already [B, S, H, W] is accepted as well.

        Returns:
            depths: [B, S, target_H, target_W] depth maps for semantics model
        """
        # Drop the trailing channel axis only when it is actually present, so
        # a 4-D input whose width happens to be 1 is not collapsed by mistake.
        if depth.dim() == 5:
            depth = depth.squeeze(-1)

        B, S, H, W = depth.shape
        size = (self.target_resolution, self.target_resolution)

        # reshape (not view) so sliced / permuted inputs are handled too.
        flat = depth.reshape(B * S, 1, H, W)
        # NOTE: bilinear resampling blends values across depth edges.
        resized = F.interpolate(flat, size=size, mode='bilinear', align_corners=False)
        return resized.view(B, S, *size)

    def convert_intrinsics_format(self, intrinsics: torch.Tensor,
                                  original_resolution: Union[int, Tuple[int, int]] = 518
                                  ) -> torch.Tensor:
        """
        Convert geometry intrinsics to semantics format.

        The focal lengths and principal point are rescaled from the original
        image size to ``target_resolution`` and the 3x3 matrix is embedded in
        the top-left corner of a 4x4 matrix whose [3, 3] entry is 1.

        Args:
            intrinsics: [B, S, 3, 3] intrinsics matrices from geometry model
            original_resolution: Original image resolution from geometry
                model. Either a single int for square images, or a
                ``(height, width)`` pair for non-square images, in which case
                fx/cx are scaled by the width ratio and fy/cy by the height
                ratio.

        Returns:
            intrinsics_4x4: [B, S, 4, 4] intrinsics matrices for semantics model
        """
        B, S, _, _ = intrinsics.shape

        if isinstance(original_resolution, (tuple, list)):
            orig_h, orig_w = original_resolution
        else:
            orig_h = orig_w = original_resolution

        scale_x = self.target_resolution / orig_w
        scale_y = self.target_resolution / orig_h

        # Work on a copy so the caller's tensor is left untouched.
        scaled = intrinsics.clone()
        scaled[:, :, 0, 0] *= scale_x  # fx
        scaled[:, :, 1, 1] *= scale_y  # fy
        scaled[:, :, 0, 2] *= scale_x  # cx
        scaled[:, :, 1, 2] *= scale_y  # cy

        # Convert 3x3 to 4x4 format.
        intrinsics_4x4 = intrinsics.new_zeros((B, S, 4, 4))
        intrinsics_4x4[:, :, :3, :3] = scaled
        intrinsics_4x4[:, :, 3, 3] = 1.0
        return intrinsics_4x4

    def resize_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Resize RGB images to target resolution.

        Args:
            images: [B, S, 3, H, W] RGB images from geometry model

        Returns:
            resized_images: [B, S, 3, target_H, target_W] resized images
        """
        B, S, C, H, W = images.shape
        size = (self.target_resolution, self.target_resolution)

        # reshape (not view) so sliced / permuted inputs are handled too.
        flat = images.reshape(B * S, C, H, W)
        resized = F.interpolate(flat, size=size, mode='bilinear', align_corners=False)
        return resized.view(B, S, C, *size)

    def convert(self, geometry_predictions: Dict[str, torch.Tensor],
                original_resolution: Union[int, Tuple[int, int]] = 518
                ) -> Dict[str, torch.Tensor]:
        """
        Main conversion function that transforms geometry outputs to semantics format.

        Args:
            geometry_predictions: Dictionary containing geometry model outputs
                - 'images': [B, S, 3, H, W] RGB images
                - 'depth': [B, S, H, W, 1] depth maps
                - 'extrinsic': [B, S, 3, 4] extrinsics matrices
                - 'intrinsic': [B, S, 3, 3] intrinsics matrices
            original_resolution: Original image resolution from geometry model
                (int for square images, or a ``(height, width)`` pair). It is
                only used to rescale the intrinsics and is not inferred from
                ``images``, so pass it explicitly when the inputs are not
                518x518.

        Returns:
            semantics_inputs: Dictionary containing semantics model compatible inputs
                - 'images': [B, S, 3, target_H, target_W] resized RGB images
                - 'depths': [B, S, target_H, target_W] resized depth maps
                - 'poses': [B, S, 4, 4] pose matrices
                - 'intrinsics': [B, S, 4, 4] intrinsics matrices

        Raises:
            ValueError: If a required key is missing from ``geometry_predictions``.
        """
        # Validate inputs
        for key in ('images', 'depth', 'extrinsic', 'intrinsic'):
            if key not in geometry_predictions:
                raise ValueError(f"Missing required key '{key}' in geometry_predictions")

        return {
            'images': self.resize_images(geometry_predictions['images']),
            'depths': self.convert_depth_format(geometry_predictions['depth']),
            'poses': self.convert_extrinsics_to_poses(geometry_predictions['extrinsic']),
            'intrinsics': self.convert_intrinsics_format(
                geometry_predictions['intrinsic'], original_resolution
            ),
        }

    def validate_conversion(self, geometry_predictions: Dict[str, torch.Tensor],
                            semantics_inputs: Dict[str, torch.Tensor]) -> bool:
        """
        Validate that the conversion was successful.

        Checks that the batch and sequence dimensions of the images are
        preserved, that the converted images have the target resolution, and
        that the 3x3 rotation part of every pose has determinant +/-1
        (within 1e-3).

        Args:
            geometry_predictions: Original geometry outputs
            semantics_inputs: Converted semantics inputs

        Returns:
            is_valid: Whether the conversion is valid. Missing keys or
                malformed tensors yield ``False`` rather than an exception.
        """
        try:
            # Check batch and sequence dimensions match
            B_orig, S_orig = geometry_predictions['images'].shape[:2]
            B_conv, S_conv = semantics_inputs['images'].shape[:2]
            if (B_orig, S_orig) != (B_conv, S_conv):
                return False

            # Check target resolution
            target_h, target_w = semantics_inputs['images'].shape[-2:]
            if target_h != self.target_resolution or target_w != self.target_resolution:
                return False

            # Check pose matrices are valid (determinant should be ±1 for rotation part).
            # Half-precision poses are upcast because torch.det rejects them.
            rotation_parts = self._to_linalg_dtype(semantics_inputs['poses'][:, :, :3, :3])
            dets = torch.det(rotation_parts)
            return bool(torch.allclose(torch.abs(dets), torch.ones_like(dets), atol=1e-3))

        except (KeyError, IndexError, ValueError, TypeError, AttributeError, RuntimeError):
            # Structurally broken inputs count as a failed validation.
            return False