import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import BertModel, BertConfig

from src.models.larrp_unimodal import AdaptiveRankReducedLinear

# ==========================================
# 1. IMAGE BRANCH (CNN)
# ==========================================

class ImageBranch(nn.Module):
    """Convolutional autoencoder branch mapping images to/from a flat latent vector.

    The encoder stacks ``depth`` stride-2 3x3 convolutions (each followed by
    ReLU and ``Dropout2d``), flattens the result and projects it to
    ``latent_dim``. The decoder mirrors it: a linear projection back to the
    flattened feature map, then one bilinear-upsample + stride-1 3x3
    convolution per stage. The last decoder stage has no activation, so
    ``decode`` returns unbounded values.

    Args:
        input_shape: Image shape as ``(C, H, W)``.
        latent_dim: Size of the latent vector produced by ``encode``.
        depth: Number of down-/up-sampling stages.
        base_channels: Channels of the first encoder stage; doubled per stage.
        dropout: Probability used by every ``Dropout2d`` layer.
    """

    def __init__(self, input_shape, latent_dim, depth=3, base_channels=32, dropout=0.1):
        super().__init__()
        self.input_shape = input_shape # (C, H, W)
        self.latent_dim = latent_dim

        # --- Encoder ---
        layers = []
        in_c, h, w = input_shape

        # Spatial size seen at the *input* of every encoder stage. The decoder
        # upsamples back to exactly these sizes, so reconstructions match the
        # input even when H or W is not divisible by 2**depth.
        stage_input_sizes = []

        # Conv Blocks
        for i in range(depth):
            stage_input_sizes.append((h, w))
            out_c = base_channels * (2**i)
            layers.append(nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout2d(dropout))
            in_c = out_c
            # kernel 3 / stride 2 / padding 1 gives ceil(size / 2)
            h, w = (h + 1) // 2, (w + 1) // 2

        self.encoder_conv = nn.Sequential(*layers)
        self.flat_size = in_c * h * w
        self.final_conv_shape = (in_c, h, w)

        # Projection
        self.to_latent = nn.Linear(self.flat_size, latent_dim)
        print(f"[ImageBranch] Encoder output shape: {self.final_conv_shape}, flat size: {self.flat_size}")

        # --- Decoder ---
        # Projection back (the unflatten itself happens in decode()).
        self.from_latent = nn.Linear(latent_dim, self.flat_size)
        print(f"[ImageBranch] Decoder input flat size: {self.flat_size}")

        # Upsample + Conv blocks, walking the encoder stages in reverse.
        dec_layers = []
        for i in range(depth - 1, -1, -1):
            in_c = base_channels * (2**i)
            out_c = base_channels * (2**(i-1)) if i > 0 else input_shape[0]

            # Upsampling to an explicit target size equals scale_factor=2 when
            # the size was even, and undoes the ceil-halving when it was odd.
            dec_layers.append(nn.Sequential(
                nn.Upsample(size=stage_input_sizes[i], mode='bilinear', align_corners=True),
                nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
            ))
            if i > 0: # No activation on the final layer (left to the loss / caller)
                dec_layers.append(nn.ReLU())
                dec_layers.append(nn.Dropout2d(dropout))

        self.decoder_conv = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Encode a batch of images ``(B, C, H, W)`` into latents ``(B, latent_dim)``."""
        x = self.encoder_conv(x)
        x = torch.flatten(x, 1)
        z = self.to_latent(x)
        return z

    def decode(self, z):
        """Decode latents ``(B, latent_dim)`` into images shaped like ``input_shape``.

        No output non-linearity is applied; values are unbounded.
        """
        x = self.from_latent(z)
        x = x.view(-1, *self.final_conv_shape)
        x = self.decoder_conv(x)
        return x


# ==========================================
# 2. TEXT BRANCH (BERT -> LATENT -> GPT-Style)
# ==========================================

class TrainableTextBranch(nn.Module):
    """Lightweight trainable text autoencoder branch (alternative to frozen BERT).

    Encoder: token embeddings + learned positional encoding, a prepended
    learned CLS token, a bidirectional ``nn.TransformerEncoder``; the CLS
    output is projected to the latent. Decoder: a causal
    ``nn.TransformerDecoder`` that cross-attends to the latent (as a
    length-1 memory) and is fed the target tokens (teacher forcing).
    """

    def __init__(self, latent_dim, vocab_size=30522, encoder_dim=256, encoder_depth=2,
                 encoder_heads=4, decoder_depth=2, decoder_heads=4, decoder_dim=256,
                 max_seq_len=512, dropout=0.1):
        """
        Args:
            latent_dim: Dimension of the compressed latent space
            vocab_size: Vocabulary size (default matches BERT)
            encoder_dim: Hidden dimension of encoder transformer
            encoder_depth: Number of encoder transformer layers
            encoder_heads: Number of attention heads in encoder
            decoder_depth: Number of transformer decoder layers
            decoder_heads: Number of attention heads in decoder
            decoder_dim: Hidden dimension of decoder
            max_seq_len: Maximum sequence length (size of the positional table)
            dropout: Dropout rate
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

        # 1. Encoder (Trainable Bidirectional Transformer)
        self.encoder_embeddings = nn.Embedding(vocab_size, encoder_dim)
        self.pos_encoding = nn.Parameter(torch.randn(1, max_seq_len, encoder_dim) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=encoder_dim,
            nhead=encoder_heads,
            dim_feedforward=encoder_dim * 2,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        # The nested-tensor path is unavailable for norm_first layers; asking
        # for it only produces a UserWarning, so it is disabled explicitly.
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=encoder_depth, enable_nested_tensor=False
        )

        # CLS token for pooling
        self.cls_token = nn.Parameter(torch.randn(1, 1, encoder_dim) * 0.02)

        # 2. Projections
        self.to_latent = nn.Linear(encoder_dim, latent_dim)
        self.from_latent = nn.Linear(latent_dim, decoder_dim)

        # 3. Decoder (Lightweight Transformer)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=decoder_dim,
            nhead=decoder_heads,
            dim_feedforward=decoder_dim * 2,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_depth)

        # 4. Input embeddings for decoder
        self.decoder_embeddings = nn.Embedding(vocab_size, decoder_dim)

        # 5. Output Head
        self.output_head = nn.Linear(decoder_dim, vocab_size)

        self._init_weights()

    def encode(self, input_ids, attention_mask):
        """Encode token ids into a latent vector.

        Args:
            input_ids: (B, Seq) token ids; Seq must not exceed ``max_seq_len``.
            attention_mask: (B, Seq), non-zero for real tokens, zero for padding.

        Returns:
            (B, latent_dim) latent taken from the CLS position.
        """
        batch, seq_len = input_ids.shape

        # Token embeddings plus learned positions: (B, Seq, encoder_dim)
        tokens = self.encoder_embeddings(input_ids) + self.pos_encoding[:, :seq_len, :]

        # Prepend the CLS token -> (B, Seq+1, encoder_dim)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)

        # CLS is always attended to, so extend the mask with a column of ones.
        cls_mask = torch.ones(batch, 1, dtype=attention_mask.dtype, device=attention_mask.device)
        extended_mask = torch.cat([cls_mask, attention_mask], dim=1)  # (B, Seq+1)

        # PyTorch expects True at positions that must be ignored.
        padding_mask = ~extended_mask.bool()

        encoded = self.encoder(tokens, src_key_padding_mask=padding_mask)  # (B, Seq+1, encoder_dim)

        # Pool via the CLS position and project to the latent.
        return self.to_latent(encoded[:, 0, :])

    def decode(self, z, input_ids, attention_mask):
        """Decode a latent into per-position vocabulary logits.

        Args:
            z: (B, Latent)
            input_ids: (B, Seq) - The target text for Teacher Forcing. The ids
                are fed to the decoder as-is (no shift is applied here).
            attention_mask: (B, Seq), non-zero for real tokens, zero for padding.

        Returns:
            (B, Seq, Vocab) logits.
        """
        seq_len = input_ids.shape[1]

        # Latent as a length-1 memory sequence for cross attention.
        memory = self.from_latent(z).unsqueeze(1)  # (B, 1, decoder_dim)

        tgt_emb = self.decoder_embeddings(input_ids)  # (B, Seq, decoder_dim)

        # Causal mask built directly on the target device (True = blocked):
        # position i may not attend to any j > i.
        positions = torch.arange(seq_len, device=z.device)
        tgt_mask = positions.unsqueeze(0) > positions.unsqueeze(1)  # (Seq, Seq)

        padding_mask = ~attention_mask.bool()

        out = self.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=padding_mask
        )

        return self.output_head(out)  # (B, Seq, Vocab)

    def _init_weights(self):
        """Initialize like BERT (std=0.02) to prevent explosion"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)


class TextBranch(nn.Module):
    """Text autoencoder branch with either a frozen BERT or a trainable encoder.

    With ``use_bert=True`` the [CLS] state of a frozen pretrained BERT is
    projected to the latent and a small causal transformer decoder
    reconstructs the tokens. With ``use_bert=False`` everything is delegated
    to ``TrainableTextBranch``.
    """

    def __init__(self, latent_dim, use_bert=True, bert_model_name='bert-base-uncased',
                 encoder_dim=256, encoder_depth=2, encoder_heads=4,
                 decoder_depth=2, decoder_heads=4, decoder_dim=256, dropout=0.1,
                 max_seq_len=512):
        """
        Args:
            latent_dim: Dimension of the compressed latent space
            use_bert: If True, use frozen pretrained BERT. If False, use trainable transformer.
            bert_model_name: Pretrained BERT model to use (only if use_bert=True)
            encoder_dim: Hidden dimension of trainable encoder (only if use_bert=False)
            encoder_depth: Number of encoder layers in trainable encoder (only if use_bert=False)
            encoder_heads: Number of attention heads in trainable encoder (only if use_bert=False)
            decoder_depth: Number of transformer decoder layers
            decoder_heads: Number of attention heads in decoder
            decoder_dim: Hidden dimension of decoder
            dropout: Dropout rate
            max_seq_len: Maximum sequence length (only forwarded when use_bert=False)
        """
        super().__init__()

        self.use_bert = use_bert
        self.decoder_dim = decoder_dim

        if use_bert:
            # 1. Encoder (Frozen BERT)
            self.bert = BertModel.from_pretrained(bert_model_name)
            for param in self.bert.parameters():
                param.requires_grad = False
            # A frozen feature extractor must be deterministic: disable its
            # dropout. train() below keeps it that way.
            self.bert.eval()

            self.bert_hidden = self.bert.config.hidden_size  # 768
            self.vocab_size = self.bert.config.vocab_size

            # 2. Projections
            self.to_latent = nn.Linear(self.bert_hidden, latent_dim)
            self.from_latent = nn.Linear(latent_dim, decoder_dim)  # Project to smaller decoder dim

            # 3. Decoder (Lightweight Transformer)
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=decoder_dim,  # Use smaller hidden dimension
                nhead=decoder_heads,
                dim_feedforward=decoder_dim * 2,  # 2x instead of 4x
                dropout=dropout,
                batch_first=True,
                norm_first=True
            )
            self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_depth)

            # 4. Input embeddings (project from vocab to decoder_dim)
            self.decoder_embeddings = nn.Embedding(self.vocab_size, decoder_dim)

            # 5. Output Head (project from decoder_dim to vocab)
            self.output_head = nn.Linear(decoder_dim, self.vocab_size)
        else:
            # Use trainable transformer branch
            self.trainable_branch = TrainableTextBranch(
                latent_dim=latent_dim,
                vocab_size=30522,  # Match BERT vocab size
                encoder_dim=encoder_dim,
                encoder_depth=encoder_depth,
                encoder_heads=encoder_heads,
                decoder_depth=decoder_depth,
                max_seq_len=max_seq_len,
                decoder_heads=decoder_heads,
                decoder_dim=decoder_dim,
                dropout=dropout
            )
            self.vocab_size = self.trainable_branch.vocab_size
            # Alias of the same module, so it is registered under two names.
            self.to_latent = self.trainable_branch.to_latent

    def train(self, mode=True):
        """Set training mode, but always keep the frozen BERT in eval mode.

        Without this, a ``.train()`` call on this module or any parent would
        switch BERT's dropout back on and make the frozen features stochastic.
        """
        super().train(mode)
        if self.use_bert:
            self.bert.eval()
        return self

    def encode(self, input_ids, attention_mask):
        """Encode token ids ``(B, Seq)`` into a latent ``(B, latent_dim)``."""
        if not self.use_bert:
            return self.trainable_branch.encode(input_ids, attention_mask)

        # BERT is frozen: skip building an autograd graph for it.
        with torch.no_grad():
            out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_token = out.last_hidden_state[:, 0, :] # (B, bert_hidden)
        return self.to_latent(cls_token)

    def decode(self, z, input_ids, attention_mask):
        """Decode a latent into per-position vocabulary logits.

        Args:
            z: (B, Latent)
            input_ids: (B, Seq) - The target text for Teacher Forcing. The ids
                are fed to the decoder as-is (no shift is applied here).
            attention_mask: (B, Seq), non-zero for real tokens, zero for padding.

        Returns:
            (B, Seq, Vocab) logits.
        """
        if not self.use_bert:
            return self.trainable_branch.decode(z, input_ids, attention_mask)

        seq_len = input_ids.shape[1]

        # Latent as a length-1 memory sequence for cross attention: (B, 1, Hidden)
        memory = self.from_latent(z).unsqueeze(1)

        tgt_emb = self.decoder_embeddings(input_ids) # (B, Seq, Hidden)

        # Boolean causal mask (True = blocked), built directly on the target
        # device; boolean to match the key padding mask type. Shape (Seq, Seq).
        positions = torch.arange(seq_len, device=z.device)
        tgt_mask = positions.unsqueeze(0) > positions.unsqueeze(1)

        # Invert the attention mask: padding positions become True (ignored).
        padding_mask = ~attention_mask.bool()

        out = self.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=padding_mask
        )

        return self.output_head(out) # (B, Seq, Vocab)


# ==========================================
# 3. MAIN MULTIMODAL MODEL
# ==========================================

class MultimodalAdaptiveAE(nn.Module):
    """Image + text autoencoder with shared and modality-specific adaptive bottlenecks.

    Each modality is encoded by its own branch, normalised and passed through
    a pre-fusion layer. The concatenated latents feed a shared
    ``AdaptiveRankReducedLinear``; each modality latent also feeds its own
    specific one. For decoding, ``[shared, specific]`` is projected back to
    the size the modality decoder expects.
    """

    def __init__(self,
                 image_shape=(3, 64, 64),
                 latent_dim_img=128,
                 latent_dim_text=128,
                 latent_dim_shared=128,
                 use_bert=True,
                 bert_model_name='bert-base-uncased',
                 text_encoder_dim=256,
                 text_encoder_depth=2,
                 text_encoder_heads=4,
                 text_decoder_dim=256,
                 text_decoder_depth=2,
                 text_decoder_heads=4,
                 max_seq_len=512):
        """
        Args:
            image_shape: Input image shape (C, H, W)
            latent_dim_img: Image-specific latent dimension
            latent_dim_text: Text-specific latent dimension
            latent_dim_shared: Shared latent dimension
            use_bert: If True, use frozen BERT encoder. If False, use trainable transformer.
            bert_model_name: BERT model for text encoding (only if use_bert=True)
            text_encoder_dim: Hidden dimension for trainable encoder (only if use_bert=False)
            text_encoder_depth: Number of layers in trainable encoder (only if use_bert=False)
            text_encoder_heads: Number of attention heads in trainable encoder (only if use_bert=False)
            text_decoder_dim: Hidden dimension for text decoder (smaller = fewer params)
            text_decoder_depth: Number of transformer layers in text decoder
            text_decoder_heads: Number of attention heads in text decoder
            max_seq_len: Maximum text sequence length, forwarded to the text branch
        """
        super().__init__()

        # --- Unimodal Encoders/Decoders ---
        self.img_branch = ImageBranch(image_shape, latent_dim_img)
        self.text_branch = TextBranch(
            latent_dim_text,
            use_bert=use_bert,
            bert_model_name=bert_model_name,
            encoder_dim=text_encoder_dim,
            encoder_depth=text_encoder_depth,
            encoder_heads=text_encoder_heads,
            decoder_depth=text_decoder_depth,
            decoder_heads=text_decoder_heads,
            decoder_dim=text_decoder_dim,
            max_seq_len=max_seq_len
        )

        # --- Adaptive Bottlenecks ---
        # 1. Shared Layer (Takes concatenated latents -> compresses to shared)
        self.shared_layer = AdaptiveRankReducedLinear(
            in_features=latent_dim_img + latent_dim_text,
            out_features=latent_dim_shared
        )

        # 2. Specific Layers (Take unimodal latent -> compresses to specific)
        self.img_spec_layer = AdaptiveRankReducedLinear(latent_dim_img, latent_dim_img)
        self.text_spec_layer = AdaptiveRankReducedLinear(latent_dim_text, latent_dim_text)

        # Store for easy access in training loop.
        # Order matters: index 0 = shared, 1 = image-specific, 2 = text-specific.
        self.adaptive_layers = nn.ModuleList([
            self.shared_layer,
            self.img_spec_layer,
            self.text_spec_layer
        ])

        self.img_norm = nn.LayerNorm(latent_dim_img)
        self.text_norm = nn.LayerNorm(latent_dim_text)

        # Layers that let image and text latents reach comparable magnitudes
        # after unimodal pretraining (if it is used).
        self.img_pre_fusion = nn.Linear(latent_dim_img, latent_dim_img)
        self.text_pre_fusion = nn.Linear(latent_dim_text, latent_dim_text)
        # --- Fusion layers to project concatenated (shared + specific) back to decoder input dims ---
        # These are necessary because decoders expect latent_dim_img/text but receive (shared + specific)
        self.img_fusion = nn.Linear(latent_dim_shared + latent_dim_img, latent_dim_img)
        self.text_fusion = nn.Linear(latent_dim_shared + latent_dim_text, latent_dim_text)

        # Initialize fusion layers with small weights for smoother transition from unimodal
        with torch.no_grad():
            nn.init.xavier_uniform_(self.img_fusion.weight, gain=0.5)
            nn.init.zeros_(self.img_fusion.bias)
            nn.init.xavier_uniform_(self.text_fusion.weight, gain=0.5)
            nn.init.zeros_(self.text_fusion.bias)

    def encode(self, images, text_input_ids, text_mask):
        """Encode both modalities.

        Returns:
            ``(h_shared, [h_img_specific, h_text_specific])``.
        """
        h_img = self.img_branch.encode(images)
        h_text = self.text_branch.encode(text_input_ids, text_mask)
        h_img = F.relu(self.img_pre_fusion(self.img_norm(h_img)))
        h_text = F.relu(self.text_pre_fusion(self.text_norm(h_text)))

        h_shared = self.shared_layer(torch.cat([h_img, h_text], dim=1))
        h_specific = [self.img_spec_layer(h_img), self.text_spec_layer(h_text)]
        return (h_shared, h_specific)

    def decode(self, h, text_input_ids, text_mask):
        """Decode the output of ``encode``.

        Args:
            h: ``(h_shared, h_specific)`` as returned by ``encode``; entry 0 of
                ``h_specific`` is treated as the image latent, every later
                entry as a text latent.
            text_input_ids: Target token ids for teacher forcing.
            text_mask: Attention mask matching ``text_input_ids``.

        Returns:
            List of reconstructions in the order of ``h_specific``:
            the image reconstruction first, then text logits.
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            h_concat = torch.cat([h_shared, h_m], dim=1)
            if m == 0: # Image Modality
                # Project concatenated (shared + specific) to image decoder input dimension
                z_img_full = self.img_fusion(h_concat)
                x_hat.append(self.img_branch.decode(z_img_full))
            else: # Text Modality
                # Project concatenated (shared + specific) to text decoder input dimension
                z_text_full = self.text_fusion(h_concat)
                x_hat.append(self.text_branch.decode(z_text_full, text_input_ids, text_mask))
        return x_hat

    def forward(self, images, text_input_ids, text_mask):
        """Run encode then decode; returns ``(reconstructions, latents)``."""
        h = self.encode(images, text_input_ids, text_mask)
        x_hat = self.decode(h, text_input_ids, text_mask)
        return x_hat, h

    def encode_modalities(self, images, text_input_ids, text_mask):
        """Return one ``[shared, specific]`` concatenation per modality."""
        h_shared, h_specific = self.encode(images, text_input_ids, text_mask)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def _selected_layers(self, layer_ids):
        """Yield the adaptive layers addressed by ``layer_ids`` (``None`` = all)."""
        for i, layer in enumerate(self.adaptive_layers):
            if layer_ids is None or i in layer_ids:
                yield layer

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=None, dim=0):
        """Reduce the rank of adaptive layers based on singular value importance.

        Args:
            reduction_ratio: The new rank is never below
                ``int(current_rank * reduction_ratio)``.
            threshold: Fraction of cumulative energy that may be discarded.
            layer_ids: Indices into ``adaptive_layers`` to process. ``None``
                (default) processes every layer; an explicit list restricts
                the operation to those indices (an empty list selects none).
            dim: Forwarded to each layer's ``reduce_rank``.

        Returns:
            True if at least one layer's rank was reduced.
        """
        changes_made = False

        for layer in self._selected_layers(layer_ids):
            # presumably singular values in descending order — confirm against
            # AdaptiveRankReducedLinear.get_rank_reduction_info
            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            # Normalized cumulative energy of the spectrum.
            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Rank that preserves the requested energy, never below min_rank.
            target_rank = max(layer.min_rank,
                              torch.sum(cumulative_energy < (1.0 - threshold)).item())

            # Alternative: reduce by a fixed ratio, also never below min_rank.
            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))

            # Take the larger (more conservative) of the two approaches.
            new_rank = max(target_rank, ratio_rank)

            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """Increase the rank of adaptive layers.

        Args:
            increment: Forwarded to each layer's ``increase_rank``.
            increase_ratio: Forwarded to each layer's ``increase_rank``.
            layer_ids: Indices into ``adaptive_layers`` to process. ``None``
                (default) processes every layer; an explicit list restricts
                the operation to those indices (an empty list selects none).
            dim: Forwarded to each layer's ``increase_rank``.

        Returns:
            True if at least one layer reported a change.
        """
        changes_made = False

        for layer in self._selected_layers(layer_ids):
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)

