import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import lightning as L
import vector_quantize_pytorch

from models.base import BaseModel


class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.

    Adds a fixed (non-learned) table of sine/cosine values to the input so
    that each sequence position gets a distinct signature. The table has
    the same width as the embeddings so the two can be summed.

    Reference: "Attention Is All You Need" (Vaswani et al., 2017)

    Args:
        d_model (int): The embedding dimension. Even values give complete
            sin/cos pairs; odd values are supported, in which case the
            last column holds a sine without a matching cosine.
        max_len (int): The maximum sequence length anticipated. Default: 5000.

    Shape:
        - Input: `(batch_size, seq_len, d_model)` with `seq_len <= max_len`
        - Output: `(batch_size, seq_len, d_model)`
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        if d_model % 2 != 0:
            # The textbook formulation pairs every sine with a cosine, which
            # needs an even width. Odd widths still work (see below).
            print(
                f"Warning: d_model ({d_model}) should ideally be even for standard sinusoidal PE."
            )

        # Encoding table, shape [max_len, d_model].
        pe = torch.zeros(max_len, d_model)

        # Column vector of positions 0..max_len-1, shape [max_len, 1].
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i / d_model), computed as exp(-log(10000) * 2i / d_model).
        # Shape: [ceil(d_model / 2)].
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )

        # Shape: [max_len, ceil(d_model / 2)].
        angles = position * div_term

        # Even columns (2i) take the sine of every frequency.
        pe[:, 0::2] = torch.sin(angles)
        # Odd columns (2i+1) take the cosine. There are only d_model // 2 odd
        # columns, so for an odd d_model the last frequency is dropped here;
        # without the slice the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis so the table broadcasts: [1, max_len, d_model].
        pe = pe.unsqueeze(0)

        # A buffer is saved with the module and moved by .to()/.cuda(), but is
        # not a parameter and is therefore not updated by the optimizer.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Adds positional encoding to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape `(batch_size, seq_len, d_model)`.

        Returns:
            torch.Tensor: Output tensor with added positional encoding,
                        shape `(batch_size, seq_len, d_model)`.
        """
        # Take the first seq_len rows of the table ([1, seq_len, d_model]) and
        # let broadcasting replicate them across the batch axis.
        x = x + self.pe[:, : x.size(1), :]
        return x


# --- Encoder Module ---
class MotionEncoder(nn.Module):
    """
    Tokenizes multi-sensor time series and encodes them with a Transformer.

    Steps performed by `forward`:
      1. Cut each sensor's series into non-overlapping blocks of
         `block_size` samples and normalize every block by its own
         mean and std.
      2. Project each normalized block to `vq_dim` and pass it through the
         SimVQ layer to get one discrete code index per block.
      3. During training only, optionally replace a random subset of the
         indices with the MASK index (MAE-style masking).
      4. Embed the indices and add a per-sensor time-step PE, the adapted
         sensor embedding and an embedding of the block's (mean, std).
      5. Surround each sensor's tokens with START/END tokens, concatenate
         all sensors, prepend the CLS embedding, add an absolute PE and run
         the Transformer encoder.

    Config attributes read: `window_size[0]`, `block_size`, `codebook_size`,
    `embedding_dim`, `transformer_nhead`, `transformer_dim_feedforward`,
    `transformer_dropout`, `transformer_activation`,
    `transformer_num_layers`, and optionally `vq_dim` (default 512),
    `text_embedding_dim` (default 768) and `max_sensors` (default 32).
    The defaults for `vq_dim` and `text_embedding_dim` are written back
    onto the config object.
    """
    def __init__(self, config):
        """Build all sub-layers from `config` (see the class docstring for the attributes used)."""
        super().__init__()
        self.hparams = config  # Kept so forward() can read hyperparameters
        self.eps = 1e-6  # Added to the per-block std to avoid division by zero

        # --- Calculate derived parameters ---
        # The window must split evenly into blocks; each block becomes one token.
        assert self.hparams.window_size[0] % self.hparams.block_size == 0
        self.tokens_per_sensor = self.hparams.window_size[0] // self.hparams.block_size # L

        # --- Special Token Indices (Needed for masking logic inside) ---
        # Code indices occupy [0, codebook_size); the three special tokens
        # take the next three ids: START, END, MASK.
        self.num_motion_tokens = self.hparams.codebook_size
        self.start_token_idx = self.num_motion_tokens
        self.end_token_idx = self.start_token_idx + 1
        self.mask_token_idx = self.num_motion_tokens + 2

        # Stored as buffers so they follow the module across devices.
        self.register_buffer("start_token_tensor", torch.tensor([self.start_token_idx]))
        self.register_buffer("end_token_tensor", torch.tensor([self.end_token_idx]))
        self.register_buffer("mask_token_tensor", torch.tensor([self.mask_token_idx]))


        # --- Layer Definitions ---
        # 1. Window Projector: one block of samples -> vq_dim
        # NOTE: this writes the default back onto the shared config object.
        self.hparams.vq_dim = getattr(self.hparams, 'vq_dim', 512)
        # One scalar feature per sensor per time step is assumed.
        self.num_features_per_sensor = 1
        self.window_feature_dim = self.hparams.block_size * self.num_features_per_sensor
        self.window_projector = nn.Linear(self.window_feature_dim, self.hparams.vq_dim)

        # 2. Normalization Parameter Projector: (mean, std) of a block -> embedding_dim
        self.norm_param_projector = nn.Linear(2, self.hparams.embedding_dim)

        # 3. VQ Layer
        self.vq_layer = vector_quantize_pytorch.SimVQ(
            dim=self.hparams.vq_dim,
            codebook_size=self.hparams.codebook_size,
        )

        # 4. Token Embedding Layer
        # Holds codebook_size + 32 rows: the code indices plus headroom for
        # the START / END / MASK ids defined above (codebook_size .. +2).
        self.token_embedder = nn.Embedding(self.hparams.codebook_size + 32, self.hparams.embedding_dim)

        # 5. Sensor Embedding Adapter: text_embedding_dim -> embedding_dim
        # NOTE: this writes the default back onto the shared config object.
        self.hparams.text_embedding_dim = getattr(self.hparams, 'text_embedding_dim', 768)
        self.sensor_adapter = nn.Linear(self.hparams.text_embedding_dim, self.hparams.embedding_dim)

        # 6. Time Step Positional Encoding (position of a token within its own sensor)
        self.time_pe = PositionalEncoding(self.hparams.embedding_dim, max_len=self.tokens_per_sensor + 10)

        # 7. Absolute Positional Encoding (Applied after sequence construction)
        # Sized for max_sensors sensors; inputs with more sensors than that
        # can exceed the table length.
        max_sensors = getattr(self.hparams, 'max_sensors', 32)
        max_total_len = 1 + max_sensors * (1 + self.tokens_per_sensor + 1) # CLS + S*(Start+L+End)
        self.absolute_pe = PositionalEncoding(self.hparams.embedding_dim, max_len=max_total_len + 10)

        # 8. Transformer Encoder
        # The activation is looked up by name in torch.nn.functional and
        # falls back to ReLU when the name does not exist there.
        encoder_layer = TransformerEncoderLayer(
            d_model=self.hparams.embedding_dim,
            nhead=self.hparams.transformer_nhead,
            dim_feedforward=self.hparams.transformer_dim_feedforward,
            dropout=self.hparams.transformer_dropout,
            activation=getattr(F, self.hparams.transformer_activation, F.relu),
            batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer,
            num_layers=self.hparams.transformer_num_layers
        )

    def forward(self, x: Tensor, sensor_embeddings: Tensor, mask_ratio: float = 0.0, cls_token_embedding: Tensor = None) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """
        Encodes the input motion sequence.

        Args:
            x: Input tensor (B, T, S). T // block_size must equal
                `tokens_per_sensor`, otherwise the reshapes below fail.
            sensor_embeddings: Sensor embeddings (B, S, TextDim), fed to the
                sensor adapter.
            mask_ratio: Per-token masking probability. Masking is applied
                only when this is > 0 and the module is in training mode.
            cls_token_embedding: CLS embedding, expanded over the batch
                axis (so expected as (1, 1, D)). Required despite the
                `None` default.

        Returns:
            tuple:
                - transformer_output: Output sequence from Transformer (B, 1 + S*(L+2), D).
                - commit_loss: Loss value returned by the VQ layer.
                - target_mae_ids: Original VQ indices of masked tokens (NumMasked,) or None.
                - masked_indices_for_loss: Boolean mask indicating masked positions (B, S, L) or None.

        Raises:
            ValueError: If `cls_token_embedding` is None.
        """
        B, T, S = x.shape
        L = self.tokens_per_sensor
        D = self.hparams.embedding_dim

        # --- 1. Windowing, Normalization & Projection ---
        x_permuted = x.permute(0, 2, 1)  # (B, S, T)
        # size == step, so the blocks do not overlap.
        windows = x_permuted.unfold(dimension=2, size=self.hparams.block_size, step=self.hparams.block_size) # (B, S, T // Block_size, Block_size)
        windows_flat = windows.reshape(-1, self.hparams.block_size * self.num_features_per_sensor) # (B*S*L, Block_size)

        # Per-block statistics; the block is normalized with them and they
        # are also embedded so the scale/offset information is not lost.
        mean = windows_flat.mean(dim=1, keepdim=True)
        std = windows_flat.std(dim=1, keepdim=True)
        normalized_windows = (windows_flat - mean) / (std + self.eps) # (B*S*L, Block_size)
        norm_params = torch.cat([mean, std], dim=1)  # (B*S*L, 2)
        norm_param_embedding = self.norm_param_projector(norm_params)
        norm_param_embedding = norm_param_embedding.view(B, S, L, D)

        projected_windows = self.window_projector(normalized_windows) # (B*S*L, vq_dim)

        # --- 2. VQ Quantization ---
        # Only the indices and the commit loss are used below; `quantized`
        # itself is not read again in this method.
        # Presumably quantized is (B*S*L, vq_dim) and vq_indices is (B*S*L,) — confirm against SimVQ.
        quantized, vq_indices, commit_loss = self.vq_layer(projected_windows)
        original_vq_indices = vq_indices.view(B, S, L) # Keep original for MAE target

        # --- 3. MAE Masking (Operate on Indices) ---
        target_mae_ids = None
        masked_indices_for_loss = None
        indices_for_embedding = original_vq_indices # Default to original indices

        if mask_ratio > 0.0 and self.training: # Apply mask only during training
            # Each token is masked independently with probability mask_ratio.
            prob = torch.full(original_vq_indices.shape, mask_ratio, device=x.device)
            mae_mask_bool = torch.bernoulli(prob).bool()
            target_mae_ids = original_vq_indices[mae_mask_bool] # Store original IDs
            masked_vq_indices = original_vq_indices.clone()
            masked_vq_indices[mae_mask_bool] = self.mask_token_tensor # Apply mask ID
            indices_for_embedding = masked_vq_indices # Use masked indices for lookup
            masked_indices_for_loss = mae_mask_bool # Store mask bool for MAE head

        # --- 4. Get Embeddings ---
        motion_embeds = self.token_embedder(indices_for_embedding) # (B, S, L, D)
        sensor_embeds_adapted = self.sensor_adapter(sensor_embeddings) # (B, S, D)

        # --- 5. Combine Embeddings ---
        # Fold sensors into the batch axis so the time PE restarts at 0 for
        # every sensor.
        motion_embeds_bs_l_d = motion_embeds.view(B * S, L, D)
        motion_embeds_time_pe = self.time_pe(motion_embeds_bs_l_d)
        motion_embeds_time_pe = motion_embeds_time_pe.view(B, S, L, D)

        final_motion_embeds = (
            motion_embeds_time_pe
            + sensor_embeds_adapted.view(B, S, 1, D) # Broadcast the sensor embedding over the L tokens
            + norm_param_embedding
        )

        # START/END embeddings come from this module's own token embedder
        # and get the sensor embedding added so they are sensor-specific.
        start_emb = self.token_embedder(self.start_token_tensor) # Shape: (1, D)
        end_emb = self.token_embedder(self.end_token_tensor)   # Shape: (1, D)
        start_tokens_sensor = start_emb.view(1, 1, D) + sensor_embeds_adapted  # (B, S, D)
        end_tokens_sensor = end_emb.view(1, 1, D) + sensor_embeds_adapted  # (B, S, D)
        start_tokens_final = start_tokens_sensor.unsqueeze(2)  # (B, S, 1, D)
        end_tokens_final = end_tokens_sensor.unsqueeze(2)  # (B, S, 1, D)

        # --- 6. Construct Sensor Sequences ---
        full_sensor_sequences = torch.cat([start_tokens_final, final_motion_embeds, end_tokens_final], dim=2) # (B, S, 1+L+1, D)
        flat_sensor_sequence = full_sensor_sequences.reshape(B, -1, D) # (B, S*(L+2), D)

        # --- 7. Prepend CLS Token ---
        if cls_token_embedding is None:
             raise ValueError("CLS token embedding must be provided to the encoder")
        cls_embeddings = cls_token_embedding.expand(B, -1, -1) # (B, 1, D)
        transformer_input_sequence = torch.cat([cls_embeddings, flat_sensor_sequence], dim=1) # (B, 1+S*(L+2), D)

        # --- 8. Add Absolute Positional Encoding ---
        transformer_input_sequence = self.absolute_pe(transformer_input_sequence)

        # --- 9. Transformer Encoder ---
        transformer_output = self.transformer_encoder(transformer_input_sequence) # (B, 1+S*(L+2), D)

        return transformer_output, commit_loss, target_mae_ids, masked_indices_for_loss
    

# --- MAE Head Module ---
class MAEHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mae_head = nn.Linear(config.embedding_dim, config.codebook_size)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, encoder_output: Tensor, target_mae_ids: Tensor | None, masked_indices_bool: Tensor | None) -> tuple[Tensor | None, Tensor | None]:
        """
        Calculates MAE loss from encoder output.

        Args:
            encoder_output: Full output sequence from MotionEncoder (B, SeqLen, D).
            target_mae_ids: Original VQ indices of masked tokens (NumMasked,).
            masked_indices_bool: Boolean mask indicating masked motion token positions (B, S, L).

        Returns:
            tuple: (mae_loss, mae_predictions)
        """
        mae_loss = None
        mae_predictions = None

        if masked_indices_bool is not None and target_mae_ids is not None:
            B, S, L = masked_indices_bool.shape
            D = encoder_output.shape[-1]
            seq_len_per_sensor_with_specials = 1 + L + 1 # Start + L + End

            # Select masked output states (same logic as before)
            mask_flat_bool = masked_indices_bool.view(B, -1)
            transformer_output_motion_only = encoder_output[:, 1:].reshape(B, S, seq_len_per_sensor_with_specials, D)[:, :, 1:1+L, :]
            transformer_output_motion_only_flat = transformer_output_motion_only.reshape(B, S*L, D)
            masked_output_states = transformer_output_motion_only_flat[mask_flat_bool]

            if masked_output_states.shape[0] > 0:
                mae_logits = self.mae_head(masked_output_states)
                mae_predictions = mae_logits
                if target_mae_ids.numel() > 0:
                    mae_loss = self.loss_fn(mae_logits, target_mae_ids)
                else:
                    mae_loss = torch.tensor(0.0, device=encoder_output.device)
            else:
                 mae_loss = torch.tensor(0.0, device=encoder_output.device)

        return mae_loss, mae_predictions


# --- Classification Head Module ---
class ClassificationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.classification_head = nn.Linear(config.embedding_dim, config.num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, encoder_output: Tensor, labels: Tensor | None) -> tuple[Tensor | None, Tensor]:
        """
        Calculates classification loss/predictions from encoder output.

        Args:
            encoder_output: Full output sequence from MotionEncoder (B, SeqLen, D).
            labels: Ground truth labels (B,).

        Returns:
            tuple: (classification_loss, classification_predictions)
        """
        cls_token_output = encoder_output[:, 0] # Extract CLS token state
        classification_logits = self.classification_head(cls_token_output)
        classification_predictions = classification_logits

        classification_loss = None
        if labels is not None:
            classification_loss = self.loss_fn(classification_logits, labels)

        return classification_loss, classification_predictions
    
# --- Main MoVQFormer Model ---
class l_model(BaseModel):
    """
    Top-level model: a MotionEncoder plus an MAE head and an optional
    classification head, trained in one of three modes.

    Modes (read from `hparams.mode`, default 'joint'):
        - 'pretrain': MAE loss + commit loss; no classification head is built.
        - 'finetune': classification loss + commit loss; masking is disabled.
        - 'joint':    MAE + classification + commit loss.

    `self.hparams`, `self.log`, `self.device`, `self.train_metrics` and
    `self.val_metrics` are not defined in this class; they are presumably
    provided by `BaseModel` / Lightning — confirm in `models.base`.
    """

    def __init__(self, config):
        super().__init__(config)  # Handles hparams, base metrics

        self.mode = getattr(self.hparams, 'mode', 'joint')  # Default to joint training

        # --- Core components ---
        self.encoder = MotionEncoder(config)
        self.mae_head = MAEHead(config)

        # The classification head only exists in modes that train it.
        self.classification_head = None
        if self.mode == 'finetune' or self.mode == 'joint':
            self.classification_head = ClassificationHead(config)
        else:
            print(f"Skipping ClassificationHead initialization for mode '{self.mode}'.")

        # Learnable CLS token. It lives on this module, not on the encoder,
        # so `self.encoder.parameters()` does NOT include it.
        self.cls_token_embedding = nn.Parameter(torch.zeros(1, 1, self.hparams.embedding_dim))
        nn.init.normal_(self.cls_token_embedding, std=0.02)

    def forward(
        self,
        x: Tensor,
        sensor_embeddings: Tensor,
        labels: Tensor | None = None,
        mask_ratio: float = 0.0,
    ) -> dict:
        """
        Run the encoder and whichever heads exist.

        Args:
            x: Raw input passed to the encoder (B, T, S).
            sensor_embeddings: Sensor embeddings passed to the encoder.
            labels: Optional class labels; without them no classification
                loss is produced.
            mask_ratio: Masking probability for the encoder. Forced to 0.0
                in 'finetune' mode.

        Returns:
            dict: Always contains "commit_loss". Contains "mae_loss",
            "mae_predictions", "classification_loss" and
            "classification_predictions" only when the corresponding value
            is not None.
        """
        # 1. Run Encoder
        transformer_output, commit_loss, target_mae_ids, masked_indices_bool = self.encoder(
            x=x,
            sensor_embeddings=sensor_embeddings,
            mask_ratio=mask_ratio if self.mode != 'finetune' else 0.0,
            cls_token_embedding=self.cls_token_embedding.to(x.device),
        )

        # 2. Run Heads
        # The MAE head returns (None, None) when the encoder did not mask.
        mae_loss, mae_predictions = self.mae_head(
            encoder_output=transformer_output,
            target_mae_ids=target_mae_ids,
            masked_indices_bool=masked_indices_bool
        )

        classification_loss = None
        classification_predictions = None
        if self.classification_head is not None:
            classification_loss, classification_predictions = self.classification_head(
                encoder_output=transformer_output,
                labels=labels
            )

        # 3. Package Results, dropping entries that were not produced.
        results = {
            "commit_loss": commit_loss.mean(),
            "mae_loss": mae_loss,
            "classification_loss": classification_loss,
            "mae_predictions": mae_predictions,
            "classification_predictions": classification_predictions,
        }
        return {k: v for k, v in results.items() if v is not None}

    def training_step(self, batch, batch_idx):
        """
        Compute and log the weighted training loss for one batch.

        `batch` is indexed as (x, sensor_embeddings, y). Loss weights come
        from `hparams` ('classification_loss_weight', 'mae_loss_weight',
        'commit_loss_weight'), each defaulting to 1.0.

        Returns:
            The total weighted loss for the active mode.
        """
        x, sensor_embeddings, y = batch[0], batch[1], batch[2]

        # Masking is only used by the modes that train the MAE objective.
        if self.mode == 'pretrain' or self.mode == 'joint':
            train_mask_ratio = self.hparams.get("mask_ratio", 0.25)
        else:  # Finetune mode
            train_mask_ratio = 0.0

        outputs = self(x=x, sensor_embeddings=sensor_embeddings, labels=y, mask_ratio=train_mask_ratio)

        # Losses that were not produced default to a constant zero so the
        # logging below always has a value.
        classification_loss_raw = outputs.get("classification_loss", torch.tensor(0.0, device=self.device))
        mae_loss_raw = outputs.get("mae_loss", torch.tensor(0.0, device=self.device))
        commit_loss_raw = outputs.get("commit_loss", torch.tensor(0.0, device=self.device))

        w_cls = self.hparams.get('classification_loss_weight', 1.0)
        w_mae = self.hparams.get('mae_loss_weight', 1.0)
        w_commit = self.hparams.get('commit_loss_weight', 1.0)

        classification_loss_w = w_cls * classification_loss_raw
        mae_loss_w = w_mae * mae_loss_raw
        commit_loss_w = w_commit * commit_loss_raw

        # --- Total loss for the active mode ---
        # The commit loss is always relevant.
        total_loss = torch.tensor(0.0, device=self.device) + commit_loss_w

        if self.mode in ('pretrain', 'joint'):
            total_loss = total_loss + mae_loss_w
        if self.mode in ('joint', 'finetune'):
            # Only add the classification term when the head exists. A missing
            # loss is the zero default above, so adding it unconditionally is
            # harmless and avoids a per-step `.item()` device-to-host sync.
            if self.classification_head is not None:
                total_loss = total_loss + classification_loss_w
        if self.mode not in ('pretrain', 'joint', 'finetune'):
            print(f"Warning: Unknown training mode '{self.mode}' for loss calculation.")

        # --- Logging ---
        # Individual *weighted* losses, per step.
        self.log("loss/train/classification_w", classification_loss_w, on_step=True, on_epoch=False, logger=True, sync_dist=True)
        self.log("loss/train/mae_w", mae_loss_w, on_step=True, on_epoch=False, logger=True, sync_dist=True)
        self.log("loss/train/commit_w", commit_loss_w, on_step=True, on_epoch=False, logger=True, sync_dist=True)

        # Total weighted loss, per step and per epoch.
        self.log("loss/train/total", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        # Epoch-level copy under a separate key, not sent to the logger.
        self.log("loss_train_monitor", total_loss, on_step=False, on_epoch=True, logger=False, sync_dist=True)

        # --- Update Classification Metrics ---
        # Predictions are absent when no classification head was built.
        classification_logits = outputs.get("classification_predictions")
        if classification_logits is not None:
            self.train_metrics.update(classification_logits, y)

        return total_loss

    def validation_step(self, batch, batch_idx):
        """
        Compute and log the weighted validation loss for one batch.

        The forward pass uses `mask_ratio=0.0`, so the encoder applies no
        masking, no MAE loss is produced and the MAE term below is the
        zero default.

        Returns:
            The sum of the weighted classification, MAE and commit losses.
        """
        x, sensor_embeddings, y = batch[0], batch[1], batch[2]
        outputs = self(x=x, sensor_embeddings=sensor_embeddings, labels=y, mask_ratio=0.0)

        classification_loss_raw = outputs.get("classification_loss", torch.tensor(0.0, device=self.device))
        mae_loss_raw = outputs.get("mae_loss", torch.tensor(0.0, device=self.device))
        commit_loss_raw = outputs.get("commit_loss", torch.tensor(0.0, device=self.device))

        # Same weights as training so the two losses are on the same scale.
        w_cls = self.hparams.get('classification_loss_weight', 1.0)
        w_mae = self.hparams.get('mae_loss_weight', 1.0)
        w_commit = self.hparams.get('commit_loss_weight', 1.0)

        classification_loss_w = w_cls * classification_loss_raw
        mae_loss_w = w_mae * mae_loss_raw
        commit_loss_w = w_commit * commit_loss_raw

        # Unlike training, all components are summed regardless of mode.
        val_loss_total_w = classification_loss_w + mae_loss_w + commit_loss_w

        # --- Logging (epoch level) ---
        self.log("loss/val/classification_w", classification_loss_w, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log("loss/val/mae_w", mae_loss_w, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log("loss/val/commit_w", commit_loss_w, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log("loss/val", val_loss_total_w, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        # --- Update Classification Metrics ---
        # Predictions exist only when the classification head was built
        # (i.e. not in 'pretrain' mode).
        classification_logits = outputs.get("classification_predictions")
        if classification_logits is not None:
            self.val_metrics.update(classification_logits, y)

        return val_loss_total_w

    def configure_optimizers(self):
        """
        Build the AdamW optimizer.

        Reads `hparams.lr`, optional `weight_decay` (default 1e-5), and in
        'finetune' mode the optional `freeze_encoder` flag and
        `encoder_lr_factor`.

        Returns:
            torch.optim.AdamW over the parameter groups for the active mode.
        """
        lr = self.hparams.lr
        weight_decay = self.hparams.get('weight_decay', 1e-5)
        finetuning = self.mode == 'finetune'

        if finetuning and getattr(self.hparams, 'freeze_encoder', False):
            print("Fine-tuning mode: Optimizing only head parameters.")
            # NOTE(review): despite the message, the Transformer encoder
            # layers are optimized here as well; only the remaining encoder
            # sub-layers are left out. They are excluded from the optimizer
            # but still have requires_grad=True. Confirm this is intended.
            optimizer = torch.optim.AdamW(
                [
                    {'params': [self.cls_token_embedding]},
                    {'params': self.classification_head.parameters()},
                    {'params': self.encoder.transformer_encoder.parameters()},
                ],
                lr=lr,
                weight_decay=weight_decay,
            )
        elif finetuning and getattr(self.hparams, 'encoder_lr_factor', 1.0) != 1.0:
            print(f"Fine-tuning mode: Differential learning rates (Factor: {self.hparams.encoder_lr_factor}).")
            # Encoder at a scaled LR, head at the full LR. The CLS token is a
            # parameter of this module rather than of the encoder, so it needs
            # its own group; otherwise it would never be updated.
            optimizer = torch.optim.AdamW(
                [
                    {'params': self.encoder.parameters(), 'lr': lr * self.hparams.encoder_lr_factor},
                    {'params': self.classification_head.parameters(), 'lr': lr},
                    {'params': [self.cls_token_embedding], 'lr': lr},
                ],
                lr=lr,
                weight_decay=weight_decay,
            )
        else:
            # Pre-training, joint training, or full fine-tuning
            print(f"{self.mode.capitalize()} mode: Optimizing all parameters.")
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
            )

        return optimizer

    def load_encoder_weights(self, checkpoint_path: str):
        """
        Load only the encoder weights from a checkpoint.

        The checkpoint must be a mapping with a 'state_dict' entry; keys
        starting with "encoder." are stripped of that prefix and loaded
        non-strictly into `self.encoder`. The CLS token and the heads are
        not touched.

        Args:
            checkpoint_path: Path to the checkpoint file.
        """
        print(f"Loading encoder weights from: {checkpoint_path}")
        # SECURITY: weights_only=False allows arbitrary pickled objects to be
        # executed on load. Only use this with checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        state_dict = checkpoint['state_dict']

        prefix = "encoder."
        encoder_state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

        missing_keys, unexpected_keys = self.encoder.load_state_dict(encoder_state_dict, strict=False)
        print(f"Encoder weights loaded. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}")
        if unexpected_keys:
            print("Warning: Unexpected keys found, encoder architecture might have changed.")