import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL

from src.models.larrp_unimodal import AdaptiveRankReducedLinear

class _SDVAEBranch(nn.Module):
    """Shared implementation of a modality branch built on a pretrained SD VAE.

    The VAE maps 3-channel images in ``[-1, 1]`` to a grid of shape
    ``(latent_channels, spatial_dim, spatial_dim)``. Subclasses adapt their
    modality to and from that pixel space by overriding ``_to_vae_pixels`` /
    ``_from_vae_pixels`` (identity by default, i.e. RGB). Optional trainable
    linear projections map the flattened VAE latent to and from a
    ``latent_dim`` vector.
    """

    # Prefix for the construction log lines; subclasses override it.
    _log_tag = "SDVAEBranch"

    def __init__(self, latent_dim, model_name="stabilityai/sd-vae-ft-mse", freeze_vae=True,
                 use_projections=True, image_size=256, vae=None):
        """
        Args:
            latent_dim: Dimension for projection layers.
            model_name: HuggingFace model name for the VAE (ignored when ``vae`` is given).
            freeze_vae: Whether to freeze VAE parameters. A frozen VAE is also kept
                in eval mode even when this branch (or a parent) is put in train mode.
            use_projections: Whether to include linear projection layers.
            image_size: Side length of the square inputs. Must be divisible by the
                VAE downsampling factor; the default 256 gives a 32x32 latent grid
                for the SD VAE.
            vae: Optional pre-built VAE module to use instead of loading ``model_name``.

        Raises:
            ValueError: If ``image_size`` is not divisible by the VAE downsampling factor.
        """
        super().__init__()
        tag = self._log_tag
        if vae is None:
            print(f"[{tag}] Loading SD-VAE: {model_name}...")
            vae = AutoencoderKL.from_pretrained(model_name)
        self.vae = vae
        self.use_projections = use_projections
        self.freeze_vae = freeze_vae

        if freeze_vae:
            self.vae.requires_grad_(False)
            self.vae.eval()

        # Latent geometry and scaling come from the VAE config when available; the
        # fallbacks are the SD-VAE values (4 latent channels, /8 downsampling, 0.18215).
        config = getattr(self.vae, "config", None)
        self.latent_channels = int(getattr(config, "latent_channels", 4))
        block_out_channels = getattr(config, "block_out_channels", None)
        # Same formula diffusers pipelines use for their ``vae_scale_factor``.
        self.downsample_factor = 2 ** (len(block_out_channels) - 1) if block_out_channels else 8
        self.scale_factor = float(getattr(config, "scaling_factor", 0.18215))

        if image_size % self.downsample_factor != 0:
            raise ValueError(
                f"image_size={image_size} must be divisible by the VAE downsampling "
                f"factor {self.downsample_factor}"
            )
        self.image_size = image_size
        self.spatial_dim = image_size // self.downsample_factor
        self.vae_flat_dim = self.latent_channels * self.spatial_dim * self.spatial_dim
        print(f"[{tag}] VAE latent flat dim: {self.vae_flat_dim}")

        # Trainable projections (optional)
        if use_projections:
            self.to_latent = nn.Linear(self.vae_flat_dim, latent_dim)
            self.from_latent = nn.Linear(latent_dim, self.vae_flat_dim)
        else:
            self.to_latent = None
            self.from_latent = None

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen VAE in eval mode.

        ``nn.Module.train`` recurses into children, so without this override a
        parent's ``.train()`` would switch the frozen VAE back to training mode.
        """
        super().train(mode)
        if self.freeze_vae:
            self.vae.eval()
        return self

    def _vae_grad_enabled(self):
        """Grad mode for VAE calls: off in eval mode, otherwise the ambient mode.

        Keyed on this branch's own ``training`` flag rather than the VAE's: a
        frozen VAE is always in eval mode, yet during training gradients must
        still flow *through* its decoder to ``from_latent`` and everything upstream.
        """
        return self.training and torch.is_grad_enabled()

    def _to_vae_pixels(self, x):
        """Map modality input in [0, 1] to 3-channel pixels in [0, 1] (identity here)."""
        return x

    def _from_vae_pixels(self, pixels):
        """Map 3-channel decoder output in [-1, 1] to the modality's channels (identity here)."""
        return pixels

    def encode(self, x):
        """Encode a batch into the latent space (deterministic: VAE posterior mode).

        Args:
            x: (B, C, H, W) in [0, 1]; C is 3 for RGB and 1 for depth, H = W = image_size.
        Returns:
            z: (B, latent_dim) if use_projections else
               (B, latent_channels, spatial_dim, spatial_dim), scaled by ``scale_factor``.
        """
        # Normalize [0, 1] -> [-1, 1] for the VAE.
        pixels = 2.0 * self._to_vae_pixels(x) - 1.0
        with torch.set_grad_enabled(self._vae_grad_enabled()):
            z_vae = self.vae.encode(pixels).latent_dist.mode() * self.scale_factor

        if not self.use_projections:
            return z_vae
        return self.to_latent(torch.flatten(z_vae, start_dim=1))

    def decode(self, z):
        """Decode latents back to the modality.

        Args:
            z: (B, latent_dim) if use_projections else
               (B, latent_channels, spatial_dim, spatial_dim).
        Returns:
            x_hat: (B, C, H, W) in [0, 1]; C is 3 for RGB and 1 for depth.
        """
        if self.use_projections:
            z_vae = self.from_latent(z).view(
                -1, self.latent_channels, self.spatial_dim, self.spatial_dim
            )
        else:
            z_vae = z
        z_vae = z_vae / self.scale_factor

        with torch.set_grad_enabled(self._vae_grad_enabled()):
            pixels = self.vae.decode(z_vae).sample

        # Convert [-1, 1] back to [0, 1].
        return (self._from_vae_pixels(pixels) / 2.0 + 0.5).clamp(0, 1)