###########################################

from transformers import T5ForConditionalGeneration, T5Tokenizer
from diffusers import AutoencoderTiny

# ==========================================
# 2. IMAGE BRANCH (Pretrained VAE)
# ==========================================

class ImageBranch2(nn.Module):
    """Image branch built on a pretrained ``AutoencoderTiny`` VAE.

    The VAE latents are flattened and linearly projected to ``latent_dim``
    (and back for decoding). Inputs and outputs are images in ``[0, 1]``.
    """

    def __init__(self, input_shape, latent_dim, model_name="madebyollin/taesd", freeze=True):
        """
        Args:
            input_shape: (C, H, W) - Expected to be (3, 256, 256) for VAE
            latent_dim: Dimension of the bottleneck z (e.g. 128)
            model_name: HuggingFace model ID for the VAE
            freeze: If True, VAE parameters do not require grad and the VAE is
                kept in eval mode (see ``train``).
        """
        super().__init__()
        print(f"Loading Pretrained VAE: {model_name}...")
        self.vae = AutoencoderTiny.from_pretrained(model_name)

        self.freeze = freeze
        if freeze:
            for param in self.vae.parameters():
                param.requires_grad = False
            self.vae.eval()

        self.latent_dim = latent_dim

        # Calculate VAE flat dimension.
        # The code assumes the VAE downsamples by a factor of 8 and has
        # 4 latent channels (256x256 input -> 4x32x32 latent).
        img_h = input_shape[1]
        # Width falls back to the height when only (C, H) is supplied, which
        # keeps the previous square-image behaviour.
        img_w = input_shape[2] if len(input_shape) > 2 else img_h
        self.spatial_dim = img_h // 8   # latent height
        self.spatial_w = img_w // 8     # latent width (== spatial_dim for square inputs)
        self.vae_flat_dim = 4 * self.spatial_dim * self.spatial_w

        # Projections
        self.to_latent = nn.Sequential(
            nn.LayerNorm(self.vae_flat_dim),
            nn.Linear(self.vae_flat_dim, latent_dim)
        )
        print(f"[ImageBranch2] VAE latent shape: (4, {self.spatial_dim}, {self.spatial_w}), flat size: {self.vae_flat_dim}")
        self.from_latent = nn.Linear(latent_dim, self.vae_flat_dim)
        print(f"[ImageBranch2] Decoder input flat size: {self.vae_flat_dim}")

        # Scaling applied to the VAE latents on encode and undone on decode
        # (0.18215 is the value conventionally used with SD VAEs).
        self.scale_factor = 0.18215

    def train(self, mode=True):
        """Set training mode, but keep a frozen VAE in eval mode.

        ``__init__`` puts the frozen VAE in eval mode; without this override
        any later ``.train()`` on this module or a parent would undo that.
        """
        super().train(mode)
        if getattr(self, "freeze", False):
            self.vae.eval()
        return self

    def encode(self, x):
        """Encode images in ``[0, 1]`` into latents ``(B, latent_dim)``."""
        # Input x is [0, 1]. VAE expects [-1, 1].
        x_norm = 2.0 * x - 1.0

        # NOTE(review): the VAE encoder always runs under no_grad, even when
        # freeze=False, so it never receives gradients — confirm this is intended.
        with torch.no_grad():
            z_vae = self.vae.encode(x_norm).latents * self.scale_factor

        # Flatten and project to the bottleneck.
        z_flat = torch.flatten(z_vae, start_dim=1)
        return self.to_latent(z_flat)

    def decode(self, z):
        """Decode latents ``(B, latent_dim)`` into images clamped to ``[0, 1]``."""
        # Project back and restore the VAE latent layout.
        z_flat = self.from_latent(z)
        latent_w = getattr(self, "spatial_w", self.spatial_dim)
        z_vae = z_flat.view(-1, 4, self.spatial_dim, latent_w)
        z_vae = z_vae / self.scale_factor

        # Deliberately not under no_grad: gradients must flow through the VAE
        # decoder back to from_latent.
        x_hat = self.vae.decode(z_vae).sample

        # Convert [-1, 1] back to [0, 1]
        x_hat = (x_hat / 2.0 + 0.5).clamp(0, 1)
        return x_hat


