import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

from src.models.larrp_unimodal import AdaptiveRankReducedLinear

class StrongRGBBranch(nn.Module):
    """RGB encoder/decoder wrapping a pretrained Stable Diffusion VAE.

    ``encode`` maps images in [0, 1] to scaled VAE latents (flattened, or
    projected to ``latent_dim`` when ``use_projections`` is set); ``decode``
    inverts that mapping back to images in [0, 1].
    """

    # Channels of the SD-VAE latent map.
    LATENT_CHANNELS = 4

    def __init__(self, latent_dim, model_name="stabilityai/sd-vae-ft-mse", freeze_vae=True, use_projections=True):
        """
        Adapts the 83M param Stable Diffusion VAE for 3-channel RGB images.
        Can load pretrained weights from stage 1 or stage 2 pretraining.

        Args:
            latent_dim: Output size of the optional projection layers.
            model_name: HuggingFace model name for the VAE.
            freeze_vae: If True, disable gradients for all VAE parameters and
                put the VAE in eval mode.
            use_projections: If True, add trainable linear layers mapping the
                flattened VAE latent to/from ``latent_dim``.
        """
        super().__init__()
        print(f"[StrongRGBBranch] Loading SD-VAE: {model_name}...")
        self.vae = AutoencoderKL.from_pretrained(model_name)
        self.use_projections = use_projections

        if freeze_vae:
            for param in self.vae.parameters():
                param.requires_grad = False
            self.vae.eval()

        # The VAE downsamples by 8, so a 256x256 input gives a 32x32 latent map.
        # NOTE: this size is fixed here; decode() reshapes to exactly this grid.
        self.spatial_dim = 32
        self.vae_flat_dim = self.LATENT_CHANNELS * self.spatial_dim * self.spatial_dim
        print(f"[StrongRGBBranch] VAE latent flat dim: {self.vae_flat_dim}")

        # Trainable projections (optional)
        if use_projections:
            self.to_latent = nn.Linear(self.vae_flat_dim, latent_dim)
            self.from_latent = nn.Linear(latent_dim, self.vae_flat_dim)
        else:
            self.to_latent = None
            self.from_latent = None

        # Latent scaling constant: applied after encoding, undone before decoding.
        self.scale_factor = 0.18215

    def _grad_context(self, inp):
        """Return the autograd context for a pass through the VAE.

        The graph is skipped only when the VAE is in eval mode AND ``inp`` does
        not require grad. Gating on ``self.vae.training`` alone would silently
        cut gradients to trainable layers feeding ``decode`` whenever a frozen
        VAE is kept in eval mode during training.
        """
        if not self.vae.training and not inp.requires_grad:
            return torch.no_grad()
        return contextlib.nullcontext()

    def encode(self, x):
        """Encode RGB image to latent space.
        Args:
            x: (B, 3, H, W) in [0, 1]
        Returns:
            z: (B, latent_dim) if use_projections, else the flattened scaled
               VAE latent of shape (B, vae_flat_dim).
        """
        # Normalize [0, 1] -> [-1, 1] for the VAE.
        x_norm = 2.0 * x - 1.0

        with self._grad_context(x_norm):
            # Deterministic encoding: take the mode of the latent distribution.
            z_vae = self.vae.encode(x_norm).latent_dist.mode() * self.scale_factor

        z_flat = torch.flatten(z_vae, start_dim=1)
        if self.use_projections:
            return self.to_latent(z_flat)
        return z_flat

    def decode(self, z):
        """Decode latent to RGB image.
        Args:
            z: (B, latent_dim) if use_projections, else a flattened
               (B, vae_flat_dim) or already-shaped (B, 4, 32, 32) latent.
        Returns:
            x_hat: (B, 3, H, W) in [0, 1]
        """
        if self.use_projections:
            z = self.from_latent(z)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        z_vae = z.reshape(-1, self.LATENT_CHANNELS, self.spatial_dim, self.spatial_dim)
        z_vae = z_vae / self.scale_factor

        with self._grad_context(z_vae):
            x_hat = self.vae.decode(z_vae).sample

        # Convert [-1, 1] back to [0, 1].
        return (x_hat / 2.0 + 0.5).clamp(0, 1)