class StrongRGBBranch(_SDVAEBranch):
    """Adapts the 83M param Stable Diffusion VAE for 3-channel RGB images.

    Inputs and reconstructions are (B, 3, H, W) in [0, 1]. Constructor arguments
    are those of ``_SDVAEBranch``. Pretrained weights from stage 1 or stage 2 can
    be loaded into it; its state-dict keys are ``vae.*``, ``to_latent.*`` and
    ``from_latent.*``.
    """

    _log_tag = "StrongRGBBranch"


class StrongDepthBranch(_SDVAEBranch):
    """Adapts the 83M param Stable Diffusion VAE for 1-channel depth maps.

    Depth (B, 1, H, W) in [0, 1] is replicated to 3 channels before encoding, and
    the decoder's 3 output channels are averaged back to one. Constructor
    arguments are those of ``_SDVAEBranch``. Pretrained weights from stage 1 or
    stage 2 can be loaded into it; its state-dict keys are ``vae.*``,
    ``to_latent.*`` and ``from_latent.*``.
    """

    _log_tag = "StrongDepthBranch"

    def _to_vae_pixels(self, x):
        """Adapter 1ch -> 3ch by channel replication.

        Raises:
            ValueError: If ``x`` is not shaped (B, 1, H, W).
        """
        if x.dim() != 4 or x.shape[1] != 1:
            raise ValueError(f"expected depth of shape (B, 1, H, W), got {tuple(x.shape)}")
        # expand() is a broadcast view; the normalization in encode() materializes it.
        return x.expand(-1, 3, -1, -1)

    def _from_vae_pixels(self, pixels):
        """Adapter 3ch -> 1ch by averaging the decoder's channels."""
        return pixels.mean(dim=1, keepdim=True)