# ==========================================
# 3. TEXT BRANCH (Pretrained T5)
# ==========================================

class TextBranch2(nn.Module):
    """Text branch built on a pretrained T5 encoder-decoder.

    The T5 encoder output is mean-pooled over non-padding tokens and
    projected to ``latent_dim``. For decoding, the latent is projected back
    to the T5 width and handed to T5 as a length-1 encoder output.
    """

    def __init__(self, latent_dim, model_name="google/t5-efficient-tiny", freeze=True):
        """
        Args:
            latent_dim: Dimension of the bottleneck z
            model_name: HuggingFace model ID (t5-small, t5-base)
            freeze: If True, T5 parameters do not require grad and T5 is kept
                in eval mode (see ``train``).
        """
        super().__init__()
        print(f"Loading Pretrained T5: {model_name}...")
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_name)

        self.freeze = freeze
        if freeze:
            for param in self.t5.parameters():
                param.requires_grad = False
            self.t5.eval()

        self.t5_dim = self.t5.config.d_model
        self.to_latent = nn.Linear(self.t5_dim, latent_dim)
        print(f"[TextBranch2] T5 hidden dimension: {self.t5_dim}")
        self.from_latent = nn.Linear(latent_dim, self.t5_dim)
        print(f"[TextBranch2] Decoder input dimension: {self.t5_dim}")

        # Kept from the original: explicitly (re)store the latent size on the
        # projection for reference.
        self.to_latent.out_features = latent_dim

    def train(self, mode=True):
        """Set training mode, but keep a frozen T5 in eval mode.

        ``__init__`` puts the frozen T5 in eval mode; without this override
        any later ``.train()`` on this module or a parent would re-enable its
        dropout and make the frozen features stochastic.
        """
        super().train(mode)
        if getattr(self, "freeze", False):
            self.t5.eval()
        return self

    def encode(self, input_ids, attention_mask):
        """Encode token ids ``(B, Seq)`` into a latent ``(B, latent_dim)``."""
        # NOTE(review): the T5 encoder always runs under no_grad, even when
        # freeze=False — confirm this is intended.
        with torch.no_grad():
            enc_out = self.t5.encoder(input_ids=input_ids, attention_mask=attention_mask)
            hidden_states = enc_out.last_hidden_state # (B, Seq, Dim)

            # Mean pooling that ignores padding positions.
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
            # Clamp so an all-padding row does not divide by zero.
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask

        return self.to_latent(pooled)

    def decode(self, z, input_ids, attention_mask):
        """Decode a latent into per-position vocabulary logits ``(B, Seq, Vocab)``.

        ``attention_mask`` is accepted for interface parity with the other
        text branches but is not used here.
        """
        # Project back to T5 dimension; (B, 1, Dim) acts as a single memory token.
        memory = self.from_latent(z).unsqueeze(1)

        # Passing 'labels' makes T5 build decoder_input_ids (shifted right)
        # itself; the latent 'memory' stands in for the encoder hidden states.
        # T5 handles causal masking internally.
        outputs = self.t5(
            encoder_outputs=(memory,),
            labels=input_ids,
        )

        return outputs.logits # (B, Seq, Vocab)