class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: two 3x3 conv + GroupNorm stages with a GELU
    between them, an identity skip connection, and a GELU after the sum.

    ``channels`` must be divisible by 8, since both norms use 8 groups.
    """

    def __init__(self, channels):
        super().__init__()
        # conv1 is created before conv2 so that seeded initialization is unchanged.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=channels)

    def forward(self, x):
        branch = self.norm2(self.conv2(F.gelu(self.norm1(self.conv1(x)))))
        skip = x
        return F.gelu(skip + branch)

class StrongDepthBranch(nn.Module):
    """Depth-map encoder/decoder built on a pretrained Depth-Anything backbone.

    ``encode`` repeats the single depth channel to 3 channels, applies ImageNet
    normalization, runs the backbone, and converts its final hidden state into a
    ``(B, feature_channels, 32, 32)`` feature map. That map is then projected to
    ``latent_dim`` or flattened. ``decode`` maps back through a small
    PixelShuffle CNN to a ``(B, 1, 256, 256)`` map in [0, 1].
    """

    # Fallback channel counts keyed by lower-cased model-size keyword; used only
    # when the loaded backbone's config does not expose ``hidden_size``.
    _SIZE_TO_CHANNELS = (("small", 384), ("base", 768), ("large", 1024))
    # ImageNet normalization statistics applied in encode().
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, latent_dim, freeze_encoder=True, use_projections=True,
                 depth_model_name="depth-anything/Depth-Anything-V2-Small-hf"):
        """
        Adapts the Depth-Anything-V2 model for depth map encoding/decoding.
        Uses the encoder from Depth-Anything-V2 and creates a decoder.

        Args:
            latent_dim: Dimension for projection layers
            freeze_encoder: Whether to freeze encoder parameters
            use_projections: Whether to include projection layers
            depth_model_name: Hugging Face model ID (Small/Base/Large variants)
        """
        super().__init__()
        print(f"[StrongDepthBranch] Loading Depth-Anything-V2: {depth_model_name}...")

        # The processor is stored, but encode() normalizes manually and does not use it.
        self.processor = AutoImageProcessor.from_pretrained(depth_model_name)
        depth_model = AutoModelForDepthEstimation.from_pretrained(depth_model_name)

        # Only the backbone is kept; the rest of the depth model is not stored.
        self.encoder = depth_model.backbone
        self.use_projections = use_projections

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

        # Channels per backbone token (e.g. 384 / 768 / 1024 for Small / Base / Large).
        self.feature_channels = self._infer_feature_channels(self.encoder, depth_model_name)

        self.spatial_dim = 32
        # Channel bottleneck used by the projections: floor(sqrt(floor(sqrt(C)))),
        # e.g. 384 -> 4, 768 -> 5, 1024 -> 5.
        self.reduced_feature_channels = int(int(self.feature_channels ** 0.5) ** 0.5)
        self.encoder_flat_dim = self.reduced_feature_channels * self.spatial_dim * self.spatial_dim
        print(f"[StrongDepthBranch] Feature channels: {self.feature_channels}, Reduced channels: {self.reduced_feature_channels}")
        print(f"[StrongDepthBranch] Encoder flat dim: {self.encoder_flat_dim}")

        # Trainable projections (optional)
        if use_projections:
            self.to_latent = nn.Sequential(
                nn.Conv2d(self.feature_channels, self.reduced_feature_channels, kernel_size=1),
                nn.Flatten(),
                nn.Linear(self.encoder_flat_dim, latent_dim)
            )
            self.from_latent = nn.Sequential(
                nn.Linear(latent_dim, self.encoder_flat_dim),
                nn.Unflatten(1, (self.reduced_feature_channels, self.spatial_dim, self.spatial_dim)),
                nn.Conv2d(self.reduced_feature_channels, self.feature_channels, kernel_size=1)
            )
        else:
            self.to_latent = None
            self.from_latent = None

        # Decoder: Compact CNN with PixelShuffle, GELU, and Residual Blocks
        # Input: (B, feature_channels, 32, 32) -> Output: (B, 1, 256, 256)
        self.decoder = nn.Sequential(
            nn.Conv2d(self.feature_channels, 256, 3, padding=1),
            nn.GroupNorm(8, 256),
            nn.GELU(),

            ResidualBlock(256),

            # Upsample 32 -> 64 using PixelShuffle (factor 2)
            nn.Conv2d(256, 128 * 4, 3, padding=1),  # 4 = 2^2 for PixelShuffle
            nn.PixelShuffle(2),  # (B, 128, 64, 64)
            nn.GroupNorm(8, 128),
            nn.GELU(),

            ResidualBlock(128),

            # Upsample 64 -> 128 using PixelShuffle (factor 2)
            nn.Conv2d(128, 64 * 4, 3, padding=1),
            nn.PixelShuffle(2),  # (B, 64, 128, 128)
            nn.GroupNorm(8, 64),
            nn.GELU(),

            ResidualBlock(64),

            # Upsample 128 -> 256 using PixelShuffle (factor 2)
            nn.Conv2d(64, 32 * 4, 3, padding=1),
            nn.PixelShuffle(2),  # (B, 32, 256, 256)
            nn.GroupNorm(8, 32),
            nn.GELU(),

            # Output projection
            nn.Conv2d(32, 1, 3, padding=1),
            nn.Sigmoid()  # Output in [0, 1]
        )

    @classmethod
    def _infer_feature_channels(cls, encoder, model_name):
        """Return the backbone's token channel count.

        Prefers ``encoder.config.hidden_size``. Otherwise falls back to a
        case-insensitive match of the size keyword in ``model_name``, and finally
        to 384 (Small).
        """
        hidden_size = getattr(getattr(encoder, "config", None), "hidden_size", None)
        if isinstance(hidden_size, int) and hidden_size > 0:
            return hidden_size
        lowered = model_name.lower()
        for keyword, channels in cls._SIZE_TO_CHANNELS:
            if keyword in lowered:
                return channels
        return 384  # Default to Small

    @staticmethod
    def _to_feature_map(features):
        """Turn ViT tokens (B, 1 + N, C) into a square (B, C, H, W) grid.

        The leading class token is dropped only for 3-D token tensors. Tensors
        that are already 4-D feature maps are returned unchanged, so none of
        their channels are sliced away.
        """
        if features.ndim != 3:
            return features
        tokens = features[:, 1:, :]  # drop the class token
        B, N, C = tokens.shape
        H = W = int(N ** 0.5)
        return tokens.transpose(1, 2).reshape(B, C, H, W)

    def _grad_context(self, inp):
        """Skip the autograd graph only if the encoder is in eval mode and ``inp`` needs no grad."""
        if not self.encoder.training and not inp.requires_grad:
            return torch.no_grad()
        return contextlib.nullcontext()

    def encode(self, x):
        """Encode depth map to latent space.
        Args:
            x: (B, 1, H, W) in [0, 1]
        Returns:
            z: (B, latent_dim) if use_projections, else the flattened feature
               map of shape (B, feature_channels * spatial_dim * spatial_dim).
        """
        # The backbone expects 3-channel input, so repeat the single depth channel.
        x_3ch = x.repeat(1, 3, 1, 1)

        # ImageNet normalization, done manually instead of via self.processor.
        mean = torch.tensor(self._IMAGENET_MEAN, device=x.device).view(1, 3, 1, 1)
        std = torch.tensor(self._IMAGENET_STD, device=x.device).view(1, 3, 1, 1)
        x_norm = (x_3ch - mean) / std

        with self._grad_context(x_norm):
            outputs = self.encoder(x_norm, output_hidden_states=True)

        # Prefer last_hidden_state, then the last hidden state; otherwise treat
        # the output itself as the feature tensor.
        features = getattr(outputs, "last_hidden_state", None)
        if features is None:
            hidden_states = getattr(outputs, "hidden_states", None)
            features = hidden_states[-1] if hidden_states is not None else outputs

        features = self._to_feature_map(features)

        # Resize (both height and width) to the expected spatial grid if needed.
        if tuple(features.shape[-2:]) != (self.spatial_dim, self.spatial_dim):
            features = F.interpolate(features, size=(self.spatial_dim, self.spatial_dim),
                                     mode='bilinear', align_corners=False)

        if self.use_projections:
            return self.to_latent(features)
        return torch.flatten(features, start_dim=1)

    def decode(self, z):
        """Decode latent to depth map.
        Args:
            z: (B, latent_dim) if use_projections, else a flattened or shaped
               (B, feature_channels, spatial_dim, spatial_dim) feature map.
        Returns:
            x_hat: the decoder output, (B, 1, 256, 256) in [0, 1] for a 32x32 grid.
        """
        if self.use_projections:
            features = self.from_latent(z)
        else:
            # reshape (not view) so non-contiguous inputs are accepted as well.
            features = z.reshape(-1, self.feature_channels, self.spatial_dim, self.spatial_dim)
        return self.decoder(features)


class MultimodalAdaptiveAE_ImageDepth(nn.Module):
    """Multimodal RGB + depth autoencoder with adaptive-rank bottlenecks.

    Each modality is encoded by a pretrained branch, layer-normalized, and passed
    through a pre-fusion linear layer. A shared adaptive layer compresses the
    concatenated modality features. One adaptive layer per modality produces a
    modality-specific code. Decoding concatenates the shared code with each
    specific code and maps it back to the branch decoder's input size.
    """

    def __init__(self,
                 image_shape=(3, 256, 256),
                 depth_shape=(1, 256, 256),
                 img_flat_dim=4*32*32,
                 depth_flat_dim=4*32*32,
                 latent_dim_rgb=512,
                 latent_dim_depth=512,
                 latent_dim_shared=512,
                 vae_model_name="stabilityai/sd-vae-ft-mse",
                 depth_model_name="depth-anything/Depth-Anything-V2-Small-hf",
                 freeze_vae=True):
        """
        Multimodal autoencoder for RGB images and depth maps using pretrained models.

        Args:
            image_shape: RGB image shape (C, H, W). Currently unused.
            depth_shape: Depth map shape (C, H, W). Currently unused.
            img_flat_dim: Size of the flattened RGB-branch latent. The RGB branch
                is built without projections, so its encode() output has this size.
            depth_flat_dim: Currently unused.
            latent_dim_rgb: RGB feature size after pre-fusion.
            latent_dim_depth: Depth latent size (depth-branch projection output).
            latent_dim_shared: Output size of the shared and the modality-specific
                adaptive layers.
            vae_model_name: HuggingFace model ID for RGB VAE
            depth_model_name: HuggingFace model ID for depth encoder
            freeze_vae: Whether to freeze the pretrained VAE and depth encoder
        """
        super().__init__()

        # --- Unimodal Encoders/Decoders ---
        self.rgb_branch = StrongRGBBranch(
            latent_dim_rgb,
            model_name=vae_model_name,
            freeze_vae=freeze_vae,
            use_projections=False
        )

        self.depth_branch = StrongDepthBranch(
            latent_dim_depth,
            freeze_encoder=freeze_vae,
            use_projections=True,
            depth_model_name=depth_model_name
        )

        # --- Adaptive Bottlenecks ---
        # 1. Shared layer: concatenated modality features -> shared code.
        self.shared_layer = AdaptiveRankReducedLinear(
            in_features=latent_dim_rgb + latent_dim_depth,
            out_features=latent_dim_shared
        )

        # 2. Specific layers: one modality's features -> its specific code.
        self.rgb_spec_layer = AdaptiveRankReducedLinear(latent_dim_rgb, latent_dim_shared)
        self.depth_spec_layer = AdaptiveRankReducedLinear(latent_dim_depth, latent_dim_shared)

        # Indexed by reduce_rank/increase_rank: 0=shared, 1=rgb-specific, 2=depth-specific.
        self.adaptive_layers = nn.ModuleList([
            self.shared_layer,
            self.rgb_spec_layer,
            self.depth_spec_layer
        ])

        # Normalization layers. rgb_norm sees the raw flattened RGB latent before
        # rgb_pre_fusion, so it must normalize over img_flat_dim (not latent_dim_rgb).
        self.rgb_norm = nn.LayerNorm(img_flat_dim)
        self.depth_norm = nn.LayerNorm(latent_dim_depth)

        # Pre-fusion layers to align pretrained representations
        self.rgb_pre_fusion = nn.Linear(img_flat_dim, latent_dim_rgb)
        self.depth_pre_fusion = nn.Linear(latent_dim_depth, latent_dim_depth)

        # Fusion layers: concatenated (shared + specific) -> decoder input size.
        self.rgb_fusion = nn.Linear(2*latent_dim_shared, img_flat_dim)
        self.depth_fusion = nn.Linear(2*latent_dim_shared, latent_dim_depth)

        # Small initial weights so the fused outputs start near zero.
        with torch.no_grad():
            nn.init.xavier_uniform_(self.rgb_fusion.weight, gain=0.01)
            nn.init.xavier_uniform_(self.depth_fusion.weight, gain=0.01)
            nn.init.zeros_(self.rgb_fusion.bias)
            nn.init.zeros_(self.depth_fusion.bias)

    def encode(self, rgb_images, depth_maps):
        """Encode RGB and depth to shared and specific representations.

        Args:
            rgb_images: (B, 3, H, W) in [0, 1]
            depth_maps: (B, 1, H, W) in [0, 1]

        Returns:
            tuple: (h_shared, [h_rgb_specific, h_depth_specific])
        """
        h_rgb = self.rgb_branch.encode(rgb_images)
        h_depth = self.depth_branch.encode(depth_maps)

        # Normalize, then pre-fuse with a ReLU.
        h_rgb = F.relu(self.rgb_pre_fusion(self.rgb_norm(h_rgb)))
        h_depth = F.relu(self.depth_pre_fusion(self.depth_norm(h_depth)))

        # Shared representation from both modalities.
        h_shared = self.shared_layer(torch.cat([h_rgb, h_depth], dim=1))

        # Modality-specific representations.
        h_specific = [self.rgb_spec_layer(h_rgb), self.depth_spec_layer(h_depth)]

        return (h_shared, h_specific)

    def decode(self, h):
        """Decode shared and specific representations back to RGB and depth.

        Args:
            h: tuple of (h_shared, [h_rgb_specific, h_depth_specific])

        Returns:
            list: [rgb_reconstruction, depth_reconstruction]
        """
        h_shared, h_specific = h

        h_rgb_fused = self.rgb_fusion(torch.cat([h_shared, h_specific[0]], dim=1))
        rgb_recon = self.rgb_branch.decode(h_rgb_fused)

        h_depth_fused = self.depth_fusion(torch.cat([h_shared, h_specific[1]], dim=1))
        depth_recon = self.depth_branch.decode(h_depth_fused)

        return [rgb_recon, depth_recon]

    def forward(self, rgb_images, depth_maps):
        """Full forward pass.

        Args:
            rgb_images: (B, 3, H, W)
            depth_maps: (B, 1, H, W)

        Returns:
            list: [rgb_reconstruction, depth_reconstruction]
        """
        return self.decode(self.encode(rgb_images, depth_maps))

    def encode_modalities(self, rgb_images, depth_maps):
        """Encode and return combined representations for each modality.

        Args:
            rgb_images: (B, 3, H, W)
            depth_maps: (B, 1, H, W)

        Returns:
            list: [rgb_combined_repr, depth_combined_repr], each the shared code
                  concatenated with that modality's specific code along dim 1.
        """
        h_shared, h_specific = self.encode(rgb_images, depth_maps)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=None, dim=0):
        """Shrink the active rank of selected adaptive layers.

        For each selected layer, the new rank is the larger of two values:
        - the smallest rank whose singular values retain at least
          ``1 - threshold`` of the spectral energy;
        - ``reduction_ratio`` times the current rank.
        Both are floored at ``layer.min_rank``.

        Args:
            reduction_ratio: Fraction of the current rank to keep at minimum.
            threshold: Fraction of spectral energy allowed to be discarded.
            layer_ids: Indices into ``self.adaptive_layers``; None/empty selects none.
            dim: Unused; kept for interface compatibility.

        Returns:
            True if any layer's rank was reduced.
        """
        if not layer_ids:
            return False
        selected = set(layer_ids)
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in selected:
                continue
            S = layer.get_rank_reduction_info()
            if len(S) <= layer.min_rank:
                continue

            energy = S ** 2
            total = energy.sum()
            if total > 0:
                cumulative_energy = torch.cumsum(energy / total, dim=0)
                # Values strictly below the target, plus the one that reaches it.
                needed = int(torch.sum(cumulative_energy < (1.0 - threshold)).item()) + 1
                needed = min(needed, len(S))
            else:
                # Zero (or NaN) spectrum: nothing to preserve beyond min_rank.
                needed = 0
            target_rank = max(layer.min_rank, needed)

            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            new_rank = max(target_rank, ratio_rank)

            if new_rank < current_rank:
                layer.reduce_rank(new_rank)
                changes_made = True
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """Grow the active rank of selected adaptive layers.

        Args:
            increment: Forwarded to each layer's ``increase_rank``.
            increase_ratio: Forwarded to each layer's ``increase_rank``.
            layer_ids: Indices into ``self.adaptive_layers``; None/empty selects none.
            dim: Unused; kept for interface compatibility.

        Returns:
            True if any selected layer reported a change.
        """
        if not layer_ids:
            return False
        selected = set(layer_ids)
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            # Every selected layer is called, even after an earlier one changed.
            if i in selected and layer.increase_rank(increment=increment, increase_ratio=increase_ratio):
                changes_made = True
        return changes_made

    def get_total_rank(self):
        """Return total active rank across all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)