class MultimodalAdaptiveAE_ImageDepth(nn.Module):
    """Multimodal autoencoder for RGB images and depth maps built on pretrained SD VAEs.

    Each modality is encoded by its own VAE branch into a flat latent, then
    layer-normalized and passed through a ReLU pre-fusion layer. From these,
    ``AdaptiveRankReducedLinear`` layers produce:

    - one shared code, from the concatenation of both modalities;
    - one modality-specific code per branch.

    The ranks of these layers can be adjusted via ``reduce_rank`` / ``increase_rank``.
    """

    def __init__(self, 
                 image_shape=(3, 256, 256),
                 depth_shape=(1, 256, 256),
                 latent_dim_rgb=512,
                 latent_dim_depth=512,
                 latent_dim_shared=512,
                 vae_model_name="stabilityai/sd-vae-ft-mse",
                 freeze_vae=True):
        """
        Args:
            image_shape: Expected RGB input shape (C, H, W), typically (3, 256, 256).
                The branches are built with their default input size, so this is
                only checked: a warning is issued if H, W disagree with it.
            depth_shape: Expected depth input shape (C, H, W), typically
                (1, 256, 256). Checked the same way as ``image_shape``.
            latent_dim_rgb: RGB branch latent dimension.
            latent_dim_depth: Depth branch latent dimension.
            latent_dim_shared: Output dimension of the shared layer and of both
                modality-specific layers.
            vae_model_name: HuggingFace model ID for the VAE (same for both branches).
            freeze_vae: Whether to freeze VAE parameters (typically True for pretrained).
        """
        super().__init__()
        self.image_shape = image_shape
        self.depth_shape = depth_shape
        
        # --- Unimodal Encoders/Decoders (with pretrained VAEs) ---
        self.rgb_branch = StrongRGBBranch(
            latent_dim_rgb, 
            model_name=vae_model_name,
            freeze_vae=freeze_vae,
            use_projections=True
        )
        
        self.depth_branch = StrongDepthBranch(
            latent_dim_depth,
            model_name=vae_model_name,
            freeze_vae=freeze_vae,
            use_projections=True
        )

        self._warn_on_shape_mismatch("image_shape", image_shape, self.rgb_branch)
        self._warn_on_shape_mismatch("depth_shape", depth_shape, self.depth_branch)
        
        # --- Adaptive Bottlenecks ---
        # 1. Shared layer: concatenated latents -> shared code.
        self.shared_layer = AdaptiveRankReducedLinear(
            in_features=latent_dim_rgb + latent_dim_depth,
            out_features=latent_dim_shared
        )
        
        # 2. Specific layers: unimodal latent -> specific code (also latent_dim_shared wide).
        self.rgb_spec_layer = AdaptiveRankReducedLinear(latent_dim_rgb, latent_dim_shared)
        self.depth_spec_layer = AdaptiveRankReducedLinear(latent_dim_depth, latent_dim_shared)
        
        # Indexed view used by reduce_rank/increase_rank (layer_ids 0=shared, 1=rgb, 2=depth).
        # These modules are registered twice, so their weights appear under both
        # names in state_dict; kept as-is for checkpoint compatibility.
        self.adaptive_layers = nn.ModuleList([
            self.shared_layer, 
            self.rgb_spec_layer, 
            self.depth_spec_layer
        ])

        # Normalization layers
        self.rgb_norm = nn.LayerNorm(latent_dim_rgb)
        self.depth_norm = nn.LayerNorm(latent_dim_depth)
        
        # Pre-fusion layers to align pretrained representations
        self.rgb_pre_fusion = nn.Linear(latent_dim_rgb, latent_dim_rgb)
        self.depth_pre_fusion = nn.Linear(latent_dim_depth, latent_dim_depth)
        
        # Fusion layers: concatenated (shared + specific) -> decoder input dims
        self.rgb_fusion = nn.Linear(2*latent_dim_shared, latent_dim_rgb)
        self.depth_fusion = nn.Linear(2*latent_dim_shared, latent_dim_depth)
        
        # Initialize fusion layers with small weights for smoother transition
        with torch.no_grad():
            nn.init.xavier_uniform_(self.rgb_fusion.weight, gain=0.01)
            nn.init.xavier_uniform_(self.depth_fusion.weight, gain=0.01)
            nn.init.zeros_(self.rgb_fusion.bias)
            nn.init.zeros_(self.depth_fusion.bias)

    @staticmethod
    def _warn_on_shape_mismatch(name, shape, branch):
        """Warn when a declared (C, H, W) shape disagrees with the branch's input size."""
        if shape is None:
            return
        # The SD VAE downsamples by 8, so the branch's latent grid implies its input side.
        side = getattr(branch, "image_size", 8 * branch.spatial_dim)
        if tuple(shape)[-2:] != (side, side):
            warnings.warn(
                f"{name}={tuple(shape)} does not match the {side}x{side} inputs expected "
                f"by {type(branch).__name__}; encode/decode will fail on other sizes.",
                stacklevel=3,
            )
    
    def encode(self, rgb_images, depth_maps):
        """Encode RGB and depth to shared and specific representations.
        
        Args:
            rgb_images: (B, 3, H, W) in [0, 1]
            depth_maps: (B, 1, H, W) in [0, 1]
            
        Returns:
            tuple: (h_shared, [h_rgb_specific, h_depth_specific]), each of
            shape (B, latent_dim_shared).
        """
        # Encode each modality
        h_rgb = self.rgb_branch.encode(rgb_images)
        h_depth = self.depth_branch.encode(depth_maps)
        
        # Normalize and pre-fuse
        h_rgb = self.rgb_norm(h_rgb)
        h_rgb = F.relu(self.rgb_pre_fusion(h_rgb))
        
        h_depth = self.depth_norm(h_depth)
        h_depth = F.relu(self.depth_pre_fusion(h_depth))
        
        # Concatenate for shared representation
        h = torch.cat([h_rgb, h_depth], dim=1)
        h_shared = self.shared_layer(h)
        
        # Modality-specific representations
        h_rgb_spec = self.rgb_spec_layer(h_rgb)
        h_depth_spec = self.depth_spec_layer(h_depth)
        
        return (h_shared, [h_rgb_spec, h_depth_spec])
    
    def decode(self, h):
        """Decode shared and specific representations back to RGB and depth.
        
        Args:
            h: tuple of (h_shared, [h_rgb_specific, h_depth_specific])
            
        Returns:
            list: [rgb_reconstruction, depth_reconstruction], both in [0, 1].
        """
        h_shared, h_specific = h
        
        # Reconstruct RGB
        h_rgb_combined = torch.cat([h_shared, h_specific[0]], dim=1)
        h_rgb_fused = self.rgb_fusion(h_rgb_combined)
        rgb_recon = self.rgb_branch.decode(h_rgb_fused)
        
        # Reconstruct Depth
        h_depth_combined = torch.cat([h_shared, h_specific[1]], dim=1)
        h_depth_fused = self.depth_fusion(h_depth_combined)
        depth_recon = self.depth_branch.decode(h_depth_fused)
        
        return [rgb_recon, depth_recon]

    def forward(self, rgb_images, depth_maps):
        """Full forward pass (encode then decode).
        
        Args:
            rgb_images: (B, 3, H, W)
            depth_maps: (B, 1, H, W)
            
        Returns:
            list: [rgb_reconstruction, depth_reconstruction]
        """
        return self.decode(self.encode(rgb_images, depth_maps))
    
    def encode_modalities(self, rgb_images, depth_maps):
        """Encode and return the combined (shared + specific) code for each modality.
        
        Args:
            rgb_images: (B, 3, H, W)
            depth_maps: (B, 1, H, W)
            
        Returns:
            list: [rgb_combined_repr, depth_combined_repr], each of shape
            (B, 2 * latent_dim_shared) since the specific codes are also
            latent_dim_shared wide.
        """
        h_shared, h_specific = self.encode(rgb_images, depth_maps)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]
    
    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=(), dim=0):
        """Shrink the rank of selected adaptive layers.

        For each selected layer, the new rank is the larger of:
        - the smallest rank whose leading singular values retain at least
          ``1 - threshold`` of the spectral energy (sum of S**2);
        - ``int(current_rank * reduction_ratio)``.
        Both are floored at ``layer.min_rank``.

        Args:
            reduction_ratio: Upper bound on the shrink per call, as a fraction of the current rank.
            threshold: Fraction of spectral energy allowed to be discarded.
            layer_ids: Indices into ``adaptive_layers`` to consider (0=shared,
                1=rgb specific, 2=depth specific). Empty means no layer is touched.
            dim: Unused; kept for interface compatibility.

        Returns:
            bool: True if any layer's rank was reduced.
        """
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            S = layer.get_rank_reduction_info()
            if len(S) <= layer.min_rank:
                continue

            energy = S ** 2
            total_energy = energy.sum()
            if total_energy > 0:
                cumulative_energy = torch.cumsum(energy / total_energy, dim=0)
                # k prefixes fall short of the target, so k + 1 components are needed.
                needed = int(torch.sum(cumulative_energy < (1.0 - threshold)).item()) + 1
                target_rank = max(layer.min_rank, min(needed, len(S)))
            else:
                # Degenerate all-zero spectrum: energy gives no signal, fall back to the floor.
                target_rank = layer.min_rank

            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            new_rank = max(target_rank, ratio_rank)

            if new_rank < current_rank:
                layer.reduce_rank(new_rank)
                changes_made = True
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=(), dim=0):
        """Grow the rank of selected adaptive layers.

        Args:
            increment: Forwarded to each layer's ``increase_rank``.
            increase_ratio: Forwarded to each layer's ``increase_rank``.
            layer_ids: Indices into ``adaptive_layers`` to consider. Empty means
                no layer is touched.
            dim: Unused; kept for interface compatibility.

        Returns:
            bool: True if any layer reported a change.
        """
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio):
                changes_made = True
        return changes_made
    
    def get_total_rank(self):
        """Return total rank (sum of ``active_dims``) across all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)