# ==========================================
# 4. MAIN MULTIMODAL MODEL
# ==========================================

class MultimodalAdaptiveAE2(nn.Module):
    """Image + text autoencoder joined by adaptive-rank bottleneck layers.

    Each modality is encoded by its own branch (``ImageBranch2`` for images,
    ``TrainableTextBranch`` for text). The two branch latents are
    concatenated and passed through a shared ``AdaptiveRankReducedLinear``
    layer, and each latent is additionally passed through its own
    modality-specific adaptive layer. For decoding, ``[shared, specific]`` is
    projected back to the branch latent size by a linear fusion layer and
    handed to the corresponding branch decoder.

    ``adaptive_layers`` holds the three adaptive layers in the fixed order
    ``[shared, image-specific, text-specific]``; the ``layer_ids`` arguments
    of :meth:`reduce_rank` and :meth:`increase_rank` index into that list.
    """

    def __init__(self, 
                 image_shape=(3, 256, 256), # SD VAE likes 256 or 512
                 latent_dim_img=128,
                 latent_dim_text=128,
                 latent_dim_shared=128,
                 # Arguments kept for compatibility with your training script calls
                 use_bert=None, bert_model_name=None, 
                 text_encoder_dim=None, text_encoder_depth=None, text_encoder_heads=None,
                 text_decoder_dim=None, text_decoder_depth=None, text_decoder_heads=None,
                 max_seq_len=None,
                 vocab_size=None):
        """Build branches, adaptive bottlenecks and fusion layers.

        Args:
            image_shape: Image shape forwarded to ``ImageBranch2``.
            latent_dim_img: Requested image latent size. The value actually
                used downstream is ``self.img_branch.latent_dim``.
            latent_dim_text: Text latent size.
            latent_dim_shared: Output size of every adaptive layer.
            max_seq_len: Maximum text length (defaults to 512 when ``None``).
            vocab_size: Text vocabulary size (defaults to 30522 when ``None``).
            use_bert, bert_model_name, text_encoder_*, text_decoder_*:
                Accepted only for call-site compatibility; ignored here.
        """
        super().__init__()

        # --- Unimodal Encoders/Decoders ---
        self.img_branch = ImageBranch2(image_shape, latent_dim_img)
        # The branch is authoritative for its latent size.
        latent_dim_img = self.img_branch.latent_dim

        # The text branch uses fixed architecture hyper-parameters; the
        # text_encoder_* / text_decoder_* constructor arguments are ignored.
        self.text_branch = TrainableTextBranch(
            latent_dim=latent_dim_text,
            vocab_size=vocab_size if vocab_size is not None else 30522,
            encoder_dim=256,
            encoder_depth=2,
            encoder_heads=4,
            decoder_depth=2,
            max_seq_len=max_seq_len if max_seq_len is not None else 512,
            decoder_heads=4,
            decoder_dim=256,
            dropout=0.1
        )

        # --- Adaptive Bottlenecks ---
        self.shared_layer = AdaptiveRankReducedLinear(
            in_features=latent_dim_img + latent_dim_text,
            out_features=latent_dim_shared
        )
        self.img_spec_layer = AdaptiveRankReducedLinear(latent_dim_img, latent_dim_shared)
        self.text_spec_layer = AdaptiveRankReducedLinear(latent_dim_text, latent_dim_shared)

        # Order matters: index 0 = shared, 1 = image-specific, 2 = text-specific.
        self.adaptive_layers = nn.ModuleList([
            self.shared_layer,
            self.img_spec_layer,
            self.text_spec_layer
        ])

        # --- Fusion layers ---
        # Project (Shared + Specific) -> Unimodal Decoder Input
        self.img_fusion = nn.Linear(latent_dim_shared * 2, latent_dim_img)
        self.text_fusion = nn.Linear(latent_dim_shared * 2, latent_dim_text)

        # Initialize trainable parameters with small weights (like transformers).
        # The call order below is kept fixed so seeded runs draw the same
        # random numbers for the same parameters.
        # NOTE: an identity-on-specific / zero-on-shared initialisation of the
        # fusion layers is a possible alternative for fine-tuning; the small
        # Xavier initialisation is what is used.
        with torch.no_grad():
            # to and from latent layers (image branch)
            nn.init.xavier_uniform_(self.img_branch.to_latent[1].weight, gain=0.5)
            nn.init.zeros_(self.img_branch.to_latent[1].bias)
            nn.init.xavier_uniform_(self.img_branch.from_latent.weight, gain=0.5)
            nn.init.zeros_(self.img_branch.from_latent.bias)

            # fusion layers
            for fusion in (self.img_fusion, self.text_fusion):
                nn.init.xavier_uniform_(fusion.weight, gain=0.5)
                nn.init.zeros_(fusion.bias)

            # adaptive layers: every U factor first, then every V factor
            for factor_name in ("U", "V"):
                for layer in self.adaptive_layers:
                    nn.init.xavier_uniform_(getattr(layer, factor_name), gain=0.5)

    def encode(self, images, text_input_ids, text_mask):
        """Encode both modalities into shared and modality-specific codes.

        Returns:
            ``(h_shared, h_specific)`` where ``h_shared`` is the shared layer
            applied to the concatenated branch latents and ``h_specific`` is
            the list ``[image-specific code, text-specific code]``.
        """
        h_img = self.img_branch.encode(images)
        h_text = self.text_branch.encode(text_input_ids, text_mask)

        # Note: no LayerNorms here, so the scale of the branch latents is preserved.
        h_shared = self.adaptive_layers[0](torch.cat([h_img, h_text], dim=1))

        # Identity check (not ==): we are asking "is this the image layer?".
        h_specific = [
            layer(h_img if layer is self.img_spec_layer else h_text)
            for layer in self.adaptive_layers[1:]
        ]
        return (h_shared, h_specific)

    def decode(self, h, text_input_ids, text_mask):
        """Reconstruct both modalities from the output of :meth:`encode`.

        Args:
            h: ``(h_shared, h_specific)`` as returned by :meth:`encode`.
            text_input_ids: Token ids passed through to the text decoder.
            text_mask: Mask passed through to the text decoder.

        Returns:
            List ``[reconstructed image, text logits]``.
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            h_concat = torch.cat([h_shared, h_m], dim=1)

            if m == 0:  # Image modality
                x_hat.append(self.img_branch.decode(self.img_fusion(h_concat)))
            else:  # Text modality
                z_text_full = self.text_fusion(h_concat)
                # The text decoder also receives the token ids and mask.
                x_hat.append(self.text_branch.decode(z_text_full, text_input_ids, text_mask))
        return x_hat

    def forward(self, images, text_input_ids, text_mask):
        """Run encode then decode; returns ``(reconstructions, codes)``."""
        h = self.encode(images, text_input_ids, text_mask)
        x_hat = self.decode(h, text_input_ids, text_mask)
        return x_hat, h

    def encode_modalities(self, images, text_input_ids, text_mask):
        """Return one ``[shared, specific]`` concatenation per modality."""
        h_shared, h_specific = self.encode(images, text_input_ids, text_mask)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=(), dim=0):
        """Shrink the rank of the selected adaptive layers.

        For each selected layer the new rank is the larger of
        (a) the number of leading components whose cumulative squared-value
        share is below ``1 - threshold`` and (b) ``current_rank *
        reduction_ratio``, both floored at ``layer.min_rank``. The layer is
        changed only if that is strictly smaller than its current rank.

        Args:
            reduction_ratio: Upper bound on how much one call may shrink a layer.
            threshold: Share of total energy that may be discarded.
            layer_ids: Indices into ``adaptive_layers`` to process. The
                default (empty) selects no layer, so the call is a no-op.
            dim: Unused; kept for interface compatibility.

        Returns:
            ``True`` if at least one layer was reduced.
        """
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            # Presumably singular values in descending order - TODO confirm
            # against AdaptiveRankReducedLinear.get_rank_reduction_info.
            S = layer.get_rank_reduction_info()
            if len(S) <= layer.min_rank:
                continue

            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)
            # NOTE(review): counting entries strictly below (1 - threshold)
            # keeps slightly less than that share of energy; confirm whether
            # "+ 1" is intended. Harmless in practice because of the
            # ratio_rank floor below.
            target_rank = max(layer.min_rank, torch.sum(cumulative_energy < (1.0 - threshold)).item())

            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            new_rank = max(target_rank, ratio_rank)

            if new_rank < current_rank:
                layer.reduce_rank(new_rank)
                changes_made = True
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=(), dim=0):
        """Grow the rank of the selected adaptive layers.

        Args:
            increment: Forwarded to ``layer.increase_rank``.
            increase_ratio: Forwarded to ``layer.increase_rank``.
            layer_ids: Indices into ``adaptive_layers`` to process. The
                default (empty) selects no layer.
            dim: Unused; kept for interface compatibility.

        Returns:
            ``True`` if any selected layer reported a change.
        """
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio):
                changes_made = True
        return changes_made

    def get_total_rank(self):
        """Return the sum of ``active_dims`` over all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)

from transformers import BertConfig, BertModel, BertLMHeadModel

class TextBranch3(nn.Module):
    """BERT-to-BERT text autoencoder branch.

    A standard BERT encoder produces a latent vector from the first-token
    hidden state; a second BERT, configured with ``is_decoder=True`` and
    ``add_cross_attention=True``, acts as the decoder and cross-attends to
    the (re-projected) latent as a length-1 memory.
    """

    def __init__(self, latent_dim, model_name='prajjwal1/bert-mini', freeze=True):
        """Load the pretrained encoder/decoder and build the latent projections.

        Args:
            latent_dim: Size of the latent vector produced by :meth:`encode`.
            model_name: Hugging Face model id used for both encoder and decoder.
            freeze: If true, all encoder parameters and every decoder
                parameter whose name does not contain ``crossattention`` /
                ``cross_attention`` get ``requires_grad = False``, and the
                encoder is kept in eval mode (see :meth:`train`).
        """
        super().__init__()
        print(f"Loading Pretrained TinyBERT: {model_name}...")

        # --- 1. ENCODER (Standard BERT) ---
        self.encoder = BertModel.from_pretrained(model_name)

        # --- 2. DECODER (BERT adapted to Causal LM) ---
        # Load config and modify to enable decoding features
        decoder_config = BertConfig.from_pretrained(model_name)
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        # Load weights into the decoder structure
        self.decoder = BertLMHeadModel.from_pretrained(model_name, config=decoder_config)

        # --- 3. FREEZING LOGIC ---
        # Remembered so train() can keep the frozen encoder in eval mode.
        self._freeze_encoder = bool(freeze)
        if freeze:
            print("[TextBranch3] Freezing Encoder and Pretrained Decoder layers...")
            # Freeze Encoder completely
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

            # Freeze Decoder, BUT keep Cross-Attention trainable: those
            # weights do not exist in the encoder-only checkpoint, so they
            # are newly initialized and must be trained.
            for name, param in self.decoder.named_parameters():
                param.requires_grad = "crossattention" in name or "cross_attention" in name

        self.bert_dim = self.encoder.config.hidden_size

        # --- 4. PROJECTIONS ---
        self.to_latent = nn.Linear(self.bert_dim, latent_dim)
        print(f"[TextBranch3] BERT hidden dimension: {self.bert_dim}")

        self.from_latent = nn.Linear(latent_dim, self.bert_dim)
        print(f"[TextBranch3] Decoder input dimension: {self.bert_dim}")

    def train(self, mode=True):
        """Set training mode, keeping a frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, so a plain
        ``model.train()`` on any parent would switch the frozen encoder back
        to training mode (re-enabling its dropout) and undo the ``eval()``
        call made in ``__init__``. This override re-applies it.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        # getattr: tolerate instances un-pickled from before this attribute existed.
        if getattr(self, "_freeze_encoder", False):
            self.encoder.eval()
        return self

    def encode(self, input_ids, attention_mask):
        """Encode token ids into a latent vector.

        Uses the hidden state at sequence position 0 (the [CLS] token for
        BERT-style inputs) as the pooled representation and projects it with
        ``self.to_latent``.

        No ``torch.no_grad()`` is used: when the encoder is frozen its
        parameters do not require grad, so no graph is built through it.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_token = outputs.last_hidden_state[:, 0, :]  # (B, Hidden)
        return self.to_latent(cls_token)

    def decode(self, z, input_ids, attention_mask):
        """Decode a latent vector into per-token vocabulary logits.

        Args:
            z: Latent vectors, (B, Latent).
            input_ids: The target text, fed to the decoder as teacher-forcing
                input.
            attention_mask: Accepted for interface compatibility; not used.
                No ``encoder_attention_mask`` is needed either because the
                memory has length 1.

        Returns:
            ``outputs.logits`` of the decoder, (B, Seq, Vocab).

        NOTE(review): ``input_ids`` are passed unshifted and only logits are
        returned (no ``labels``), so any next-token shift has to be applied
        by the loss in the caller - confirm the training code does this.
        """
        # (B, Latent) -> (B, 1, Hidden): acts as the "encoder hidden state".
        memory = self.from_latent(z).unsqueeze(1)

        outputs = self.decoder(
            input_ids=input_ids,             # Teacher forcing input
            encoder_hidden_states=memory,    # The bottleneck z
        )

        return outputs.logits # (B, Seq, Vocab)

class MultimodalAdaptiveAE3(nn.Module):
    """Image + text autoencoder with a BERT-based text branch.

    Each modality is encoded by its own branch (``ImageBranch2`` for images,
    ``TextBranch3`` for text). The two branch latents are concatenated and
    passed through a shared ``AdaptiveRankReducedLinear`` layer, and each
    latent is additionally passed through its own modality-specific adaptive
    layer. For decoding, ``[shared, specific]`` is projected back to the
    branch latent size by a linear fusion layer and handed to the
    corresponding branch decoder.

    ``adaptive_layers`` holds the three adaptive layers in the fixed order
    ``[shared, image-specific, text-specific]``; the ``layer_ids`` arguments
    of :meth:`reduce_rank` and :meth:`increase_rank` index into that list.
    """

    def __init__(self, 
                 image_shape=(3, 256, 256), # SD VAE likes 256 or 512
                 latent_dim_img=128,
                 latent_dim_text=128,
                 latent_dim_shared=128,
                 # Arguments kept for compatibility with your training script calls
                 use_bert=None, bert_model_name=None, 
                 text_encoder_dim=None, text_encoder_depth=None, text_encoder_heads=None,
                 text_decoder_dim=None, text_decoder_depth=None, text_decoder_heads=None,
                 max_seq_len=None,
                 vocab_size=None):
        """Build branches, adaptive bottlenecks and fusion layers.

        Args:
            image_shape: Image shape forwarded to ``ImageBranch2``.
            latent_dim_img: Requested image latent size. The value actually
                used downstream is ``self.img_branch.latent_dim``.
            latent_dim_text: Text latent size.
            latent_dim_shared: Output size of every adaptive layer.
            use_bert, bert_model_name, text_encoder_*, text_decoder_*,
            max_seq_len, vocab_size:
                Accepted only for call-site compatibility; ignored here (the
                text branch is always ``prajjwal1/bert-tiny``).
        """
        super().__init__()

        # --- Unimodal Encoders/Decoders (Pretrained) ---
        self.img_branch = ImageBranch2(image_shape, latent_dim_img)
        # The branch is authoritative for its latent size.
        latent_dim_img = self.img_branch.latent_dim

        self.text_branch = TextBranch3(latent_dim_text, model_name='prajjwal1/bert-tiny')

        # --- Adaptive Bottlenecks ---
        self.shared_layer = AdaptiveRankReducedLinear(
            in_features=latent_dim_img + latent_dim_text,
            out_features=latent_dim_shared
        )
        self.img_spec_layer = AdaptiveRankReducedLinear(latent_dim_img, latent_dim_shared)
        self.text_spec_layer = AdaptiveRankReducedLinear(latent_dim_text, latent_dim_shared)

        # Order matters: index 0 = shared, 1 = image-specific, 2 = text-specific.
        self.adaptive_layers = nn.ModuleList([
            self.shared_layer,
            self.img_spec_layer,
            self.text_spec_layer
        ])

        # --- Fusion layers ---
        # Project (Shared + Specific) -> Unimodal Decoder Input
        self.img_fusion = nn.Linear(latent_dim_shared * 2, latent_dim_img)
        self.text_fusion = nn.Linear(latent_dim_shared * 2, latent_dim_text)

        # Initialize trainable parameters with small weights (like transformers).
        # The call order below is kept fixed so seeded runs draw the same
        # random numbers for the same parameters.
        with torch.no_grad():
            # to and from latent layers - image branch
            nn.init.xavier_uniform_(self.img_branch.to_latent[1].weight, gain=0.5)
            nn.init.zeros_(self.img_branch.to_latent[1].bias)
            nn.init.xavier_uniform_(self.img_branch.from_latent.weight, gain=0.5)
            nn.init.zeros_(self.img_branch.from_latent.bias)

            # to and from latent layers - text branch
            # (TextBranch3 uses a single nn.Linear for each, not a Sequential)
            nn.init.xavier_uniform_(self.text_branch.to_latent.weight, gain=0.5)
            nn.init.zeros_(self.text_branch.to_latent.bias)
            nn.init.xavier_uniform_(self.text_branch.from_latent.weight, gain=0.5)
            nn.init.zeros_(self.text_branch.from_latent.bias)

            # fusion layers
            for fusion in (self.img_fusion, self.text_fusion):
                nn.init.xavier_uniform_(fusion.weight, gain=0.5)
                nn.init.zeros_(fusion.bias)

            # adaptive layers: every U factor first, then every V factor
            for factor_name in ("U", "V"):
                for layer in self.adaptive_layers:
                    nn.init.xavier_uniform_(getattr(layer, factor_name), gain=0.5)

    def encode(self, images, text_input_ids, text_mask):
        """Encode both modalities into shared and modality-specific codes.

        Returns:
            ``(h_shared, h_specific)`` where ``h_shared`` is the shared layer
            applied to the concatenated branch latents and ``h_specific`` is
            the list ``[image-specific code, text-specific code]``.
        """
        h_img = self.img_branch.encode(images)
        h_text = self.text_branch.encode(text_input_ids, text_mask)

        # Note: no LayerNorms here because we want to preserve the pretrained scale.
        h_shared = self.adaptive_layers[0](torch.cat([h_img, h_text], dim=1))

        # Identity check (not ==): we are asking "is this the image layer?".
        h_specific = [
            layer(h_img if layer is self.img_spec_layer else h_text)
            for layer in self.adaptive_layers[1:]
        ]
        return (h_shared, h_specific)

    def decode(self, h, text_input_ids, text_mask):
        """Reconstruct both modalities from the output of :meth:`encode`.

        Args:
            h: ``(h_shared, h_specific)`` as returned by :meth:`encode`.
            text_input_ids: Token ids passed to the text decoder for
                teacher forcing.
            text_mask: Mask passed through to the text decoder.

        Returns:
            List ``[reconstructed image, text logits]``.
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            h_concat = torch.cat([h_shared, h_m], dim=1)

            if m == 0:  # Image modality
                x_hat.append(self.img_branch.decode(self.img_fusion(h_concat)))
            else:  # Text modality
                z_text_full = self.text_fusion(h_concat)
                # The BERT decoder needs input_ids for teacher forcing.
                x_hat.append(self.text_branch.decode(z_text_full, text_input_ids, text_mask))
        return x_hat

    def forward(self, images, text_input_ids, text_mask):
        """Run encode then decode; returns ``(reconstructions, codes)``."""
        h = self.encode(images, text_input_ids, text_mask)
        x_hat = self.decode(h, text_input_ids, text_mask)
        return x_hat, h

    def encode_modalities(self, images, text_input_ids, text_mask):
        """Return one ``[shared, specific]`` concatenation per modality."""
        h_shared, h_specific = self.encode(images, text_input_ids, text_mask)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=(), dim=0):
        """Shrink the rank of the selected adaptive layers.

        For each selected layer the new rank is the larger of
        (a) the number of leading components whose cumulative squared-value
        share is below ``1 - threshold`` and (b) ``current_rank *
        reduction_ratio``, both floored at ``layer.min_rank``. The layer is
        changed only if that is strictly smaller than its current rank.

        Args:
            reduction_ratio: Upper bound on how much one call may shrink a layer.
            threshold: Share of total energy that may be discarded.
            layer_ids: Indices into ``adaptive_layers`` to process. The
                default (empty) selects no layer, so the call is a no-op.
            dim: Unused; kept for interface compatibility.

        Returns:
            ``True`` if at least one layer was reduced.
        """
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            # Presumably singular values in descending order - TODO confirm
            # against AdaptiveRankReducedLinear.get_rank_reduction_info.
            S = layer.get_rank_reduction_info()
            if len(S) <= layer.min_rank:
                continue

            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)
            # NOTE(review): counting entries strictly below (1 - threshold)
            # keeps slightly less than that share of energy; confirm whether
            # "+ 1" is intended. Harmless in practice because of the
            # ratio_rank floor below.
            target_rank = max(layer.min_rank, torch.sum(cumulative_energy < (1.0 - threshold)).item())

            current_rank = layer.active_dims
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))
            new_rank = max(target_rank, ratio_rank)

            if new_rank < current_rank:
                layer.reduce_rank(new_rank)
                changes_made = True
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=(), dim=0):
        """Grow the rank of the selected adaptive layers.

        Args:
            increment: Forwarded to ``layer.increase_rank``.
            increase_ratio: Forwarded to ``layer.increase_rank``.
            layer_ids: Indices into ``adaptive_layers`` to process. The
                default (empty) selects no layer.
            dim: Unused; kept for interface compatibility.

        Returns:
            ``True`` if any selected layer reported a change.
        """
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio):
                changes_made = True
        return changes_made

    def get_total_rank(self):
        """Return the sum of ``active_dims`` over all adaptive layers."""
        return sum(layer.active_dims for layer in self.adaptive_layers)